In [1]:
#sam                        cv:0.895 lb:0.893
#use loss_kd_regularization cv:      lb:
#use ema
#add 2019 traindata         cv:      lb:

In [2]:
import sys
from datetime import datetime
import time
import random
import cv2
import pandas as pd
import numpy as np
import os
from os.path import isfile
import albumentations as A
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from sklearn.metrics import classification_report, accuracy_score
from albumentations.pytorch.transforms import ToTensorV2
from sklearn.model_selection import StratifiedKFold
from glob import glob
from copy import deepcopy
import collections
import warnings
import timm
from timm.scheduler import CosineLRScheduler
import math
from copy import deepcopy

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data.sampler import SequentialSampler, RandomSampler, Sampler
from torch.optim import lr_scheduler
import torch.nn.functional as F

from bi_tempered_loss_pytorch import bi_tempered_logistic_loss
from aug_mix import RandomAugMix
from sam.sam import SAM
from FMix.fmix import sample_mask


In [3]:
warnings.filterwarnings("ignore")

def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

In [4]:
class TrainGlobalConfig:

#     fold_num = 0
    n_class = 5
    train_path = "./input/train_images/"
    train_2019_path = "./input_2019/train/train/"
    loss_f = "loss_kd_regularization"
    label_smoothing = 0.2
    gamma = 2.0
    num_workers = 4
    resample = False
    n_epochs = 40
    
    use_sam = True
    use_ema = True #use model ema
    use_mix_precision = False
    use_mixup = False
    use_cutmix = False
    use_fmix = True
    
    accumulate_steps = 1
    train_batch_size = 12
    test_batch_size = 32
    early_stop_patience = 5
    lr = 1e-4
    weight_decay = 1e-6
    seed = 2020
    n_fold = 5
    image_size = 512
    folder = 'effnet-b4-ema'

    # -------------------
    verbose = True
    verbose_step = 1
    # -------------------

    # --------------------
    step_scheduler = False  # do scheduler.step after optimizer.step
    validation_scheduler = True #True  # do scheduler.step after validation stage loss

    SchedulerClass = lr_scheduler.CosineAnnealingLR
    scheduler_params = dict(
        T_max=n_epochs
    )
    

In [5]:
def loss_kd_regularization(outputs, labels):
    """
    loss function for mannually-designed regularization: Tf-KD_{reg}
    """
    alpha = 0.1
    T = 20
    correct_prob = 0.99    # the probability for correct class in u(k)
    loss_CE = F.cross_entropy(outputs, labels)
    K = outputs.size(1)

    teacher_soft = torch.ones_like(outputs).cuda()
    teacher_soft = teacher_soft*(1-correct_prob)/(K-1)  # p^d(k)
    for i in range(outputs.shape[0]):
        teacher_soft[i ,labels[i]] = correct_prob
    loss_soft_regu = nn.KLDivLoss()(F.log_softmax(outputs, dim=1), F.softmax(teacher_soft/T, dim=1))*50

    KD_loss = (1. - alpha)*loss_CE + alpha*loss_soft_regu

    return KD_loss

class LabelSmoothingCrossEntropy(nn.Module):
    """
    NLL loss with label smoothing.
    """
    def __init__(self, smoothing=0.1):
        """
        Constructor for the LabelSmoothing module.
        :param smoothing: label smoothing factor
        """
        super(LabelSmoothingCrossEntropy, self).__init__()
        assert smoothing < 1.0
        self.smoothing = smoothing
        self.confidence = 1. - smoothing

    def forward(self, x, target):
        logprobs = F.log_softmax(x, dim=-1)
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()


class BinaryFocalLoss(nn.Module):

    def __init__(self, gamma=2.0, eps=1e-7, reduction='mean'):
        super(BinaryFocalLoss, self).__init__()
        self.gamma = gamma
        self.eps = eps
        self.reduction = reduction

    def forward(self, input, target):

        probs = torch.sigmoid(input)
        log_probs = -torch.log(probs)
        loss = torch.sum(torch.pow(1 - probs + self.eps, self.gamma).mul(log_probs).mul(target), dim=1)

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss


