<a href="https://colab.research.google.com/github/koukyo1994/kaggle-bengali-ai/blob/master/notebook/Resnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Dependencies

In [None]:
%%sh
pip install albumentations==0.4.3 catalyst==20.1.1 easydict==1.9.0 >> /dev/null
pip install efficientnet-pytorch==0.6.1 PyYAML==5.3 >> /dev/null
pip install pretrainedmodels==0.7.4 >> /dev/null

## Integration with Google Drive

In [2]:
# Mount Google Drive at /content/gdrive (prompts for an authorization code in
# Colab). The dataset copy and the checkpoint directory below live under
# "/content/gdrive/My Drive/kaggle-bengali".
from google.colab import drive
drive.mount("/content/gdrive")

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/gdrive


In [None]:
%%sh
# Copy the competition data from Drive and extract the image archives.
# Written to be re-runnable in the same runtime: `mkdir -p` does not fail on an
# existing dir, `src/.` copies the *contents* (no nested directory on re-run),
# and `unzip -o` overwrites instead of prompting.
mkdir -p input
cp -r "/content/gdrive/My Drive/kaggle-bengali/." ./input/bengaliai-cv19
unzip -qq -o -d input/bengaliai-cv19/ input/bengaliai-cv19/train_images.zip
unzip -qq -o -d input/bengaliai-cv19/ input/bengaliai-cv19/test_images.zip

## Libraries

In [4]:
import albumentations as A
import catalyst as ct
import cv2
import numpy as np
import pandas as pd
import pretrainedmodels
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as torchdata
import torchvision.models as models
import yaml

from pathlib import Path
from typing import Tuple, Dict, Union, Optional, List

from catalyst.dl import SupervisedRunner
from catalyst.dl.core import Callback, CallbackOrder, RunnerState
from catalyst.dl.callbacks import MixupCallback
from catalyst.utils import get_device
from easydict import EasyDict as edict
from efficientnet_pytorch import EfficientNet
from fastprogress import progress_bar
from skimage.transform import AffineTransform, warp
from sklearn.metrics import recall_score, confusion_matrix
from sklearn.model_selection import KFold, train_test_split
from torch.nn.parameter import Parameter
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import (ReduceLROnPlateau, 
                                      CosineAnnealingLR,
                                      CosineAnnealingWarmRestarts)

alchemy not available, to install alchemy, run `pip install alchemy-catalyst`.


## Settings

In [None]:
# Index of the CV fold to train and evaluate (used as `splits[i]` below).
i = 0
# Experiment label. NOTE(review): not referenced by any visible cell.
trial = "resnet34_fourth"

## Config

In [None]:
conf_string = '''
dataset:
  train:
    affine: True
    morphology: False
  val:
    affine: False
    morphology: False
  test:
    affine: False
    morphology: False

data:
  train_df_path: input/bengaliai-cv19/train.csv
  train_images_path: input/bengaliai-cv19/train_images
  test_images_path: input/bengaliai-cv19/test_images
  sample_submission_path: input/bengaliai-cv19/sample_submission.csv

model:
  model_name: resnet34
  pretrained: True
  num_classes: 186
  head: custom
  in_channels: 3

train:
  batch_size: 128
  num_epochs: 10

test:
  batch_size: 128

loss:
  name: cross_entropy
  params:
    n_grapheme: 168
    n_vowel: 11
    n_consonant: 7

optimizer:
  name: Adam
  params:
    lr: 0.0001

scheduler:
  name: cosine
  params:
    T_max: 10

transforms:
  train:
    Noise: False
    Contrast: False
    Rotate: True
    RandomScale: True
    Cutout:
      num_holes: 0
  val:
    Noise: False
    Contrast: False
    Rotate: False
    RandomScale: False
    Cutout:
      num_holes: 0
  test:
    Noise: False
    Contrast: False
    Rotate: False
    RandomScale: False
    Cutout:
      num_holes: 0

val:
  name: kfold
  params:
    random_state: 42
    n_splits: 5

callbacks:
  - AverageRecall:
      index: 0
      offset: 0
      n_classes: 168
      prefix: grapheme_recall
      loss_type: cross_entropy
  - AverageRecall:
      index: 1
      offset: 168
      n_classes: 11
      prefix: vowel_recall
      loss_type: cross_entropy
  - AverageRecall:
      index: 2
      offset: 179
      n_classes: 7
      prefix: consonant_recall
      loss_type: cross_entropy
  - TotalAverageRecall:
      loss_type: cross_entropy
  - SaveWeightsCallback:
      to: /content/gdrive/My Drive/kaggle-bengali/checkpoints

log_dir: log/
num_workers: 2
seed: 1213
img_size: 128
mixup: False
'''

