In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.backends.cudnn.benchmark = True

import os
import cv2
import shutil
import pandas as pd
import numpy as np
import pretrainedmodels
import matplotlib.pyplot as plt

from apex import amp
from time import time
from warnings import filterwarnings
from tqdm.notebook import tqdm
from efficientnet_pytorch import EfficientNet
from torch.utils.data import Dataset, DataLoader
import argparse

import gc
from PIL import Image
from albumentations import *
from albumentations.pytorch import ToTensorV2
from torchvision.transforms import ToPILImage
from sklearn.model_selection import StratifiedKFold

torch.backends.cudnn.benchmark = True

In [None]:
# Configuration for averaging several checkpoints of one run into a single model (SWA-style merge).
# Data locations
train_root = 'data/train_384'
test_root = 'data/test_images'
# Checkpoints to merge, in the order they are added to the running average: 4.pt, 3.pt, ..., 0.pt
model_dir = 'experiments/final_b4_0_stage1'
model_names = ['{}.pt'.format(idx) for idx in reversed(range(5))]
# Inference settings
batch_size, image_size = 128, 384
device = 'cuda:0'
# Index of the StratifiedKFold split used as the validation set
fold = 1

In [None]:
# Deterministic preprocessing shared by the validation and test sets:
# Lanczos resize to image_size x image_size, default Normalize(), then array -> tensor.
test_transform = Compose(transforms=[
    Resize(height=image_size, width=image_size, interpolation=cv2.INTER_LANCZOS4),
    Normalize(),
    ToTensorV2(),
])

In [None]:
class TrashDataset_test(Dataset):
    """Unlabelled test dataset yielding ``{'img': image}`` for each row of ``df``.

    Each image is read from ``<root>/<df['id'][idx]>.png`` with PIL and converted to RGB.
    NOTE(review): ``TrashDataset`` reads images with cv2, which returns BGR; confirm which
    channel order the model was trained with before using this dataset for predictions.

    Args:
        df: table with an ``id`` column naming the image files (without extension).
        root: directory containing the ``.png`` files.
        transform: optional albumentations-style callable invoked as ``transform(image=...)``
            whose result dict provides ``'image'``; when None the raw uint8 array is returned.
    """

    def __init__(self, df, root='data', transform=None):
        self.transform = transform
        self.df = df
        self.root = root

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        path = os.path.join(self.root, str(self.df['id'].values[idx]) + '.png')
        # The context manager closes the file handle PIL keeps open for lazy loading;
        # convert('RGB') guarantees 3 channels even for palette, grayscale or RGBA PNGs.
        with Image.open(path) as pil_image:
            image = np.array(pil_image.convert('RGB'))
        if self.transform is not None:
            image = self.transform(image=image)['image']
        return {'img': image}

class TrashDataset(Dataset):
    """Labelled dataset yielding the transformed image, its label and the untransformed image.

    Each image is read from ``<root>/<image_id>.png`` with ``cv2.imread`` (BGR channel order).

    Args:
        df: table with ``image_id`` and ``label`` columns.
        root: directory containing the ``.png`` files.
        transforms: albumentations-style callable invoked as ``transforms(image=...)`` whose
            result dict provides ``'image'``. NOTE: the default ``ToTensorV2()`` instance is
            created once at definition time and shared by every dataset that relies on it.

    Items are dicts ``{'img': transformed, 'label': label, 'original': raw BGR array}``.
    """

    def __init__(self, df, root, transforms=ToTensorV2()):
        self.df, self.root = df, root
        self.len = len(df)
        self.transforms = transforms

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        name, label = str(self.df.image_id.values[idx]), self.df.label.values[idx]
        path = os.path.join(self.root, name + '.png')
        img = cv2.imread(path)
        if img is None:
            # cv2.imread reports a missing/unreadable file by returning None instead of raising.
            raise FileNotFoundError('Could not read image: {}'.format(path))
        # Keep the untransformed image so misclassifications can be displayed later.
        original = img.copy()
        img = self.transforms(image=img)['image']
        return {'img': img, 'label': label, 'original': original}