class FocalLoss(nn.Module):

    def __init__(self, gamma=0, eps=1e-7, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.eps = eps
        self.reduction = reduction

    def forward(self, input, target):

        y = F.one_hot(target, num_classes=input.size(-1)).to(input.device)

        logit = F.softmax(input, dim=-1)
        logit = logit.clamp(self.eps, 1. - self.eps)

        loss = -1 * y * torch.log(logit)  # cross entropy
        loss = loss * (1 - logit) ** self.gamma  # focal loss
        # loss = torch.sum(loss, dim=1)

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss

def focal_loss(labels, logits, alpha, gamma):
    """Compute the focal loss between `logits` and the ground truth `labels`.
    Focal loss = -alpha_t * (1-pt)^gamma * log(pt)
    where pt is the probability of being classified to the true class.
    pt = p (if true class), otherwise pt = 1 - p. p = sigmoid(logit).
    Args:
        labels: A float tensor of size [batch, num_classes].
        logits: A float tensor of size [batch, num_classes].
        alpha: A float tensor of size [batch_size]
            specifying per-example weight for balanced cross entropy.
        gamma: A float scalar modulating loss from hard and easy examples.
    Returns:
        focal_loss: A float32 scalar representing normalized total loss.
     """
    BCLoss = F.binary_cross_entropy_with_logits(input=logits, target=labels, reduction="none")

    if gamma == 0.0:
        modulator = 1.0
    else:
        modulator = torch.exp(-gamma * labels * logits - gamma * torch.log(1 +
                                                                               torch.exp(-1.0 * logits)))

    loss = modulator * BCLoss

    weighted_loss = alpha * loss
    focal_loss = torch.sum(weighted_loss)

    focal_loss /= torch.sum(labels)
    return focal_loss

class CB_loss(nn.Module):

    def __init__(self, samples_per_cls, n_class=5, loss_type="sigmoid", beta=0.9999, gamma=2.0):

        effective_num = 1.0 - np.power(beta, samples_per_cls)
        weights = (1.0 - beta) / np.array(effective_num)
        self.weights = weights / np.sum(weights) * n_class


        self.n_class = n_class
        self.loss_type = loss_type
        self.gamma =gamma

    def forward(self, input, target):

        weights = self.weights.copy()

        weights = torch.tensor(weights).float()
        weights = weights.unsqueeze(0)
        weights = weights.repeat(target.shape[0], 1) * target
        weights = weights.sum(1)
        weights = weights.unsqueeze(1)
        weights = weights.repeat(1, self.n_class)

        if self.loss_type == "focal":
            cb_loss = focal_loss(target, input, weights, self.gamma)
        elif self.loss_type == "sigmoid":
            cb_loss = F.binary_cross_entropy_with_logits(input=input, target=target, weights=weights)
        elif self.loss_type == "softmax":
            pred = input.softmax(dim=1)
            cb_loss = F.binary_cross_entropy(input=pred, target=target, weight=weights)
        return cb_loss

def loss_function(config):
    if config.loss_f == "cross_entropy":
        return nn.CrossEntropyLoss()
    if config.loss_f == "focal":
        return FocalLoss(gamma=config.gamma)
    if config.loss_f == "binary_cross_entropy":
        return nn.BCEWithLogitsLoss()
    if config.loss_f == "binary_focal_loss":
        return BinaryFocalLoss(gamma=config.gamma)
    if config.loss_f == "cb_loss":
        return CB_loss(config.sample_per_class)
    if config.loss_f == "labelsmoothingcrossentropy":
        return LabelSmoothingCrossEntropy(config.label_smoothing)
    if config.loss_f == "bi-tempered":
        return bi_tempered_logistic_loss
    if config.loss_f == "loss_kd_regularization":
        return loss_kd_regularization

In [6]:
def get_train_transforms(image_size):
    return A.Compose(
        [
            A.RandomRotate90(p=0.5),
            A.Transpose(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.ShiftScaleRotate(p=0.5),
#             A.ShiftScaleRotate(
#                 shift_limit=(-0.2, 0.2), 
#                 scale_limit=(-0.2, 0.2), 
#                 rotate_limit=(-20, 20), 
#                 interpolation=cv2.INTER_LINEAR, 
#                 border_mode=cv2.BORDER_REFLECT_101,
#                 p=0.8),
            A.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
            A.RandomBrightnessContrast(brightness_limit=(-0.1,0.1), contrast_limit=(-0.1, 0.1), p=0.5),
            A.CoarseDropout(
                p=0.5,
                max_holes=100,
                max_height=50,
                max_width=50,
                min_holes=30,
                min_height=20,
                min_width=20,
            ),
            A.Resize(height=image_size, width=image_size, p=1.0),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0),
            ToTensorV2(p=1.0),
        ],
        p=1.0,
    )

In [7]:
def get_valid_transforms(image_size):
    return A.Compose(
        [   
#             A.RandomRotate90(p=0.5),
#             A.Transpose(p=0.5),
#             A.HorizontalFlip(p=0.5),
#             A.VerticalFlip(p=0.5),
            A.Resize(height=image_size, width=image_size, p=1.0),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0),
            ToTensorV2(p=1.0),
        ],
        p=1.0)


