# Model soup

In [5]:
# 기본코드
import inspect
import os
import sys
import time
import numpy as np
import pandas as pd

import torch
from dataloader import CustomDataLoader, do_transform, collate_fn
from modules.losses import create_criterion
from utils.utils import label_accuracy_score, add_hist, set_seed
from easydict import EasyDict
from modules.model import create_model


# Notebook-wide settings read by the validation / soup helpers below.
args = EasyDict(dict(
    val_json_path='../data/val.json',  # validation annotation file
    data_dir='../data',                # root directory passed as data_path
    batch_size=2,
    num_workers=4,
    model_name='mit_unet_3plus',       # forwarded to create_model()
    use_losses='combo',                # forwarded to create_criterion()
    save_submission=True,              # validation() collects per-image predictions when True
))

In [6]:
def validation(args, epoch, model, data_loader, criterion, device):
    """Evaluate ``model`` on ``data_loader`` and print loss, accuracy and mIoU.

    Args:
        args: config object. ``args.save_submission`` controls whether one
            prediction row per image is appended to the returned submission.
            An optional ``args.sample_submission_path`` overrides the default
            sample-submission CSV location.
        epoch: number used only in the printed log lines.
        model: segmentation model, called as ``model(images)``.
        data_loader: iterable of ``(images, masks, image_infos)`` where
            ``images`` / ``masks`` are sequences of tensors that are stacked
            into a batch and ``image_infos[i]['file_name']`` names image ``i``.
        criterion: loss callable, ``criterion(outputs, masks)``.
        device: torch device to evaluate on.

    Returns:
        ``(avrg_loss, mIoU, submission)``: the mean of the per-batch losses
        (a tensor), the mIoU from ``label_accuracy_score``, and the sample
        submission DataFrame (with appended prediction rows when
        ``args.save_submission`` is set).

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    print(f'\n Start validation #{epoch}')
    model = model.to(device)  # move once up front, not on every batch
    model.eval()
    category_names = ['Background', 'General trash', 'Paper', 'Paper pack', 'Metal', 'Glass', 'Plastic', 'Styrofoam', 'Plastic bag', 'Battery', 'Clothing']
    n_class = len(category_names)  # 11; derived so it cannot drift from the name list

    sample_path = getattr(args, 'sample_submission_path',
                          '/opt/ml/input/code/submission/sample_submission.csv')
    submission = pd.read_csv(sample_path, index_col=None)
    # Rows are collected and concatenated once: DataFrame.append was removed in
    # pandas 2.0 and appending row-by-row is quadratic.
    new_rows = []

    total_loss = 0
    cnt = 0
    hist = np.zeros((n_class, n_class))
    with torch.no_grad():
        for images, masks, image_infos in data_loader:
            images = torch.stack(images).to(device)
            masks = torch.stack(masks).long().to(device)

            outputs = model(images)
            loss = criterion(outputs, masks)
            total_loss += loss
            cnt += 1

            # argmax over dim=1 (presumably the class axis) gives per-pixel labels.
            preds = torch.argmax(outputs, dim=1).detach().cpu().numpy()
            masks = masks.detach().cpu().numpy()

            if args.save_submission:
                for info, pred in zip(image_infos, preds):
                    new_rows.append({
                        "image_id": info['file_name'],
                        "PredictionString": ' '.join(map(str, pred.flatten())),
                    })

            hist = add_hist(hist, masks, preds, n_class=n_class)

    if cnt == 0:
        raise ValueError('validation() received an empty data_loader; nothing to evaluate')
    if new_rows:
        submission = pd.concat([submission, pd.DataFrame(new_rows)], ignore_index=True)

    acc, _, mIoU, _, IoU = label_accuracy_score(hist)
    IoU_by_class = [{name: round(iou, 4)} for iou, name in zip(IoU, category_names)]

    avrg_loss = total_loss / cnt
    print(f'Validation #{epoch}  Average Loss: {round(avrg_loss.item(), 4)}, Accuracy : {round(acc, 4)}, mIoU: {round(mIoU, 4)}')
    print(f'IoU by class : {IoU_by_class}')

    return avrg_loss, mIoU, submission

In [7]:
def _average_checkpoints(model_dict, checkpoint_paths, device):
    """Return ``{key: mean tensor}`` for every ``model_dict`` key found in the checkpoints."""
    sums, counts, dtypes = {}, {}, {}
    for checkpoint_path in checkpoint_paths:
        # NOTE(review): torch.load unpickles arbitrary Python objects; only load
        # trusted files. On torch >= 2.6 the default weights_only=True rejects
        # pickled nn.Module checkpoints -- pass weights_only=False there if needed.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        # Whole-module checkpoints expose state_dict(); plain state-dict files
        # are already a key -> tensor mapping.
        if hasattr(checkpoint, 'state_dict'):
            weight_dict = checkpoint.state_dict()
        else:
            weight_dict = checkpoint
        for key, tensor in weight_dict.items():
            if key not in model_dict:
                raise KeyError(f'{checkpoint_path}: key {key!r} is not present in the model')
            if key not in sums:
                # Accumulate floats in at least float32 so half-precision weights
                # keep precision; integer buffers are summed exactly in their dtype.
                if tensor.is_floating_point():
                    acc_dtype = torch.promote_types(tensor.dtype, torch.float32)
                else:
                    acc_dtype = tensor.dtype
                sums[key] = tensor.detach().clone().to(acc_dtype)
                counts[key] = 1
                dtypes[key] = tensor.dtype
            else:
                sums[key] += tensor.detach().to(sums[key].dtype)
                counts[key] += 1
        # Drop this checkpoint before loading the next so only the running sums
        # (not N full copies) stay in memory.
        del checkpoint, weight_dict
    # True division (integer sums become float), then cast back to the first
    # checkpoint's dtype -- same result as stacking all copies and averaging.
    return {key: (total / counts[key]).to(dtypes[key]) for key, total in sums.items()}


def uniform_soup(cfg, model, checkpoint_paths, device, by_name=False):
    """Build a uniform model soup by averaging the weights of several checkpoints.

    Every entry of ``model.state_dict()`` that appears in at least one
    checkpoint is replaced by its element-wise mean over the checkpoints that
    contain it; entries no checkpoint provides keep their current value. The
    souped model is then evaluated once with ``validation`` (metrics are
    printed only).

    Args:
        cfg: config providing ``val_json_path``, ``data_dir``, ``batch_size``,
            ``num_workers``, ``use_losses`` and ``save_submission``.
        model: model whose architecture matches the checkpoints. It is moved
            to ``device`` and its weights are overwritten in place.
        checkpoint_paths: files readable by ``torch.load``; each must load to an
            object with ``state_dict()`` or to a state-dict mapping.
        device: device used as ``map_location`` and for validation.
        by_name: unused; kept for interface compatibility.

    Returns:
        The model holding the averaged weights.

    Raises:
        KeyError: if a checkpoint contains a key the model does not have.
    """
    # Build the validation pipeline first so a bad data path fails before the
    # (expensive) checkpoint loading.
    val_dataset = CustomDataLoader(data_dir=cfg.val_json_path,
                                   mode='val',
                                   transform=do_transform(mode='val'),
                                   data_path=cfg.data_dir)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                             batch_size=cfg.batch_size,
                                             shuffle=False,
                                             num_workers=cfg.num_workers,
                                             collate_fn=collate_fn)
    criterion = create_criterion(cfg.use_losses)

    model = model.to(device)
    model_dict = model.state_dict()
    model_dict.update(_average_checkpoints(model_dict, checkpoint_paths, device))
    model.load_state_dict(model_dict)

    # Evaluate once for logging; the returned metrics are not used here.
    validation(cfg, 0, model, val_loader, criterion, device)

    return model

0. 모델 & checkpoint 가져오기

In [8]:
model = create_model(args.model_name)

# ---- Checkpoints to average into the soup (edit this list) ----
checkpoint_paths = [
    '/opt/ml/input/code/saved/mit_unet_3plus_final_2/best_mIoU_12.pt',
    '/opt/ml/input/code/saved/mit_unet_3plus_final_2/best_mIoU_11.pt',
    '/opt/ml/input/code/saved/mit_unet_3plus_final_2/best_mIoU_9.pt',
]
# ----------------------------------------------------------------

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

1. Uniform Soup

In [11]:
# ---- Output location for the souped model (edit as needed) ----
save_dir_path = './soup/'
name = 'uniform_soup'  # suffix used in the saved file name
# ----------------------------------------------------------------

# torch.save cannot write into a missing directory; create it up front so the
# run does not fail only after the soup has been computed.
os.makedirs(save_dir_path, exist_ok=True)

print("\n[Uniform Soup]")
# NOTE: uniform_soup loads the averaged weights into `model` itself, so the
# global `model` holds the soup weights after this call.
uniform_model = uniform_soup(args, model, checkpoint_paths, device)

torch.save(uniform_model, os.path.join(save_dir_path, f'uniform_model_soup_{name}.pt'))


[Uniform Soup]
loading annotations into memory...
Done (t=1.10s)
creating index...
index created!

 Start validation #0
Validation #0  Average Loss: 0.4245, Accuracy : 0.9447, mIoU: 0.6807
IoU by class : [{'Background': 0.968}, {'General trash': 0.4625}, {'Paper': 0.8205}, {'Paper pack': 0.5377}, {'Metal': 0.5816}, {'Glass': 0.6223}, {'Plastic': 0.5142}, {'Styrofoam': 0.7712}, {'Plastic bag': 0.8651}, {'Battery': 0.8434}, {'Clothing': 0.5016}]


2. Greedy Soup (uniform weight update)