# Configuration

In [None]:
class CFG:
    """Central configuration read by every cell of this inference notebook."""
    # -----------
    # Environment
    # -----------
    platform = 'colab'  # Options: ['kaggle', 'colab']; selects install steps and data/output paths
    device = 'cuda'  # Options: ['cuda', 'cpu']; device the models and input batches are moved to

    # -----
    # Paths
    # -----
    # Directory holding test.zip and sample_submission.csv
    test_data_path = '/content/gdrive/MyDrive/kaggle/kaggle-competition-datasets/cassava-leaf-disease-classification/'
    # Directory the test images are read from
    test_img_path = '/content/test/'
    # Directory listed by load_pretrained_models() to find checkpoint files
    model_path = '/content/gdrive/MyDrive/kaggle/kaggle-models/kaggle-leaf-classification-models/pretrained_models/'
    log_path = '/content/gdrive/MyDrive/kaggle/kaggle-models/kaggle-leaf-classification-models/'

    # ----
    # Data
    # ----
    n_classes = 5  # Number of target classes; sets the width of every model's final layer
    img_size = 384  # Options: [384x384, 512x512]; ViT and DeiT models require 384 x 384
    n_epochs = 5  # Number of training epochs
    n_folds = 5  # Number of folds for k-fold cross-validation
    num_workers = 2  # Worker processes used by the test DataLoader
    batch_size = 1  # Must stay 1 while use_TTA is True (checked right after this class)

    # -----------------
    # Pretrained Models
    # -----------------
    model_list = [0, 1, 2, 3] # Positions, within the listing of model_path, of the checkpoints used for inference
    use_TTA = True  # Whether infer() applies test-time augmentation
    n_TTA = 8  # Number of augmented copies generated per image when use_TTA is True
    pretrained = False  # Forwarded as `pretrained=` to the model constructors
    debug = True  # When True, the main cell restricts model_list to [0]


    seed = 42  # Default seed used by seed_everything()

# infer() builds each TTA batch from a single source image, so TTA only works
# when the dataloader yields one image per batch.
if CFG.use_TTA:
  assert CFG.batch_size == 1, "CFG.use_TTA=True requires CFG.batch_size == 1"

# Libraries

In [None]:
# On Colab, install/upgrade the third-party packages imported below
# (IPython `!` shell magics; this cell only runs inside a notebook).
if CFG.platform == 'colab':
  !pip install tqdm --upgrade
  !pip install -U albumentations
  !pip install timm

import os
import sys
import time
from datetime import datetime
from zipfile import ZipFile
import random
import warnings
warnings.filterwarnings('ignore')
from logging import Formatter, StreamHandler, getLogger

import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedKFold

import torch 
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
from torch.cuda.amp import autocast, GradScaler

import timm
import cv2
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from google.colab import drive
    
# Extra source directories, chosen per platform, that are appended to
# sys.path so the packages stored there become importable.
package_paths = (
    ['/content/gdrive/MyDrive//kaggle/kaggle-models/efficientnet_pytorch-0.7.0',
     '/content/gdrive/MyDrive/kaggle/kaggle-models/FMix-master']
    if CFG.platform == 'colab'
    else ['/kaggle/input/timm-pytorch-image-models/pytorch-image-models-master']
)

for path in package_paths:
  sys.path.append(path)

Collecting tqdm
  Downloading tqdm-4.62.2-py2.py3-none-any.whl (76 kB)
     |████████████████████████████████| 76 kB 2.7 MB/s 
Installing collected packages: tqdm
  Attempting uninstall: tqdm
    Found existing installation: tqdm 4.62.0
    Uninstalling tqdm-4.62.0:
      Successfully uninstalled tqdm-4.62.0
Successfully installed tqdm-4.62.2
Collecting albumentations
  Downloading albumentations-1.0.3-py3-none-any.whl (98 kB)
     |████████████████████████████████| 98 kB 4.5 MB/s 
Col

# Data Import

In [None]:
# Locate the test images for the current platform.
if CFG.platform == 'colab':
  # The competition data lives on Google Drive, so mount it first.
  drive.mount('/content/gdrive')

  test_zip_path = CFG.test_data_path + "test.zip"
  test_img_path = CFG.test_img_path

  # Only unpack the archive when the image directory is not there yet.
  if not os.path.isdir(test_img_path):
    with ZipFile(test_zip_path, 'r') as zip_f:
      zip_f.extractall(path='/content')