In [8]:
def expand_path(img_name, cfg):
    
    img_name = str(img_name)
    if isfile(cfg.train_path + img_name):
        return cfg.train_path + img_name
    if isfile(cfg.train_2019_path + "cbb/" + img_name):
        return cfg.train_2019_path + "cbb/" + img_name
    if isfile(cfg.train_2019_path + "cbsd/" + img_name):
        return cfg.train_2019_path + "cbsd/" + img_name
    if isfile(cfg.train_2019_path + "cgm/" + img_name):
        return cfg.train_2019_path + "cgm/" + img_name
    if isfile(cfg.train_2019_path + "cmd/" + img_name):
        return cfg.train_2019_path + "cmd/" + img_name
    if isfile(cfg.train_2019_path + "healthy/" + img_name):
        return cfg.train_2019_path + "healthy/" + img_name

    return img_name

In [9]:
class Cassava_dataset(Dataset):
    def __init__(self, cfg, dataframe, transforms=None, mode="train"):
        self.config = cfg
        self.image_size = cfg.image_size
        self.df = dataframe
        self.transforms = transforms
        self.mode = mode
        self.mixup = True
    def __len__(self) -> int:
        return len(self.df)
    def __getitem__(self, index: int):
        image, label = self.get_image_and_label(index)
        if self.transforms:
            image = self.transforms(image=image)["image"]
        return image, label
    def get_image_and_label(self, index):
        data = self.df.iloc[index]
        image_id = data["image_id"]
        image_path = expand_path(image_id, self.config)
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)#.astype(np.float32)
        label = data["label"]
        return image, label

In [10]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()
    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [11]:
def mixup_data(x, y, alpha=1.0, use_cuda=True):
    '''Returns mixed inputs, pairs of targets, and lambda'''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    if use_cuda:
        index = torch.randperm(batch_size).cuda()
    else:
        index = torch.randperm(batch_size)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

In [12]:
def cutmix_data(x, y, alpha=1.0, use_cuda=True):

    mixed_x = x.clone()
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    def rand_bbox(size, lam):
        W = size[2]
        H = size[3]
        cut_rat = np.sqrt(1. - lam)
        cut_w = np.int(W * cut_rat)
        cut_h = np.int(H * cut_rat)

        # uniform
        cx = np.random.randint(W)
        cy = np.random.randint(H)

        bbx1 = np.clip(cx - cut_w // 2, 0, W)
        bby1 = np.clip(cy - cut_h // 2, 0, H)
        bbx2 = np.clip(cx + cut_w // 2, 0, W)
        bby2 = np.clip(cy + cut_h // 2, 0, H)

        return bbx1, bby1, bbx2, bby2

    batch_size = x.size()[0]
    if use_cuda:
        rand_index = torch.randperm(batch_size).cuda()
    else:
        rand_index = torch.randperm(batch_size)

    bbx1, bby1, bbx2, bby2 = rand_bbox(mixed_x.size(), lam)
    mixed_x[:, :, bbx1:bbx2, bby1:bby2] = mixed_x[rand_index, :, bbx1:bbx2, bby1:bby2]
    # adjust lambda to exactly match pixel ratio
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (mixed_x.size()[-1] * mixed_x.size()[-2]))
    y_a, y_b = y, y[rand_index]

    return mixed_x, y_a, y_b, lam

In [13]:
def fmix_data(x, y, alpha=1.0, decay_power=3, size=(512, 512), max_soft=0.0, reformulate=False, use_cuda=True):
    
    lam, mask = sample_mask(alpha, decay_power, size, max_soft, reformulate)
    batch_size = x.size()[0]
    if use_cuda:
        index = torch.randperm(batch_size).cuda()
    else:
        index = torch.randperm(batch_size)
    mask = torch.from_numpy(mask).float().to(x.device)
    
    mixed_x = mask * x + (1 - mask) * x[index, :]
    y_a, y_b = y, y[index]
    
    return mixed_x, y_a, y_b, lam

In [14]:
def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

In [15]:
def is_parallel(model):
    # is model is parallel with DP or DDP
    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)

class ModelEMA:
    """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
    Keep a moving average of everything in the model state_dict (parameters and buffers).
    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    A smoothed version of the weights is necessary for some training schemes to perform well.
    This class is sensitive where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """

    def __init__(self, model, decay=0.9999, updates=0):
        # Create EMA
        self.ema = deepcopy(model.module if is_parallel(model) else model).eval()  # FP32 EMA
        # if next(model.parameters()).device.type != 'cpu':
        #     self.ema.half()  # FP16 EMA
        self.updates = updates  # number of EMA updates
        self.decay = lambda x: decay * (1 - math.exp(-x / 2000))  # decay exponential ramp (to help early epochs)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        # Update EMA parameters
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)

            msd = model.module.state_dict() if is_parallel(model) else model.state_dict()  # model state_dict
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1. - d) * msd[k].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        # Update EMA attributes
        copy_attr(self.ema, model, include, exclude)

