## Dependencies

In [None]:
%%sh
# Install pinned versions of the third-party packages used by this notebook.
# pip's standard output is appended to /dev/null to keep the cell output short;
# error output is not redirected.
pip install albumentations==0.4.3 catalyst==20.1.1 easydict==1.9.0 >> /dev/null
pip install efficientnet-pytorch==0.6.1 PyYAML==5.3 >> /dev/null
pip install pretrainedmodels==0.7.4 >> /dev/null

## Integration with Google Drive

In [None]:
# Mount Google Drive at /content/gdrive so the data copy cell can read from it.
from google.colab import drive
drive.mount("/content/gdrive")

In [None]:
%%sh
# Copy the kaggle-bengali folder from the mounted Drive into
# ./input/bengaliai-cv19 and unpack the two image archives next to it.
# NOTE(review): `mkdir input` reports an error if the directory already exists
# (e.g. when this cell is re-run) -- confirm whether `mkdir -p` is wanted.
mkdir input
cp -r /content/gdrive/My\ Drive/kaggle-bengali ./input/bengaliai-cv19
# -qq: suppress unzip's output; -d: directory to extract into.
unzip -qq -d input/bengaliai-cv19/ input/bengaliai-cv19/train_images.zip
unzip -qq -d input/bengaliai-cv19/ input/bengaliai-cv19/test_images.zip

## Libraries

In [None]:
import albumentations as A
import catalyst as ct
import cv2
import numpy as np
import pandas as pd
import pretrainedmodels
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as torchdata
import torchvision.models as models
import yaml

from pathlib import Path
from typing import Tuple, Dict, Union, Optional, List

from catalyst.dl import SupervisedRunner
from catalyst.dl.core import Callback, CallbackOrder, RunnerState
from catalyst.dl.callbacks import MixUpCallback
from easydict import EasyDict as edict
from efficientnet_pytorch import EfficientNet
from skimage.transform import AffineTransform, warp
from sklearn.metrics import recall_score
from sklearn.model_selection import KFold, train_test_split
from torch.nn.parameter import Parameter
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import (ReduceLROnPlateau, 
                                      CosineAnnealingLR,
                                      CosineAnnealingWarmRestarts)

## Settings

In [None]:
# Index into the list of validation splits: selects the fold the training
# cell runs.
i = 0
# Experiment name; the training cell passes it to SaveWeightsCallback, which
# uses it as the file name of the saved weights.
trial = "resnet18_init"

## Config

In [None]:
# Experiment configuration written as YAML text. It is parsed with a safe
# loader and merged over the defaults produced by _get_default().
# `checkpoints:` has no value, which YAML loads as None; the training cell
# treats None as "no extra checkpoint directory".
conf_string = '''
dataset:
  train:
    affine: True
    morphology: False
  val:
    affine: False
    morphology: False
  test:
    affine: False
    morphology: False

data:
  train_df_path: input/bengaliai-cv19/train.csv
  train_images_path: input/bengaliai-cv19/train_images
  test_images_path: input/bengaliai-cv19/test_images
  sample_submission_path: input/bengaliai-cv19/sample_submission.csv

model:
  model_name: resnet18
  pretrained: True
  num_classes: 186
  head: custom
  in_channels: 3

train:
  batch_size: 64
  num_epochs: 10

test:
  batch_size: 32

loss:
  name: cross_entropy
  params:
    n_grapheme: 168
    n_vowel: 11
    n_consonant: 7

optimizer:
  name: Adam
  params:
    lr: 0.001

scheduler:
  name: cosine
  params:
    T_max: 10

transforms:
  Noise: False
  Contrast: False
  Rotate: True
  RandomScale: True
  Cutout:
    num_holes: 0

val:
  name: kfold
  params:
    random_state: 42
    n_splits: 5

log_dir: log/
num_workers: 2
seed: 1213
img_size: 128
checkpoints:
mixup: True
'''