else:
  test_img_path = '../input/cassava-leaf-disease-classification/test_images/'
  # CassavaTestDataset reads CFG.test_img_path, not this module-level name,
  # so the config has to be pointed at the Kaggle directory as well.
  CFG.test_img_path = test_img_path
  # NOTE(review): CFG.test_data_path (used for sample_submission.csv) still
  # holds the Colab location on this branch -- confirm the intended Kaggle path.

Mounted at /content/gdrive


# Utils

In [None]:
def seed_everything(seed=CFG.seed):
  """Seed every random number generator used in this notebook.

  Seeds Python's `random`, the `PYTHONHASHSEED` environment variable, NumPy
  and PyTorch (CPU and current CUDA device), and configures cuDNN for
  repeatable results.

  Args:
    seed: Integer seed applied to all generators. Defaults to CFG.seed.
  """
  random.seed(seed)
  os.environ['PYTHONHASHSEED'] = str(seed)
  np.random.seed(seed)

  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)
  torch.backends.cudnn.deterministic = True
  # Benchmark mode auto-tunes convolution algorithms and may pick a different
  # one from run to run, so it must be off for reproducible results.
  torch.backends.cudnn.benchmark = False

def read_img_from_path(path):
  """Load an image file and return it as an RGB array.

  Args:
    path: Location of the image file on disk.

  Returns:
    A copy of the decoded image with its last axis reversed, i.e. OpenCV's
    BGR channel order converted to RGB.

  Raises:
    FileNotFoundError: If OpenCV cannot read an image from `path`.
  """
  im_bgr = cv2.imread(path)
  # cv2.imread signals a missing or undecodable file by returning None
  # instead of raising, which would otherwise surface as a TypeError below.
  if im_bgr is None:
    raise FileNotFoundError(f"Could not read image: {path}")
  # Reverse the channel axis (BGR -> RGB); copy() makes the result contiguous.
  im_rgb = im_bgr[:, :, ::-1].copy()
  return im_rgb

# Dataset

In [None]:
from albumentations.pytorch import ToTensorV2
from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout, CoarseDropout, ShiftScaleRotate, CenterCrop, Resize
)

def get_heavy_transforms():
  """Return the randomized augmentation pipeline.

  The pipeline crops/resizes to CFG.img_size, applies random flips, an
  occasional shift-scale-rotate and brightness/contrast jitter, normalizes
  with the given mean/std and converts the result to a tensor.
  """
  side = CFG.img_size
  augmentations = [
      RandomResizedCrop(side, side),
      HorizontalFlip(p=0.5),
      VerticalFlip(p=0.5),
      ShiftScaleRotate(p=0.1),
      RandomBrightnessContrast(
          brightness_limit=(-0.1,0.1),
          contrast_limit=(-0.1, 0.1),
          p=0.5,
      ),
      Normalize(
          mean=[0.485, 0.456, 0.406],
          std=[0.229, 0.224, 0.225],
          max_pixel_value=255.0,
          p=1.0,
      ),
      ToTensorV2(p=1.0),
  ]
  return Compose(augmentations)


def get_test_transforms():
  """Return the test pipeline, which only converts the image to a tensor."""
  return Compose([ToTensorV2(p=1.0)])


def get_light_transforms():
  """Return a pipeline that center-crops to CFG.img_size and converts to a tensor."""
  side = CFG.img_size
  steps = [CenterCrop(side, side), ToTensorV2(p=1.0)]
  return Compose(steps)

In [None]:
class CassavaTestDataset(Dataset):
  """Test dataset yielding `(image, image id)` pairs.

  Images are read from CFG.test_img_path using the ids passed at construction.

  Args:
    img_id: Indexable collection of image file names.
    transform: Optional callable invoked as `transform(image=img)`; its
      `'image'` entry replaces the loaded image.
  """
  def __init__(self, img_id, transform=None):
    self.transform = transform
    self.img_id = img_id

  def __len__(self):
    return len(self.img_id)

  def __getitem__(self, idx):
    sample_id = self.img_id[idx]
    image = read_img_from_path(f"{CFG.test_img_path}{sample_id}")

    # Without a transform the raw RGB array is returned unchanged.
    if not self.transform:
      return image, sample_id

    return self.transform(image=image)['image'], sample_id


def prepare_test_dataloader():
  """Build the DataLoader that iterates over the test images in file order.

  Image ids come from the `image_id` column of sample_submission.csv under
  CFG.test_data_path; batching and worker count come from CFG.
  """
  submission_csv = CFG.test_data_path+"sample_submission.csv"
  image_ids = pd.read_csv(submission_csv)['image_id'].to_numpy()

  dataset = CassavaTestDataset(image_ids, transform=get_test_transforms())
  return DataLoader(
      dataset,
      batch_size=CFG.batch_size,
      shuffle=False,
      num_workers=CFG.num_workers,
  )