In [16]:
class Fitter:

    def __init__(self, model, device, config):
        
        self.config = config
        self.epoch = 0
        self.early_stop_count = 0
        self.fold_num = config.fold_num
        self.base_dir = f'./{config.folder}'
        if not os.path.exists(self.base_dir):
            os.makedirs(self.base_dir)
        self.log_path = f'{self.base_dir}/log.txt'
        self.best_summary_loss = 10 ** 5
        self.best_score = 0.0
        self.model = model
        self.device = device
        
        self.use_sam = config.use_sam
        self.use_mix_precision = config.use_mix_precision
        self.use_mixup = config.use_mixup
        self.use_cutmix = config.use_cutmix
        self.use_fmix = config.use_fmix
        #ema
        self.use_ema = config.use_ema
        self.model_ema = None
        self.model_ema_decay = 0.999
        
        self.accumulate_steps = config.accumulate_steps
        self.valset_size = config.valset_size
        self.test_batch_size = config.test_batch_size
        self.loss_function_train = loss_function(config)
        self.loss_function_test = nn.CrossEntropyLoss()
        
        self.early_stop_patience = config.early_stop_patience
        
        self.target_names = ["cbb", "cbsd", "cgm", "cmd", "healthy"]
        self.model.to(self.device)
        
        if config.use_sam:
            base_optimizer = torch.optim.Adam
            self.optimizer = SAM(self.model.parameters(), base_optimizer, lr=config.lr, weight_decay=config.weight_decay)
        else:
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
        if self.use_mix_precision:
            self.scaler = GradScaler()
        if self.use_ema:
            self.model_ema = ModelEMA(self.model, decay=self.model_ema_decay)
            
        self.scheduler = config.SchedulerClass(self.optimizer, **config.scheduler_params)
        self.log(f'Fitter prepared. Device is {self.device}. fold num is {self.fold_num}')

    def fit(self, train_loader, validation_loader):
        
        for e in range(self.config.n_epochs):
            if self.config.verbose:
                lr = self.optimizer.param_groups[0]['lr']
                timestamp = datetime.utcnow().isoformat()
                self.log(f'\n{timestamp}\nLR: {lr}')
            t = time.time()
            summary_loss = self.train_one_epoch(train_loader)
            self.log(
                f'[RESULT]: Train. Epoch: {self.epoch}, summary_loss: {summary_loss.avg:.5f}, time: {(time.time() - t):.5f}')
            self.save(f'{self.base_dir}/last-checkpoint.bin')
            t = time.time()
            
            if self.use_ema:
                summary_loss, score, report = self.validation(self.model_ema.ema.module if hasattr(self.model_ema.ema, 'module') else self.model_ema.ema, 
                                               validation_loader)
            else:
                summary_loss, score, report = self.validation(self.model, validation_loader)
            
            self.log(
                f'[RESULT]: Val. Epoch: {self.epoch}, summary_loss: {summary_loss.avg:.5f}, time: {(time.time() - t):.5f}')
            self.log(f'[RESULT]: Val. Score: {score:.5f}')
            self.log(f'[RESULT]: Val. Report: \n{report}')
            
            if summary_loss.avg < self.best_summary_loss:
                self.best_summary_loss = summary_loss.avg
                self.model.eval()
                self.save(f'{self.base_dir}/{self.fold_num}-best-checkpoint-{str(self.epoch).zfill(3)}epoch.bin')
                for path in sorted(glob(f'{self.base_dir}/{self.fold_num}-best-checkpoint-*epoch.bin'))[:-3]:
                    os.remove(path)
                    
            if score > self.best_score:
                self.best_score = score
                self.model.eval()
                self.save(f'{self.base_dir}/{self.fold_num}-best-score-checkpoint-{str(self.epoch).zfill(3)}epoch.bin')
                for path in sorted(glob(f'{self.base_dir}/{self.fold_num}-best-score-checkpoint-*epoch.bin'))[:-3]:
                    os.remove(path)
                self.early_stop_count = 0
            else:
                self.early_stop_count += 1
            
            if self.early_stop_count == self.early_stop_patience:
                break
            #break
            if self.config.validation_scheduler:
                self.scheduler.step()
            self.epoch += 1

    def validation(self, model, val_loader):
        
        model.eval()
        summary_loss = AverageMeter()
        t = time.time()
        predicts = np.zeros(self.valset_size)
        truths = np.zeros(self.valset_size)
        for step, (images, labels) in enumerate(val_loader):
            if self.config.verbose:
                if step % self.config.verbose_step == 0:
                    print(
                        f'Val Step {step}/{len(val_loader)}, ' + \
                        f'summary_loss: {summary_loss.avg:.5f}, ' + \
                        f'time: {(time.time() - t):.5f}', end='\r'
                    )
            start = step * self.test_batch_size
            end = min(start + self.test_batch_size, self.valset_size)
            with torch.no_grad():
                images = images.to(self.device).float()
                targets = labels.long().to(self.device)
                outputs = self.model(images)
                loss = self.loss_function_test(outputs, targets)
                summary_loss.update(loss.detach().item())
                probs = torch.softmax(outputs, dim=1)
                predicts[start:end] = np.argmax(probs.detach().cpu().numpy(), axis=1).flatten()
                truths[start:end] = labels.detach().cpu().numpy().flatten()
            #break
            
        score = accuracy_score(truths, predicts)
        report = classification_report(truths, predicts, target_names=self.target_names)
        
        return summary_loss, score, report

    def train_one_epoch(self, train_loader):
        
        self.model.train()
        summary_loss = AverageMeter()
        t = time.time()
        self.optimizer.zero_grad()  # very important
        for step, (images, labels) in enumerate(train_loader):
            if self.config.verbose:
                if step % self.config.verbose_step == 0:
                    print(
                        f'Train Step {step}/{len(train_loader)}, ' + \
                        f'summary_loss: {summary_loss.avg:.5f}, ' + \
                        f'time: {(time.time() - t):.5f}', end='\r'
                    )
            images = images.float().to(self.device)
            targets = labels.long().to(self.device)
            original_images = True

            if self.use_mixup and random.random() < 0.5:
                images, targets_a, targets_b, lam = mixup_data(images, targets, alpha=1.0, use_cuda=True)
                original_images = False
            
            if self.use_cutmix and original_images and random.random() < 0.5:
                images, targets_a, targets_b, lam = cutmix_data(images, targets)
                original_images = False
                
            if self.use_fmix and original_images and random.random() < 0.5:
                images, targets_a, targets_b, lam = fmix_data(images, targets)
                original_images = False
                
            if original_images:
                targets_a = targets
                targets_b = targets
                lam = 1.0

            if self.use_mix_precision:
                assert self.use_sam == False
                with autocast():
                    outputs = self.model(images)
                    loss = mixup_criterion(self.loss_function_train, outputs, targets_a, targets_b, lam)
                    loss /= self.accumulate_steps
                self.scaler.scale(loss).backward()
                if (step + 1) % self.accumulate_steps == 0:  # Wait for several backward steps
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                    self.optimizer.zero_grad()
            elif self.use_sam:
                loss = mixup_criterion(self.loss_function_train, self.model(images), targets_a, targets_b, lam)
                loss.backward()
                self.optimizer.first_step(zero_grad=True)
                    
                loss = mixup_criterion(self.loss_function_train, self.model(images), targets_a, targets_b, lam)
                loss.backward()
                self.optimizer.second_step(zero_grad=True)
            else:
                assert self.use_sam == False
                outputs = self.model(images)
                loss = mixup_criterion(self.loss_function_train, outputs, targets_a, targets_b, lam)
                loss /= self.accumulate_steps
                loss.backward()
                if (step + 1) % self.accumulate_steps == 0:  # Wait for several backward steps
                    self.optimizer.step()
                    self.optimizer.zero_grad()
            
            if self.model_ema is not None:
                self.model_ema.update(self.model)
            
            summary_loss.update(loss.detach().item() * self.accumulate_steps)
            if self.config.step_scheduler and (step + 1) % self.accumulate_steps == 0:
                self.scheduler.step()
            #break
                    
        return summary_loss

    def save(self, path):
        self.model.eval()
        
        if self.use_ema:
            torch.save({
                'model_state_dict': self.model.state_dict(),
                'state_dict_ema': self.model_ema.ema.module.state_dict() if hasattr(self.model_ema.ema, 'module') else self.model_ema.ema.state_dict(),
            }, path)
        else:
            torch.save({
                'model_state_dict': self.model.state_dict(),
            }, path)

    def load(self, path):
        checkpoint = torch.load(path)
        self.model.load_state_dict(checkpoint['model_state_dict'])

    def log(self, message):
        if self.config.verbose:
            print(message)
        with open(self.log_path, 'a+') as logger:
            logger.write(f'{message}\n')

