<a href="https://colab.research.google.com/github/koukyo1994/kaggle-bengali-ai/blob/master/notebook/Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Dependencies

In [None]:
%%sh
# Install the third-party packages this notebook imports, pinned to exact
# versions. Only stdout is redirected to /dev/null, so errors written to
# stderr still show up in the cell output.
pip install albumentations==0.4.3 easydict==1.9.0 >> /dev/null
pip install efficientnet-pytorch==0.6.1 PyYAML==5.3 >> /dev/null
pip install pretrainedmodels==0.7.4 >> /dev/null

## Integration with Google Drive

In [2]:
# Mount Google Drive at /content/gdrive; the next cell copies the dataset
# from there. This triggers the interactive authorization prompt shown below.
from google.colab import drive
drive.mount("/content/gdrive")

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/gdrive


In [4]:
%%sh
mkdir input
cp -r /content/gdrive/My\ Drive/kaggle-bengali ./input/bengaliai-cv19
unzip -qq -d input/bengaliai-cv19/ input/bengaliai-cv19/train_images.zip

cp: cannot open '/content/gdrive/My Drive/kaggle-bengali/Kaggle bengaliai-cv19.gsheet' for reading: Operation not supported


In [1]:
# Report which GPU (and driver / CUDA version) is attached to this runtime.
!nvidia-smi

Wed Mar  4 17:33:42 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.59       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   39C    P0    26W / 250W |      0MiB / 16280MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|  No ru

## Libraries

In [5]:
import os
import random

import albumentations as A
import cv2
import numpy as np
import pandas as pd
import pretrainedmodels
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as torchdata
import torchvision.models as models
import yaml

from pathlib import Path
from typing import Tuple, Dict, Union, Optional, List

from easydict import EasyDict as edict
from efficientnet_pytorch import EfficientNet
from fastprogress import progress_bar
from skimage.transform import AffineTransform, warp
from sklearn.metrics import recall_score, confusion_matrix
from sklearn.model_selection import KFold, train_test_split
from torch.nn.parameter import Parameter
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import (ReduceLROnPlateau, 
                                      CosineAnnealingLR,
                                      CosineAnnealingWarmRestarts)

alchemy not available, to install alchemy, run `pip install alchemy-catalyst`.


## Settings

In [None]:
# `stage` is added to the configured seed when the RNGs are seeded below.
stage = 0
# Free-form label of this experiment.
# NOTE(review): presumably used to name outputs further down — confirm.
trial = "seresnext_size236_90epoch_no_fold"

## Config

In [None]:
# Experiment configuration as inline YAML. It is parsed with the YAML safe
# loader and merged on top of the defaults built by `_get_default()`, so only
# values that differ from (or are absent in) the defaults need to appear here.
conf_string = '''
dataset:
  train:
    affine: False
    morphology: False
    binarization: False
  val:
    affine: False
    morphology: False
    binarization: False
  test:
    affine: False
    morphology: False

data:
  train_df_path: input/bengaliai-cv19/train.csv
  train_images_path: input/bengaliai-cv19/train_images
  test_images_path: input/bengaliai-cv19/test_images
  sample_submission_path: input/bengaliai-cv19/sample_submission.csv

model:
  model_name: se_resnext50_32x4d
  pretrained: imagenet
  num_classes: 186
  head: custom
  in_channels: 3

train:
  batch_size: 64
  num_epochs: 15

test:
  batch_size: 100

loss:
  name: ohem
  params:
    n_grapheme: 168
    n_vowel: 11
    n_consonant: 7
    weights:
      - 2.0
      - 1.0
      - 1.0

optimizer:
  name: Adam
  params:
    lr: 0.00005

scheduler:
  name: cosine
  params:
    T_max: 10

transforms:
  train:
    Noise: False
    Contrast: False
    Rotate: True
    RandomScale: True
    Cutout:
      num_holes: 1
      max_h_size: 50
      max_w_size: 75
      fill_value: 255
      always_apply: True
    ShiftScaleRotate: False
    RandomResizedCrop: False
    CoarseDropout: False
    GridDistortion: False
  val:
    Noise: False
    Contrast: False
    Rotate: False
    RandomScale: False
    Cutout:
      num_holes: 0
  test:
    Noise: False
    Contrast: False
    Rotate: False
    RandomScale: False
    Cutout:
      num_holes: 0

val:
  name: no_fold
  params:

callback:
  params:
    mixup_prob: 0.5
    cutmix_prob: 0.0
    no_aug_epochs: 5

log_dir: log/
num_workers: 2
seed: 1213
img_size: 236
weights:
'''

In [None]:
def _get_default():
    """Build the baseline configuration that user settings are merged into.

    Every preprocessing and augmentation switch defaults to "off" for each of
    the ``train`` / ``val`` / ``test`` phases; the model defaults to a
    pretrained ``resnet18`` with a linear head; sections without defaults
    (``data``, ``train``, ``test``, ...) are created empty so that they can be
    filled in later.

    Returns:
        edict: A freshly created, fully independent default configuration.
    """
    phases = ("train", "val", "test")
    cfg = edict()

    # dataset: per-phase preprocessing switches
    cfg.dataset = edict()
    for phase in phases:
        setattr(cfg.dataset, phase, edict())
    for phase in phases:
        section = getattr(cfg.dataset, phase)
        for switch in ("affine", "morphology", "binarization"):
            setattr(section, switch, False)

    # data: paths, filled in from the user configuration
    cfg.data = edict()

    # model
    cfg.model = edict()
    cfg.model.model_name = "resnet18"
    cfg.model.num_classes = 186
    cfg.model.pretrained = True
    cfg.model.head = "linear"
    cfg.model.in_channels = 3

    # train / test: no defaults, only the empty sections
    cfg.train = edict()
    cfg.test = edict()

    # loss, optimizer, scheduler: empty sections with an empty `params`
    for section_name in ("loss", "optimizer", "scheduler"):
        setattr(cfg, section_name, edict())
        getattr(cfg, section_name).params = edict()

    # transforms: the same set of disabled augmentations for every phase.
    # Cutout sits between the two groups of boolean switches so that the key
    # order of each phase section is unchanged.
    switches_before_cutout = ("HorizontalFlip", "VerticalFlip", "Noise",
                              "Contrast", "Rotate", "RandomScale")
    switches_after_cutout = ("ShiftScaleRotate", "RandomResizedCrop",
                             "CoarseDropout", "GridDistortion")
    cfg.transforms = edict()
    for phase in phases:
        setattr(cfg.transforms, phase, edict())
        section = getattr(cfg.transforms, phase)
        for switch in switches_before_cutout:
            setattr(section, switch, False)
        section.Cutout = edict()
        section.Cutout.num_holes = 0
        for switch in switches_after_cutout:
            setattr(section, switch, False)
    cfg.transforms.mean = [0.485, 0.456, 0.406]
    cfg.transforms.std = [0.229, 0.224, 0.225]

    # val
    cfg.val = edict()
    cfg.val.params = edict()

    cfg.callbacks = []

    return cfg