In [None]:
def _get_default():
    cfg = edict()

    # dataset
    cfg.dataset = edict()
    cfg.dataset.train = edict()
    cfg.dataset.val = edict()
    cfg.dataset.test = edict()
    cfg.dataset.train.affine = False
    cfg.dataset.train.morphology = False
    cfg.dataset.val.affine = False
    cfg.dataset.val.morphology = False
    cfg.dataset.test.affine = False
    cfg.dataset.test.morphology = False

    # dataset
    cfg.data = edict()

    # model
    cfg.model = edict()
    cfg.model.model_name = "resnet18"
    cfg.model.num_classes = 186
    cfg.model.pretrained = True
    cfg.model.head = "linear"
    cfg.model.in_channels = 3

    # train
    cfg.train = edict()

    # test
    cfg.test = edict()

    # loss
    cfg.loss = edict()
    cfg.loss.params = edict()

    # optimizer
    cfg.optimizer = edict()
    cfg.optimizer.params = edict()

    # scheduler
    cfg.scheduler = edict()
    cfg.scheduler.params = edict()

    # transforms:
    cfg.transforms = edict()
    cfg.transforms.HorizontalFlip = False
    cfg.transforms.VerticalFlip = False
    cfg.transforms.Noise = False
    cfg.transforms.Contrast = False
    cfg.transforms.Rotate = False
    cfg.transforms.RandomScale = False
    cfg.transforms.Cutout = edict()
    cfg.transforms.Cutout.num_holes = 0
    cfg.transforms.mean = [0.485, 0.456, 0.406]
    cfg.transforms.std = [0.229, 0.224, 0.225]

    # val
    cfg.val = edict()
    cfg.val.params = edict()

    return cfg


def _merge_config(src: edict, dst: edict):
    if not isinstance(src, edict):
        return
    for k, v in src.items():
        if isinstance(v, edict):
            _merge_config(src[k], dst[k])
        else:
            dst[k] = v

In [None]:
# Start from the defaults, then overlay the values parsed from the YAML text;
# `config` is the merged result used by the rest of the notebook.
config = _get_default()
cfg = edict(yaml.load(conf_string, Loader=yaml.SafeLoader))
_merge_config(cfg, config)

## Environment settings

In [None]:
# Seed the random number generators with the configured seed (catalyst helper).
ct.utils.set_global_seed(config.seed)
# Request deterministic cuDNN behaviour through catalyst's helper.
ct.utils.prepare_cudnn(deterministic=True)

In [None]:
# Root directory for the run's outputs; the training cell creates one
# sub-directory per fold underneath it.
output_base_dir = Path("output")
output_base_dir.mkdir(exist_ok=True, parents=True)

# Directory holding the training images, taken from data.train_images_path.
train_images_path = Path(config.data.train_images_path)

## Data and utilities preparation

### Validation utils

In [None]:
def no_fold(df: pd.DataFrame,
            config: edict) -> List[Tuple[np.ndarray, np.ndarray]]:
    """Create a single hold-out split over the row positions of ``df``.

    Args:
        df: Data frame to split; only its length is used.
        config: Configuration whose ``val.params`` are forwarded to
            ``train_test_split`` as keyword arguments.

    Returns:
        A one-element list holding ``(train_positions, valid_positions)``.
    """
    positions = np.arange(len(df))
    train_part, valid_part = train_test_split(positions, **config.val.params)
    return [(train_part, valid_part)]


def kfold(df: pd.DataFrame,
          config: edict) -> List[Tuple[np.ndarray, np.ndarray]]:
    """Create shuffled K-fold splits of ``df``.

    Args:
        df: Data frame to split.
        config: Configuration whose ``val.params`` are forwarded to
            ``KFold`` as keyword arguments (shuffling is always enabled).

    Returns:
        One ``(train_indices, valid_indices)`` pair per fold.
    """
    splitter = KFold(shuffle=True, **config.val.params)
    return list(splitter.split(df))