In [17]:
class ImbalancedDatasetSampler(Sampler):
    """Samples elements randomly from a given list of indices for imbalanced dataset
    Arguments:
        indices (list, optional): a list of indices
        num_samples (int, optional): number of samples to draw
        callback_get_label func: a callback-like function which takes two arguments - dataset and index
    """

    def __init__(self, dataset, indices=None, num_samples=None):

        # if indices is not provided,
        # all elements in the dataset will be considered
        self.indices = list(range(len(dataset))) \
            if indices is None else indices

        # if num_samples is not provided,
        # draw `len(indices)` samples in each iteration
        self.num_samples = len(self.indices) \
            if num_samples is None else num_samples

        # distribution of classes in the dataset
        label_to_count = {}
        for idx in self.indices:
            label = self._get_label(dataset, idx)
            if label in label_to_count:
                label_to_count[label] += 1
            else:
                label_to_count[label] = 1
        # weight for each sample
        weights = [1.0 / label_to_count[self._get_label(dataset, idx)]
                   for idx in self.indices]
        self.weights = torch.DoubleTensor(weights)

    def _get_label(self, dataset, idx):
        return dataset.df.iloc[idx]["label"]

    def __iter__(self):
        return (self.indices[i] for i in torch.multinomial(
            self.weights, self.num_samples, replacement=True))

    def __len__(self):
        return self.num_samples