In [None]:
def _get_default():
    """Build the default experiment configuration as a nested ``edict``.

    The YAML in ``conf_string`` is merged on top of this tree by
    ``_merge_config``. Every key defined here is therefore optional in the YAML.
    """

    def _dataset_switches():
        # Dataset-level augmentation switches, all off by default.
        return {"affine": False, "morphology": False}

    def _transform_switches():
        # Albumentations switches for one phase, all off by default.
        return {
            "HorizontalFlip": False,
            "VerticalFlip": False,
            "Noise": False,
            "Contrast": False,
            "Rotate": False,
            "RandomScale": False,
            "Cutout": {"num_holes": 0},
        }

    defaults = {
        "dataset": {
            "train": _dataset_switches(),
            "val": _dataset_switches(),
            "test": _dataset_switches(),
        },
        "data": {},
        "model": {
            "model_name": "resnet18",
            "num_classes": 186,
            "pretrained": True,
            "head": "linear",
            "in_channels": 3,
        },
        "train": {},
        "test": {},
        "loss": {"params": {}},
        "optimizer": {"params": {}},
        "scheduler": {"params": {}},
        "transforms": {
            "train": _transform_switches(),
            "val": _transform_switches(),
            "test": _transform_switches(),
            # ImageNet normalisation statistics.
            "mean": [0.485, 0.456, 0.406],
            "std": [0.229, 0.224, 0.225],
        },
        "val": {"params": {}},
        "callbacks": [],
    }
    # edict converts every nested plain dict into an edict recursively.
    return edict(defaults)


def _merge_config(src: edict, dst: edict):
    """Recursively merge ``src`` into ``dst`` in place.

    A nested ``edict`` in ``src`` is merged key by key into the matching
    mapping in ``dst``. Any other value (scalar, list, ``None``) overwrites the
    entry in ``dst``. If ``dst`` has no such section, or holds a non-mapping
    value there, the whole ``src`` section is assigned instead. Previously this
    raised ``KeyError`` / ``TypeError``.

    Does nothing if ``src`` is not an ``edict``.
    """
    if not isinstance(src, edict):
        return
    for k, v in src.items():
        if isinstance(v, edict) and isinstance(dst.get(k), dict):
            _merge_config(v, dst[k])
        else:
            dst[k] = v

In [None]:
# Start from the defaults, then overlay the YAML experiment config.
config = _get_default()
cfg = edict(yaml.safe_load(conf_string))
_merge_config(cfg, config)

## Environment settings

In [9]:
# Seed the random generators via catalyst (presumably python/numpy/torch; TODO
# confirm in catalyst docs) using config.seed, and request deterministic cuDNN.
ct.utils.set_global_seed(config.seed)
ct.utils.prepare_cudnn(deterministic=True)

In [None]:
# Directory with the extracted training PNGs (from the config).
train_images_path = Path(config.data.train_images_path)

# Root for per-fold training outputs; created up front so re-runs are safe.
output_base_dir = Path("output")
output_base_dir.mkdir(parents=True, exist_ok=True)

## Data and utilities preparation

### Validation utilities

In [None]:
def no_fold(df: pd.DataFrame,
            config: edict) -> List[Tuple[np.ndarray, np.ndarray]]:
    """Single random hold-out split over the row positions of ``df``.

    ``config.val.params`` is forwarded verbatim to ``train_test_split``.
    The result is a one-element list so callers can treat it like k-fold splits.
    """
    row_positions = np.arange(len(df))
    train_rows, valid_rows = train_test_split(row_positions,
                                              **config.val.params)
    return [(train_rows, valid_rows)]


def kfold(df: pd.DataFrame,
          config: edict) -> List[Tuple[np.ndarray, np.ndarray]]:
    """Shuffled K-fold splits of ``df``.

    ``config.val.params`` (e.g. ``n_splits``, ``random_state``) is forwarded to
    ``KFold``. Shuffling is always on. Returns a list of
    ``(train_indices, valid_indices)`` tuples.
    """
    splitter = KFold(shuffle=True, **config.val.params)
    return [(train_rows, valid_rows)
            for train_rows, valid_rows in splitter.split(df)]


def get_validation(df: pd.DataFrame,
                   config: edict) -> List[Tuple[np.ndarray, np.ndarray]]:
    """Dispatch to the split function named by ``config.val.name``.

    The name is looked up among this notebook's globals (e.g. ``"kfold"`` or
    ``"no_fold"``) and called as ``func(df, config)``.

    Raises:
        NotImplementedError: if no callable global with that name exists. The
            message names the offending config value.
    """
    name: str = config.val.name

    func = globals().get(name)
    if func is None or not callable(func):
        raise NotImplementedError(
            f"unknown validation scheme {name!r} (config.val.name)")

    return func(df, config)

### Transforms