def get_validation(df: pd.DataFrame,
                   config: edict) -> List[Tuple[np.ndarray, np.ndarray]]:
    """Build the validation splits with the scheme named in the config.

    ``config.val.name`` is looked up among the module-level names and the
    function found there is called with ``(df, config)``.

    Args:
        df: Data frame to split.
        config: Configuration providing ``val.name`` and the parameters the
            chosen scheme reads.

    Returns:
        A list of ``(train_indices, valid_indices)`` pairs.

    Raises:
        NotImplementedError: If ``val.name`` does not refer to a callable
            defined at module level.
    """
    name: str = config.val.name

    func = globals().get(name)
    # Any global matches the lookup, including imported modules and plain
    # values; reject the ones that cannot be called instead of failing
    # later with an unrelated TypeError.
    if not callable(func):
        raise NotImplementedError(f"Unknown validation scheme: {name!r}")

    return func(df, config)

### Transforms

In [None]:
def get_transforms(config: edict):
    """Assemble the albumentations pipeline selected by ``config.transforms``.

    Each augmentation is added only when its switch is enabled; the
    normalisation step is always appended last.

    Args:
        config: Configuration whose ``transforms`` section holds the
            boolean switches, the ``Cutout`` keyword arguments and the
            ``mean`` / ``std`` used for normalisation.

    Returns:
        An ``A.Compose`` applying the selected transforms in order.
    """
    settings = config.transforms
    list_transforms = []
    if settings.HorizontalFlip:
        list_transforms.append(A.HorizontalFlip())
    if settings.VerticalFlip:
        list_transforms.append(A.VerticalFlip())
    if settings.Rotate:
        list_transforms.append(A.Rotate(limit=15))
    if settings.RandomScale:
        list_transforms.append(A.RandomScale())
    if settings.Noise:
        list_transforms.append(
            A.OneOf(
                [A.GaussNoise(), A.IAAAdditiveGaussianNoise()], p=0.5))
    if settings.Contrast:
        list_transforms.append(
            A.OneOf(
                [A.RandomContrast(0.5),
                 A.RandomGamma(),
                 A.RandomBrightness()],
                p=0.5))
    if settings.Cutout.num_holes > 0:
        # The Cutout arguments live under ``transforms``, not at the top
        # level of the config.
        list_transforms.append(A.Cutout(**settings.Cutout))

    list_transforms.append(
        A.Normalize(mean=settings.mean, std=settings.std, p=1))

    return A.Compose(list_transforms, p=1.0)

## Dataset and DataLoader

In [None]:
class BaseDataset(torchdata.Dataset):
    """Dataset that reads one PNG image and its three labels per row.

    Args:
        image_dir: Directory containing ``<image_id>.png`` files.
        df: Table with the columns ``image_id``, ``grapheme_root``,
            ``vowel_diacritic`` and ``consonant_diacritic``.
        transforms: Optional callable invoked as ``transforms(image=...)``
            whose result provides the transformed array under ``"image"``.
        size: Target size handed to ``cv2.resize``.
    """

    def __init__(self, image_dir: Path, df: pd.DataFrame, transforms,
                 size: Tuple[int, int]):
        self.df = df
        self.image_dir = image_dir
        self.transforms = transforms
        self.size = size

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``{"images": array, "targets": labels}`` for row ``idx``.

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        # Positional lookup: samplers hand out positions 0..len-1, which
        # only coincide with index labels when the index has been reset.
        row = self.df.iloc[idx]
        image_path = self.image_dir / f"{row['image_id']}.png"

        image = cv2.imread(str(image_path))
        if image is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        if self.transforms is not None:
            image = self.transforms(image=image)["image"]
        image = cv2.resize(image, self.size)
        if image.shape[2] == 3:
            # Move the channel axis to the front.
            image = np.moveaxis(image, -1, 0)
        label = np.zeros(3, dtype=int)
        label[0] = row["grapheme_root"]
        label[1] = row["vowel_diacritic"]
        label[2] = row["consonant_diacritic"]
        return {"images": image, "targets": label}
    
    