def _merge_config(src: edict, dst: edict):
    """Recursively overlay the values of ``src`` onto ``dst`` in place.

    Nested sections that exist as ``edict`` on both sides are merged key by
    key, so defaults in ``dst`` that ``src`` does not mention are kept. Any
    other value in ``src`` — including a whole section for which ``dst`` has
    no ``edict`` counterpart (e.g. a section that has no defaults) — is
    assigned as is instead of raising ``KeyError``.

    Args:
        src: Configuration whose values take precedence. If it is not an
            ``edict`` the call is a no-op.
        dst: Configuration that is updated in place.
    """
    if not isinstance(src, edict):
        return
    for key, value in src.items():
        target = dst.get(key)
        if isinstance(value, edict) and isinstance(target, edict):
            _merge_config(value, target)
        else:
            # Leaf value, or a section unknown to `dst`: take it wholesale.
            dst[key] = value

In [None]:
# Start from the defaults, then overlay the values parsed from the inline
# YAML. `config` is the merged result used by the rest of the notebook.
config = _get_default()
cfg = edict(yaml.safe_load(conf_string))
_merge_config(cfg, config)

## Environment settings

In [10]:
def seed_torch(seed: int = 42):
    """Seed every random number generator the notebook relies on.

    Sets ``PYTHONHASHSEED``, seeds Python's ``random``, NumPy and PyTorch
    (CPU and CUDA), and asks cuDNN to use deterministic algorithms.

    Args:
        seed: Value used for all generators.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    seeders = (random.seed, np.random.seed, torch.manual_seed,
               torch.cuda.manual_seed)
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True

    
# Offset the configured seed by `stage` so that each stage is seeded differently.
seed_torch(config["seed"] + stage)

In [None]:
# Directory for the outputs of this run; created if it does not exist yet.
output_base_dir = Path("output")
output_base_dir.mkdir(exist_ok=True, parents=True)

# Directory holding the training images, taken from the configuration.
train_images_path = Path(config.data.train_images_path)

## Data and utilities preparation

### transforms

In [None]:
def get_transforms(config: edict, phase: str = "train"):
    """Assemble the albumentations pipeline for one phase.

    Args:
        config: Full configuration. ``config.transforms.<train|val|test>``
            holds the per-phase switches, ``config.transforms.mean`` and
            ``config.transforms.std`` the normalization statistics.
        phase: ``"train"``, ``"valid"`` or ``"test"``. The validation phase is
            spelled ``"valid"`` here but reads the ``val`` config section.

    Returns:
        An ``A.Compose`` that applies every enabled augmentation in a fixed
        order and always ends with ``A.Normalize``.
    """
    assert phase in ["train", "valid", "test"]
    if phase == "train":
        cfg = config.transforms.train
    elif phase == "valid":
        cfg = config.transforms.val
    elif phase == "test":
        cfg = config.transforms.test
    list_transforms = []
    if cfg.HorizontalFlip:
        # Was `A.HorizontalFrip`, which does not exist and raised
        # AttributeError as soon as the switch was enabled.
        list_transforms.append(A.HorizontalFlip())
    if cfg.VerticalFlip:
        list_transforms.append(A.VerticalFlip())
    if cfg.Rotate:
        list_transforms.append(A.Rotate(limit=15))
    if cfg.RandomScale:
        list_transforms.append(A.RandomScale())
    if cfg.Noise:
        list_transforms.append(
            A.OneOf(
                [A.GaussNoise(), A.IAAAdditiveGaussianNoise()], p=0.5))
    if cfg.Contrast:
        list_transforms.append(
            A.OneOf(
                [A.RandomContrast(0.5),
                 A.RandomGamma(),
                 A.RandomBrightness()],
                p=0.5))
    if cfg.Cutout.num_holes > 0:
        # The whole Cutout section is forwarded as keyword arguments.
        list_transforms.append(A.Cutout(**cfg.Cutout))
    if cfg.ShiftScaleRotate:
        list_transforms.append(
            A.ShiftScaleRotate(
                shift_limit=0.0625, scale_limit=0, rotate_limit=7, p=0.5))
    if cfg.RandomResizedCrop:
        # NOTE(review): the crop size is fixed at 128x128 and does not follow
        # `img_size` from the configuration — confirm this is intended.
        list_transforms.append(
            A.RandomResizedCrop(128, 128, scale=(0.8, 1), p=0.5))
    if cfg.CoarseDropout:
        list_transforms.append(
            A.CoarseDropout(max_holes=8, max_height=8, max_width=8, p=0.2))
    if cfg.GridDistortion:
        list_transforms.append(A.GridDistortion(p=0.2))

    # Normalization is applied in every phase, after all augmentations.
    list_transforms.append(
        A.Normalize(
            mean=config.transforms.mean,
            std=config.transforms.std,
            p=1,
            always_apply=True))

    return A.Compose(list_transforms, p=1.0)

### Data Loading

In [None]:
# Label table of the training set.
df = pd.read_csv(config.data.train_df_path)

# Augmentation pipeline for the training phase.
transforms = get_transforms(config, "train")

# Number of distinct values of each target column in the training labels.
cls_levels = {
    "grapheme": df.grapheme_root.nunique(),
    "vowel": df.vowel_diacritic.nunique(),
    "consonant": df.consonant_diacritic.nunique()
}

## Dataset and DataLoader

In [None]:
class BaseDataset(torchdata.Dataset):
    """Dataset yielding padded, augmented and resized images with their labels.

    Each item is read from ``<image_dir>/<image_id>.png``, padded with white
    to a square, passed through ``transforms`` (if given), resized to ``size``
    and returned channel-first together with the three target labels.

    Args:
        image_dir: Directory containing the ``.png`` images.
        df: Label table with the columns ``image_id``, ``grapheme_root``,
            ``vowel_diacritic`` and ``consonant_diacritic``. Rows are looked
            up with ``df.loc[idx]``, so the index must be ``0..len(df)-1``.
        transforms: Callable following the albumentations convention
            (``transforms(image=...)["image"]``), or ``None``.
        size: Target size passed to ``cv2.resize``.
        binarization: If true, ``binarization_and_opening`` is applied to the
            image right after loading.
    """

    def __init__(self,
                 image_dir: Path,
                 df: pd.DataFrame,
                 transforms,
                 size: Tuple[int, int],
                 binarization=False):
        self.df = df
        self.image_dir = image_dir
        self.transforms = transforms
        self.size = size
        self.binarization = binarization

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``{"images": ndarray, "targets": ndarray of 3 ints}``.

        Raises:
            FileNotFoundError: If the image cannot be read.
        """
        image_id = self.df.loc[idx, "image_id"]
        image_path = self.image_dir / f"{image_id}.png"

        image = cv2.imread(str(image_path))
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"could not read image: {image_path}")
        if self.binarization:
            image = binarization_and_opening(image)

        # Pad with white to a square whose side is the longer image side. The
        # image is placed at a random offset along the axis that needs
        # padding; an axis without slack gets offset 0 (np.random.randint
        # rejects an empty range, which used to crash on square images).
        height, width = image.shape[:2]
        side = max(height, width)
        if image.ndim == 2:
            new_image = np.ones((side, side), dtype=np.uint8) * 255
        else:
            new_image = np.ones((side, side, 3), dtype=np.uint8) * 255
        offset_y = np.random.randint(0, side - height) if side > height else 0
        offset_x = np.random.randint(0, side - width) if side > width else 0
        new_image[offset_y:offset_y + height,
                  offset_x:offset_x + width] = image

        # Always continue with the padded image; previously the padding was
        # silently dropped whenever no transforms were supplied.
        image = new_image
        if self.transforms is not None:
            image = self.transforms(image=image)["image"]
        image = cv2.resize(image, self.size)
        if image.ndim == 2:
            # Single-channel image: add the channel axis in front.
            image = image[np.newaxis]
        elif image.shape[2] == 3:
            # HWC -> CHW
            image = np.moveaxis(image, -1, 0)

        row = self.df.loc[idx]
        label = np.zeros(3, dtype=int)
        label[0] = row["grapheme_root"]
        label[1] = row["vowel_diacritic"]
        label[2] = row["consonant_diacritic"]
        return {"images": image, "targets": label}
    
    
