# In [None]:
# Directory holding a local copy of pytorch-image-models. It is appended to
# sys.path before the other imports run -- presumably so the later
# `import timm` resolves to this copy (TODO confirm timm is not otherwise
# installed in the target environment).
package_path = '../input/pytorch-image-models/pytorch-image-models-master'

import sys
sys.path.append(package_path)

# Root directory of the competition data; the inference code below reads the
# test images from f'{DATA_DIR}/test_images/'.
DATA_DIR = '../input/cassava-leaf-disease-classification'

from datetime import datetime
import os
import random
import time
import warnings

import pandas as pd
import numpy as np
from tqdm import tqdm
import cv2
import matplotlib.pyplot as plt
import sklearn
from sklearn.metrics import log_loss
from sklearn.model_selection import StratifiedKFold

import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler
from torchvision import transforms
from torch.utils.data import Dataset,DataLoader
from torch.utils.data.sampler import SequentialSampler, RandomSampler
import timm

from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout, CoarseDropout, ShiftScaleRotate, CenterCrop, Resize
)
from albumentations.pytorch import ToTensorV2

# Run configuration.
# Only 'valid_bs', 'num_workers', 'device', 'tta' and 'weights' are read by the
# code visible in this file. The remaining keys look like training settings --
# presumably carried over from the training notebook (TODO confirm they are
# still needed here).
CFG = {
    'fold_num': 5,
    'seed': 719,
    'model_arch': 'vit_base_patch16_384',
    'img_size': 384,
    'epochs': 10,
    'train_bs': 32,
    # Batch size of the test DataLoader (the EfficientNet loader uses twice this).
    'valid_bs': 16,
    'lr': 1e-4,
    # Worker processes used by the test DataLoaders.
    'num_workers': 4,
    'accum_iter': 1, # support to do batch accumulation for backprop with effectively larger batch size
    'verbose_step': 1,
    # Device string passed to torch.device() for the models and input batches.
    'device': 'cuda:0',
    # Number of augmented inference passes run per checkpoint.
    'tta': 4,
    # Ensemble weights, looked up by checkpoint index when predictions are combined.
    'weights': [1] * 10
}

# Checkpoint files for the ViT models. A path containing 'large' is loaded as
# 'vit_large_patch16_384', any other path as 'vit_base_patch16_384'.
model_path_vit = []
# Checkpoint files for the EfficientNet models. A path containing 'b5' is
# loaded as 'tf_efficientnet_b5_ns', any other path as 'tf_efficientnet_b4_ns'.
model_path_eff = []

def seed_everything(seed):
    """Seed every random number generator used here and make cuDNN repeatable.

    Args:
        seed: Integer seed applied to ``random``, NumPy and PyTorch (CPU and
            CUDA). It is also exported as ``PYTHONHASHSEED``.
    """
    random.seed(seed)
    # NOTE: changing PYTHONHASHSEED at runtime does not re-seed str hashing in
    # the current interpreter; it only reaches processes started afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may pick a different convolution algorithm on each
    # run, which breaks repeatability, so it has to be off when determinism
    # is requested.
    torch.backends.cudnn.benchmark = False
    
def get_img(path):
    """Read the image at ``path`` with OpenCV and return it in RGB channel order.

    Args:
        path: Location of the image file.

    Returns:
        The decoded image with its last axis reversed (OpenCV decodes to BGR).

    Raises:
        FileNotFoundError: If OpenCV could not read the file.
    """
    im_bgr = cv2.imread(path)
    if im_bgr is None:
        # cv2.imread reports failure by returning None instead of raising;
        # without this check the slice below fails with an opaque TypeError.
        raise FileNotFoundError(
            f'cv2.imread could not read image (missing or not decodable): {path}')
    return im_bgr[:, :, ::-1]