def get_base_loader(df: pd.DataFrame,
                    image_dir: Path,
                    phase: str = "train",
                    size: Tuple[int, int] = (128, 128),
                    batch_size=256,
                    num_workers=2,
                    transforms=None):
    """Wrap ``df`` in a ``BaseDataset`` and return a ``DataLoader`` for it.

    The training loader shuffles and drops the last incomplete batch; the
    validation loader does neither.

    Args:
        df: Table describing the samples.
        image_dir: Directory containing the image files.
        phase: Either ``"train"`` or ``"valid"``.
        size: Target image size passed to the dataset.
        batch_size: Number of samples per batch.
        num_workers: Number of loader worker processes.
        transforms: Optional transform pipeline passed to the dataset.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    assert phase in ["train", "valid"]
    # Shuffling and dropping the ragged last batch both apply to training only.
    is_train = phase == "train"

    samples = BaseDataset(  # type: ignore
        image_dir, df, transforms, size)
    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=is_train,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=is_train)
    return torchdata.DataLoader(samples, **loader_kwargs)

## Model and Loss

### Model

In [None]:
def gem(x: torch.Tensor, p=3, eps=1e-6):
    """Generalized-mean (GeM) pooling over the last two dimensions.

    Values are clamped from below at ``eps``, raised to the power ``p``,
    averaged over the whole spatial window, and the ``p``-th root of that
    average is returned. The pooled dimensions are kept with size 1.
    """
    window = (x.size(-2), x.size(-1))
    powered = x.clamp(min=eps).pow(p)
    pooled = F.avg_pool2d(powered, window)
    return pooled.pow(1. / p)


def mish(input):
    '''
    Element-wise Mish activation.

    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))

    The result has the same shape as ``input``.
    '''
    gate = torch.tanh(F.softplus(input))
    return input * gate


class Mish(nn.Module):
    '''
    Module wrapper around the element-wise ``mish`` function:
    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))

    The module has no parameters and keeps the input shape.

    Shape:
        - Input: (N, *), where * stands for any number of additional
          dimensions
        - Output: (N, *), identical to the input shape
    Examples:
        >>> m = Mish()
        >>> input = torch.randn(2)
        >>> output = m(input)
    '''

    def __init__(self):
        '''
        Set up the (stateless) module.
        '''
        super(Mish, self).__init__()

    def forward(self, input):
        '''
        Apply mish to ``input`` and return the result.
        '''
        activated = mish(input)
        return activated


class GeM(nn.Module):
    """Generalized-mean pooling layer with a learnable exponent.

    Args:
        p: Initial value of the pooling exponent, stored as a parameter.
        eps: Lower clamp applied to the input before the power is taken.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        """Pool ``x`` with ``gem`` and drop the two pooled dimensions."""
        pooled = gem(x, p=self.p, eps=self.eps)
        pooled = pooled.squeeze(-1)
        return pooled.squeeze(-1)

    def __repr__(self):
        exponent = self.p.data.tolist()[0]
        return f"{self.__class__.__name__}(p={exponent:.4f}, eps={self.eps})"