In [None]:
def get_transforms(config: edict, phase: str = "train"):
    """Build the albumentations pipeline for one phase.

    Args:
        config: full experiment config. The switches are read from
            ``config.transforms.{train,val,test}``, and normalisation from
            ``config.transforms.mean`` / ``.std``.
        phase: ``"train"``, ``"valid"`` (reads the ``val`` section) or ``"test"``.

    Returns:
        ``A.Compose`` of the enabled augmentations, always ending in
        ``A.Normalize``.
    """
    assert phase in ["train", "valid", "test"]
    if phase == "train":
        cfg = config.transforms.train
    elif phase == "valid":
        cfg = config.transforms.val
    else:
        cfg = config.transforms.test
    list_transforms = []
    if cfg.HorizontalFlip:
        # Was `A.HorizontalFrip` (AttributeError whenever the flag was on).
        list_transforms.append(A.HorizontalFlip())
    if cfg.VerticalFlip:
        list_transforms.append(A.VerticalFlip())
    if cfg.Rotate:
        list_transforms.append(A.Rotate(limit=15))
    if cfg.RandomScale:
        list_transforms.append(A.RandomScale())
    if cfg.Noise:
        list_transforms.append(
            A.OneOf(
                [A.GaussNoise(), A.IAAAdditiveGaussianNoise()], p=0.5))
    if cfg.Contrast:
        list_transforms.append(
            A.OneOf(
                [A.RandomContrast(0.5),
                 A.RandomGamma(),
                 A.RandomBrightness()],
                p=0.5))
    if cfg.Cutout.num_holes > 0:
        # Cutout settings live in the phase section, not at the config root.
        list_transforms.append(A.Cutout(**cfg.Cutout))

    list_transforms.append(
        A.Normalize(
            mean=config.transforms.mean, std=config.transforms.std, p=1))

    return A.Compose(list_transforms, p=1.0)

### Data Loading

In [None]:
df = pd.read_csv(config.data.train_df_path)
splits = get_validation(df, config)

# One albumentations pipeline per training phase.
transforms_dict = {phase: get_transforms(config, phase)
                   for phase in ("train", "valid")}

# Number of distinct labels per component, as observed in train.csv.
cls_levels = {
    component: df[column].nunique()
    for component, column in (("grapheme", "grapheme_root"),
                              ("vowel", "vowel_diacritic"),
                              ("consonant", "consonant_diacritic"))
}

## Dataset and DataLoader

In [None]:
class BaseDataset(torchdata.Dataset):
    """Bengali grapheme image dataset.

    Each row of ``df`` names a PNG ``{image_id}.png`` in ``image_dir``. It also
    carries the integer labels ``grapheme_root``, ``vowel_diacritic`` and
    ``consonant_diacritic``. Each image is processed in four steps:
    1. Pad it with white (255) to a square.
    2. Optionally transform it.
    3. Resize it to ``size``.
    4. Move it to channel-first layout when it has 3 channels.

    Args:
        image_dir: directory containing the PNG files.
        df: label frame. Rows are fetched with ``df.loc[idx, ...]``, so it
            should have a 0..len-1 index (e.g. after ``reset_index``).
        transforms: callable used as ``transforms(image=...)["image"]``, or None.
        size: target size passed to ``cv2.resize``.
        random_offset: True (default) places the image at a random position
            inside the square canvas. False centres it, which makes
            validation/test samples deterministic.
    """

    def __init__(self, image_dir: Path, df: pd.DataFrame, transforms,
                 size: Tuple[int, int], random_offset: bool = True):
        self.df = df
        self.image_dir = image_dir
        self.transforms = transforms
        self.size = size
        self.random_offset = random_offset

    def __len__(self):
        return len(self.df)

    def _offset(self, slack: int) -> int:
        """Start position along one axis given ``slack`` free pixels."""
        if slack <= 0:
            # No free space: np.random.randint(0, 0) would raise ValueError.
            return 0
        if self.random_offset:
            # Exclusive upper bound, same draw as the original implementation.
            return int(np.random.randint(0, slack))
        return slack // 2

    def _pad_to_square(self, image: np.ndarray) -> np.ndarray:
        """Place ``image`` on a white square canvas of side ``max(h, w)``."""
        height, width = image.shape[:2]
        side = max(height, width)
        canvas = np.ones((side, side) + image.shape[2:], dtype=np.uint8) * 255
        top = self._offset(side - height)
        left = self._offset(side - width)
        canvas[top:top + height, left:left + width] = image
        return canvas

    def __getitem__(self, idx):
        image_id = self.df.loc[idx, "image_id"]
        image_path = self.image_dir / f"{image_id}.png"

        image = cv2.imread(str(image_path))
        if image is None:
            # cv2.imread returns None instead of raising on missing/bad files.
            raise FileNotFoundError(f"could not read image {image_path}")
        # Pad first so both the transforms and the plain resize see a square
        # image (previously the padding was dropped when transforms was None).
        image = self._pad_to_square(image)

        if self.transforms is not None:
            image = self.transforms(image=image)["image"]
        image = cv2.resize(image, self.size)
        if image.ndim == 3 and image.shape[2] == 3:
            # HWC -> CHW for PyTorch.
            image = np.moveaxis(image, -1, 0)
        grapheme = self.df.loc[idx, "grapheme_root"]
        vowel = self.df.loc[idx, "vowel_diacritic"]
        consonant = self.df.loc[idx, "consonant_diacritic"]
        label = np.array([grapheme, vowel, consonant], dtype=int)
        return {"images": image, "targets": label}
    
    