In [18]:
class UpsampleSampler(Sampler):

    def __init__(self, dataset):
        
        self.indices = list(range(len(dataset)))
        self.classes_num = len(np.unique(dataset.df["label"].values))
        label_to_count = [0] * self.classes_num
        indexes_per_class = []
        self.indexes_per_class = [[],[],[],[],[]]
        print(self.indexes_per_class)
        for idx in self.indices:
            label = self._get_label(dataset, idx)
            label_to_count[label] += 1
            self.indexes_per_class[label].append(idx)
            
        self.min_mun = min(label_to_count) 
        self.length = 0
        for k in range(self.classes_num):
            if len(self.indexes_per_class[k]) >= self.min_mun*3:
                self.length += len(self.indexes_per_class[k])
            else:
                self.length += self.min_mun*2
        print(self.length)

    def __iter__(self):
        
        all_indexs = []
        
        for k in range(self.classes_num):
            if len(self.indexes_per_class[k]) >= self.min_mun*2:
                all_indexs.extend(self.indexes_per_class[k])
            else:
                gap = self.min_mun*2 - len(self.indexes_per_class[k])
                random_choice = np.random.choice(self.indexes_per_class[k], int(gap), replace=True)
                all_indexs.extend(list(random_choice) + list(self.indexes_per_class[k]))
        
        l = np.array(all_indexs)
        l = l.reshape(-1)
        random.shuffle(l)
        return iter(l)

    def __len__(self):
        return int(self.length)
    
    def _get_label(self, dataset, idx):
        return dataset.df.iloc[idx]["label"]