def get_base_loader(df: pd.DataFrame,
                    image_dir: Path,
                    phase: str = "train",
                    size: Tuple[int, int] = (128, 128),
                    batch_size=256,
                    num_workers=2,
                    transforms=None,
                    binarization=False):
    """Wrap a ``BaseDataset`` in a ``DataLoader`` configured for ``phase``.

    Training loaders shuffle and drop the last incomplete batch; validation
    loaders do neither. Memory pinning is always enabled.

    Args:
        df: Label table handed to ``BaseDataset``.
        image_dir: Directory with the images.
        phase: ``"train"`` or ``"valid"``.
        size: Target image size.
        batch_size: Number of samples per batch.
        num_workers: Number of loader worker processes.
        transforms: Optional augmentation pipeline.
        binarization: Forwarded to ``BaseDataset``.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    assert phase in ["train", "valid"]
    is_training = phase == "train"

    dataset = BaseDataset(  # type: ignore
        image_dir, df, transforms, size, binarization)
    loader_options = dict(
        batch_size=batch_size,
        shuffle=is_training,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=is_training)
    return torchdata.DataLoader(dataset, **loader_options)

## Model and Loss

### Model

In [None]:
def gem(x: torch.Tensor, p=3, eps=1e-6):
    """Generalized-mean pooling over the last two dimensions of ``x``.

    Values are clamped from below at ``eps``, raised to the power ``p``,
    averaged over the full spatial extent and raised to ``1 / p``. The two
    pooled dimensions are kept with size 1.
    """
    spatial_extent = (x.size(-2), x.size(-1))
    powered = x.clamp(min=eps).pow(p)
    pooled = F.avg_pool2d(powered, spatial_extent)
    return pooled.pow(1. / p)


def mish(input):
    '''
    Element-wise mish activation:
    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
    The result has the same shape as the input.
    '''
    gate = torch.tanh(F.softplus(input))
    return input * gate


class Mish(nn.Module):
    '''
    Module wrapper around the element-wise mish activation:
    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
    The module has no parameters and no state.
    Shape:
        - Input: (N, *), * being any number of additional dimensions
        - Output: (N, *), identical to the input shape
    Examples:
        >>> m = Mish()
        >>> input = torch.randn(2)
        >>> output = m(input)
    '''

    def __init__(self):
        '''
        Nothing to configure; only initializes the base module.
        '''
        super().__init__()

    def forward(self, input):
        '''
        Apply mish to every element of the input.
        '''
        activated = mish(input)
        return activated


class GeM(nn.Module):
    """Generalized-mean pooling layer with a learnable exponent.

    ``p`` is stored as a one-element parameter initialised to the given
    value; ``eps`` is the lower clamp applied before the power is taken. The
    two pooled spatial dimensions are squeezed away in the output.
    """

    def __init__(self, p=3, eps=1e-6):
        super(GeM, self).__init__()
        self.p = Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        pooled = gem(x, p=self.p, eps=self.eps)
        return pooled.squeeze(-1).squeeze(-1)

    def __repr__(self):
        exponent = self.p.data.tolist()[0]
        return f"{self.__class__.__name__}(p={exponent:.4f}, eps={self.eps})"


class SpatialAttention2d(nn.Module):
    """Spatial attention gate.

    A bias-free 1x1 convolution reduces the input to a single channel, a
    sigmoid turns that into a per-position gate, and the input is multiplied
    by the gate (broadcast over the channels).
    """

    def __init__(self, channel):
        super(SpatialAttention2d, self).__init__()
        self.squeeze = nn.Conv2d(channel, 1, kernel_size=1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        gate = self.sigmoid(self.squeeze(x))
        return x * gate


class GAB(nn.Module):
    """Channel attention gate.

    The input is average-pooled to 1x1, passed through a two-layer 1x1
    convolutional bottleneck (``input_dim // reduction`` hidden channels,
    ReLU in between) and a sigmoid; the resulting per-channel gate scales
    the input.
    """

    def __init__(self, input_dim, reduction=4):
        super(GAB, self).__init__()
        hidden_dim = input_dim // reduction
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(
            input_dim, hidden_dim, kernel_size=1, stride=1)
        self.conv2 = nn.Conv2d(
            hidden_dim, input_dim, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        pooled = self.global_avgpool(x)
        hidden = self.relu(self.conv1(pooled))
        gate = self.sigmoid(self.conv2(hidden))
        return x * gate


class SCse(nn.Module):
    """Sum of a spatially gated and a channel-gated copy of the input.

    ``satt`` is a ``SpatialAttention2d`` and ``catt`` a ``GAB``, both built
    for ``dim`` channels; the output has the same shape as the input.
    """

    def __init__(self, dim):
        super(SCse, self).__init__()
        self.satt = SpatialAttention2d(dim)
        self.catt = GAB(dim)

    def forward(self, x):
        spatially_gated = self.satt(x)
        channel_gated = self.catt(x)
        return spatially_gated + channel_gated
    
    
class SEResNext(nn.Module):
    """Backbone from ``pretrainedmodels`` with one of three output heads.

    Args:
        model_name: Name of a model factory in ``pretrainedmodels.models``.
        num_classes: Output width of the ``"linear"`` head. The ``"custom"``
            and ``"scse"`` heads always emit 168 + 11 + 7 logits.
        pretrained: Forwarded to the model factory (e.g. ``"imagenet"``).
        head: ``"linear"`` replaces the classifier by one linear layer.
            ``"custom"`` and ``"scse"`` replace the last two child modules of
            the backbone (presumably pooling and classifier — confirm) by
            three separate heads for grapheme, vowel and consonant.
        in_channels: 1 or 3. For 1 the first convolution is replaced by a
            single-channel one; with ``pretrained == "imagenet"`` its weights
            are the channel-wise mean of the pretrained weights.

    The forward pass returns one tensor; for the multi-head variants the
    grapheme, vowel and consonant logits are concatenated along dim 1.
    """

    def __init__(self,
                 model_name: str,
                 num_classes: int,
                 pretrained=None,
                 head="linear",
                 in_channels=3):
        super().__init__()
        self.num_classes = num_classes
        self.base = getattr(pretrainedmodels.models,
                            model_name)(pretrained=pretrained)
        self.head = head
        assert in_channels in [1, 3]
        assert head in ["linear", "custom", "scse"]
        if in_channels == 1:
            if pretrained == "imagenet":
                weight = self.base.layer0.conv1.weight
                self.base.layer0.conv1 = nn.Conv2d(
                    1, 64, kernel_size=7, stride=2, padding=3, bias=False)
                self.base.layer0.conv1.weight = nn.Parameter(
                    data=torch.mean(weight, dim=1, keepdim=True),
                    requires_grad=True)
            else:
                self.base.layer0.conv1 = nn.Conv2d(
                    1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        if head == "linear":
            n_in_features = self.base.last_linear.in_features
            self.base.last_linear = nn.Linear(n_in_features, self.num_classes)
        elif head in ("custom", "scse"):
            n_in_features = self.base.last_linear.in_features
            # Keep everything except the last two child modules.
            self.base = nn.Sequential(*list(self.base.children())[:-2])
            build_head = (self._conv_head
                          if head == "custom" else self._scse_head)
            self.grapheme_head = build_head(n_in_features, 168)
            self.vowel_head = build_head(n_in_features, 11)
            self.consonant_head = build_head(n_in_features, 7)
        else:
            raise NotImplementedError

    @staticmethod
    def _conv_head(n_in_features: int, n_out: int) -> nn.Sequential:
        """Head: Mish -> 3x3 conv to 512 channels -> BN -> GeM -> linear."""
        return nn.Sequential(
            Mish(), nn.Conv2d(n_in_features, 512, kernel_size=3),
            nn.BatchNorm2d(512), GeM(), nn.Linear(512, n_out))

    @staticmethod
    def _scse_head(n_in_features: int, n_out: int) -> nn.Sequential:
        """Head: SCse -> Mish -> BN -> GeM -> dropout -> linear.

        SCse keeps the channel count, so the following layers are sized by
        ``n_in_features``. They used to be hard-coded to 512, which only
        matches backbones whose feature maps have exactly 512 channels.
        """
        return nn.Sequential(
            SCse(n_in_features), Mish(), nn.BatchNorm2d(n_in_features),
            GeM(), nn.Dropout(0.3), nn.Linear(n_in_features, n_out))

    def forward(self, x):
        if self.head == "linear":
            return self.base(x)
        elif self.head in ("custom", "scse"):
            x = self.base(x)
            grapheme = self.grapheme_head(x)
            vowel = self.vowel_head(x)
            consonant = self.consonant_head(x)
            return torch.cat([grapheme, vowel, consonant], dim=1)
        else:
            raise NotImplementedError


class Resnet(nn.Module):
    """torchvision backbone with one of three output heads.

    Args:
        model_name: Name of a model factory in ``torchvision.models``.
        num_classes: Output width of the ``"linear"`` head. The ``"custom"``
            and ``"scse"`` heads always emit 168 + 11 + 7 logits.
        pretrained: Forwarded to the model factory.
        head: ``"linear"`` replaces ``fc`` by one linear layer. ``"custom"``
            and ``"scse"`` replace the last two child modules of the backbone
            (presumably pooling and classifier — confirm) by three separate
            heads for grapheme, vowel and consonant.
        in_channels: 1 or 3. For 1 the first convolution is replaced by a
            single-channel one; if ``pretrained`` is truthy its weights are
            the channel-wise mean of the pretrained weights.

    The forward pass returns one tensor; for the multi-head variants the
    grapheme, vowel and consonant logits are concatenated along dim 1.
    """

    def __init__(self,
                 model_name: str,
                 num_classes: int,
                 pretrained=False,
                 head="linear",
                 in_channels=3):
        super().__init__()
        self.num_classes = num_classes
        self.base = getattr(models, model_name)(pretrained=pretrained)
        self.head = head
        assert in_channels in [1, 3]
        assert head in ["linear", "custom", "scse"]
        if in_channels == 1:
            if pretrained:
                weight = self.base.conv1.weight
                self.base.conv1 = nn.Conv2d(
                    1, 64, kernel_size=7, stride=2, padding=3, bias=False)
                self.base.conv1.weight = nn.Parameter(
                    data=torch.mean(weight, dim=1, keepdim=True),
                    requires_grad=True)
            else:
                self.base.conv1 = nn.Conv2d(
                    1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        if head == "linear":
            n_in_features = self.base.fc.in_features
            self.base.fc = nn.Linear(n_in_features, self.num_classes)
        elif head in ("custom", "scse"):
            n_in_features = self.base.fc.in_features
            # Keep everything except the last two child modules.
            self.base = nn.Sequential(*list(self.base.children())[:-2])
            build_head = (self._conv_head
                          if head == "custom" else self._scse_head)
            self.grapheme_head = build_head(n_in_features, 168)
            self.vowel_head = build_head(n_in_features, 11)
            self.consonant_head = build_head(n_in_features, 7)
        else:
            raise NotImplementedError

    @staticmethod
    def _conv_head(n_in_features: int, n_out: int) -> nn.Sequential:
        """Head: Mish -> 3x3 conv to 512 channels -> BN -> GeM -> linear."""
        return nn.Sequential(
            Mish(), nn.Conv2d(n_in_features, 512, kernel_size=3),
            nn.BatchNorm2d(512), GeM(), nn.Linear(512, n_out))

    @staticmethod
    def _scse_head(n_in_features: int, n_out: int) -> nn.Sequential:
        """Head: SCse -> Mish -> BN -> GeM -> dropout -> linear.

        SCse keeps the channel count, so the following layers are sized by
        ``n_in_features``. They used to be hard-coded to 512, which is
        unchanged for backbones with 512 feature channels and fixes the shape
        mismatch for wider ones.
        """
        return nn.Sequential(
            SCse(n_in_features), Mish(), nn.BatchNorm2d(n_in_features),
            GeM(), nn.Dropout(0.3), nn.Linear(n_in_features, n_out))

    def forward(self, x):
        if self.head == "linear":
            return self.base(x)
        elif self.head in ("custom", "scse"):
            x = self.base(x)
            grapheme = self.grapheme_head(x)
            vowel = self.vowel_head(x)
            consonant = self.consonant_head(x)
            return torch.cat([grapheme, vowel, consonant], dim=1)
        else:
            raise NotImplementedError


def get_model(config: edict):
    """Instantiate the model described by ``config.model``.

    The whole ``config.model`` section is forwarded as keyword arguments. A
    model name containing ``"resnet"`` selects ``Resnet``, otherwise one
    containing ``"se_resnext"`` selects ``SEResNext``.

    Raises:
        NotImplementedError: If the model name matches neither family.
    """
    params = config.model
    model_name = params.model_name
    if "resnet" in model_name:
        return Resnet(**params)
    if "se_resnext" in model_name:
        return SEResNext(**params)
    raise NotImplementedError

### Loss

In [None]:
class BengaliCrossEntropyLoss(nn.Module):
    """Weighted sum of three cross-entropy losses, one per target.

    Args:
        n_grapheme: Number of grapheme logits (first columns of ``pred``).
        n_vowel: Number of vowel logits (next columns).
        n_consonant: Number of consonant logits (columns after that).
        weights: Three factors applied to the grapheme, vowel and consonant
            loss. The default is an immutable tuple, as in the sibling
            losses, so that it cannot be shared and mutated across instances.
    """

    def __init__(self,
                 n_grapheme: int,
                 n_vowel: int,
                 n_consonant: int,
                 weights=(1.0, 1.0, 1.0)):
        super().__init__()
        self.n_grapheme = n_grapheme
        self.n_vowel = n_vowel
        self.n_consonant = n_consonant
        self.cross_entropy = nn.CrossEntropyLoss()
        self.weights = weights

    def forward(self, pred, true):
        """Compute the loss.

        ``pred`` holds the concatenated logits along dim 1; ``true`` holds the
        class indices of grapheme, vowel and consonant in columns 0, 1 and 2.
        """
        # Column boundaries of the three logit groups inside `pred`.
        vowel_start = self.n_grapheme
        consonant_start = vowel_start + self.n_vowel
        consonant_end = consonant_start + self.n_consonant

        grapheme_loss = self.cross_entropy(
            pred[:, :vowel_start], true[:, 0])
        vowel_loss = self.cross_entropy(
            pred[:, vowel_start:consonant_start], true[:, 1])
        consonant_loss = self.cross_entropy(
            pred[:, consonant_start:consonant_end], true[:, 2])

        return self.weights[0] * grapheme_loss + \
            self.weights[1] * vowel_loss + \
            self.weights[2] * consonant_loss


class BengaliBCELoss(nn.Module):
    """Unweighted sum of three BCE-with-logits losses, one per target.

    Both ``pred`` and ``true`` are sliced along dim 1 into the same
    grapheme / vowel / consonant column groups, so ``true`` has to use the
    same column layout as ``pred``.
    """

    def __init__(self, n_grapheme: int, n_vowel: int, n_consonant: int):
        super().__init__()
        self.n_grapheme = n_grapheme
        self.n_vowel = n_vowel
        self.n_consonant = n_consonant
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, pred, true):
        vowel_start = self.n_grapheme
        consonant_start = vowel_start + self.n_vowel
        consonant_end = consonant_start + self.n_consonant

        grapheme_cols = slice(0, vowel_start)
        vowel_cols = slice(vowel_start, consonant_start)
        consonant_cols = slice(consonant_start, consonant_end)

        grapheme_loss = self.bce(pred[:, grapheme_cols],
                                 true[:, grapheme_cols])
        vowel_loss = self.bce(pred[:, vowel_cols], true[:, vowel_cols])
        consonant_loss = self.bce(pred[:, consonant_cols],
                                  true[:, consonant_cols])
        return grapheme_loss + vowel_loss + consonant_loss
    
    
class BengaliFocalLoss(nn.Module):
    """Weighted sum of three focal losses, one per target.

    Each target gets its own ``FocalLoss`` (operating on logits) sized to its
    number of classes. ``pred`` holds the concatenated logits along dim 1;
    ``true`` holds the class indices in columns 0, 1 and 2.
    """

    def __init__(self,
                 n_grapheme: int,
                 n_vowel: int,
                 n_consonant: int,
                 weights=(1.0, 1.0, 1.0)):
        super().__init__()
        self.n_grapheme = n_grapheme
        self.n_vowel = n_vowel
        self.n_consonant = n_consonant
        self.grapheme_focal = FocalLoss(logits=True, n_class=n_grapheme)
        self.vowel_focal = FocalLoss(logits=True, n_class=n_vowel)
        self.consonant_focal = FocalLoss(logits=True, n_class=n_consonant)
        self.weights = weights

    def forward(self, pred, true):
        vowel_start = self.n_grapheme
        consonant_start = vowel_start + self.n_vowel
        consonant_end = consonant_start + self.n_consonant

        grapheme_loss = self.grapheme_focal(
            pred[:, :vowel_start], true[:, 0])
        vowel_loss = self.vowel_focal(
            pred[:, vowel_start:consonant_start], true[:, 1])
        consonant_loss = self.consonant_focal(
            pred[:, consonant_start:consonant_end], true[:, 2])

        return self.weights[0] * grapheme_loss + \
            self.weights[1] * vowel_loss + \
            self.weights[2] * consonant_loss
            

class FocalLoss(nn.Module):
    """Focal loss on one-hot encoded targets, built on binary cross-entropy.

    Args:
        alpha: Constant factor applied to the loss.
        gamma: Exponent of the ``(1 - pt)`` modulating factor.
        logits: If true ``inputs`` are raw logits, otherwise probabilities.
        n_class: Number of classes, i.e. the width of the one-hot encoding.
        reduction: ``None`` returns the element-wise loss; any other value
            (including the default) returns the mean over all elements.
    """

    def __init__(self,
                 alpha=1,
                 gamma=2,
                 logits=False,
                 n_class=168,
                 reduction='elementwise_mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduction = reduction
        # Rows of the identity matrix serve as one-hot vectors. The matrix is
        # created on the CPU and moved to the device of the inputs on first
        # use, so the loss no longer depends on a globally chosen device and
        # works wherever the model runs.
        self.eye = torch.eye(n_class)

    def forward(self, inputs, targets):
        """Return the focal loss of ``inputs`` against class indices ``targets``."""
        if self.eye.device != inputs.device:
            self.eye = self.eye.to(inputs.device)
        one_hot = self.eye[targets]
        if self.logits:
            BCE_loss = F.binary_cross_entropy_with_logits(
                inputs, one_hot, reduction='none')
        else:
            BCE_loss = F.binary_cross_entropy(
                inputs, one_hot, reduction='none')
        # exp(-BCE) is the probability assigned to the correct binary outcome;
        # well-classified elements (pt close to 1) are down-weighted.
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1 - pt)**self.gamma * BCE_loss

        if self.reduction is None:
            return F_loss
        else:
            return torch.mean(F_loss)
    
    
class BengaliMultiMarginLoss(nn.Module):
    """Weighted sum of three multi-class margin losses, one per target.

    ``pred`` holds the concatenated scores along dim 1; ``true`` holds the
    class indices of grapheme, vowel and consonant in columns 0, 1 and 2.
    """

    def __init__(self,
                 n_grapheme: int,
                 n_vowel: int,
                 n_consonant: int,
                 weights=(1.0, 1.0, 1.0)):
        super().__init__()
        self.n_grapheme = n_grapheme
        self.n_vowel = n_vowel
        self.n_consonant = n_consonant
        self.margin = nn.MultiMarginLoss()
        self.weights = weights

    def forward(self, pred, true):
        vowel_start = self.n_grapheme
        consonant_start = vowel_start + self.n_vowel
        consonant_end = consonant_start + self.n_consonant

        grapheme_loss = self.margin(pred[:, :vowel_start], true[:, 0])
        vowel_loss = self.margin(
            pred[:, vowel_start:consonant_start], true[:, 1])
        consonant_loss = self.margin(
            pred[:, consonant_start:consonant_end], true[:, 2])

        return self.weights[0] * grapheme_loss + \
            self.weights[1] * vowel_loss + \
            self.weights[2] * consonant_loss
    
    
class OHEMLoss(nn.Module):
    """Online hard example mining on top of cross-entropy.

    Only the ``rate`` fraction of samples with the largest per-sample
    cross-entropy contributes to the loss, which is the mean over the samples
    that were kept.

    Args:
        rate: Fraction of the batch to keep.
    """

    def __init__(self, rate=0.7):
        super().__init__()
        self.rate = rate

    def forward(self, pred, target):
        """Return the mean cross-entropy of the hardest samples in the batch."""
        batch_size = pred.size(0)
        ohem_cls_loss = F.cross_entropy(
            pred, target, reduction="none", ignore_index=-1)

        n_losses = ohem_cls_loss.size(0)
        if n_losses == 0:
            # Nothing to average; return a zero that is still part of the graph.
            return ohem_cls_loss.sum()

        # Keep at least one sample: for small batches int(batch_size * rate)
        # is 0, which used to cause a division by zero and a NaN loss.
        keep_num = min(n_losses, max(1, int(batch_size * self.rate)))
        if keep_num < n_losses:
            # The k largest losses, in descending order.
            ohem_cls_loss, _ = torch.topk(ohem_cls_loss, keep_num)
        cls_loss = ohem_cls_loss.sum() / keep_num
        return cls_loss


class BengaliOHEMLoss(nn.Module):
    """Weighted sum of three OHEM losses, one per target.

    A single ``OHEMLoss`` with the given ``rate`` is applied separately to
    the grapheme, vowel and consonant logit groups of ``pred`` (concatenated
    along dim 1) against the class indices in columns 0, 1 and 2 of ``true``.
    """

    def __init__(self,
                 n_grapheme: int,
                 n_vowel: int,
                 n_consonant: int,
                 weights=(1.0, 1.0, 1.0),
                 rate=0.7):
        super().__init__()
        self.n_grapheme = n_grapheme
        self.n_vowel = n_vowel
        self.n_consonant = n_consonant
        self.ohem = OHEMLoss(rate=rate)
        self.weights = weights

    def forward(self, pred, true):
        vowel_start = self.n_grapheme
        consonant_start = vowel_start + self.n_vowel
        consonant_end = consonant_start + self.n_consonant

        grapheme_loss = self.ohem(pred[:, :vowel_start], true[:, 0])
        vowel_loss = self.ohem(
            pred[:, vowel_start:consonant_start], true[:, 1])
        consonant_loss = self.ohem(
            pred[:, consonant_start:consonant_end], true[:, 2])

        return self.weights[0] * grapheme_loss + \
            self.weights[1] * vowel_loss + \
            self.weights[2] * consonant_loss


def get_loss(config: edict):
    """Instantiate the criterion selected by ``config.loss.name``.

    ``config.loss.params`` is forwarded as keyword arguments to the selected
    loss class. Supported names: ``"bce"``, ``"cross_entropy"``,
    ``"margin"``, ``"focal"`` and ``"ohem"``.

    Raises:
        NotImplementedError: If the name is not one of the supported losses.
    """
    name = config.loss.name
    params = config.loss.params
    if name == "bce":
        # BengaliBCELoss accepts no `weights` argument, so `params` must not
        # contain one when this loss is selected.
        criterion = BengaliBCELoss(**params)
    elif name == "cross_entropy":
        criterion = BengaliCrossEntropyLoss(**params)  # type: ignore
    elif name == "margin":
        # The class counts are required constructor arguments; calling the
        # constructor without `params` raised TypeError.
        criterion = BengaliMultiMarginLoss(**params)  # type: ignore
    elif name == "focal":
        criterion = BengaliFocalLoss(**params)  # type: ignore
    elif name == "ohem":
        criterion = BengaliOHEMLoss(**params)  # type: ignore
    else:
        raise NotImplementedError
    return criterion

## Optimizer and Scheduler

### Optimizer

In [None]:
# Optimizer types that `get_optimizer` can return.
Optimizer = Union[Adam, SGD]


def get_optimizer(model, config: edict) -> Optimizer:
    """Create the optimizer selected by ``config.optimizer.name``.

    Args:
        model: Module whose ``parameters()`` are optimized.
        config: Full configuration; ``config.optimizer.params`` is forwarded
            as keyword arguments to the optimizer constructor.

    Returns:
        An ``Adam`` instance for ``"Adam"`` or an ``SGD`` instance for ``"SGD"``.

    Raises:
        NotImplementedError: For any other optimizer name.
    """
    name = config.optimizer.name
    params = config.optimizer.params
    if name == "Adam":
        optimizer = Adam(model.parameters(), **params)
    elif name == "SGD":
        # This branch used to build Adam, so "SGD" was silently ignored.
        optimizer = SGD(model.parameters(), **params)
    else:
        raise NotImplementedError
    return optimizer

### Scheduler

In [None]:
# Scheduler types that `get_scheduler` can return; None means "no scheduler".
Scheduler = Optional[
    Union[ReduceLROnPlateau, CosineAnnealingLR, CosineAnnealingWarmRestarts]]


def get_scheduler(optimizer, config: edict) -> Scheduler:
    """Create the learning-rate scheduler selected by ``config.scheduler.name``.

    ``config.scheduler.params`` is forwarded as keyword arguments. Known
    names are ``"plateau"``, ``"cosine"`` and ``"cosine_warmup"``; for any
    other name ``None`` is returned.
    """
    params = config.scheduler.params
    name = config.scheduler.name
    scheduler_classes = {
        "plateau": ReduceLROnPlateau,
        "cosine": CosineAnnealingLR,
        "cosine_warmup": CosineAnnealingWarmRestarts,
    }
    scheduler_cls = scheduler_classes.get(name)
    if scheduler_cls is None:
        return None
    return scheduler_cls(optimizer, **params)

## Callbacks

In [None]:
class BatchCallback:
    """Base class for per-batch training callbacks.

    Each hook receives the mutable training ``state`` dict and must return
    it. The base implementation only moves the current batch onto
    ``self.device`` (CUDA when available, otherwise CPU).
    """

    def __init__(self):
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

    def on_loader_start(self, state: dict):
        """Hook run once before iterating the loader; no-op here."""
        return state

    def on_batch_start(self, state: dict):
        """Move ``state["batch"]["images"]`` and ``["targets"]`` to the device."""
        batch = state["batch"]
        batch["images"] = batch["images"].to(self.device)
        # The key was previously misspelled "tragets", which raised KeyError;
        # every other reader of the batch uses "targets".
        batch["targets"] = batch["targets"].to(self.device)
        return state

    def on_batch_end(self, state: dict):
        """Hook run after the forward pass of each batch; no-op here."""
        return state


def rand_bbox(size, lam):
    """Sample a random box for cutmix, clipped to the image extent.

    Args:
        size: 4-element size; ``size[2]`` and ``size[3]`` are the extents
            of the two dimensions the box is drawn in (named ``W`` and
            ``H`` here; the ``bbx*`` values index dimension 2 and the
            ``bby*`` values index dimension 3).
        lam: mixing coefficient; before clipping, each side of the box is
            ``sqrt(1 - lam)`` times the corresponding extent.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)`` with ``0 <= bbx1 <= bbx2 <= size[2]``
        and ``0 <= bby1 <= bby2 <= size[3]``. Because of clipping the box
        may be smaller than requested.
    """
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    # Built-in int() instead of the np.int alias, which was deprecated in
    # NumPy 1.20 and removed in 1.24 (it was the same type).
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # Box centre, drawn uniformly over the image.
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2


class CriterionCallback(BatchCallback):
    """Callback that computes the loss and performs the optimisation step.

    After each batch it zeroes the gradients, back-propagates the loss,
    steps the optimizer and (when one is configured) the scheduler, and
    keeps a running average of the loss in ``state["avg_loss"]``.
    """

    def __init__(self, criterion):
        self.criterion = criterion
        super().__init__()

    def on_loader_start(self, state: dict):
        """Reset the running loss and record the number of batches."""
        self.avg_loss = 0.0
        self.n_steps = len(state["loader"])
        return state

    def _calc_loss(self, state: dict):
        """Apply the criterion to ``state["pred"]`` and the batch targets."""
        return self.criterion(state["pred"], state["batch"]["targets"])

    def on_batch_end(self, state: dict):
        """Run backward and optimizer/scheduler steps; update ``avg_loss``."""
        optimizer = state["optimizer"]
        scheduler = state["scheduler"]
        optimizer.zero_grad()

        loss = self._calc_loss(state)
        loss.backward()
        optimizer.step()
        # The scheduler is optional (the factory can return None), so only
        # step it when present. The step happens once per batch.
        # NOTE(review): ReduceLROnPlateau.step() needs a metric argument and
        # is not supported by this per-batch call - confirm intended usage.
        if scheduler is not None:
            scheduler.step()
        # Each batch contributes loss / n_steps, so after the last batch
        # avg_loss is the mean loss over the loader.
        self.avg_loss += loss.item() / self.n_steps
        state["avg_loss"] = self.avg_loss
        return state


class MixupOrCutmixCallback(CriterionCallback):
    """Randomly applies mixup or cutmix to each batch and mixes the loss.

    For every batch one of three actions is drawn: mixup (probability
    ``mixup_prob``), cutmix (``cutmix_prob``) or no augmentation (the
    remaining probability). When an augmentation was applied, the loss is
    ``lam * criterion(pred, y) + (1 - lam) * criterion(pred, y[index])``
    where ``index`` is the permutation used to pair samples.

    Args:
        criterion: loss callable ``criterion(pred, target)``.
        alpha: parameter of the ``Beta(alpha, alpha)`` distribution that
            ``lam`` is drawn from; ``0`` fixes ``lam`` to 1.
        mixup_prob: probability of applying mixup to a batch.
        cutmix_prob: probability of applying cutmix to a batch.
        no_aug_epochs: number of final epochs trained without augmentation.
    """

    def __init__(self,
                 criterion,
                 alpha=1.0,
                 mixup_prob=0.5,
                 cutmix_prob=0.5,
                 no_aug_epochs=0):
        super().__init__(criterion)

        assert alpha >= 0, "alpha must be>=0"
        assert 1 >= mixup_prob >= 0, "mixup_prob must be within 1 and 0"
        assert 1 >= cutmix_prob >= 0, "cutmix_prob must be within 1 and 0"
        assert 1 >= mixup_prob + cutmix_prob, \
            "sum of mixup_prob and cutmix_prob must be lower than 1"
        self.alpha = alpha
        self.lam = 1
        self.dice = 2  # 0: mixup, 1: cutmix, 2: no augmentation
        self.index = None
        self.is_needed = True
        self.mixup_prob = mixup_prob
        self.cutmix_prob = cutmix_prob
        self.no_action_prob = 1 - (mixup_prob + cutmix_prob)
        self.no_aug_epochs = no_aug_epochs

    def on_loader_start(self, state: dict):
        """Decide whether augmentation is active for the coming epoch."""
        state = super().on_loader_start(state)
        # state["epoch"] is 1-based (train() passes epoch + 1), so the number
        # of epochs left including this one is num_epochs - epoch + 1. The
        # previous `num_epochs - epoch <= no_aug_epochs` test switched
        # augmentation off one epoch too early (even for no_aug_epochs=0).
        epochs_left = state["num_epochs"] - state["epoch"] + 1
        self.is_needed = epochs_left > self.no_aug_epochs
        return state

    def on_batch_start(self, state: dict):
        """Draw the action for this batch and augment the images in place."""
        batch = state["batch"]
        if not self.is_needed:
            return state

        dice = np.random.choice([0, 1, 2],
                                p=(self.mixup_prob, self.cutmix_prob,
                                   self.no_action_prob))
        self.dice = dice
        if dice == 2:
            return state

        images = batch["images"]
        if self.alpha > 0:
            lam = float(np.random.beta(self.alpha, self.alpha))
        else:
            lam = 1.0
        # The permutation is kept on the CPU: a CPU index can index tensors
        # on either device. (The former `self.index.to(self.device)` call
        # discarded its result and therefore had no effect.)
        self.index = torch.randperm(images.shape[0])

        if dice == 0:
            # mixup: convex combination of each image with its partner
            self.lam = lam
            batch["images"] = lam * images + (1 - lam) * images[self.index]
        else:
            # cutmix: paste a random box from the partner image. This used to
            # be nested under the `alpha <= 0` branch, so with alpha > 0 no
            # box was ever pasted while the loss was still mixed with a
            # stale lam.
            bbx1, bby1, bbx2, bby2 = rand_bbox(images.size(), lam)
            images[:, :, bbx1:bbx2, bby1:bby2] = \
                images[self.index, :, bbx1:bbx2, bby1:bby2]
            # Recompute lam from the actual (clipped) box area.
            box_area = float((bbx2 - bbx1) * (bby2 - bby1))
            self.lam = 1 - box_area / (images.size()[-1] * images.size()[-2])
        state["batch"] = batch
        return state

    def _calc_loss(self, state: dict):
        """Loss for the batch, mixed across paired targets when augmented."""
        if not self.is_needed or self.dice not in (0, 1):
            return super()._calc_loss(state)

        pred = state["pred"]
        targets = state["batch"]["targets"]
        return self.lam * self.criterion(pred, targets) + \
            (1 - self.lam) * self.criterion(pred, targets[self.index])

## Training

In [None]:
def train_one_epoch(model, criterion, optimizer, scheduler, loader,
                    current_epoch, batch_callbacks, num_epochs: int):
    """Run one pass over ``loader``, delegating the update to callbacks.

    The model is put in training mode and a ``state`` dict (optimizer,
    scheduler, loader, epoch, num_epochs) is threaded through the
    callbacks: ``on_loader_start`` once, then for every batch
    ``on_batch_start``, the forward pass on ``state["batch"]["images"]``
    (stored as ``state["pred"]``) and ``on_batch_end``. ``criterion`` is
    accepted but not used here.

    Returns:
        The final ``state`` dict.
    """
    model.train()
    initial_state = dict(
        optimizer=optimizer,
        scheduler=scheduler,
        loader=loader,
        epoch=current_epoch,
        num_epochs=num_epochs,
    )
    state = run_callbacks(batch_callbacks, initial_state, "on_loader_start")

    for batch in progress_bar(loader, leave=False):
        state["batch"] = batch
        state = run_callbacks(batch_callbacks, state, "on_batch_start")
        # Callbacks may have replaced the batch, so re-read the images.
        images = state["batch"]["images"]
        state["pred"] = model(images)
        state = run_callbacks(batch_callbacks, state, "on_batch_end")
    return state


def train(model, criterion, optimizer, scheduler, loader, callbacks,
          num_epochs: int):
    """Train for ``num_epochs`` epochs, checkpointing after each one.

    After every epoch the average loss and the current learning rate of
    the first parameter group are printed, and a checkpoint (model,
    optimizer and scheduler state plus the 0-based epoch index) is written
    to Google Drive. The checkpoint filename uses the notebook-level
    variable ``trial``, which must be defined before calling this function.

    Returns:
        The trained ``model``.
    """
    for epoch in range(num_epochs):
        print(f"Epoch: [{epoch + 1}/{num_epochs}]:", end=" ")
        state = train_one_epoch(model, criterion, optimizer, scheduler, loader,
                                epoch + 1, callbacks, num_epochs)
        print(f"avg_loss: {state['avg_loss']:.5f}", end=" ")
        # Read the lr from the optimizer rather than scheduler.get_lr():
        # get_lr() is not meant to be called outside scheduler.step() and can
        # report a wrong value, ReduceLROnPlateau does not provide it, and
        # the scheduler may be None.
        print(f"base_lr: {optimizer.param_groups[0]['lr']:.5f}")
        checkpoint = {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": (scheduler.state_dict()
                                     if scheduler is not None else None),
            "epoch": epoch,
        }
        torch.save(checkpoint, f"/content/gdrive/My Drive/kaggle-bengali/{trial}_latest.pth")
    return model

In [None]:
%%sh
# Copy the se_resnext50_32x4d checkpoint from Google Drive into torch's
# checkpoint cache directory (presumably so the pretrained-weight loader
# finds it locally instead of downloading it - confirm).
ckpt_name=se_resnext50_32x4d-a260b3a4.pth
cache_dir=/root/.cache/torch/checkpoints
mkdir -p "$cache_dir"
cp "/content/gdrive/My Drive/kaggle-bengali/$ckpt_name" "$cache_dir/$ckpt_name"

In [None]:
# Build every training component from the notebook-level `config`.
data_loader = get_base_loader(
    df,
    train_images_path,
    phase="train",
    size=(config.img_size, config.img_size),
    batch_size=config.train.batch_size,
    num_workers=config.num_workers,
    transforms=transforms)
model = get_model(config)
criterion = get_loss(config)
optimizer = get_optimizer(model, config)
scheduler = get_scheduler(optimizer, config)
callbacks = [
    MixupOrCutmixCallback(criterion,
                          **config.callback.params)
]

# Optionally resume model/optimizer/scheduler state from a checkpoint.
if config.weights is not None:
    state_dict = torch.load(config.weights)
    model.load_state_dict(state_dict["model_state_dict"])
    optimizer.load_state_dict(state_dict["optimizer_state_dict"])
    # get_scheduler may return None, and a checkpoint may carry no scheduler
    # state; only restore when both sides are present.
    scheduler_state = state_dict.get("scheduler_state_dict")
    if scheduler is not None and scheduler_state is not None:
        scheduler.load_state_dict(scheduler_state)

model = train(
    model,
    criterion,
    optimizer,
    scheduler,
    data_loader,
    callbacks,
    config.train.num_epochs)