class CassavaDataset(Dataset):
    """Dataset that loads images listed in the ``image_id`` column of a DataFrame.

    Args:
        df: DataFrame with an ``image_id`` column and, when ``output_label``
            is true, a ``label`` column. A re-indexed copy is stored, so the
            caller's frame is not modified.
        data_root: Directory prefix joined with each ``image_id`` using ``/``.
        transforms: Optional callable invoked as ``transforms(image=img)``;
            the ``'image'`` entry of its result replaces the loaded image.
        output_label: When true, items are ``(image, label)`` pairs; otherwise
            only the image is returned.
    """

    def __init__(self, df, data_root, transforms=None, output_label=True):
        super().__init__()
        # Re-index to 0..n-1 so positional lookups match the sampler indices.
        self.df = df.reset_index(drop=True).copy()
        self.transforms = transforms
        self.data_root = data_root
        self.output_label = output_label

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index: int):
        # Fetch the row once, positionally, and use it for both the file name
        # and the label so the two can never refer to different rows.
        row = self.df.iloc[index]

        img = get_img(f"{self.data_root}/{row['image_id']}")

        if self.transforms:
            img = self.transforms(image=img)['image']

        if self.output_label:
            return img, row['label']
        return img
        
def get_inference_transforms_vit():
    """Build the augmentation pipeline used on test images for the ViT models.

    The pipeline crops/resizes to 384x384, applies random flips, transposes
    and colour jitter, normalises the channels and converts to a tensor.
    Because several steps are random, repeated passes over the same image can
    produce different inputs.

    Returns:
        An albumentations ``Compose`` object.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        RandomResizedCrop(384, 384),
        Transpose(p=0.5),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        HueSaturationValue(
            hue_shift_limit=0.2,
            sat_shift_limit=0.2,
            val_shift_limit=0.2,
            p=0.5,
        ),
        RandomBrightnessContrast(
            brightness_limit=(-0.1, 0.1),
            contrast_limit=(-0.1, 0.1),
            p=0.5,
        ),
        Normalize(mean=channel_mean, std=channel_std, max_pixel_value=255.0, p=1.0),
        ToTensorV2(p=1.0),
    ]
    return Compose(steps, p=1.0)

def get_inference_transforms_eff():
    """Build the augmentation pipeline used on test images for the EfficientNet models.

    The pipeline crops/resizes to 512x512, applies random flips, transposes
    and colour jitter, normalises the channels and converts to a tensor.
    Because several steps are random, repeated passes over the same image can
    produce different inputs.

    Returns:
        An albumentations ``Compose`` object.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        RandomResizedCrop(512, 512),
        Transpose(p=0.5),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        HueSaturationValue(
            hue_shift_limit=0.2,
            sat_shift_limit=0.2,
            val_shift_limit=0.2,
            p=0.5,
        ),
        RandomBrightnessContrast(
            brightness_limit=(-0.1, 0.1),
            contrast_limit=(-0.1, 0.1),
            p=0.5,
        ),
        Normalize(mean=channel_mean, std=channel_std, max_pixel_value=255.0, p=1.0),
        ToTensorV2(p=1.0),
    ]
    return Compose(steps, p=1.0)

class CassvaImgClassifier(nn.Module):
    """timm backbone whose final linear layer is replaced by an ``n_class``-way one.

    Args:
        model_arch: Architecture name passed to ``timm.create_model``.
        n_class: Number of output units of the replacement linear layer.
        pretrained: Forwarded to ``timm.create_model``.
    """

    def __init__(self, model_arch, n_class, pretrained=False):
        super().__init__()
        self.model = timm.create_model(model_arch, pretrained=pretrained)

        # Architectures with 'vit' in their name are accessed through `.head`,
        # every other architecture through `.classifier`.
        output_layer_name = 'head' if 'vit' in model_arch else 'classifier'
        in_features = getattr(self.model, output_layer_name).in_features
        setattr(self.model, output_layer_name, nn.Linear(in_features, n_class))

    def forward(self, x):
        return self.model(x)
    
    