In [19]:
def run_training(net, train_dataset, validation_dataset, config):
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.train_batch_size,
        sampler=RandomSampler(train_dataset) if config.resample == False else UpsampleSampler(train_dataset),
        pin_memory=False,
        drop_last=True,
        num_workers=config.num_workers,
    )
    val_loader = DataLoader(
        validation_dataset,
        batch_size=config.test_batch_size,
        num_workers=config.num_workers,
        shuffle=False,
        sampler=SequentialSampler(validation_dataset),
        pin_memory=False,
    )

    fitter = Fitter(model=net, device=device, config=config)
    fitter.fit(train_loader, val_loader)

In [20]:
class CassvaImgClassifier(nn.Module):
    def __init__(self, model_arch, n_class, pretrained=False):
        super().__init__()
        self.model = timm.create_model(model_arch, pretrained=pretrained)
        in_features = self.model.classifier.in_features
        self.model.classifier = nn.Linear(in_features, n_class)
        
    def forward(self, x):
        x = self.model(x)
        
        return x

In [21]:
if __name__ == "__main__":
    seed_everything(TrainGlobalConfig.seed)
    train_csv = pd.read_csv("./input/train.csv")
#     train_2019_csv = pd.read_csv("./input_2019/train.csv")
    kfolds = StratifiedKFold(n_splits=TrainGlobalConfig.n_fold,
                             random_state=TrainGlobalConfig.seed,
                             shuffle=True).split(np.arange(train_csv.shape[0]), train_csv.label.values)
    for fold, (train_idx, val_idx) in enumerate(kfolds):
        if fold == 0: continue
        TrainGlobalConfig.fold_num = fold
        
        net = CassvaImgClassifier(model_arch="tf_efficientnet_b4_ns",
                                  n_class=TrainGlobalConfig.n_class, 
                                  pretrained=True)
        train_df = train_csv.loc[train_idx, :].reset_index(drop=True)
        valid_df = train_csv.loc[val_idx, :].reset_index(drop=True)
        TrainGlobalConfig.valset_size = len(valid_df)

        train_dataset = Cassava_dataset(TrainGlobalConfig,
                                        train_df,
                                        transforms=get_train_transforms(TrainGlobalConfig.image_size),
                                        mode="train")

        val_dataset = Cassava_dataset(TrainGlobalConfig,
                                      valid_df,
                                      transforms=get_valid_transforms(TrainGlobalConfig.image_size),
                                      mode="val")

        run_training(net=net, train_dataset=train_dataset, validation_dataset=val_dataset, config=TrainGlobalConfig)

Fitter prepared. Device is cuda. fold num is 0

2021-01-08T01:18:35.118175
LR: 0.0001
[RESULT]: Train. Epoch: 0, summary_loss: 1.10148, time: 1700.73007
[RESULT]: Val. Epoch: 0, summary_loss: 0.84572, time: 52.79914
[RESULT]: Val. Score: 0.87360
[RESULT]: Val. Report: 
              precision    recall  f1-score   support

         cbb       0.51      0.65      0.57       218
        cbsd       0.90      0.68      0.77       438
         cgm       0.85      0.73      0.79       477
         cmd       0.95      0.98      0.96      2631
     healthy       0.69      0.75      0.72       516

    accuracy                           0.87      4280
   macro avg       0.78      0.76      0.76      4280
weighted avg       0.88      0.87      0.87      4280


2021-01-08T01:47:52.089972
LR: 9.98458666866564e-05
[RESULT]: Train. Epoch: 1, summary_loss: 1.03389, time: 1700.28576
[RESULT]: Val. Epoch: 1, summary_loss: 0.78321, time: 52.31196
[RESULT]: Val. Score: 0.88481
[RESULT]: Val. Report: 
    