def get_base_loader(df: pd.DataFrame,
                    image_dir: Path,
                    phase: str = "train",
                    size: Tuple[int, int] = (128, 128),
                    batch_size=256,
                    num_workers=2,
                    transforms=None):
    """Wrap ``BaseDataset`` in a DataLoader configured for ``phase``.

    The ``"train"`` phase shuffles and drops the last incomplete batch. The
    ``"valid"`` phase keeps the original order and every sample. Memory is
    pinned in both cases.
    """
    assert phase in ["train", "valid"]
    training = phase == "train"

    dataset = BaseDataset(image_dir, df, transforms, size)  # type: ignore
    return torchdata.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=training,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=training)

## Model and Loss

### Model

In [None]:
def gem(x: torch.Tensor, p=3, eps=1e-6):
    """Generalized-mean pooling over the last two dimensions.

    Computes ``mean(clamp(x, min=eps) ** p) ** (1 / p)`` with a single
    average-pool window covering the last two dims. Those dims are kept as
    size 1.
    """
    window = (x.size(-2), x.size(-1))
    powered = x.clamp(min=eps).pow(p)
    pooled = F.avg_pool2d(powered, window)
    return pooled.pow(1. / p)


def mish(input):
    '''
    Mish activation, element-wise:
    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
    '''
    gate = F.softplus(input).tanh()
    return input * gate


class Mish(nn.Module):
    '''
    Module form of the Mish activation, applied element-wise:
    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
    Shape:
        - Input: (N, *) where * means any number of additional dimensions
        - Output: (N, *), same shape as the input
    Examples:
        >>> m = Mish()
        >>> input = torch.randn(2)
        >>> output = m(input)
    '''

    def forward(self, input):
        '''
        Return ``input * tanh(softplus(input))``.
        '''
        return input * F.softplus(input).tanh()