def inference_one_epoch(model, data_loader, device):
    """Run ``model`` over every batch of ``data_loader`` and collect softmax outputs.

    Args:
        model: Module called on each batch; it is switched to eval mode.
        data_loader: Iterable with a length that yields image batches only
            (no labels).
        device: Device the batches are moved to before the forward pass.

    Returns:
        NumPy array holding the per-batch softmax outputs (softmax over
        dimension 1) concatenated along axis 0, in loader order.
    """
    model.eval()

    batch_probs = []

    # Disable autograd here rather than relying on the caller, so no graph is
    # built even when the function is used outside a no_grad block.
    with torch.no_grad():
        for imgs in tqdm(data_loader, total=len(data_loader)):
            imgs = imgs.to(device).float()
            logits = model(imgs)
            batch_probs.append(torch.softmax(logits, 1).detach().cpu().numpy())

    return np.concatenate(batch_probs, axis=0)


print('Inference fold started')

# All checkpoints of both families form one weighted ensemble, so each of them
# needs its own entry in CFG['weights'].
n_models = len(model_path_vit) + len(model_path_eff)
if n_models == 0:
    # Fail early with a clear message instead of failing at the very end with
    # an unrelated error once there is nothing to average.
    raise ValueError(
        'No checkpoints configured: add paths to model_path_vit and/or model_path_eff.')
if n_models > len(CFG['weights']):
    raise ValueError(
        f"CFG['weights'] has {len(CFG['weights'])} entries but "
        f"{n_models} checkpoints are configured.")

# Normalise by the weights that are actually used, so every row of the final
# ensemble output sums to 1.
ensemble_weights = CFG['weights'][:n_models]
sum_weights = sum(ensemble_weights)

device = torch.device(CFG['device'])


def _predict_with_tta(model_arch, checkpoint_path, loader):
    """Return one checkpoint's softmax outputs averaged over CFG['tta'] passes."""
    model = CassvaImgClassifier(model_arch, 5).to(device)
    model.load_state_dict(torch.load(checkpoint_path))

    with torch.no_grad():
        # The loader's transforms are random, so every pass sees a different
        # augmentation of each test image.
        passes = [inference_one_epoch(model, loader, device)
                  for _ in range(CFG['tta'])]

    # Release this model before the next checkpoint is loaded.
    del model
    torch.cuda.empty_cache()
    return np.mean(passes, axis=0)


test = pd.DataFrame()
test['image_id'] = os.listdir(f'{DATA_DIR}/test_images/')

# Running weighted sum of the per-checkpoint predictions.
tst_preds = 0

# ---- ViT checkpoints (384x384 inputs) ----
test_ds = CassavaDataset(
    test, f'{DATA_DIR}/test_images/',
    transforms=get_inference_transforms_vit(), output_label=False)

tst_loader = torch.utils.data.DataLoader(
    test_ds,
    batch_size=CFG['valid_bs'],
    num_workers=CFG['num_workers'],
    shuffle=False,
    pin_memory=False,
)

for i, path in enumerate(model_path_vit):
    arch = 'vit_large_patch16_384' if 'large' in path else 'vit_base_patch16_384'
    tst_preds = tst_preds + (
        ensemble_weights[i] / sum_weights * _predict_with_tta(arch, path, tst_loader))

# ---- EfficientNet checkpoints (512x512 inputs) ----
test_ds = CassavaDataset(
    test, f'{DATA_DIR}/test_images/',
    transforms=get_inference_transforms_eff(), output_label=False)

tst_loader = torch.utils.data.DataLoader(
    test_ds,
    batch_size=CFG['valid_bs'] * 2,
    num_workers=CFG['num_workers'],
    shuffle=False,
    pin_memory=False,
)

# Continue the weight index after the ViT checkpoints so that EfficientNet
# models use their own weights instead of re-using the ViT ones.
for i, path in enumerate(model_path_eff, start=len(model_path_vit)):
    arch = 'tf_efficientnet_b5_ns' if 'b5' in path else 'tf_efficientnet_b4_ns'
    tst_preds = tst_preds + (
        ensemble_weights[i] / sum_weights * _predict_with_tta(arch, path, tst_loader))

torch.cuda.empty_cache()

# The predicted class is the column with the highest ensemble score.
test['label'] = np.argmax(tst_preds, axis=1)
test.to_csv('submission.csv', index=False)