def metric(preds, labels):
    """Top-1 accuracy: fraction of rows of ``preds`` whose arg-max column equals ``labels``.

    Returns a Python float.
    """
    hits = preds.argmax(dim=1).eq(labels)
    return hits.float().mean().item()

In [None]:
# Recreate the seeded stratified K-fold split and keep split number `fold` as the validation set.
n_splits = 5
if not 0 <= fold < n_splits:
    # An out-of-range fold must fail loudly instead of silently falling back to another split.
    raise ValueError('fold must be in [0, {}), got {}'.format(n_splits, fold))
labelled_set = TrashDataset(pd.read_csv('data/train.csv'), train_root, test_transform)
indices = np.arange(len(labelled_set))
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=2020)
# A fixed random_state makes the split sequence deterministic, so indexing by `fold` is stable.
train_indices, val_indices = list(skf.split(indices, labelled_set.df.label.values))[fold]
valset = torch.utils.data.Subset(labelled_set, val_indices)
testset = TrashDataset_test(pd.read_csv('data/test.csv'), test_root, test_transform)
print(len(valset), 'validation', len(testset), 'test')

In [None]:
import json
train_path = 'data/train.json'
# Map each category 'id' in train.json's 'categories' list to its 'name'.
with open(train_path, 'r') as fp:
    train = dict(json.load(fp))
id_map = {category['id']: category['name'] for category in train['categories']}

In [None]:
# Base TorchScript model; evaluate() later overwrites its weights via load_state_dict.
base_path = os.path.join(model_dir, model_names[0])
model = torch.jit.load(base_path).to(device)
model.eval()
# Validation / test loaders (DataLoader's default shuffle=False keeps dataset order).
loader_options = {'batch_size': batch_size, 'num_workers': 8}
valloader = DataLoader(valset, **loader_options)
testloader = DataLoader(testset, **loader_options)

In [None]:
def _tta_probs(model, img):
    """Return softmax probabilities for ``img`` averaged over four flip views, cast to float32.

    Views: unchanged, flipped along the last dim, along the second-to-last dim, and along both.
    Each flipped copy only lives for its own forward pass.
    """
    flip_dims = ((), (-1,), (-2,), (-1, -2))
    probs = [F.softmax(model(torch.flip(img, dims) if dims else img), dim=1) for dims in flip_dims]
    return (sum(probs) / len(probs)).float()


def _show_mistakes(wrong_mask, original, labels, preds, pred_values):
    """Print and plot every sample of one batch flagged in the boolean ``wrong_mask``."""
    # Python ints index CPU and CUDA tensors alike; a CUDA mask indexing a CPU tensor is
    # rejected by some PyTorch versions.
    for wrong_id in wrong_mask.nonzero(as_tuple=True)[0].tolist():
        # id_map keys are offset by +1 from the model's class indices
        # (presumably 1-based category ids in train.json - confirm).
        label = id_map[labels[wrong_id].item() + 1]
        pred = id_map[preds[wrong_id].item() + 1]
        confidence = pred_values[wrong_id].item()
        print('Pred:', pred, ' label:', label, ' conf:', round(confidence, 4))
        # TrashDataset reads 'original' with cv2.imread (BGR); reverse the channel axis so
        # matplotlib, which expects RGB, shows the true colours.
        wrong_img = np.ascontiguousarray(np.asarray(original[wrong_id])[..., ::-1])
        plt.imshow(wrong_img)
        plt.show()


