<a href="https://colab.research.google.com/github/koukyo1994/kaggle-bengali-ai/blob/master/notebook/Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Dependencies

In [0]:
%%sh
pip install albumentations==0.4.3 catalyst==20.1.1 easydict==1.9.0 >> /dev/null
pip install efficientnet-pytorch==0.6.1 PyYAML==5.3 >> /dev/null
pip install pretrainedmodels==0.7.4 >> /dev/null

## Integration with Google Drive

In [2]:
from google.colab import drive
drive.mount("/content/gdrive")

Mounted at /content/gdrive


In [4]:
%%sh
mkdir input
cp -r /content/gdrive/My\ Drive/kaggle-bengali ./input/bengaliai-cv19
unzip -qq -d input/bengaliai-cv19/ input/bengaliai-cv19/train_images.zip

cp: cannot open '/content/gdrive/My Drive/kaggle-bengali/Kaggle bengaliai-cv19.gsheet' for reading: Operation not supported


In [1]:
!nvidia-smi

Fri Feb 28 05:46:44 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.48.02    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   33C    P0    27W / 250W |      0MiB / 16280MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|  No ru

## Libraries

In [1]:
import albumentations as A
import catalyst as ct
import cv2
import numpy as np
import pandas as pd
import pretrainedmodels
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as torchdata
import torchvision.models as models
import yaml

from pathlib import Path
from typing import Tuple, Dict, Union, Optional, List

from catalyst.dl import SupervisedRunner
from catalyst.dl.core import Callback, CallbackOrder, RunnerState
from catalyst.dl.callbacks import CriterionCallback, OptimizerCallback
from catalyst.utils import get_device
from easydict import EasyDict as edict
from efficientnet_pytorch import EfficientNet
from fastprogress import progress_bar
from skimage.transform import AffineTransform, warp
from sklearn.metrics import recall_score, confusion_matrix
from sklearn.model_selection import KFold, train_test_split
from torch.nn.parameter import Parameter
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import (ReduceLROnPlateau, 
                                      CosineAnnealingLR,
                                      CosineAnnealingWarmRestarts)

alchemy not available, to install alchemy, run `pip install alchemy-catalyst`.


## Settings

In [0]:
i = 0
stage = 1
trial = "seresnext_size236_90epoch_cutout"

## Config

In [0]:
conf_string = '''
dataset:
  train:
    affine: False
    morphology: False
    binarization: False
  val:
    affine: False
    morphology: False
    binarization: False
  test:
    affine: False
    morphology: False

data:
  train_df_path: input/bengaliai-cv19/train.csv
  train_images_path: input/bengaliai-cv19/train_images
  test_images_path: input/bengaliai-cv19/test_images
  sample_submission_path: input/bengaliai-cv19/sample_submission.csv

model:
  model_name: se_resnext50_32x4d
  pretrained: imagenet
  num_classes: 186
  head: custom
  in_channels: 3

train:
  batch_size: 64
  num_epochs: 15

test:
  batch_size: 100

loss:
  name: ohem
  params:
    n_grapheme: 168
    n_vowel: 11
    n_consonant: 7
    weights:
      - 2.0
      - 1.0
      - 1.0

optimizer:
  name: Adam
  params:
    lr: 0.00005

scheduler:
  name: cosine
  params:
    T_max: 10

transforms:
  train:
    Noise: False
    Contrast: False
    Rotate: True
    RandomScale: True
    Cutout:
      num_holes: 1
      max_h_size: 50
      max_w_size: 75
      fill_value: 255
      always_apply: True
    ShiftScaleRotate: False
    RandomResizedCrop: False
    CoarseDropout: False
    GridDistortion: False
  val:
    Noise: False
    Contrast: False
    Rotate: False
    RandomScale: False
    Cutout:
      num_holes: 0
  test:
    Noise: False
    Contrast: False
    Rotate: False
    RandomScale: False
    Cutout:
      num_holes: 0

val:
  name: kfold
  params:
    random_state: 42
    n_splits: 5

callbacks:
  - AverageRecall:
      index: 0
      offset: 0
      n_classes: 168
      prefix: grapheme_recall
      loss_type: cross_entropy
  - AverageRecall:
      index: 1
      offset: 168
      n_classes: 11
      prefix: vowel_recall
      loss_type: cross_entropy
  - AverageRecall:
      index: 2
      offset: 179
      n_classes: 7
      prefix: consonant_recall
      loss_type: cross_entropy
  - TotalAverageRecall:
      loss_type: cross_entropy
  - SaveWeightsCallback:
      to: /content/gdrive/My Drive/kaggle-bengali/checkpoints/fold{}/
      name: {}
      is_larger_better: True
      main_metric: tar
      save_optimizer_state: True
  - SaveWeightsCallback:
      to: /content/gdrive/My Drive/kaggle-bengali/checkpoints/fold{}/
      name: {}
      is_larger_better: True
      main_metric: grapheme_recall
      save_optimizer_state: False
      suffix: _grapheme
  - SaveWeightsCallback:
      to: /content/gdrive/My Drive/kaggle-bengali/checkpoints/fold{}/
      name: {}
      is_larger_better: True
      main_metric: vowel_recall
      save_optimizer_state: False
      suffix: _vowel
  - SaveWeightsCallback:
      to: /content/gdrive/My Drive/kaggle-bengali/checkpoints/fold{}/
      name: {}
      is_larger_better: True
      main_metric: consonant_recall
      save_optimizer_state: False
      suffix: _consonant
  - MixupOrCutmixCallback:
      mixup_prob: 0.5
      cutmix_prob: 0.0
      no_aug_epochs: 0

log_dir: log/
num_workers: 2
seed: 1213
img_size: 236
weights: 
'''.format(i, trial, i, trial, i, trial, i, trial)
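
The `offset` values in the `AverageRecall` entries above (0, 168, 179) are the cumulative class counts of the three targets, and the last boundary matches `num_classes: 186`. A minimal sketch of that arithmetic follows; the helper name `head_slices` is hypothetical and not part of the notebook.

```python
from itertools import accumulate


def head_slices(class_counts):
    """Return {name: (start, stop)} column ranges for concatenated logits.

    `class_counts` maps each target name to its number of classes, in the
    order the heads are concatenated. (Hypothetical helper for illustration.)
    """
    names = list(class_counts)
    stops = list(accumulate(class_counts[name] for name in names))
    starts = [0] + stops[:-1]
    return {name: (start, stop) for name, start, stop in zip(names, starts, stops)}


slices = head_slices({"grapheme": 168, "vowel": 11, "consonant": 7})
print(slices)
# {'grapheme': (0, 168), 'vowel': (168, 179), 'consonant': (179, 186)}
```

Each `start` is the `offset` a callback would use, and each `stop - start` is its `n_classes`.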

In [0]:
def _get_default():
    cfg = edict()

    # dataset
    cfg.dataset = edict()
    cfg.dataset.train = edict()
    cfg.dataset.val = edict()
    cfg.dataset.test = edict()
    cfg.dataset.train.affine = False
    cfg.dataset.train.morphology = False
    cfg.dataset.train.binarization = False
    cfg.dataset.val.affine = False
    cfg.dataset.val.morphology = False
    cfg.dataset.val.binarization = False
    cfg.dataset.test.affine = False
    cfg.dataset.test.morphology = False
    cfg.dataset.test.binarization = False

    # data
    cfg.data = edict()

    # model
    cfg.model = edict()
    cfg.model.model_name = "resnet18"
    cfg.model.num_classes = 186
    cfg.model.pretrained = True
    cfg.model.head = "linear"
    cfg.model.in_channels = 3
    # cfg.model.outputs = ["grapheme", "vowel", "consonant"]

    # train
    cfg.train = edict()

    # test
    cfg.test = edict()

    # loss
    cfg.loss = edict()
    cfg.loss.params = edict()

    # optimizer
    cfg.optimizer = edict()
    cfg.optimizer.params = edict()

    # scheduler
    cfg.scheduler = edict()
    cfg.scheduler.params = edict()

    # transforms:
    cfg.transforms = edict()
    cfg.transforms.train = edict()
    cfg.transforms.train.HorizontalFlip = False
    cfg.transforms.train.VerticalFlip = False
    cfg.transforms.train.Noise = False
    cfg.transforms.train.Contrast = False
    cfg.transforms.train.Rotate = False
    cfg.transforms.train.RandomScale = False
    cfg.transforms.train.Cutout = edict()
    cfg.transforms.train.Cutout.num_holes = 0
    cfg.transforms.train.ShiftScaleRotate = False
    cfg.transforms.train.RandomResizedCrop = False
    cfg.transforms.train.CoarseDropout = False
    cfg.transforms.train.GridDistortion = False
    cfg.transforms.val = edict()
    cfg.transforms.val.HorizontalFlip = False
    cfg.transforms.val.VerticalFlip = False
    cfg.transforms.val.Noise = False
    cfg.transforms.val.Contrast = False
    cfg.transforms.val.Rotate = False
    cfg.transforms.val.RandomScale = False
    cfg.transforms.val.Cutout = edict()
    cfg.transforms.val.Cutout.num_holes = 0
    cfg.transforms.val.ShiftScaleRotate = False
    cfg.transforms.val.RandomResizedCrop = False
    cfg.transforms.val.CoarseDropout = False
    cfg.transforms.val.GridDistortion = False
    cfg.transforms.test = edict()
    cfg.transforms.test.HorizontalFlip = False
    cfg.transforms.test.VerticalFlip = False
    cfg.transforms.test.Noise = False
    cfg.transforms.test.Contrast = False
    cfg.transforms.test.Rotate = False
    cfg.transforms.test.RandomScale = False
    cfg.transforms.test.Cutout = edict()
    cfg.transforms.test.Cutout.num_holes = 0
    cfg.transforms.test.ShiftScaleRotate = False
    cfg.transforms.test.RandomResizedCrop = False
    cfg.transforms.test.CoarseDropout = False
    cfg.transforms.test.GridDistortion = False
    cfg.transforms.mean = [0.485, 0.456, 0.406]
    cfg.transforms.std = [0.229, 0.224, 0.225]

    # val
    cfg.val = edict()
    cfg.val.params = edict()

    cfg.callbacks = []

    return cfg


def _merge_config(src: edict, dst: edict):
    if not isinstance(src, edict):
        return
    for k, v in src.items():
        if isinstance(v, edict):
            _merge_config(src[k], dst[k])
        else:
            dst[k] = v

In [0]:
cfg = edict(yaml.load(conf_string, Loader=yaml.SafeLoader))
config = _get_default()
_merge_config(cfg, config)
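
The merge step overlays the parsed YAML on the defaults: nested sections are merged key by key, so a default that the YAML omits survives. The sketch below shows the same idea on plain `dict` objects. It is an illustration rather than the notebook's `_merge_config`, and it deviates in one labeled way: a nested section with no default is copied over instead of raising `KeyError`.

```python
def merge_config(src, dst):
    """Recursively copy values from `src` into `dst`, in place.

    Keys present only in `dst` are left untouched. Assumption of this
    sketch: a nested section missing from `dst` is copied over as a whole.
    """
    for key, value in src.items():
        if isinstance(value, dict) and isinstance(dst.get(key), dict):
            merge_config(value, dst[key])
        else:
            dst[key] = value


defaults = {
    "model": {"model_name": "resnet18", "head": "linear", "in_channels": 3},
    "callbacks": [],
}
overrides = {
    "model": {"model_name": "se_resnext50_32x4d", "head": "custom"},
    "seed": 1213,
}

merge_config(overrides, defaults)
print(defaults["model"])
# {'model_name': 'se_resnext50_32x4d', 'head': 'custom', 'in_channels': 3}
```

Here `in_channels` keeps its default because the override does not mention it, while `seed` is added at the top level.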

## Environmental settings

In [6]:
ct.utils.set_global_seed(config.seed + stage)
ct.utils.prepare_cudnn(deterministic=True)

In [0]:
output_base_dir = Path("output")
output_base_dir.mkdir(exist_ok=True, parents=True)

train_images_path = Path(config.data.train_images_path)

## Data and utilities preparation

### validation utils

In [0]:
def no_fold(df: pd.DataFrame,
            config: edict) -> List[Tuple[np.ndarray, np.ndarray]]:
    params = config.val.params
    idx = np.arange(len(df))
    trn_idx, val_idx = train_test_split(idx, **params)
    return [(trn_idx, val_idx)]


def kfold(df: pd.DataFrame,
          config: edict) -> List[Tuple[np.ndarray, np.ndarray]]:
    params = config.val.params
    kf = KFold(shuffle=True, **params)
    splits = list(kf.split(df))
    return splits


def get_validation(df: pd.DataFrame,
                   config: edict) -> List[Tuple[np.ndarray, np.ndarray]]:
    name: str = config.val.name

    func = globals().get(name)
    if func is None:
        raise NotImplementedError

    return func(df, config)
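
`get_validation` resolves `config.val.name` through `globals()`, so any module-level name is accepted as a splitter. One possible alternative is an explicit registry, sketched below. The names `SPLITTERS`, `register`, `get_splitter`, and `toy_kfold` are hypothetical, and `toy_kfold` is a contiguous, unshuffled stand-in for scikit-learn's `KFold`, used only to keep the example dependency-free.

```python
SPLITTERS = {}


def register(name):
    """Decorator that records a splitter under `name` (hypothetical helper)."""
    def decorator(func):
        SPLITTERS[name] = func
        return func
    return decorator


@register("kfold")
def toy_kfold(n_samples, n_splits):
    """Contiguous folds without shuffling; not a replacement for KFold."""
    indices = list(range(n_samples))
    sizes = [n_samples // n_splits + (1 if i < n_samples % n_splits else 0)
             for i in range(n_splits)]
    splits, start = [], 0
    for size in sizes:
        val_idx = indices[start:start + size]
        trn_idx = indices[:start] + indices[start + size:]
        splits.append((trn_idx, val_idx))
        start += size
    return splits


def get_splitter(name):
    """Look up a registered splitter, mirroring the NotImplementedError above."""
    try:
        return SPLITTERS[name]
    except KeyError:
        raise NotImplementedError(f"unknown validation scheme: {name}") from None


splits = get_splitter("kfold")(10, 5)
print(splits[0])
# ([2, 3, 4, 5, 6, 7, 8, 9], [0, 1])
```

A registry limits the accepted names to functions that were registered on purpose, at the cost of one decorator per splitter.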

### transforms

In [0]:
def get_transforms(config: edict, phase: str = "train"):
    assert phase in ["train", "valid", "test"]
    if phase == "train":
        cfg = config.transforms.train
    elif phase == "valid":
        cfg = config.transforms.val
    elif phase == "test":
        cfg = config.transforms.test
    list_transforms = []
    if cfg.HorizontalFlip:
        list_transforms.append(A.HorizontalFlip())
    if cfg.VerticalFlip:
        list_transforms.append(A.VerticalFlip())
    if cfg.Rotate:
        list_transforms.append(A.Rotate(limit=15))
    if cfg.RandomScale:
        list_transforms.append(A.RandomScale())
    if cfg.Noise:
        list_transforms.append(
            A.OneOf(
                [A.GaussNoise(), A.IAAAdditiveGaussianNoise()], p=0.5))
    if cfg.Contrast:
        list_transforms.append(
            A.OneOf(
                [A.RandomContrast(0.5),
                 A.RandomGamma(),
                 A.RandomBrightness()],
                p=0.5))
    if cfg.Cutout.num_holes > 0:
        list_transforms.append(A.Cutout(**cfg.Cutout))
    if cfg.ShiftScaleRotate:
        list_transforms.append(
            A.ShiftScaleRotate(
                shift_limit=0.0625, scale_limit=0, rotate_limit=7, p=0.5))
    if cfg.RandomResizedCrop:
        list_transforms.append(
            A.RandomResizedCrop(128, 128, scale=(0.8, 1), p=0.5))
    if cfg.CoarseDropout:
        list_transforms.append(
            A.CoarseDropout(max_holes=8, max_height=8, max_width=8, p=0.2))
    if cfg.GridDistortion:
        list_transforms.append(A.GridDistortion(p=0.2))

    list_transforms.append(
        A.Normalize(
            mean=config.transforms.mean,
            std=config.transforms.std,
            p=1,
            always_apply=True))

    return A.Compose(list_transforms, p=1.0)

### Data Loading

In [0]:
df = pd.read_csv(config.data.train_df_path)
splits = get_validation(df, config)

transforms_dict = {
    phase: get_transforms(config, phase)
    for phase in ["train", "valid"]
}

cls_levels = {
    "grapheme": df.grapheme_root.nunique(),
    "vowel": df.vowel_diacritic.nunique(),
    "consonant": df.consonant_diacritic.nunique()
}

## Dataset and DataLoader

In [0]:
class BaseDataset(torchdata.Dataset):
    def __init__(self,
                 image_dir: Path,
                 df: pd.DataFrame,
                 transforms,
                 size: Tuple[int, int],
                 binarization=False):
        self.df = df
        self.image_dir = image_dir
        self.transforms = transforms
        self.size = size
        self.binarization = binarization

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        image_id = self.df.loc[idx, "image_id"]
        image_path = self.image_dir / f"{image_id}.png"

        image = cv2.imread(str(image_path))
        if self.binarization:
            image = binarization_and_opening(image)
        longer_side = image.shape[1]
        if image.ndim == 2:
            new_image = np.ones(
                (longer_side, longer_side), dtype=np.uint8) * 255
        else:
            new_image = np.ones(
                (longer_side, longer_side, 3), dtype=np.uint8) * 255
        offset = np.random.randint(0, longer_side - image.shape[0])
        new_image[offset:offset + image.shape[0], :] = image

        if self.transforms is not None:
            image = self.transforms(image=new_image)["image"]
        image = cv2.resize(image, self.size)
        if image.shape[2] == 3:
            image = np.moveaxis(image, -1, 0)
        grapheme = self.df.loc[idx, "grapheme_root"]
        vowel = self.df.loc[idx, "vowel_diacritic"]
        consonant = self.df.loc[idx, "consonant_diacritic"]
        label = np.zeros(3, dtype=int)
        label[0] = grapheme
        label[1] = vowel
        label[2] = consonant
        return {"images": image, "targets": label}
    
    
def get_base_loader(df: pd.DataFrame,
                    image_dir: Path,
                    phase: str = "train",
                    size: Tuple[int, int] = (128, 128),
                    batch_size=256,
                    num_workers=2,
                    transforms=None,
                    binarization=False):
    assert phase in ["train", "valid"]
    if phase == "train":
        is_shuffle = True
        drop_last = True
    else:
        is_shuffle = False
        drop_last = False

    dataset = BaseDataset(  # type: ignore
        image_dir, df, transforms, size, binarization)
    return torchdata.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=is_shuffle,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=drop_last)
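
`BaseDataset.__getitem__` pads each landscape image onto a white square canvas at a random vertical offset before the transforms and the resize. The sketch below isolates that step, assuming a 137×236 input. The function name `pad_to_square` is hypothetical, and the sketch deviates from the cell above in one labeled way: the upper bound of the offset is inclusive, so a square input is accepted as well.

```python
import numpy as np


def pad_to_square(image, rng, fill=255):
    """Place a landscape image at a random vertical offset on a square canvas.

    Assumes width >= height, as in the dataset cell above. Works for both
    (H, W) and (H, W, C) arrays. (Hypothetical helper for illustration.)
    """
    height, width = image.shape[:2]
    if height > width:
        raise ValueError("this sketch assumes width >= height")
    canvas = np.full((width, width) + image.shape[2:], fill, dtype=image.dtype)
    # Inclusive upper bound: offset may be 0 .. width - height.
    offset = int(rng.integers(0, width - height + 1))
    canvas[offset:offset + height] = image
    return canvas, offset


rng = np.random.default_rng(0)
image = np.zeros((137, 236, 3), dtype=np.uint8)  # all-black dummy image
square, offset = pad_to_square(image, rng)
print(square.shape)
# (236, 236, 3)
```

Because the offset is drawn on every call, the padding itself acts as a small vertical-shift augmentation, including at validation time.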

## Model and Loss

### Model

In [0]:
def gem(x: torch.Tensor, p=3, eps=1e-6):
    return F.avg_pool2d(x.clamp(min=eps).pow(p),
                        (x.size(-2), x.size(-1))).pow(1. / p)


def mish(input):
    '''
    Applies the mish function element-wise:
    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
    See additional documentation for mish class.
    '''
    return input * torch.tanh(F.softplus(input))


class Mish(nn.Module):
    '''
    Applies the mish function element-wise:
    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
    Shape:
        - Input: (N, *) where * means, any number of additional
          dimensions
        - Output: (N, *), same shape as the input
    Examples:
        >>> m = Mish()
        >>> input = torch.randn(2)
        >>> output = m(input)
    '''

    def __init__(self):
        '''
        Init method.
        '''
        super().__init__()

    def forward(self, input):
        '''
        Forward pass of the function.
        '''
        return mish(input)


class GeM(nn.Module):
    def __init__(self, p=3, eps=1e-6):
        super(GeM, self).__init__()
        self.p = Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        return gem(x, p=self.p, eps=self.eps).squeeze(-1).squeeze(-1)

    def __repr__(self):
        return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(
            self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')'


class SpatialAttention2d(nn.Module):
    def __init__(self, channel):
        super(SpatialAttention2d, self).__init__()
        self.squeeze = nn.Conv2d(channel, 1, kernel_size=1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        z = self.squeeze(x)
        z = self.sigmoid(z)
        return x * z


class GAB(nn.Module):
    def __init__(self, input_dim, reduction=4):
        super(GAB, self).__init__()
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(
            input_dim, input_dim // reduction, kernel_size=1, stride=1)
        self.conv2 = nn.Conv2d(
            input_dim // reduction, input_dim, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        z = self.global_avgpool(x)
        z = self.relu(self.conv1(z))
        z = self.sigmoid(self.conv2(z))
        return x * z


class SCse(nn.Module):
    def __init__(self, dim):
        super(SCse, self).__init__()
        self.satt = SpatialAttention2d(dim)
        self.catt = GAB(dim)

    def forward(self, x):
        return self.satt(x) + self.catt(x)
    
    
class SEResNext(nn.Module):
    def __init__(self,
                 model_name: str,
                 num_classes: int,
                 pretrained=None,
                 head="linear",
                 in_channels=3):
        super().__init__()
        self.num_classes = num_classes
        self.base = getattr(pretrainedmodels.models,
                            model_name)(pretrained=pretrained)
        self.head = head
        assert in_channels in [1, 3]
        assert head in ["linear", "custom", "scse"]
        if in_channels == 1:
            if pretrained == "imagenet":
                weight = self.base.layer0.conv1.weight
                self.base.layer0.conv1 = nn.Conv2d(
                    1, 64, kernel_size=7, stride=2, padding=3, bias=False)
                self.base.layer0.conv1.weight = nn.Parameter(
                    data=torch.mean(weight, dim=1, keepdim=True),
                    requires_grad=True)
            else:
                self.base.layer0.conv1 = nn.Conv2d(
                    1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        if head == "linear":
            n_in_features = self.base.last_linear.in_features
            self.base.last_linear = nn.Linear(n_in_features, self.num_classes)
        elif head == "custom":
            n_in_features = self.base.last_linear.in_features
            arch = list(self.base.children())
            for _ in range(2):
                arch.pop()
            self.base = nn.Sequential(*arch)
            self.grapheme_head = nn.Sequential(
                Mish(), nn.Conv2d(n_in_features, 512, kernel_size=3),
                nn.BatchNorm2d(512), GeM(), nn.Linear(
                    512, 168))
            self.vowel_head = nn.Sequential(
                Mish(), nn.Conv2d(n_in_features, 512, kernel_size=3),
                nn.BatchNorm2d(512), GeM(), nn.Linear(
                    512, 11))
            self.consonant_head = nn.Sequential(
                Mish(), nn.Conv2d(n_in_features, 512, kernel_size=3),
                nn.BatchNorm2d(512), GeM(), nn.Linear(512, 7))
        elif head == "scse":
            n_in_features = self.base.last_linear.in_features
            arch = list(self.base.children())
            for _ in range(2):
                arch.pop()
            self.base = nn.Sequential(*arch)
            # SCse keeps the channel count, so BatchNorm2d and Linear must use
            # n_in_features (2048 for se_resnext50_32x4d), not a fixed 512.
            self.grapheme_head = nn.Sequential(
                SCse(n_in_features), Mish(), nn.BatchNorm2d(n_in_features),
                GeM(), nn.Dropout(0.3), nn.Linear(n_in_features, 168))
            self.vowel_head = nn.Sequential(
                SCse(n_in_features), Mish(), nn.BatchNorm2d(n_in_features),
                GeM(), nn.Dropout(0.3), nn.Linear(n_in_features, 11))
            self.consonant_head = nn.Sequential(
                SCse(n_in_features), Mish(), nn.BatchNorm2d(n_in_features),
                GeM(), nn.Dropout(0.3), nn.Linear(n_in_features, 7))
        else:
            raise NotImplementedError

    def forward(self, x):
        if self.head == "linear":
            return self.base(x)
        elif self.head == "custom":
            x = self.base(x)
            grapheme = self.grapheme_head(x)
            vowel = self.vowel_head(x)
            consonant = self.consonant_head(x)
            return torch.cat([grapheme, vowel, consonant], dim=1)
        elif self.head == "scse":
            x = self.base(x)
            grapheme = self.grapheme_head(x)
            vowel = self.vowel_head(x)
            consonant = self.consonant_head(x)
            return torch.cat([grapheme, vowel, consonant], dim=1)
        else:
            raise NotImplementedError


class Resnet(nn.Module):
    def __init__(self,
                 model_name: str,
                 num_classes: int,
                 pretrained=False,
                 head="linear",
                 in_channels=3):
        super().__init__()
        self.num_classes = num_classes
        self.base = getattr(models, model_name)(pretrained=pretrained)
        self.head = head
        assert in_channels in [1, 3]
        assert head in ["linear", "custom", "scse"]
        if in_channels == 1:
            if pretrained:
                weight = self.base.conv1.weight
                self.base.conv1 = nn.Conv2d(
                    1, 64, kernel_size=7, stride=2, padding=3, bias=False)
                self.base.conv1.weight = nn.Parameter(
                    data=torch.mean(weight, dim=1, keepdim=True),
                    requires_grad=True)
            else:
                self.base.conv1 = nn.Conv2d(
                    1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        if head == "linear":
            n_in_features = self.base.fc.in_features
            self.base.fc = nn.Linear(n_in_features, self.num_classes)
        elif head == "custom":
            n_in_features = self.base.fc.in_features
            arch = list(self.base.children())
            for _ in range(2):
                arch.pop()
            self.base = nn.Sequential(*arch)
            self.grapheme_head = nn.Sequential(
                Mish(), nn.Conv2d(n_in_features, 512, kernel_size=3),
                nn.BatchNorm2d(512), GeM(), nn.Linear(512, 168))
            self.vowel_head = nn.Sequential(
                Mish(), nn.Conv2d(n_in_features, 512, kernel_size=3),
                nn.BatchNorm2d(512), GeM(), nn.Linear(512, 11))
            self.consonant_head = nn.Sequential(
                Mish(), nn.Conv2d(n_in_features, 512, kernel_size=3),
                nn.BatchNorm2d(512), GeM(), nn.Linear(512, 7))
        elif head == "scse":
            n_in_features = self.base.fc.in_features
            arch = list(self.base.children())
            for _ in range(2):
                arch.pop()
            self.base = nn.Sequential(*arch)
            # Use n_in_features instead of a fixed 512 so that deeper variants
            # (e.g. resnet50 with 2048 features) also work with this head.
            self.grapheme_head = nn.Sequential(
                SCse(n_in_features), Mish(), nn.BatchNorm2d(n_in_features),
                GeM(), nn.Dropout(0.3), nn.Linear(n_in_features, 168))
            self.vowel_head = nn.Sequential(
                SCse(n_in_features), Mish(), nn.BatchNorm2d(n_in_features),
                GeM(), nn.Dropout(0.3), nn.Linear(n_in_features, 11))
            self.consonant_head = nn.Sequential(
                SCse(n_in_features), Mish(), nn.BatchNorm2d(n_in_features),
                GeM(), nn.Dropout(0.3), nn.Linear(n_in_features, 7))
        else:
            raise NotImplementedError

    def forward(self, x):
        if self.head == "linear":
            return self.base(x)
        elif self.head == "custom":
            x = self.base(x)
            grapheme = self.grapheme_head(x)
            vowel = self.vowel_head(x)
            consonant = self.consonant_head(x)
            return torch.cat([grapheme, vowel, consonant], dim=1)
        elif self.head == "scse":
            x = self.base(x)
            grapheme = self.grapheme_head(x)
            vowel = self.vowel_head(x)
            consonant = self.consonant_head(x)
            return torch.cat([grapheme, vowel, consonant], dim=1)
        else:
            raise NotImplementedError


def get_model(config: edict):
    params = config.model
    if "resnet" in params.model_name:
        return Resnet(**params)
    elif "se_resnext" in params.model_name:
        return SEResNext(**params)
    else:
        raise NotImplementedError
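
The `gem` function above is generalized-mean pooling: each activation is clamped to at least `eps`, raised to the power `p`, averaged over the spatial grid, and the result is taken to the power `1/p`. The dependency-free sketch below applies the same formula to a single 2D feature map given as nested lists, to show how `p` moves the result between average pooling and max pooling. The function name `gem_pool` is hypothetical.

```python
def gem_pool(feature_map, p=3.0, eps=1e-6):
    """Generalized-mean pooling of one 2D feature map (nested lists).

    Mirrors gem(): clamp to eps, raise to p, average, take the p-th root.
    """
    values = [max(v, eps) for row in feature_map for v in row]
    return (sum(v ** p for v in values) / len(values)) ** (1.0 / p)


fmap = [[1.0, 2.0], [3.0, 4.0]]

mean_like = gem_pool(fmap, p=1.0)   # p = 1 is the arithmetic mean: 2.5
default = gem_pool(fmap, p=3.0)     # cube root of (1 + 8 + 27 + 64) / 4 = 25
max_like = gem_pool(fmap, p=50.0)   # large p approaches the maximum, 4.0

print(mean_like, default, max_like)
```

In the `GeM` module, `p` is a learnable `Parameter` initialised to 3, so training can shift each head between these two regimes.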

### Loss

In [0]:
class BengaliCrossEntropyLoss(nn.Module):
    def __init__(self,
                 n_grapheme: int,
                 n_vowel: int,
                 n_consonant: int,
                 weights=(1.0, 1.0, 1.0)):
        super().__init__()
        self.n_grapheme = n_grapheme
        self.n_vowel = n_vowel
        self.n_consonant = n_consonant
        self.cross_entropy = nn.CrossEntropyLoss()
        self.weights = weights

    def forward(self, pred, true):
        head = 0
        tail = self.n_grapheme
        grapheme_pred = pred[:, head:tail]
        grapheme_true = true[:, 0]

        head = tail
        tail = head + self.n_vowel
        vowel_pred = pred[:, head:tail]
        vowel_true = true[:, 1]

        head = tail
        tail = head + self.n_consonant
        consonant_pred = pred[:, head:tail]
        consonant_true = true[:, 2]

        return self.weights[0] * self.cross_entropy(grapheme_pred, grapheme_true) + \
            self.weights[1] * self.cross_entropy(vowel_pred, vowel_true) + \
            self.weights[2] * self.cross_entropy(consonant_pred, consonant_true)


class BengaliBCELoss(nn.Module):
    def __init__(self, n_grapheme: int, n_vowel: int, n_consonant: int):
        super().__init__()
        self.n_grapheme = n_grapheme
        self.n_vowel = n_vowel
        self.n_consonant = n_consonant
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, pred, true):
        head = 0
        tail = self.n_grapheme
        grapheme_pred = pred[:, head:tail]
        grapheme_true = true[:, head:tail]

        head = tail
        tail = head + self.n_vowel
        vowel_pred = pred[:, head:tail]
        vowel_true = true[:, head:tail]

        head = tail
        tail = head + self.n_consonant
        consonant_pred = pred[:, head:tail]
        consonant_true = true[:, head:tail]

        return self.bce(grapheme_pred, grapheme_true) + \
            self.bce(vowel_pred, vowel_true) + \
            self.bce(consonant_pred, consonant_true)
    
    
class BengaliFocalLoss(nn.Module):
    def __init__(self,
                 n_grapheme: int,
                 n_vowel: int,
                 n_consonant: int,
                 weights=(1.0, 1.0, 1.0)):
        super().__init__()
        self.n_grapheme = n_grapheme
        self.n_vowel = n_vowel
        self.n_consonant = n_consonant
        self.grapheme_focal = FocalLoss(logits=True, n_class=n_grapheme)
        self.vowel_focal = FocalLoss(logits=True, n_class=n_vowel)
        self.consonant_focal = FocalLoss(logits=True, n_class=n_consonant)
        self.weights = weights

    def forward(self, pred, true):
        head = 0
        tail = self.n_grapheme
        grapheme_pred = pred[:, head:tail]
        grapheme_true = true[:, 0]

        head = tail
        tail = head + self.n_vowel
        vowel_pred = pred[:, head:tail]
        vowel_true = true[:, 1]

        head = tail
        tail = head + self.n_consonant
        consonant_pred = pred[:, head:tail]
        consonant_true = true[:, 2]

        return self.weights[0] * self.grapheme_focal(
            grapheme_pred, grapheme_true) + \
            self.weights[1] * self.vowel_focal(vowel_pred, vowel_true) + \
            self.weights[2] * self.consonant_focal(
                consonant_pred, consonant_true)
            

class FocalLoss(nn.Module):
    def __init__(self,
                 alpha=1,
                 gamma=2,
                 logits=False,
                 n_class=168,
                 reduction='elementwise_mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduction = reduction
        self.eye = torch.eye(n_class).to(get_device())

    def forward(self, inputs, targets):
        one_hot = self.eye[targets]
        if self.logits:
            BCE_loss = F.binary_cross_entropy_with_logits(
                inputs, one_hot, reduction='none')
        else:
            BCE_loss = F.binary_cross_entropy(
                inputs, one_hot, reduction='none')
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1 - pt)**self.gamma * BCE_loss

        if self.reduction is None:
            return F_loss
        else:
            return torch.mean(F_loss)
    
    
class BengaliMultiMarginLoss(nn.Module):
    def __init__(self,
                 n_grapheme: int,
                 n_vowel: int,
                 n_consonant: int,
                 weights=(1.0, 1.0, 1.0)):
        super().__init__()
        self.n_grapheme = n_grapheme
        self.n_vowel = n_vowel
        self.n_consonant = n_consonant
        self.margin = nn.MultiMarginLoss()
        self.weights = weights

    def forward(self, pred, true):
        head = 0
        tail = self.n_grapheme
        grapheme_pred = pred[:, head:tail]
        grapheme_true = true[:, 0]

        head = tail
        tail = head + self.n_vowel
        vowel_pred = pred[:, head:tail]
        vowel_true = true[:, 1]

        head = tail
        tail = head + self.n_consonant
        consonant_pred = pred[:, head:tail]
        consonant_true = true[:, 2]

        return self.weights[0] * self.margin(
            grapheme_pred, grapheme_true) + \
            self.weights[1] * self.margin(vowel_pred, vowel_true) + \
            self.weights[2] * self.margin(
                consonant_pred, consonant_true)
    
    
class OHEMLoss(nn.Module):
    def __init__(self, rate=0.7):
        super().__init__()
        self.rate = rate

    def forward(self, pred, target):
        batch_size = pred.size(0)
        ohem_cls_loss = F.cross_entropy(
            pred, target, reduction="none", ignore_index=-1)

        sorted_ohem_loss, idx = torch.sort(ohem_cls_loss, descending=True)
        keep_num = min(sorted_ohem_loss.size(0), int(batch_size * self.rate))
        if keep_num < sorted_ohem_loss.size(0):
            keep_idx_cuda = idx[:keep_num]
            ohem_cls_loss = ohem_cls_loss[keep_idx_cuda]
        cls_loss = ohem_cls_loss.sum() / keep_num
        return cls_loss


class BengaliOHEMLoss(nn.Module):
    def __init__(self,
                 n_grapheme: int,
                 n_vowel: int,
                 n_consonant: int,
                 weights=(1.0, 1.0, 1.0),
                 rate=0.7):
        super().__init__()
        self.n_grapheme = n_grapheme
        self.n_vowel = n_vowel
        self.n_consonant = n_consonant
        self.ohem = OHEMLoss(rate=rate)
        self.weights = weights

    def forward(self, pred, true):
        head = 0
        tail = self.n_grapheme
        grapheme_pred = pred[:, head:tail]
        grapheme_true = true[:, 0]

        head = tail
        tail = head + self.n_vowel
        vowel_pred = pred[:, head:tail]
        vowel_true = true[:, 1]

        head = tail
        tail = head + self.n_consonant
        consonant_pred = pred[:, head:tail]
        consonant_true = true[:, 2]

        return self.weights[0] * self.ohem(
            grapheme_pred, grapheme_true) + \
            self.weights[1] * self.ohem(vowel_pred, vowel_true) + \
            self.weights[2] * self.ohem(
                consonant_pred, consonant_true)


def get_loss(config: edict):
    name = config.loss.name
    params = config.loss.params
    if name == "bce":
        criterion = BengaliBCELoss(**params)
    elif name == "cross_entropy":
        criterion = BengaliCrossEntropyLoss(**params)  # type: ignore
    elif name == "margin":
        criterion = BengaliMultiMarginLoss(**params)  # type: ignore
    elif name == "focal":
        criterion = BengaliFocalLoss(**params)  # type: ignore
    elif name == "ohem":
        criterion = BengaliOHEMLoss(**params)  # type: ignore
    else:
        raise NotImplementedError
    return criterion
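
To make the hard-example selection in `OHEMLoss` concrete, here is a minimal plain-Python sketch of the same idea: sort the per-sample losses, keep the hardest `rate` fraction, and average only those. The helper name `ohem_mean` and the sample values are illustrative assumptions, not part of the notebook; the real class works on tensors returned by `F.cross_entropy(..., reduction="none")`.

```python
from typing import List


def ohem_mean(losses: List[float], rate: float = 0.7) -> float:
    """Average only the hardest `rate` fraction of per-sample losses.

    Hypothetical helper that imitates the selection step of OHEMLoss
    on plain floats. The max(1, ...) guard is an addition in this sketch.
    """
    keep_num = max(1, min(len(losses), int(len(losses) * rate)))
    hardest = sorted(losses, reverse=True)[:keep_num]
    return sum(hardest) / keep_num


# Made-up per-sample losses for a batch of four.
per_sample = [0.2, 1.0, 0.4, 3.0]
hard_only = ohem_mean(per_sample, rate=0.5)      # averages 3.0 and 1.0
plain_mean = sum(per_sample) / len(per_sample)   # averages all four
print(hard_only, plain_mean)
```

Because the easy samples are dropped before averaging, the OHEM value is never smaller than the plain batch mean.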

## Optimizer and Scheduler

### Optimizer

In [0]:
Optimizer = Union[Adam, SGD]


def get_optimizer(model, config: edict) -> Optimizer:
    name = config.optimizer.name
    params = config.optimizer.params
    if name == "Adam":
        optimizer = Adam(model.parameters(), **params)
    elif name == "SGD":
        optimizer = SGD(model.parameters(), **params)
    else:
        raise NotImplementedError
    return optimizer

### Scheduler

In [0]:
Scheduler = Optional[
    Union[ReduceLROnPlateau, CosineAnnealingLR, CosineAnnealingWarmRestarts]]


def get_scheduler(optimizer, config: edict) -> Scheduler:
    params = config.scheduler.params
    name = config.scheduler.name
    scheduler: Scheduler = None
    if name == "plateau":
        scheduler = ReduceLROnPlateau(optimizer, **params)
    elif name == "cosine":
        scheduler = CosineAnnealingLR(optimizer, **params)
    elif name == "cosine_warmup":
        scheduler = CosineAnnealingWarmRestarts(optimizer, **params)

    return scheduler
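
For orientation, the `cosine` option anneals the learning rate along half a cosine wave. The sketch below evaluates the closed form given in the PyTorch documentation for `CosineAnnealingLR`, `eta_t = eta_min + 0.5 * (eta_max - eta_min) * (1 + cos(pi * t / T_max))`. The function name `cosine_lr` and the values used (`eta_max=0.1`, `t_max=10`) are illustrative assumptions and are not read from this notebook's config.

```python
import math


def cosine_lr(step: int, t_max: int, eta_max: float,
              eta_min: float = 0.0) -> float:
    """Closed-form cosine annealing: eta_max at step 0, eta_min at t_max."""
    cos_term = 1.0 + math.cos(math.pi * step / t_max)
    return eta_min + 0.5 * (eta_max - eta_min) * cos_term


# Arbitrary example values, chosen only to show the shape of the curve.
schedule = [cosine_lr(t, t_max=10, eta_max=0.1) for t in range(11)]
print([round(lr, 4) for lr in schedule])
```

The rate starts at `eta_max`, passes the midpoint between `eta_max` and `eta_min` halfway through, and reaches `eta_min` at `t_max`.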

## Callbacks

In [0]:
class AverageRecall(Callback):
    def __init__(self,
                 index: int,
                 offset: int,
                 n_classes: int,
                 prefix: str,
                 loss_type: str = "bce",
                 output_key: str = "logits",
                 target_key: str = "targets"):
        self.index = index
        self.offset = offset
        self.n_classes = n_classes
        self.prefix = prefix
        self.loss_type = loss_type
        self.output_key = output_key
        self.target_key = target_key
        self.recall = 0.0
        super().__init__(CallbackOrder.Metric)

    def on_loader_start(self, state: RunnerState):
        self.prediction: List[int] = []
        self.target: List[int] = []

    def on_batch_end(self, state: RunnerState):
        targ = state.input[self.target_key].detach()
        out = state.output[self.output_key].detach()
        head = self.offset
        tail = self.offset + self.n_classes
        if self.loss_type == "bce":
            pred_np = torch.argmax(
                torch.sigmoid(out[:, head:tail]), dim=1).cpu().numpy()
            target_np = torch.argmax(targ[:, head:tail], dim=1).cpu().numpy()
        else:
            pred_np = torch.argmax(out[:, head:tail], dim=1).cpu().numpy()
            target_np = targ[:, self.index].cpu().numpy()
        self.prediction.extend(pred_np)
        self.target.extend(target_np)
        score = recall_score(
            target_np, pred_np, average="macro", zero_division=0)
        state.metrics.add_batch_value(name="batch_" + self.prefix, value=score)

    def on_loader_end(self, state: RunnerState):
        metric_name = self.prefix
        y_true = np.asarray(self.target)
        y_pred = np.asarray(self.prediction)

        metric = recall_score(y_true, y_pred, average="macro")
        state.metrics.epoch_values[state.loader_name][metric_name] = float(
            metric)
        self.recall = metric


class TotalAverageRecall(Callback):
    def __init__(self,
                 n_grapheme=168,
                 n_vowel=11,
                 n_consonant=7,
                 loss_type: str = "bce",
                 prefix: str = "tar",
                 output_key: str = "logits",
                 target_key: str = "targets"):
        self.prefix = prefix
        self.grapheme_callback = AverageRecall(
            index=0,
            offset=0,
            n_classes=n_grapheme,
            prefix="grapheme_recall",
            loss_type=loss_type,
            output_key=output_key,
            target_key=target_key)
        self.vowel_callback = AverageRecall(
            index=1,
            offset=n_grapheme,
            n_classes=n_vowel,
            prefix="vowel_recall",
            loss_type=loss_type,
            output_key=output_key,
            target_key=target_key)
        self.consonant_callback = AverageRecall(
            index=2,
            offset=n_grapheme + n_vowel,
            n_classes=n_consonant,
            prefix="consonant_recall",
            loss_type=loss_type,
            output_key=output_key,
            target_key=target_key)
        super().__init__(CallbackOrder.Metric)

    def on_loader_start(self, state):
        self.grapheme_callback.on_loader_start(state)
        self.vowel_callback.on_loader_start(state)
        self.consonant_callback.on_loader_start(state)

    def on_batch_end(self, state: RunnerState):
        self.grapheme_callback.on_batch_end(state)
        self.vowel_callback.on_batch_end(state)
        self.consonant_callback.on_batch_end(state)

    def on_loader_end(self, state: RunnerState):
        self.grapheme_callback.on_loader_end(state)
        self.vowel_callback.on_loader_end(state)
        self.consonant_callback.on_loader_end(state)

        grapheme_recall = self.grapheme_callback.recall
        vowel_recall = self.vowel_callback.recall
        consonant_recall = self.consonant_callback.recall
        final_score = np.average(
            [grapheme_recall, vowel_recall, consonant_recall],
            weights=[2, 1, 1])
        state.metrics.epoch_values[state.loader_name][self.
                                                      prefix] = final_score



class SaveWeightsCallback(Callback):
    def __init__(self,
                 to: Optional[Union[Path, str]] = None,
                 name: str = "",
                 is_larger_better=True,
                 main_metric="tar",
                 save_optimizer_state=True,
                 suffix=""):
        self.to = to
        if isinstance(self.to, str):
            self.to = Path(self.to)
        self.name = name
        self.best = -np.inf if is_larger_better else np.inf
        self.is_larger_better = is_larger_better
        self.main_metric = main_metric
        self.save_optimizer_state = save_optimizer_state
        self.suffix = suffix
        super().__init__(CallbackOrder.External)

    def on_epoch_end(self, state: RunnerState):
        val_metric = state.metrics.epoch_values["valid"][self.main_metric]
        to_save = False
        if self.is_larger_better and self.best < val_metric:
            to_save = True
            self.best = val_metric
        elif not self.is_larger_better and self.best > val_metric:
            to_save = True
            self.best = val_metric
        if to_save:
            weights = state.model.state_dict()
            epoch = state.epoch
            state_dict = {
                "model_state_dict": weights,
                "epoch": epoch
            }

            if self.save_optimizer_state:
                state_dict["optimizer_state"] = state.optimizer.state_dict()
                # get_scheduler may return None, so guard the scheduler state.
                if state.scheduler is not None:
                    state_dict["scheduler_state"] = \
                        state.scheduler.state_dict()

            logdir = state.logdir / "checkpoints"
            logdir.mkdir(exist_ok=True, parents=True)

            if self.name == "":
                torch.save(state_dict, logdir / "temp.pth")
            else:
                torch.save(state_dict, logdir / f"{self.name + self.suffix}.pth")

            if self.to is not None:
                if self.name == "":
                    torch.save(state_dict, self.to / "temp.pth")
                else:
                    torch.save(state_dict, self.to / f"{self.name + self.suffix}.pth")

                    
def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2


class MixupOrCutmixCallback(CriterionCallback):
    def __init__(self,
                 fields: List[str] = [
                     "images",
                 ],
                 alpha=1.0,
                 on_train_only=True,
                 mixup_prob=0.5,
                 cutmix_prob=0.5,
                 no_aug_epochs=0,
                 **kwargs):
        assert len(fields) > 0, \
            "At least one field is required"
        assert alpha >= 0, "alpha must be >= 0"
        assert 1 >= mixup_prob >= 0, "mixup_prob must be between 0 and 1"
        assert 1 >= cutmix_prob >= 0, "cutmix_prob must be between 0 and 1"
        assert 1 >= mixup_prob + cutmix_prob, \
            "sum of mixup_prob and cutmix_prob must not exceed 1"

        super().__init__(**kwargs)

        self.on_train_only = on_train_only
        self.fields = fields
        self.alpha = alpha
        self.lam = 1
        self.index = None
        self.is_needed = True
        self.mixup_prob = mixup_prob
        self.cutmix_prob = cutmix_prob
        self.no_action_prob = 1 - (mixup_prob + cutmix_prob)
        self.no_aug_epochs = no_aug_epochs

    def on_loader_start(self, state: RunnerState):
        self.is_needed = not self.on_train_only or \
            state.loader_name.startswith("train")

        if state.num_epochs - state.epoch < self.no_aug_epochs:
            self.is_needed = False

    def on_batch_start(self, state: RunnerState):
        if not self.is_needed:
            return

        dice = np.random.choice(
            [0, 1, 2],
            p=[self.mixup_prob, self.cutmix_prob, self.no_action_prob])
        self.dice = dice
        if dice == 0:
            if self.alpha > 0:
                self.lam = np.random.beta(self.alpha, self.alpha)
            else:
                self.lam = 1
            self.index = torch.randperm(state.input[self.fields[0]].shape[0])
            # Tensor.to is not in-place, so keep the returned tensor.
            self.index = self.index.to(state.device)

            for f in self.fields:
                state.input[f] = self.lam * state.input[f] + \
                    (1 - self.lam) * state.input[f][self.index]
        elif dice == 1:
            self.index = torch.randperm(state.input[self.fields[0]].shape[0])
            # Tensor.to is not in-place, so keep the returned tensor.
            self.index = self.index.to(state.device)

            if self.alpha > 0:
                lam = np.random.beta(self.alpha, self.alpha)
            else:
                lam = 1
            # Cut and paste the patch for every sampled lam.
            bbx1, bby1, bbx2, bby2 = rand_bbox(
                state.input[self.fields[0]].size(), lam)
            for f in self.fields:
                state.input[f][:, :, bbx1:bbx2, bby1:bby2] = \
                    state.input[f][self.index, :, bbx1:bbx2, bby1:bby2]
            # Recompute lam from the actual (clipped) patch area.
            self.lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) /
                            (state.input[self.fields[0]].size()[-1] *
                             state.input[self.fields[0]].size()[-2]))
        else:
            pass

    def _compute_loss(self, state: RunnerState, criterion):
        if not self.is_needed:
            return super()._compute_loss(state, criterion)

        if self.dice == 0 or self.dice == 1:
            pred = state.output[self.output_key]
            y_a = state.input[self.input_key]
            y_b = state.input[self.input_key][self.index]
            loss = self.lam * criterion(pred, y_a) + \
                (1 - self.lam) * criterion(pred, y_b)
            return loss
        else:
            return super()._compute_loss(state, criterion)
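
The CutMix branch relies on two small calculations: `rand_bbox` turns a mixing ratio `lam` into a box covering roughly `1 - lam` of the image, and `lam` is then recomputed from the box that survives clipping at the borders. The sketch below reproduces both in plain Python. The names `cut_box` and `adjusted_lambda` are hypothetical, and the box centre is passed in explicitly (instead of being sampled) so the result is deterministic.

```python
import math
from typing import Tuple

Box = Tuple[int, int, int, int]


def cut_box(width: int, height: int, lam: float, cx: int, cy: int) -> Box:
    """Box centred at (cx, cy) whose unclipped area is (1 - lam) * W * H."""
    cut_ratio = math.sqrt(1.0 - lam)
    cut_w = int(width * cut_ratio)
    cut_h = int(height * cut_ratio)
    # Clip the box to the image, as np.clip does in rand_bbox.
    x1 = max(cx - cut_w // 2, 0)
    y1 = max(cy - cut_h // 2, 0)
    x2 = min(cx + cut_w // 2, width)
    y2 = min(cy + cut_h // 2, height)
    return x1, y1, x2, y2


def adjusted_lambda(box: Box, width: int, height: int) -> float:
    """Share of the image that is NOT replaced by the pasted patch."""
    x1, y1, x2, y2 = box
    return 1.0 - (x2 - x1) * (y2 - y1) / (width * height)


centre = cut_box(100, 100, lam=0.75, cx=50, cy=50)
corner = cut_box(100, 100, lam=0.75, cx=0, cy=0)
print(centre, adjusted_lambda(centre, 100, 100))
print(corner, adjusted_lambda(corner, 100, 100))
```

A box near the corner is clipped to a quarter of its nominal area, which is why the loss weights use the recomputed value rather than the sampled `lam`.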

In [0]:
def get_callbacks(config: edict):
    callbacks = []
    for callback in config.callbacks:
        name = list(callback.keys())[0]
        params = callback[name]
        if globals().get(name) is not None:
            if params is not None:
                callbacks.append(globals().get(name)(**params))  # type: ignore
            else:
                callbacks.append(globals().get(name)())  # type: ignore
    return callbacks
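
`get_callbacks` expects `config.callbacks` to be a list of single-key mappings, each mapping a class name to its keyword arguments (or to `None` for defaults). The sketch below mimics that lookup with a stand-in class and an explicit registry in place of `globals()`. `DummyCallback`, `registry`, and the config contents are hypothetical and only illustrate the expected shape.

```python
from typing import Any, Dict, List, Optional


class DummyCallback:
    """Stand-in for a real callback such as TotalAverageRecall."""

    def __init__(self, prefix: str = "tar"):
        self.prefix = prefix


# Hypothetical config: a list of {class name: params or None} entries.
callbacks_config: List[Dict[str, Optional[Dict[str, Any]]]] = [
    {"DummyCallback": {"prefix": "score"}},
    {"DummyCallback": None},   # no params: fall back to the defaults
    {"NotDefined": None},      # unknown name: skipped without an error
]

registry = {"DummyCallback": DummyCallback}  # plays the role of globals()

built = []
for entry in callbacks_config:
    name = list(entry.keys())[0]
    params = entry[name]
    cls = registry.get(name)
    if cls is not None:
        built.append(cls(**params) if params is not None else cls())

print([cb.prefix for cb in built])
```

Note that a misspelled callback name is dropped silently, so it is worth checking the length of the returned list against the config.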

## Training

In [0]:
%%sh
mkdir -p /root/.cache/torch/checkpoints
cp /content/gdrive/My\ Drive/kaggle-bengali/se_resnext50_32x4d-a260b3a4.pth /root/.cache/torch/checkpoints/se_resnext50_32x4d-a260b3a4.pth

In [19]:
trn_idx, val_idx = splits[i]

print(f"Fold: {i}")

output_dir = output_base_dir / f"fold{i}"
output_dir.mkdir(exist_ok=True, parents=True)

trn_df = df.loc[trn_idx, :].reset_index(drop=True)
val_df = df.loc[val_idx, :].reset_index(drop=True)
data_loaders = {
    phase: get_base_loader(
        df,
        train_images_path,
        phase=phase,
        size=(config.img_size, config.img_size),
        batch_size=config.train.batch_size,
        num_workers=config.num_workers,
        transforms=transforms_dict[phase])
    for phase, df in zip(["train", "valid"], [trn_df, val_df])
}

model = get_model(config).to(get_device())
criterion = get_loss(config).to(get_device())
optimizer = get_optimizer(model, config)
scheduler = get_scheduler(optimizer, config)
callbacks = get_callbacks(config)

Fold: 0


In [20]:
runner = SupervisedRunner(
    device=ct.utils.get_device(),
    input_key="images",
    input_target_key="targets",
    output_key="logits")
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    loaders=data_loaders,
    logdir=output_dir,
    scheduler=scheduler,
    num_epochs=config.train.num_epochs,
    callbacks=callbacks,
    main_metric="tar",
    minimize_metric=False,
    monitoring_params=None,
    verbose=False,
    resume=config["weights"])

[2020-02-28 06:40:39,263] 
1/15 * Epoch 1 (train): _base/lr=5.000e-05 | _base/momentum=0.9000 | _timers/_fps=1661.9984 | _timers/batch_time=0.0399 | _timers/data_time=0.0016 | _timers/model_time=0.0364 | batch_consonant_recall=0.6375 | batch_grapheme_recall=0.3731 | batch_vowel_recall=0.6302 | consonant_recall=0.5599 | grapheme_recall=0.3561 | loss=8.3096 | tar=0.4731 | vowel_recall=0.6205
1/15 * Epoch 1 (valid): _base/lr=5.000e-05 | _base/momentum=0.9000 | _timers/_fps=1715.1413 | _timers/batch_time=0.0454 | _timers/data_time=0.0030 | _timers/model_time=0.0423 | batch_consonant_recall=0.9281 | batch_grapheme_recall=0.8205 | batch_vowel_recall=0.9426 | consonant_recall=0.9250 | grapheme_recall=0.8438 | loss=3.0275 | tar=0.8884 | vowel_recall=0.9411
[2020-02-28 07:12:44,920] 
2/15 * Epoch 2 (train): _base/lr=4.758e-05 | _base/momentum=0.9000 | _timers/_fps=1703.0763 | _timers/batch_time=0.0382 | _timers/data_time=0.0013 | _timers/model_time=0.0349 | batch_consonant_recall=0.7163 | batch

In [21]:
!ls output/fold0/checkpoints

best_full.pth
best.pth
last_full.pth
last.pth
_metrics.json
seresnext_size236_90epoch_cutout_consonant.pth
seresnext_size236_90epoch_cutout_grapheme.pth
seresnext_size236_90epoch_cutout.pth
seresnext_size236_90epoch_cutout_vowel.pth
train.13_full.pth
train.13.pth


In [22]:
state_dict = torch.load("output/fold0/checkpoints/last_full.pth")
state_dict.keys()

dict_keys(['epoch_metrics', 'valid_metrics', 'stage', 'epoch', 'checkpoint_data', 'model_state_dict', 'criterion_state_dict', 'optimizer_state_dict', 'scheduler_state_dict'])

In [0]:
!cp output/fold0/checkpoints/last_full.pth /content/gdrive/My\ Drive/kaggle-bengali/checkpoints/fold0/seresnext_size236_90epoch_cutout_latest.pth

## Check performance

In [0]:
def load_model(config: edict, bin_path: Union[str, Path]):
    # config.model.pretrained = None
    model = get_model(config)
    state_dict = torch.load(bin_path, map_location=get_device())
    if "model_state_dict" in state_dict.keys():
        model.load_state_dict(state_dict["model_state_dict"])
    else:
        model.load_state_dict(state_dict)
    return model

In [0]:
def macro_average_recall(prediction: np.ndarray, df: pd.DataFrame):
    grapheme = recall_score(
        df["grapheme_root"].values, prediction[:, 0], average="macro")
    vowel = recall_score(
        df["vowel_diacritic"].values, prediction[:, 1], average="macro")
    consonant = recall_score(
        df["consonant_diacritic"].values, prediction[:, 2], average="macro")
    return np.average([grapheme, vowel, consonant], weights=[2, 1, 1])
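
As a quick check of the 2:1:1 weighting, the per-component recalls from the epoch-1 validation log line above (`grapheme_recall=0.8438`, `vowel_recall=0.9411`, `consonant_recall=0.9250`) can be combined by hand. The helper name `weighted_recall` is a hypothetical stand-in for the `np.average(..., weights=[2, 1, 1])` call.

```python
def weighted_recall(grapheme: float, vowel: float, consonant: float) -> float:
    """Weighted mean of the three recalls with weights 2:1:1."""
    return (2 * grapheme + vowel + consonant) / 4


# Values copied from the epoch-1 validation log line.
score = weighted_recall(grapheme=0.8438, vowel=0.9411, consonant=0.9250)
print(f"{score:.4f}")  # 0.8884, matching tar=0.8884 in the log
```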

In [0]:
def inference_loop(model: nn.Module,
                   loader: torchdata.DataLoader,
                   cls_levels: dict,
                   loss_fn: Optional[nn.Module] = None,
                   requires_soft=False):
    n_grapheme = cls_levels["grapheme"]
    n_vowel = cls_levels["vowel"]
    n_consonant = cls_levels["consonant"]

    dataset_length = len(loader.dataset)
    prediction = np.zeros((dataset_length, 3), dtype=np.uint8)
    if requires_soft:
        soft_prediction = np.zeros(
            (dataset_length, n_grapheme + n_vowel + n_consonant),
            dtype=np.float32)

    batch_size = loader.batch_size
    device = get_device()

    avg_loss = 0.
    model.eval()

    targets: Optional[torch.Tensor] = None

    for i, batch in enumerate(progress_bar(loader, leave=False)):
        with torch.no_grad():
            if isinstance(batch, dict):
                images = batch["images"].to(device)
                targets = batch["targets"].to(device)
            else:
                images = batch.to(device)
                targets = None
            pred = model(images).detach()
            if loss_fn is not None and targets is not None:
                # targets is already on the device; reuse it.
                avg_loss += loss_fn(pred, targets).item() / len(loader)
            head = 0
            tail = n_grapheme
            pred_grapheme = torch.argmax(
                pred[:, head:tail], dim=1).cpu().numpy()

            head = tail
            tail = head + n_vowel
            pred_vowel = torch.argmax(pred[:, head:tail], dim=1).cpu().numpy()

            head = tail
            tail = head + n_consonant
            pred_consonant = torch.argmax(
                pred[:, head:tail], dim=1).cpu().numpy()

            prediction[i * batch_size:(i + 1) * batch_size, 0] = pred_grapheme
            prediction[i * batch_size:(i + 1) * batch_size, 1] = pred_vowel
            prediction[i * batch_size:(i + 1) * batch_size, 2] = pred_consonant

            if requires_soft:
                head = 0
                tail = n_grapheme
                soft_prediction[i * batch_size:(i + 1) *
                                batch_size, head:tail] = F.softmax(
                                    pred[:, head:tail], dim=1).cpu().numpy()

                head = tail
                tail = head + n_vowel
                soft_prediction[i * batch_size:(i + 1) *
                                batch_size, head:tail] = F.softmax(
                                    pred[:, head:tail], dim=1).cpu().numpy()

                head = tail
                tail = head + n_consonant
                soft_prediction[i * batch_size:(i + 1) *
                                batch_size, head:tail] = F.softmax(
                                    pred[:, head:tail], dim=1).cpu().numpy()

    return_dict = {"prediction": prediction, "loss": avg_loss}
    if requires_soft:
        return_dict["soft_prediction"] = soft_prediction

    return return_dict
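
The `head`/`tail` bookkeeping in `inference_loop` slices one concatenated logit vector into three heads and takes an argmax per head. A minimal plain-Python sketch of that step for a single row follows. The function name `split_and_argmax`, the toy head sizes, and the logit values are assumptions for illustration (the notebook's defaults are 168, 11, and 7 classes).

```python
from typing import List, Sequence


def split_and_argmax(logits: Sequence[float],
                     head_sizes: Sequence[int]) -> List[int]:
    """Slice concatenated logits into heads and return each head's argmax."""
    preds = []
    head = 0
    for size in head_sizes:
        tail = head + size
        chunk = logits[head:tail]
        # Index of the largest logit within this head.
        preds.append(max(range(size), key=lambda k: chunk[k]))
        head = tail
    return preds


# Toy row with three heads of sizes 3, 2 and 4.
row = [0.1, 2.0, -1.0, 0.3, 0.2, -0.5, 0.0, 1.5, 0.4]
print(split_and_argmax(row, head_sizes=[3, 2, 4]))  # [1, 0, 2]
```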

In [0]:
checkpoint_path = f"output/fold{i}/checkpoints/best.pth"
model = load_model(config, checkpoint_path)
model.to(get_device())
loader = data_loaders["valid"]

prediction = inference_loop(
    model,
    loader,
    cls_levels,
    criterion,
    requires_soft=False)
score = macro_average_recall(prediction["prediction"], val_df)
print(f"Score: {score:.5f}")

Score: 0.95977


In [0]:
!cp output/fold0/checkpoints/best.pth /content/gdrive/My\ Drive/kaggle-bengali/checkpoints/fold0/seresnext_50epoch_size128_mixup_focal.pth

cp: cannot stat 'output/fold0/checkpoints/best.pth': No such file or directory