# Model

In [None]:
class CassavaNet(nn.Module):
  """Thin wrapper around a backbone that adds freeze/unfreeze helpers.

  Args:
    model: The backbone network; `forward` delegates to it unchanged.
    model_name: Architecture name, used to find the classification layer.
      Matching is by substring, the same way `get_model` selects a builder.
  """
  def __init__(self, model, model_name):
    super().__init__()
    self.model = model
    self.model_name = model_name

  def forward(self, x):
    return self.model(x)

  def _classifier_layer(self):
    """Return the backbone's final classification layer, or None if the name is unrecognized."""
    name = self.model_name
    if 'efficientnet' in name:
      return self.model.classifier
    if 'vit' in name or 'deit' in name:
      return self.model.head
    if 'resnext' in name or 'resnet' in name:
      return self.model.fc
    return None

  def freeze(self):
    """Freeze every backbone parameter except those of the classification layer.

    For an unrecognized `model_name` all parameters stay frozen.
    """
    for param in self.model.parameters():
      param.requires_grad = False

    classifier = self._classifier_layer()
    if classifier is not None:
      for param in classifier.parameters():
        param.requires_grad = True

  def unfreeze(self):
    """Make every backbone parameter trainable again."""
    for param in self.model.parameters():
      param.requires_grad = True


def get_resnet_model():
  """Build a torchvision ResNet-50 whose `fc` layer outputs CFG.n_classes values."""
  backbone = torchvision.models.resnet50(pretrained=CFG.pretrained)
  # The replacement layer takes 2048 input features.
  backbone.fc = nn.Linear(in_features=2048, out_features=CFG.n_classes)
  return backbone

def get_resnext_model():
  """Build a timm `resnext50_32x4d` whose `fc` layer outputs CFG.n_classes values."""
  backbone = timm.create_model('resnext50_32x4d', pretrained=CFG.pretrained)
  # Keep the existing input width and only change the number of outputs.
  backbone.fc = nn.Linear(backbone.fc.in_features, CFG.n_classes)
  return backbone

def get_efficientnet_model():
  """Build a timm `tf_efficientnet_b4_ns` whose `classifier` outputs CFG.n_classes values."""
  backbone = timm.create_model('tf_efficientnet_b4_ns', pretrained=CFG.pretrained)
  # Keep the existing input width and only change the number of outputs.
  backbone.classifier = nn.Linear(backbone.classifier.in_features, CFG.n_classes)
  return backbone

def get_deit_model():
  """Load `deit_base_patch16_384` via torch.hub and resize its `head` to CFG.n_classes outputs."""
  backbone = torch.hub.load(
      'facebookresearch/deit:main',
      'deit_base_patch16_384',
      pretrained=CFG.pretrained,
  )
  # Keep the existing input width and only change the number of outputs.
  backbone.head = nn.Linear(backbone.head.in_features, CFG.n_classes)
  return backbone

def get_vit_model():
  """Build a timm `vit_large_patch16_384` whose `head` outputs CFG.n_classes values."""
  backbone = timm.create_model('vit_large_patch16_384', pretrained=CFG.pretrained)
  # Keep the existing input width and only change the number of outputs.
  backbone.head = nn.Linear(backbone.head.in_features, CFG.n_classes)
  return backbone

In [None]:
def get_model(model_name):
  """Build the backbone matching `model_name` and wrap it in a CassavaNet.

  Args:
    model_name: Architecture name; the first keyword found in it (checked in
      the order listed below) selects the builder.

  Returns:
    A CassavaNet wrapping the freshly built backbone.

  Raises:
    ValueError: If `model_name` contains none of the known keywords.
  """
  # Order matters: 'resnext' must be tested before 'resnet', and the first
  # matching keyword wins.
  builders = (
      ('efficientnet', get_efficientnet_model),
      ('deit', get_deit_model),
      ('vit', get_vit_model),
      ('resnext', get_resnext_model),
      ('resnet', get_resnet_model),
  )

  for keyword, build in builders:
    if keyword in model_name:
      return CassavaNet(build(), model_name)

  raise ValueError("Invalid model choice")