class Resnet(nn.Module):
    """torchvision ResNet backbone with a linear or a three-way custom head.

    With ``head="linear"`` the backbone's ``fc`` layer is replaced by one
    linear layer producing ``num_classes`` outputs. With ``head="custom"``
    the last two children of the backbone are removed and three separate
    heads (168, 11 and 7 outputs) are applied to the feature map; their
    outputs are concatenated along dimension 1.

    Args:
        model_name: Name of the constructor looked up in
            ``torchvision.models``.
        num_classes: Output size of the linear head.
        pretrained: Forwarded to the torchvision constructor.
        head: ``"linear"`` or ``"custom"``.
        in_channels: 1 or 3; with 1 the first convolution is replaced by a
            single-channel one.
    """

    def __init__(self,
                 model_name: str,
                 num_classes: int,
                 pretrained=False,
                 head="linear",
                 in_channels=3):
        super().__init__()
        self.num_classes = num_classes
        self.base = getattr(models, model_name)(pretrained=pretrained)
        self.head = head
        assert in_channels in [1, 3]
        assert head in ["linear", "custom"]

        if in_channels == 1:
            # Keep the original weight so it can be averaged over its
            # input channels once the layer has been replaced.
            rgb_weight = self.base.conv1.weight
            self.base.conv1 = nn.Conv2d(
                1, 64, kernel_size=7, stride=2, padding=3, bias=False)
            if pretrained:
                self.base.conv1.weight = nn.Parameter(
                    data=torch.mean(rgb_weight, dim=1, keepdim=True),
                    requires_grad=True)

        if head == "linear":
            n_in_features = self.base.fc.in_features
            self.base.fc = nn.Linear(n_in_features, self.num_classes)
        elif head == "custom":
            n_in_features = self.base.fc.in_features
            # Drop the last two children of the backbone so that it
            # returns a feature map for the heads below.
            trunk = list(self.base.children())[:-2]
            self.base = nn.Sequential(*trunk)
            self.grapheme_head = self._make_head(n_in_features, 168)
            self.vowel_head = self._make_head(n_in_features, 11)
            self.consonant_head = self._make_head(n_in_features, 7)
        else:
            raise NotImplementedError

    @staticmethod
    def _make_head(n_in_features: int, n_out: int) -> nn.Sequential:
        """Build one head: Mish, 3x3 conv, batch norm, GeM pooling, linear."""
        return nn.Sequential(
            Mish(), nn.Conv2d(n_in_features, 512, kernel_size=3),
            nn.BatchNorm2d(512), GeM(), nn.Linear(512, n_out))

    def forward(self, x):
        if self.head == "linear":
            return self.base(x)
        if self.head != "custom":
            raise NotImplementedError

        features = self.base(x)
        outputs = [
            self.grapheme_head(features),
            self.vowel_head(features),
            self.consonant_head(features),
        ]
        return torch.cat(outputs, dim=1)


def get_model(config: edict):
    """Instantiate the model described by ``config.model``.

    Only model names containing ``"resnet"`` are supported; the whole
    ``config.model`` section is passed to ``Resnet`` as keyword arguments.

    Raises:
        NotImplementedError: For any other model name.
    """
    model_params = config.model
    if "resnet" not in model_params.model_name:
        raise NotImplementedError
    return Resnet(**model_params)

### Loss

In [None]:
class BengaliCrossEntropyLoss(nn.Module):
    """Sum of three cross-entropy losses over consecutive logit slices.

    The prediction is cut along dimension 1 into a grapheme, a vowel and a
    consonant slice of the configured widths; slice ``k`` is scored
    against column ``k`` of the target.
    """

    def __init__(self, n_grapheme: int, n_vowel: int, n_consonant: int):
        super().__init__()
        self.n_grapheme = n_grapheme
        self.n_vowel = n_vowel
        self.n_consonant = n_consonant
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, pred, true):
        widths = (self.n_grapheme, self.n_vowel, self.n_consonant)
        losses = []
        start = 0
        for column, width in enumerate(widths):
            stop = start + width
            losses.append(
                self.cross_entropy(pred[:, start:stop], true[:, column]))
            start = stop

        grapheme_loss, vowel_loss, consonant_loss = losses
        return grapheme_loss + vowel_loss + consonant_loss


class BengaliBCELoss(nn.Module):
    """Sum of three BCE-with-logits losses over consecutive slices.

    Prediction and target are cut along dimension 1 into the same
    grapheme, vowel and consonant slices of the configured widths.
    """

    def __init__(self, n_grapheme: int, n_vowel: int, n_consonant: int):
        super().__init__()
        self.n_grapheme = n_grapheme
        self.n_vowel = n_vowel
        self.n_consonant = n_consonant
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, pred, true):
        losses = []
        start = 0
        for width in (self.n_grapheme, self.n_vowel, self.n_consonant):
            stop = start + width
            losses.append(self.bce(pred[:, start:stop], true[:, start:stop]))
            start = stop

        grapheme_loss, vowel_loss, consonant_loss = losses
        return grapheme_loss + vowel_loss + consonant_loss