class GeM(nn.Module):
    """Generalized-mean pooling with a learnable exponent ``p``.

    Pools the last two dims of the input with
    ``mean(clamp(x, min=eps) ** p) ** (1 / p)``, then drops them. A
    ``(N, C, H, W)`` input therefore becomes ``(N, C)``.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        # Learnable exponent, stored as a 1-element parameter.
        self.p = Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        window = (x.size(-2), x.size(-1))
        pooled = F.avg_pool2d(x.clamp(min=self.eps).pow(self.p), window)
        pooled = pooled.pow(1. / self.p)
        return pooled.squeeze(-1).squeeze(-1)

    def __repr__(self):
        p_value = self.p.data.tolist()[0]
        return f"{type(self).__name__}(p={p_value:.4f}, eps={self.eps})"


class Resnet(nn.Module):
    """torchvision ResNet backbone with a linear or a three-branch head.

    Args:
        model_name: torchvision constructor name, e.g. ``"resnet34"``.
        num_classes: output size of the ``"linear"`` head.
        pretrained: forwarded to the torchvision constructor.
        head: ``"linear"`` replaces ``fc``. ``"custom"`` drops avgpool/fc and
            adds three conv heads (grapheme, vowel, consonant) whose logits
            are concatenated along dim 1.
        in_channels: 1 or 3. For 1 channel with pretrained weights, the first
            conv's RGB filters are averaged over the input channels.
        n_grapheme, n_vowel, n_consonant: class counts of the custom heads.
            The defaults 168/11/7 reproduce the previously hard-coded sizes.
    """

    def __init__(self,
                 model_name: str,
                 num_classes: int,
                 pretrained=False,
                 head="linear",
                 in_channels=3,
                 n_grapheme: int = 168,
                 n_vowel: int = 11,
                 n_consonant: int = 7):
        super().__init__()
        self.num_classes = num_classes
        self.base = getattr(models, model_name)(pretrained=pretrained)
        self.head = head
        assert in_channels in [1, 3]
        assert head in ["linear", "custom"]
        if in_channels == 1:
            rgb_weight = self.base.conv1.weight if pretrained else None
            self.base.conv1 = nn.Conv2d(
                1, 64, kernel_size=7, stride=2, padding=3, bias=False)
            if rgb_weight is not None:
                # Average the pretrained RGB filters into a single channel.
                self.base.conv1.weight = nn.Parameter(
                    data=torch.mean(rgb_weight, dim=1, keepdim=True),
                    requires_grad=True)
        n_in_features = self.base.fc.in_features
        if head == "linear":
            self.base.fc = nn.Linear(n_in_features, self.num_classes)
        elif head == "custom":
            # Drop the final avgpool and fc; the heads consume the feature map.
            self.base = nn.Sequential(*list(self.base.children())[:-2])
            self.grapheme_head = self._make_head(n_in_features, n_grapheme)
            self.vowel_head = self._make_head(n_in_features, n_vowel)
            self.consonant_head = self._make_head(n_in_features, n_consonant)
        else:
            raise NotImplementedError

    @staticmethod
    def _make_head(n_in_features: int, n_out: int) -> nn.Sequential:
        """Mish -> 3x3 conv (512) -> BN -> GeM -> Linear(512, n_out)."""
        # Keep this layer order: saved checkpoints' state_dict keys
        # (e.g. "grapheme_head.4.weight") depend on the indices.
        return nn.Sequential(
            Mish(), nn.Conv2d(n_in_features, 512, kernel_size=3),
            nn.BatchNorm2d(512), GeM(), nn.Linear(512, n_out))

    def forward(self, x):
        if self.head == "linear":
            return self.base(x)
        elif self.head == "custom":
            x = self.base(x)
            grapheme = self.grapheme_head(x)
            vowel = self.vowel_head(x)
            consonant = self.consonant_head(x)
            return torch.cat([grapheme, vowel, consonant], dim=1)
        else:
            raise NotImplementedError


def get_model(config: edict):
    """Instantiate the model described by ``config.model``.

    Only ResNet variants (``model_name`` containing ``"resnet"``) are
    supported. All ``config.model`` keys are passed to ``Resnet`` as keyword
    arguments.
    """
    model_params = config.model
    if "resnet" not in model_params.model_name:
        raise NotImplementedError
    return Resnet(**model_params)

### Loss

In [None]:
class BengaliCrossEntropyLoss(nn.Module):
    """Sum of three cross-entropy losses over concatenated logits.

    ``pred`` holds grapheme, vowel and consonant logits side by side (widths
    ``n_grapheme``, ``n_vowel``, ``n_consonant``). ``true`` holds one integer
    label column per component, in the same order.
    """

    def __init__(self, n_grapheme: int, n_vowel: int, n_consonant: int):
        super().__init__()
        self.n_grapheme = n_grapheme
        self.n_vowel = n_vowel
        self.n_consonant = n_consonant
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, pred, true):
        widths = (self.n_grapheme, self.n_vowel, self.n_consonant)
        total = None
        start = 0
        for column, width in enumerate(widths):
            stop = start + width
            term = self.cross_entropy(pred[:, start:stop], true[:, column])
            total = term if total is None else total + term
            start = stop
        return total


class BengaliBCELoss(nn.Module):
    """Sum of three BCE-with-logits losses over concatenated logits.

    ``pred`` and ``true`` share the same column layout: grapheme, vowel and
    consonant blocks of widths ``n_grapheme``, ``n_vowel`` and
    ``n_consonant``. ``true`` holds float (e.g. one-hot) targets.
    """

    def __init__(self, n_grapheme: int, n_vowel: int, n_consonant: int):
        super().__init__()
        self.n_grapheme = n_grapheme
        self.n_vowel = n_vowel
        self.n_consonant = n_consonant
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, pred, true):
        edges = [0]
        for width in (self.n_grapheme, self.n_vowel, self.n_consonant):
            edges.append(edges[-1] + width)
        terms = [self.bce(pred[:, lo:hi], true[:, lo:hi])
                 for lo, hi in zip(edges[:-1], edges[1:])]
        return terms[0] + terms[1] + terms[2]


def get_loss(config: edict):
    """Build the loss named by ``config.loss.name`` from ``config.loss.params``.

    Supported names are ``"cross_entropy"`` and ``"bce"``. Any other name
    raises ``NotImplementedError``.
    """
    name = config.loss.name
    params = config.loss.params
    if name == "cross_entropy":
        return BengaliCrossEntropyLoss(**params)  # type: ignore
    if name == "bce":
        return BengaliBCELoss(**params)
    raise NotImplementedError

## Optimizer and Scheduler

### Optimizer

In [None]:
Optimizer = Union[Adam, SGD]


def get_optimizer(model, config: edict) -> Optimizer:
    """Create the optimizer named by ``config.optimizer.name``.

    ``config.optimizer.params`` is forwarded as keyword arguments. For
    ``"SGD"`` it must include ``lr``.

    Raises:
        NotImplementedError: for names other than ``"Adam"`` / ``"SGD"``.
    """
    name = config.optimizer.name
    params = config.optimizer.params
    if name == "Adam":
        optimizer = Adam(model.parameters(), **params)
    elif name == "SGD":
        # This branch used to construct Adam, silently ignoring the setting.
        optimizer = SGD(model.parameters(), **params)
    else:
        raise NotImplementedError
    return optimizer

### Scheduler

In [None]:
Scheduler = Optional[
    Union[ReduceLROnPlateau, CosineAnnealingLR, CosineAnnealingWarmRestarts]]


def get_scheduler(optimizer, config: edict) -> Scheduler:
    """Create the LR scheduler named by ``config.scheduler.name``.

    Known names are ``"plateau"``, ``"cosine"`` and ``"cosine_warmup"``.
    ``config.scheduler.params`` is forwarded as keyword arguments. Any other
    name yields ``None``, meaning no scheduler.
    """
    params = config.scheduler.params
    name = config.scheduler.name
    known = (
        ("plateau", ReduceLROnPlateau),
        ("cosine", CosineAnnealingLR),
        ("cosine_warmup", CosineAnnealingWarmRestarts),
    )
    for key, scheduler_cls in known:
        if name == key:
            return scheduler_cls(optimizer, **params)
    return None

## Callbacks

In [None]:
class AverageRecall(Callback):
    """Macro-averaged recall for one label component.

    The tensor in ``state.output[output_key]`` holds the logits of all
    components side by side. This callback scores the slice
    ``[offset, offset + n_classes)``.

    Args:
        index: column of the integer target matrix used when ``loss_type``
            is not ``"bce"``.
        offset: first logit column of this component.
        n_classes: number of logit columns of this component.
        prefix: epoch metric name. The per-batch value is logged as
            ``"batch_" + prefix``.
        loss_type: ``"bce"`` slices the targets like the logits (one-hot) and
            takes their argmax. Any other value reads integer labels from
            column ``index``.
        output_key / target_key: keys into ``state.output`` / ``state.input``.
    """

    def __init__(self,
                 index: int,
                 offset: int,
                 n_classes: int,
                 prefix: str,
                 loss_type: str = "bce",
                 output_key: str = "logits",
                 target_key: str = "targets"):
        self.index = index
        self.offset = offset
        self.n_classes = n_classes
        self.prefix = prefix
        self.loss_type = loss_type
        self.output_key = output_key
        self.target_key = target_key
        self.recall = 0.0
        super().__init__(CallbackOrder.Metric)

    def on_loader_start(self, state: RunnerState):
        # Reset the accumulators for this loader pass.
        self.prediction: List[int] = []
        self.target: List[int] = []

    def on_batch_end(self, state: RunnerState):
        targ = state.input[self.target_key].detach()
        out = state.output[self.output_key].detach()
        head = self.offset
        tail = self.offset + self.n_classes
        # argmax directly on the logits. sigmoid is monotonic, but in float32
        # it saturates to exactly 1.0 for large logits, which created ties
        # that argmax resolved to the lowest index.
        pred_np = torch.argmax(out[:, head:tail], dim=1).cpu().numpy()
        if self.loss_type == "bce":
            target_np = torch.argmax(targ[:, head:tail], dim=1).cpu().numpy()
        else:
            target_np = targ[:, self.index].cpu().numpy()
        self.prediction.extend(pred_np)
        self.target.extend(target_np)
        score = recall_score(
            target_np, pred_np, average="macro", zero_division=0)
        state.metrics.add_batch_value(name="batch_" + self.prefix, value=score)

    def on_loader_end(self, state: RunnerState):
        y_true = np.asarray(self.target)
        y_pred = np.asarray(self.prediction)

        # zero_division=0 matches the batch metric. The value equals sklearn's
        # default ("warn" -> 0) but skips a warning per ill-defined class.
        metric = recall_score(y_true, y_pred, average="macro", zero_division=0)
        state.metrics.epoch_values[state.loader_name][self.prefix] = float(
            metric)
        self.recall = metric


class TotalAverageRecall(Callback):
    """Weighted mean of the grapheme, vowel and consonant macro recalls.

    Drives three internal ``AverageRecall`` callbacks, which also publish
    ``grapheme_recall`` / ``vowel_recall`` / ``consonant_recall``. It then
    stores ``(2 * grapheme + vowel + consonant) / 4`` under ``prefix``.
    """

    def __init__(self,
                 n_grapheme=168,
                 n_vowel=11,
                 n_consonant=7,
                 loss_type: str = "bce",
                 prefix: str = "tar",
                 output_key: str = "logits",
                 target_key: str = "targets"):
        self.prefix = prefix
        shared = {
            "loss_type": loss_type,
            "output_key": output_key,
            "target_key": target_key,
        }
        self.grapheme_callback = AverageRecall(
            index=0, offset=0, n_classes=n_grapheme,
            prefix="grapheme_recall", **shared)
        self.vowel_callback = AverageRecall(
            index=1, offset=n_grapheme, n_classes=n_vowel,
            prefix="vowel_recall", **shared)
        self.consonant_callback = AverageRecall(
            index=2, offset=n_grapheme + n_vowel, n_classes=n_consonant,
            prefix="consonant_recall", **shared)
        super().__init__(CallbackOrder.Metric)

    def _parts(self):
        # Component callbacks in weight order (grapheme, vowel, consonant).
        return (self.grapheme_callback, self.vowel_callback,
                self.consonant_callback)

    def on_loader_start(self, state):
        for part in self._parts():
            part.on_loader_start(state)

    def on_batch_end(self, state: RunnerState):
        for part in self._parts():
            part.on_batch_end(state)

    def on_loader_end(self, state: RunnerState):
        for part in self._parts():
            part.on_loader_end(state)
        recalls = [part.recall for part in self._parts()]
        state.metrics.epoch_values[state.loader_name][self.prefix] = \
            np.average(recalls, weights=[2, 1, 1])



class SaveWeightsCallback(Callback):
    """Save ``state.model.state_dict()`` at the end of every epoch.

    Weights are written to ``<logdir>/checkpoints`` and, if ``to`` is given,
    also to ``to``. Each file is overwritten every epoch.

    Args:
        to: optional extra output directory. It is created if missing.
        name: file stem. An empty string writes ``temp.pth``.
    """

    def __init__(self, to: Optional[Union[Path, str]] = None, name: str=""):
        if isinstance(to, str):
            self.to = Path(to)
        else:
            self.to = to
        self.name = name
        super().__init__(CallbackOrder.External)

    def on_epoch_end(self, state: RunnerState):
        weights = state.model.state_dict()
        filename = "temp.pth" if self.name == "" else f"{self.name}.pth"
        directories = [state.logdir / "checkpoints"]
        if self.to is not None:
            directories.append(self.to)
        for directory in directories:
            # The external directory was previously never created, so
            # torch.save failed when it did not already exist.
            directory.mkdir(exist_ok=True, parents=True)
            torch.save(weights, directory / filename)

In [None]:
def get_callbacks(config: edict):
    """Instantiate the callbacks listed in ``config.callbacks``.

    Each entry is a one-key mapping ``{ClassName: params}``. The class is
    resolved among this notebook's globals, and ``params`` (or no arguments
    when ``None``) is passed to its constructor. Unknown names are skipped
    with a warning instead of silently.
    """
    import warnings

    callbacks = []
    for entry in config.callbacks:
        name = list(entry.keys())[0]
        params = entry[name]
        factory = globals().get(name)
        if factory is None:
            warnings.warn(
                f"unknown callback {name!r} in config.callbacks; skipping it")
            continue
        if params is not None:
            callbacks.append(factory(**params))  # type: ignore
        else:
            callbacks.append(factory())  # type: ignore
    return callbacks

## Training

In [24]:
trn_idx, val_idx = splits[i]

print(f"Fold: {i}")

# Per-fold catalyst logdir; checkpoints are written under it.
output_dir = output_base_dir / f"fold{i}"
output_dir.mkdir(parents=True, exist_ok=True)

trn_df = df.loc[trn_idx, :].reset_index(drop=True)
val_df = df.loc[val_idx, :].reset_index(drop=True)

# Both loaders use the *train* batch size. The comprehension variable is
# named phase_df to avoid visually shadowing the global `df`.
data_loaders = {
    phase: get_base_loader(
        phase_df,
        train_images_path,
        phase=phase,
        size=(config.img_size, config.img_size),
        batch_size=config.train.batch_size,
        num_workers=config.num_workers,
        transforms=transforms_dict[phase])
    for phase, phase_df in (("train", trn_df), ("valid", val_df))
}

model = get_model(config)
criterion = get_loss(config)
optimizer = get_optimizer(model, config)
scheduler = get_scheduler(optimizer, config)
callbacks = get_callbacks(config)

if config.mixup:
    callbacks.append(MixupCallback(fields=["images"]))

runner = SupervisedRunner(
    device=ct.utils.get_device(),
    input_key="images",
    input_target_key="targets",
    output_key="logits",
)
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders=data_loaders,
    callbacks=callbacks,
    logdir=output_dir,
    num_epochs=config.train.num_epochs,
    # Track the weighted total recall ("tar"); higher is better.
    main_metric="tar",
    minimize_metric=False,
    # NOTE(review): key spelling presumably mirrors catalyst's own argument name.
    state_kwargs={"batch_consistant_metrics": False},
    monitoring_params=None,
    verbose=False,
)

Fold: 0
[2020-01-31 08:21:31,445] 
1/10 * Epoch 1 (train): _base/lr=0.0001 | _base/momentum=0.9000 | _timers/_fps=4134.0685 | _timers/batch_time=0.0933 | _timers/data_time=0.0758 | _timers/model_time=0.0175 | batch_consonant_recall=0.9016 | batch_grapheme_recall=0.6766 | batch_vowel_recall=0.9155 | consonant_recall=0.8841 | grapheme_recall=0.6505 | loss=1.3987 | tar=0.7755 | vowel_recall=0.9171
1/10 * Epoch 1 (valid): _base/lr=0.0001 | _base/momentum=0.9000 | _timers/_fps=5098.4677 | _timers/batch_time=0.1923 | _timers/data_time=0.1790 | _timers/model_time=0.0133 | batch_consonant_recall=0.9397 | batch_grapheme_recall=0.8551 | batch_vowel_recall=0.9602 | consonant_recall=0.9372 | grapheme_recall=0.8897 | loss=0.5286 | tar=0.9195 | vowel_recall=0.9615
[2020-01-31 08:21:31,445] 
1/10 * Epoch 1 (train): _base/lr=0.0001 | _base/momentum=0.9000 | _timers/_fps=4134.0685 | _timers/batch_time=0.0933 | _timers/data_time=0.0758 | _timers/model_time=0.0175 | batch_consonant_recall=0.9016 | batch_

## Check performance

In [None]:
def load_model(config: edict, bin_path: Union[str, Path]):
    """Build the model from ``config`` and load weights from ``bin_path``.

    Accepts either a bare ``state_dict`` or a checkpoint dict that stores it
    under ``"model_state_dict"``. Tensors are mapped to ``get_device()``.
    ``torch.load`` unpickles, so only use it on trusted files.
    """
    model = get_model(config)
    checkpoint = torch.load(bin_path, map_location=get_device())
    if "model_state_dict" in checkpoint.keys():
        weights = checkpoint["model_state_dict"]
    else:
        weights = checkpoint
    model.load_state_dict(weights)
    return model

In [None]:
def macro_average_recall(prediction: np.ndarray, df: pd.DataFrame):
    """Weighted mean of per-component macro recalls (weights 2:1:1).

    ``prediction`` columns 0/1/2 are compared against ``grapheme_root``,
    ``vowel_diacritic`` and ``consonant_diacritic`` of ``df``, row by row.
    """
    label_columns = ("grapheme_root", "vowel_diacritic", "consonant_diacritic")
    recalls = [
        recall_score(df[column].values, prediction[:, k], average="macro")
        for k, column in enumerate(label_columns)
    ]
    return np.average(recalls, weights=[2, 1, 1])

In [None]:
def inference_loop(model: nn.Module,
                   loader: torchdata.DataLoader,
                   cls_levels: dict,
                   loss_fn: Optional[nn.Module] = None,
                   requires_soft=False):
    """Run ``model`` over ``loader`` and collect per-component predictions.

    The model output is treated as grapheme, vowel and consonant logits
    concatenated along dim 1, with widths taken from ``cls_levels``.
    Predictions are stored in loader order, so the loader must not shuffle.

    Args:
        model: network to evaluate. It is switched to eval mode.
        loader: yields dicts with ``"images"`` (and optionally ``"targets"``)
            or bare image tensors.
        cls_levels: dict with keys ``"grapheme"``, ``"vowel"``, ``"consonant"``.
        loss_fn: optional loss, evaluated when targets are available.
        requires_soft: also return per-component softmax probabilities.

    Returns:
        dict with these keys:
        - ``"prediction"``: int64 array ``(len(dataset), 3)`` of argmax class
          ids. int64 rather than uint8, so more than 255 classes cannot wrap.
        - ``"loss"``: ``loss_fn`` averaged over batches (0.0 if unused).
        - ``"soft_prediction"``: float32 probabilities, only when
          ``requires_soft``.
    """
    n_grapheme = cls_levels["grapheme"]
    n_vowel = cls_levels["vowel"]
    n_consonant = cls_levels["consonant"]
    # (head, tail) logit columns of each component.
    bounds = [(0, n_grapheme),
              (n_grapheme, n_grapheme + n_vowel),
              (n_grapheme + n_vowel, n_grapheme + n_vowel + n_consonant)]

    dataset_length = len(loader.dataset)
    prediction = np.zeros((dataset_length, 3), dtype=np.int64)
    soft_prediction = None
    if requires_soft:
        soft_prediction = np.zeros((dataset_length, bounds[-1][1]),
                                   dtype=np.float32)

    device = get_device()

    avg_loss = 0.
    model.eval()

    targets: Optional[torch.Tensor] = None
    # Running row cursor. It does not depend on loader.batch_size, which is
    # None when a custom batch_sampler is used.
    row = 0

    with torch.no_grad():
        for batch in progress_bar(loader, leave=False):
            if isinstance(batch, dict):
                images = batch["images"].to(device)
                targets = batch["targets"].to(device)
            else:
                images = batch.to(device)
                targets = None
            pred = model(images).detach()
            if loss_fn is not None and targets is not None:
                avg_loss += loss_fn(pred, targets).item() / len(loader)

            stop = row + pred.size(0)
            for column, (head, tail) in enumerate(bounds):
                logits = pred[:, head:tail]
                prediction[row:stop, column] = torch.argmax(
                    logits, dim=1).cpu().numpy()
                if soft_prediction is not None:
                    soft_prediction[row:stop, head:tail] = F.softmax(
                        logits, dim=1).cpu().numpy()
            row = stop

    return_dict = {"prediction": prediction, "loss": avg_loss}
    if soft_prediction is not None:
        return_dict["soft_prediction"] = soft_prediction

    return return_dict

In [28]:
# Evaluate the best checkpoint of the fold trained above (fold `i`) on that
# fold's validation split. The path was previously hard-coded to fold0, which
# mismatched `val_df` whenever i != 0.
checkpoint_path = output_dir / "checkpoints" / "best.pth"
model = load_model(config, checkpoint_path)
model.to(get_device())
loader = data_loaders["valid"]

prediction = inference_loop(
    model,
    loader,
    cls_levels,
    criterion,
    requires_soft=False)
score = macro_average_recall(prediction["prediction"], val_df)
print(f"Score: {score:.5f}")

Score: 0.97414


In [None]:
# Copy the best checkpoint to Drive.
# NOTE(review): the source is always fold0 regardless of `i`, and the target
# name says "size224" while conf_string sets img_size: 128 — confirm intent.
!cp output/fold0/checkpoints/best.pth /content/gdrive/My\ Drive/kaggle-bengali/checkpoints/fold0/resnet34_10epoch_size224.pth