In [None]:
def load_pretrained_models():
  """Load the checkpoints selected by CFG.model_list from CFG.model_path.

  Entries of CFG.model_path are enumerated in sorted name order, and an entry
  is loaded when its position is listed in CFG.model_list and it is not a
  directory. The architecture name is the part of the file name before the
  first '_f' (e.g. 'resnext50_32x4d_f1_b0.894.pth' -> 'resnext50_32x4d').

  Returns:
    List of CassavaNet models with their saved weights loaded.
  """
  models = []

  # os.listdir returns entries in arbitrary order; sorting makes the indices
  # in CFG.model_list refer to the same files on every run and machine.
  for position, model_fpath in enumerate(sorted(os.listdir(CFG.model_path))):
    full_path = os.path.join(CFG.model_path, model_fpath)

    # Directories still occupy a position but are never loaded.
    if position not in CFG.model_list or os.path.isdir(full_path):
      continue

    print("Model Loaded:", model_fpath)
    model_name = model_fpath.split('_f')[0]
    print(model_name)

    model = get_model(model_name)
    state = torch.load(full_path, map_location=torch.device(CFG.device))
    model.load_state_dict(state)
    models.append(model)

  return models

# Inference

In [None]:
def infer():
  """Predict a label for every test image with the loaded model ensemble.

  Reads the module-level `models` list. For each batch the class
  probabilities (softmax of the logits) are averaged over all models and,
  when CFG.use_TTA is True, also over CFG.n_TTA randomly augmented copies of
  the image. The predicted label is the argmax of the averaged probabilities.

  Returns:
    Tuple `(test_img_ids, test_pred_labels)` of equal-length lists: the image
    file names with their last four characters removed, and the predicted
    class indices.

  Raises:
    ValueError: If `models` is empty.
  """
  if not models:
    raise ValueError("No models loaded for inference")

  test_dataloader = prepare_test_dataloader()
  test_img_ids, test_pred_labels = [], []

  # Device placement and eval mode do not change between batches.
  for model in models:
    model.to(CFG.device)
    model.eval()

  # The augmentation pipeline is only needed for TTA; build it once.
  heavy_transforms = get_heavy_transforms() if CFG.use_TTA else None

  with torch.no_grad():
    for img, img_filename in test_dataloader:
      if CFG.use_TTA:
        # batch_size is 1 under TTA (checked at configuration time). The
        # loader yields channel-first tensors (assumed: ToTensorV2 output)
        # while the augmentations take height x width x channel arrays, so
        # the axes must be transposed -- a plain reshape would scramble pixels.
        img_np = img[0].numpy().transpose(1, 2, 0).copy()
        imgs = torch.stack(
            [heavy_transforms(image=img_np)['image'] for _ in range(CFG.n_TTA)])
      else:
        # NOTE(review): the test transforms neither resize nor normalize, so
        # the raw image is fed as-is here -- confirm this is intended.
        imgs = img

      imgs = imgs.to(torch.float32).to(CFG.device)

      # voting[m, i] holds model m's class probabilities for input i.
      voting = np.zeros((len(models), imgs.shape[0], CFG.n_classes))
      for model_idx, model in enumerate(models):
        logits = model(imgs)
        voting[model_idx] = F.softmax(logits, dim=1).cpu().numpy()

      if CFG.use_TTA:
        # All rows are augmentations of one image: average over models and copies.
        mean_probs = voting.mean(axis=(0, 1))[np.newaxis, :]
      else:
        # One row per image in the batch: average over models only.
        mean_probs = voting.mean(axis=0)
      pred_labels = np.argmax(mean_probs, axis=1)

      for filename, pred_label in zip(img_filename, pred_labels):
        # Drops the last four characters of the file name (e.g. ".jpg").
        test_img_ids.append(filename[:-4])
        test_pred_labels.append(pred_label)

  return test_img_ids, test_pred_labels


# Main

In [None]:
if __name__ == '__main__':
  seed_everything()

  # Debug runs only load the checkpoint at position 0 to keep them short.
  if CFG.debug:
    CFG.model_list = [0]

  # `models` is read by infer() as a module-level name.
  models = load_pretrained_models()
  test_img_ids, test_pred_labels = infer()

  # to_csv needs a file path; the Kaggle branch previously named a directory.
  output_path = ('/content/submission.csv' if CFG.platform == 'colab'
                 else '../output/kaggle/working/submission.csv')

  # Convert the predictions to the submission dataframe and write it as csv
  column_header = ['image_id', 'label']
  submission = pd.DataFrame(zip(test_img_ids, test_pred_labels), columns=column_header)
  submission.to_csv(path_or_buf = output_path, index = False)

Model Loaded: resnext50_32x4d_f1_b0.894.pth
resnext50_32x4d