def get_loss(config: edict):
    """Create the criterion named by ``config.loss.name``.

    ``config.loss.params`` is passed to the loss class as keyword
    arguments. Supported names are ``"bce"`` and ``"cross_entropy"``.

    Raises:
        NotImplementedError: For any other loss name.
    """
    loss_name = config.loss.name
    loss_params = config.loss.params
    if loss_name == "bce":
        return BengaliBCELoss(**loss_params)
    if loss_name == "cross_entropy":
        return BengaliCrossEntropyLoss(**loss_params)  # type: ignore
    raise NotImplementedError

## Optimizer and Scheduler

### Optimizer

In [None]:
Optimizer = Union[Adam, SGD]


def get_optimizer(model, config: edict) -> Optimizer:
    """Create the optimizer named by ``config.optimizer.name``.

    Args:
        model: Module whose ``parameters()`` are optimised.
        config: Configuration providing ``optimizer.name`` (``"Adam"`` or
            ``"SGD"``) and ``optimizer.params``, which are passed to the
            optimizer constructor as keyword arguments.

    Returns:
        The constructed optimizer.

    Raises:
        NotImplementedError: For any other optimizer name.
    """
    name = config.optimizer.name
    params = config.optimizer.params
    if name == "Adam":
        optimizer = Adam(model.parameters(), **params)
    elif name == "SGD":
        # This branch used to build Adam, so asking for SGD never gave SGD.
        optimizer = SGD(model.parameters(), **params)
    else:
        raise NotImplementedError
    return optimizer

### Scheduler

In [None]:
Scheduler = Optional[
    Union[ReduceLROnPlateau, CosineAnnealingLR, CosineAnnealingWarmRestarts]]


def get_scheduler(optimizer, config: edict) -> Scheduler:
    """Create the learning-rate scheduler named by ``config.scheduler.name``.

    Recognised names are ``"plateau"``, ``"cosine"`` and
    ``"cosine_warmup"``; ``config.scheduler.params`` is passed to the
    scheduler constructor as keyword arguments.

    Returns:
        The scheduler, or ``None`` when the name is not recognised.
    """
    params = config.scheduler.params
    name = config.scheduler.name
    factories = {
        "plateau": ReduceLROnPlateau,
        "cosine": CosineAnnealingLR,
        "cosine_warmup": CosineAnnealingWarmRestarts,
    }
    factory = factories.get(name)
    if factory is None:
        return None
    return factory(optimizer, **params)

## Callbacks

In [None]:
class MacroAverageRecall(Callback):
    """Batch-level metric: weighted mean of three macro-averaged recalls.

    The logits are cut along dimension 1 into a grapheme, a vowel and a
    consonant slice. For each slice the arg-max prediction is compared to
    the target with a macro-averaged recall, and the three scores are
    averaged with weights 2 : 1 : 1.

    Args:
        n_grapheme: Width of the grapheme slice.
        n_vowel: Width of the vowel slice.
        n_consonant: Width of the consonant slice.
        loss_type: With ``"bce"`` the targets are sliced like the logits
            and reduced with arg-max; otherwise target column ``k`` is
            used directly as the class index of slice ``k``.
        prefix: Name under which the metric value is recorded.
        output_key: Key of the logits in ``state.output``.
        target_key: Key of the targets in ``state.input``.
    """

    def __init__(self,
                 n_grapheme=168,
                 n_vowel=11,
                 n_consonant=7,
                 loss_type: str = "bce",
                 prefix: str = "mar",
                 output_key: str = "logits",
                 target_key: str = "targets"):
        self.prefix = prefix
        self.output_key = output_key
        self.target_key = target_key
        self.n_grapheme = n_grapheme
        self.n_vowel = n_vowel
        self.n_consonant = n_consonant
        self.loss_type = loss_type
        super().__init__(CallbackOrder.Metric)

    def on_batch_end(self, state: RunnerState):
        targets = state.input[self.target_key].detach()
        logits = state.output[self.output_key]
        one_hot_targets = self.loss_type == "bce"

        scores = []
        start = 0
        widths = (self.n_grapheme, self.n_vowel, self.n_consonant)
        for column, width in enumerate(widths):
            stop = start + width
            probabilities = torch.sigmoid(logits[:, start:stop])
            predicted = torch.argmax(
                probabilities, dim=1).detach().cpu().numpy()
            if one_hot_targets:
                expected = torch.argmax(
                    targets[:, start:stop], dim=1).cpu().numpy()
            else:
                expected = targets[:, column].cpu().numpy()
            scores.append(
                recall_score(
                    expected, predicted, average="macro", zero_division=0))
            start = stop

        # The grapheme recall counts twice as much as the other two.
        final_score = np.average(scores, weights=[2, 1, 1])
        state.metrics.add_batch_value(name=self.prefix, value=final_score)


