# Model soup

In [19]:
# 기본코드
import inspect
import os
import sys
import time
import numpy as np
import pandas as pd

import torch
from dataloader import CustomDataLoader, do_transform, collate_fn
from modules.losses import create_criterion
from utils.utils import label_accuracy_score, add_hist, set_seed
from easydict import EasyDict
from modules.model import create_model


# Run configuration shared by the cells below (attribute-style access via EasyDict).
args = EasyDict(dict(
    val_json_path='/opt/ml/input/data/seed21/val_1.json',  # passed as data_dir to CustomDataLoader(mode='val')
    data_dir='/opt/ml/input/data',                         # passed as data_path to CustomDataLoader
    batch_size=2,
    num_workers=4,
    model_name='mvtb4_unet',   # forwarded to create_model
    use_losses='combo',        # forwarded to create_criterion
    save_submission=False,     # when True, validation() also records per-image predictions
))

In [21]:
def validation(args, epoch, model, data_loader, criterion, device):
    """Evaluate ``model`` on ``data_loader`` and print loss / pixel metrics.

    Args:
        args: config object. ``args.save_submission`` decides whether per-image
            predictions are appended to the returned DataFrame; the optional
            ``args.sample_submission_path`` overrides the template CSV path.
        epoch: label used only in the printed log lines.
        model: segmentation model; switched to eval mode and moved to ``device``.
        data_loader: yields ``(images, masks, image_infos)``; ``images`` and
            ``masks`` are sequences of tensors that are stacked into batches here.
        criterion: loss function called as ``criterion(outputs, masks)``.
        device: torch device (string or object).

    Returns:
        ``(avrg_loss, mIoU, submission)``: the mean per-batch loss (a tensor),
        the mIoU reported by ``label_accuracy_score``, and the template
        submission DataFrame, extended with one row per image when
        ``args.save_submission`` is true.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    print(f'\n Start validation #{epoch}')
    model.eval()
    # Moving the model is loop-invariant, so do it once instead of per batch.
    model = model.to(device)
    category_names = ['Background', 'General trash', 'Paper', 'Paper pack', 'Metal', 'Glass',
                      'Plastic', 'Styrofoam', 'Plastic bag', 'Battery', 'Clothing']
    n_class = len(category_names)  # 11
    sample_submission_path = getattr(
        args, 'sample_submission_path',
        '/opt/ml/input/code/submission/sample_submission.csv')

    with torch.no_grad():
        total_loss = 0
        cnt = 0
        hist = np.zeros((n_class, n_class))
        submission = pd.read_csv(sample_submission_path, index_col=None)
        # Collect rows and concatenate once at the end: DataFrame.append was
        # removed in pandas 2.0 (and re-copied the frame on every call).
        new_rows = []
        for images, masks, image_infos in data_loader:
            images = torch.stack(images).to(device)
            masks = torch.stack(masks).long().to(device)

            outputs = model(images)
            total_loss += criterion(outputs, masks)
            cnt += 1

            # Per-pixel class index = argmax over the class dimension.
            preds = torch.argmax(outputs, dim=1).detach().cpu().numpy()
            masks = masks.detach().cpu().numpy()

            if args.save_submission:
                for pred, info in zip(preds, image_infos):
                    new_rows.append({"image_id": info['file_name'],
                                     "PredictionString": ' '.join(str(e) for e in pred.flatten())})

            hist = add_hist(hist, masks, preds, n_class=n_class)

        if cnt == 0:
            raise ValueError('validation: data_loader yielded no batches')

        if new_rows:
            submission = pd.concat([submission, pd.DataFrame(new_rows)], ignore_index=True)

        acc, acc_cls, mIoU, fwavacc, IoU = label_accuracy_score(hist)
        IoU_by_class = [{cls_name: round(iou, 4)} for iou, cls_name in zip(IoU, category_names)]

        avrg_loss = total_loss / cnt
        print(f'Validation #{epoch}  Average Loss: {round(avrg_loss.item(), 4)}, Accuracy : {round(acc, 4)}, mIoU: {round(mIoU, 4)}')
        print(f'IoU by class : {IoU_by_class}')

    return avrg_loss, mIoU, submission

0. 모델 & checkpoint 가져오기

In [3]:
# Device used by the soup cells below.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

################ checkpoints to mix into the soup (edit this list) ################
checkpoint_paths = [
    '/opt/ml/input/code/saved/mit_unet_3plus_final_2/best_mIoU_9.pt',
    '/opt/ml/input/code/saved/mit_unet_3plus_final_2/best_mIoU_11.pt',
    '/opt/ml/input/code/saved/mit_unet_3plus_final_2/best_mIoU_12.pt',
]
################ checkpoints to mix into the soup (edit this list) ################

# Fresh model instance; the soup functions overwrite its weights.
model = create_model(args.model_name)

1. uniform soup

In [None]:
def uniform_soup(args, model, checkpoint_paths, device, by_name=False):
    """Build a uniform model soup by averaging the weights of every checkpoint.

    Each path is read with ``torch.load`` and must hold an object exposing
    ``.state_dict()`` (a pickled model). The element-wise mean of every
    state-dict entry is loaded into ``model``. The result is then evaluated
    once with ``validation``, which prints the metrics.

    Args:
        args: config object; ``val_json_path``, ``data_dir``, ``batch_size``,
            ``num_workers`` and ``use_losses`` build the validation loader and
            loss, and it is forwarded to ``validation``.
        model: model instance matching the checkpoints' architecture; moved to
            ``device`` and overwritten in place.
        checkpoint_paths: iterable of checkpoint file paths.
        device: torch device (string or object) used for loading and evaluation.
        by_name: unused; kept so existing callers keep working.

    Returns:
        ``model`` with the averaged weights loaded.

    Raises:
        ValueError: if ``checkpoint_paths`` is empty.
        KeyError: if a checkpoint contains a key the model does not have.
    """
    checkpoint_paths = list(checkpoint_paths)
    if not checkpoint_paths:
        raise ValueError('uniform_soup needs at least one checkpoint path')

    val_dataset = CustomDataLoader(data_dir=args.val_json_path,
                                   mode='val',
                                   transform=do_transform(mode='val'),
                                   data_path=args.data_dir)

    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.num_workers,
                                             collate_fn=collate_fn)

    criterion = create_criterion(args.use_losses)
    model = model.to(device)
    model_dict = model.state_dict()

    # Running per-key sums: only one extra copy of the weights lives in memory,
    # instead of every checkpoint's tensors being held until a final stack.
    sums = {}
    counts = {}
    for checkpoint_path in checkpoint_paths:
        # NOTE: torch.load unpickles arbitrary objects -- only load trusted files.
        # NOTE(review): torch>=2.6 defaults to weights_only=True, which refuses
        # pickled nn.Module objects; confirm the torch version in use.
        weight_dict = torch.load(checkpoint_path, map_location=device).state_dict()
        for key, value in weight_dict.items():
            if key not in model_dict:
                raise KeyError(key)
            if key in sums:
                sums[key] += value
            else:
                sums[key] = value.detach().clone()
            counts[key] = counts.get(key, 0) + 1

    # Divide, then cast back so integer buffers keep their dtype (truncating).
    averaged = {key: (total / counts[key]).type(total.dtype) for key, total in sums.items()}
    model_dict.update(averaged)
    model.load_state_dict(model_dict)

    # Evaluated only for the metrics it prints.
    validation(args, 0, model, val_loader, criterion, device)

    return model

In [None]:
################ output location of the uniform soup (edit as needed) ################
save_dir_path = './soup/'
name = 'uniform_model_soup'  # file name (without extension) of the saved soup
################ output location of the uniform soup (edit as needed) ################

# Create the output folder before the slow soup + validation run, so a missing
# directory cannot make the final torch.save fail after all that work.
os.makedirs(save_dir_path, exist_ok=True)

print("\n[Uniform Soup]")
uniform_model = uniform_soup(args, model, checkpoint_paths, device)

# The whole model object is pickled (not only its state_dict).
torch.save(uniform_model, os.path.join(save_dir_path, f'{name}.pt'))

2. Greedy Soup (uniform weight update)

In [22]:
def greedy_soup(args, model_ori, checkpoint_paths, device):
    """Greedily grow a model soup from checkpoints ordered best-first.

    The first checkpoint seeds the soup. Its recorded mIoU (the first element
    of the pair; it is NOT re-measured here) is the score to beat. Every later
    checkpoint is averaged 1:1 with the current soup. The mix is scored with
    ``validation`` and kept only if its mIoU is >= the best so far.

    Because each accepted step is a 1:1 average, earlier ingredients are
    halved in weight with every later acceptance. The result is therefore not
    a uniform average over the accepted checkpoints.

    Args:
        args: config object; ``val_json_path``, ``data_dir``, ``batch_size``,
            ``num_workers`` and ``use_losses`` build the validation loader and
            loss, and it is forwarded to ``validation``.
        model_ori: model instance matching the checkpoints' architecture; moved
            to ``device`` and overwritten in place.
        checkpoint_paths: sequence of ``(mIoU, path)`` pairs, expected to be
            sorted by descending mIoU. Each path is read with ``torch.load`` and
            must hold an object exposing ``.state_dict()`` (a pickled model).
        device: torch device (string or object).

    Returns:
        ``(model, ckpt_used)``: ``model_ori`` loaded with the best soup found,
        and the basenames of the accepted checkpoints in acceptance order.

    Raises:
        ValueError: if ``checkpoint_paths`` is empty.
        KeyError: if a checkpoint being mixed contains a key the model lacks.
    """
    if not checkpoint_paths:
        raise ValueError('greedy_soup needs at least one (mIoU, path) checkpoint')

    val_dataset = CustomDataLoader(data_dir=args.val_json_path,
                                   mode='val',
                                   transform=do_transform(mode='val'),
                                   data_path=args.data_dir)

    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.num_workers,
                                             collate_fn=collate_fn)
    criterion = create_criterion(args.use_losses)

    model = model_ori.to(device)
    # Independent copy of the model's initial weights. It supplies values for
    # any key that neither soup member provides. Every candidate mix is built
    # in its own fresh dict, so evaluating and rejecting a candidate can never
    # overwrite the accepted soup. (Reusing and .update()-ing one shared
    # state_dict did exactly that.)
    base_state = {key: value.detach().clone() for key, value in model.state_dict().items()}

    ckpt_used = []
    best_state = None
    best_mIoU = 0
    for mIoU, checkpoint_path in checkpoint_paths:
        # NOTE: torch.load unpickles arbitrary objects -- only load trusted files.
        candidate_state = torch.load(checkpoint_path, map_location=device).state_dict()

        if best_state is None:
            # Seed the soup with the first (highest-mIoU) checkpoint.
            best_state = candidate_state
            best_mIoU = mIoU
            print("soup 모델에 가장 높은 mIou를 가진 checkpoint가 추가되었습니다")
            print(f"추가된 checkpoint_path: {checkpoint_path}")
            print(f"현재 최고 mIoU: {best_mIoU}")
            ckpt_used.append(os.path.basename(checkpoint_path))
            continue

        # Gather each key's tensors from the current soup and the candidate
        # (an unexpected key raises KeyError here).
        grouped = {key: [] for key in base_state}
        for state in (best_state, candidate_state):
            for key, value in state.items():
                grouped[key].append(value)
        mixed_state = dict(base_state)
        mixed_state.update({
            # Mean, then cast back so integer buffers keep their dtype.
            key: (torch.stack(values).sum(dim=0) / len(values)).type(values[0].dtype)
            for key, values in grouped.items() if values
        })

        model.load_state_dict(mixed_state)
        _, val_mIoU, _ = validation(args, 0, model, val_loader, criterion, device)

        if val_mIoU >= best_mIoU:
            best_mIoU = val_mIoU
            best_state = mixed_state
            print("soup 모델에 새로운 checkpoint가 추가되었습니다")
            print(f"추가된 checkpoint_path: {checkpoint_path}")
            print(f"현재 최고 mIoU: {best_mIoU}")
            ckpt_used.append(os.path.basename(checkpoint_path))
        else:
            print("이번 체크 포인트는 soup 모델에 추가되지 않았습니다")
            print(f"이번 checkpoint_path: {checkpoint_path}")
            print(f"현재 최고 mIoU: {best_mIoU}, 이번 mIou {val_mIoU}")

    # Restore the best accepted soup (the model may hold a rejected mix now).
    model.load_state_dict(best_state)

    return model, ckpt_used

In [23]:
model = create_model(args.model_name)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

################ output location of the greedy soup (edit as needed) ################
save_dir_path = '/opt/ml/trained_models/Final_fold_1/'
name = 'greedy_model_soup'  # file name (without extension) of the saved soup
################ output location of the greedy soup (edit as needed) ################
# Directory holding the top-K checkpoints to mix.
ckpt_dir = '/opt/ml/trained_models/Final_fold_1'
os.makedirs(save_dir_path, exist_ok=True)

# Collect (mIoU, path) pairs from files named like 'ep111_0.6985.pt'; the score
# is the text after the last '_'. Files without a numeric suffix are skipped.
# Skipped files include the 'greedy_model_soup.pt' a previous run saved into
# this same folder; a fixed-width slice such as f[-9:-3] crashes on that name.
checkpoint_paths = []
for file_name in os.listdir(ckpt_dir):
    stem, ext = os.path.splitext(file_name)
    if ext != '.pt':
        continue
    score_text = stem.rpartition('_')[2]
    try:
        score = float(score_text)
    except ValueError:
        continue
    checkpoint_paths.append((score, os.path.join(ckpt_dir, file_name)))

if not checkpoint_paths:
    raise FileNotFoundError(f'No scored *.pt checkpoints found in {ckpt_dir}')

# Best-first order: greedy_soup seeds the soup with the first entry.
checkpoint_paths.sort(key=lambda x: x[0], reverse=True)
print(checkpoint_paths)

greedy_model, ckpt_used = greedy_soup(args, model, checkpoint_paths, device)

print(ckpt_used)
# The whole model object is pickled (not only its state_dict).
torch.save(greedy_model, os.path.join(save_dir_path, f'{name}.pt'))

[(0.6985, '/opt/ml/trained_models/Final_fold_1/ep111_0.6985.pt'), (0.6984, '/opt/ml/trained_models/Final_fold_1/ep115_0.6984.pt'), (0.6977, '/opt/ml/trained_models/Final_fold_1/ep113_0.6977.pt'), (0.6952, '/opt/ml/trained_models/Final_fold_1/ep118_0.6952.pt'), (0.6949, '/opt/ml/trained_models/Final_fold_1/ep119_0.6949.pt'), (0.6948, '/opt/ml/trained_models/Final_fold_1/ep116_0.6948.pt'), (0.6938, '/opt/ml/trained_models/Final_fold_1/ep117_0.6938.pt'), (0.6936, '/opt/ml/trained_models/Final_fold_1/ep114_0.6936.pt'), (0.6929, '/opt/ml/trained_models/Final_fold_1/ep120_0.6929.pt'), (0.6872, '/opt/ml/trained_models/Final_fold_1/ep108_0.6872.pt'), (0.687, '/opt/ml/trained_models/Final_fold_1/ep112_0.6870.pt'), (0.6829, '/opt/ml/trained_models/Final_fold_1/ep79_0.6829.pt'), (0.6819, '/opt/ml/trained_models/Final_fold_1/ep107_0.6819.pt'), (0.6815, '/opt/ml/trained_models/Final_fold_1/ep105_0.6815.pt'), (0.6808, '/opt/ml/trained_models/Final_fold_1/ep106_0.6808.pt')]
loading annotations into m