def evaluate(model, state_dict, demonstrate=False, half=False):
    """Load ``state_dict`` into ``model`` and measure flip-TTA accuracy over ``valloader``.

    Args:
        model: module whose weights are overwritten in place by ``state_dict``.
        state_dict: weights to evaluate.
        demonstrate: if True, print and plot every misclassified validation image and also
            return the concatenated per-sample predictions.
        half: run inference in fp16. NOTE: ``.half()`` / ``.float()`` convert the module in
            place, so the caller's model stays in that precision after this call.

    Returns:
        Accuracy over all validation samples (``nan`` if ``valloader`` yields nothing).
        With ``demonstrate=True``: ``(accuracy, preds, confidences, labels)``, where the last
        three are CPU tensors concatenated over the whole loader.
    """
    model.load_state_dict(state_dict)
    model = model.half() if half else model.float()
    preds_, pred_values_, labels_ = [], [], []
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        progress = tqdm(valloader)
        for batch in progress:
            img, labels, original = batch['img'].to(device), batch['label'].to(device), batch['original']
            if half:
                img = img.half()
            pred_values, preds = _tta_probs(model, img).max(1)
            preds_.append(preds.cpu())
            pred_values_.append(pred_values.cpu())
            labels_.append(labels.cpu())
            wrong_mask = preds != labels
            if demonstrate:
                _show_mistakes(wrong_mask, original, labels, preds, pred_values)
            # Count samples rather than averaging per-batch accuracies, so a short final
            # batch carries its true weight.
            n_correct += int((~wrong_mask).sum().item())
            n_seen += len(labels)
            progress.set_postfix({'Metric': n_correct / n_seen})
    accuracy = n_correct / n_seen if n_seen else float('nan')
    if demonstrate:
        print('Metric:', accuracy)
        return accuracy, torch.cat(preds_), torch.cat(pred_values_), torch.cat(labels_)
    return accuracy

In [None]:
from collections import OrderedDict 

# Average model weights
def avg_state_dict(weights):
    """Uniformly average a sequence of state dicts, key by key.

    Args:
        weights: non-empty sequence of mappings from parameter/buffer name to tensor; every
            mapping must contain all keys of ``weights[0]``.

    Returns:
        OrderedDict with the keys of ``weights[0]`` (same order) mapped to the element-wise
        mean. Non-floating-point tensors (e.g. integer counters) are cast back to their
        original dtype, because true division would otherwise turn them into floats.

    Raises:
        ValueError: if ``weights`` is empty.
    """
    if len(weights) == 0:
        raise ValueError('avg_state_dict needs at least one state dict')
    average_dict = OrderedDict()
    for key, reference in weights[0].items():
        mean = sum(weight[key] for weight in weights) / len(weights)
        if isinstance(reference, torch.Tensor) and not reference.is_floating_point():
            mean = mean.to(reference.dtype)
        average_dict[key] = mean
    return average_dict

In [None]:
# Add checkpoints to a running uniform average one at a time (in model_names order) and
# remember the average with the best validation score seen so far (ties count as better).
weights = [torch.jit.load(os.path.join(model_dir, name)).state_dict() for name in model_names]
current_weights, metrics = [], []
best_weight, patience = None, 0
for w in weights:
    current_weights.append(w)
    average_weights = avg_state_dict(current_weights)
    score = evaluate(model, average_weights, half=True)
    metrics.append(score)
    if score != max(metrics):
        # NOTE(review): with 5 entries in model_names this patience limit of 10 is never reached.
        patience += 1
        if patience == 10:
            print('Early stopping triggered ;)')
            break
        continue
    print('Better combo found with CV:', score)
    best_weight = average_weights
    patience = 0

In [None]:
# Export the best averaged weights as a TorchScript checkpoint.
if best_weight is None:
    raise RuntimeError('No averaged weights selected; run the checkpoint-averaging cell first.')
# evaluate(..., half=True) converts `model` to fp16 in place. Cast back to fp32 *before*
# loading so the averaged weights are not rounded to half precision in the export.
model = model.float()
model.load_state_dict(best_weight)
model.eval()
torch.jit.save(model, os.path.join(model_dir, 'merged_best.pt'))

In [None]:
# Echo the experiment directory that merged_best.pt was written to.
print('{}'.format(model_dir))