class SaveWeightsCallback(Callback):
    """Save the model's ``state_dict`` at the end of every epoch.

    The weights are written to ``<logdir>/checkpoints`` and, when ``to``
    is given, to that directory as well. The file is called
    ``<name>.pth``, or ``temp.pth`` when ``name`` is empty, so each epoch
    overwrites the previous file.

    Args:
        to: Optional additional directory to save the weights into.
        name: Base name of the weights file.
    """

    def __init__(self, to: Optional[Path] = None, name: str=""):
        self.to = to
        self.name = name
        super().__init__(CallbackOrder.External)

    def on_epoch_end(self, state: RunnerState):
        weights = state.model.state_dict()
        filename = "temp.pth" if self.name == "" else f"{self.name}.pth"

        logdir = state.logdir / "checkpoints"
        logdir.mkdir(exist_ok=True, parents=True)
        torch.save(weights, logdir / filename)

        if self.to is not None:
            # Create the extra directory as well; torch.save does not
            # create missing parent directories.
            extra_dir = Path(self.to)
            extra_dir.mkdir(exist_ok=True, parents=True)
            torch.save(weights, extra_dir / filename)

## Training

In [None]:
# Train a single fold: build the loaders, model, loss, optimizer, scheduler
# and callbacks from `config`, then hand everything to catalyst's runner.
# NOTE(review): `splits`, `df`, `transforms` and `cls_levels` are not defined
# in any cell visible here -- confirm they are created before this cell runs.
trn_idx, val_idx = splits[i]

print(f"Fold: {i}")

output_dir = output_base_dir / f"fold{i}"
output_dir.mkdir(exist_ok=True, parents=True)

# Reset the index so that the dataset can address rows by position 0..n-1.
trn_df = df.loc[trn_idx, :].reset_index(drop=True)
val_df = df.loc[val_idx, :].reset_index(drop=True)
data_loaders = {
    phase: get_base_loader(
        phase_df,
        train_images_path,
        phase=phase,
        size=(config.img_size, config.img_size),
        batch_size=config.train.batch_size,
        num_workers=config.num_workers,
        transforms=transforms)
    for phase, phase_df in zip(["train", "valid"], [trn_df, val_df])
}
model = get_model(config)
criterion = get_loss(config)
optimizer = get_optimizer(model, config)
scheduler = get_scheduler(optimizer, config)
callbacks = [
    MacroAverageRecall(
        n_grapheme=cls_levels["grapheme"],
        n_vowel=cls_levels["vowel"],
        n_consonant=cls_levels["consonant"],
        loss_type=config.loss.name),
    SaveWeightsCallback(
        to=Path(config.checkpoints
                ) if config.checkpoints is not None else None,
        name=trial)
]

if config.mixup:
    # The import cell binds the name `MixUpCallback`, so `MixupCallback`
    # was never in scope here; import it under the name that is used.
    from catalyst.dl.callbacks import MixupCallback
    callbacks.append(MixupCallback(fields=[
        "images",
    ]))

runner = SupervisedRunner(
    device=ct.utils.get_device(),
    input_key="images",
    input_target_key="targets",
    output_key="logits")
# "mar" is the prefix under which MacroAverageRecall records its value;
# higher is better, hence minimize_metric=False.
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    loaders=data_loaders,
    logdir=output_dir,
    scheduler=scheduler,
    num_epochs=config.train.num_epochs,
    callbacks=callbacks,
    main_metric="mar",
    minimize_metric=False,
    monitoring_params=None,
    verbose=False)