<a href="https://colab.research.google.com/github/koukyo1994/kaggle-bengali-ai/blob/master/notebook/Colab-Train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Dependencies

In [0]:
%%sh
pip install albumentations==0.4.3 catalyst==20.1.1 easydict==1.9.0 >> /dev/null
pip install efficientnet-pytorch==0.6.1 PyYAML==5.3 >> /dev/null
pip install pretrainedmodels==0.7.4 >> /dev/null

## Integration with Google Drive

In [2]:
from google.colab import drive
drive.mount("/content/gdrive")

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/gdrive


In [0]:
%%sh
mkdir input
cp -r /content/gdrive/My\ Drive/kaggle-bengali ./input/bengaliai-cv19
unzip -qq -d input/bengaliai-cv19/ input/bengaliai-cv19/train_images.zip
unzip -qq -d input/bengaliai-cv19/ input/bengaliai-cv19/test_images.zip

## Libraries

In [4]:
import albumentations as A
import catalyst as ct
import cv2
import numpy as np
import pandas as pd
import pretrainedmodels
import torch
import torch.nn as nn
import torch.utils.data as torchdata
import yaml

from pathlib import Path
from typing import Tuple, Dict, Union, Optional, List

from catalyst.dl import SupervisedRunner
from catalyst.dl.core import Callback, CallbackOrder, RunnerState
from easydict import EasyDict as edict
from efficientnet_pytorch import EfficientNet
from skimage.transform import AffineTransform, warp
from sklearn.metrics import recall_score
from sklearn.model_selection import KFold, train_test_split
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import (ReduceLROnPlateau, 
                                      CosineAnnealingLR,
                                      CosineAnnealingWarmRestarts)


alchemy not available, to install alchemy, run `pip install alchemy-catalyst`.


## Settings

In [0]:
i = 0
trial = "cross_entropy"

## Config

In [0]:
conf_string = '''
dataset:
  train:
    affine: True
    morphology: True
  val:
    affine: False
    morphology: False
  test:
    affine: False
    morphology: False

data:
  train_df_path: input/bengaliai-cv19/train.csv
  train_images_path: input/bengaliai-cv19/train_images
  test_images_path: input/bengaliai-cv19/test_images
  sample_submission_path: input/bengaliai-cv19/sample_submission.csv

model:
  model_name: se_resnext50_32x4d
  pretrained: imagenet
  num_classes: 186

train:
  batch_size: 32
  num_epochs: 50

test:
  batch_size: 32

loss:
  name: cross_entropy
  params:
    n_grapheme: 168
    n_vowel: 11
    n_consonant: 7

optimizer:
  name: Adam
  params:
    lr: 0.001

scheduler:
  name: plateau

transforms:
  Noise: True
  Contrast: True
  Cutout:
    num_holes: 0

val:
  name: kfold
  params:
    random_state: 42
    n_splits: 5

log_dir: log/
num_workers: 2
seed: 1213
img_size: 64
checkpoints: /content/gdrive/My Drive/kaggle-bengali/checkpoints/
'''

In [0]:
def _get_default():
    cfg = edict()

    # dataset
    cfg.dataset = edict()
    cfg.dataset.train = edict()
    cfg.dataset.val = edict()
    cfg.dataset.test = edict()
    cfg.dataset.train.affine = False
    cfg.dataset.train.morphology = False
    cfg.dataset.val.affine = False
    cfg.dataset.val.morphology = False
    cfg.dataset.test.affine = False
    cfg.dataset.test.morphology = False

    # dataset
    cfg.data = edict()

    # model
    cfg.model = edict()

    # train
    cfg.train = edict()

    # test
    cfg.test = edict()

    # loss
    cfg.loss = edict()
    cfg.loss.params = edict()

    # optimizer
    cfg.optimizer = edict()
    cfg.optimizer.params = edict()

    # scheduler
    cfg.scheduler = edict()
    cfg.scheduler.params = edict()

    # transforms:
    cfg.transforms = edict()
    cfg.transforms.HorizontalFlip = False
    cfg.transforms.VerticalFlip = False
    cfg.transforms.Noise = False
    cfg.transforms.Contrast = False
    cfg.transforms.Cutout = edict()
    cfg.transforms.Cutout.num_holes = 0
    cfg.transforms.mean = [0.485, 0.456, 0.406]
    cfg.transforms.std = [0.229, 0.224, 0.225]

    # val
    cfg.val = edict()
    cfg.val.params = edict()

    return cfg


def _merge_config(src: edict, dst: edict):
    if not isinstance(src, edict):
        return
    for k, v in src.items():
        if isinstance(v, edict):
            _merge_config(src[k], dst[k])
        else:
            dst[k] = v

In [0]:
cfg = edict(yaml.load(conf_string, Loader=yaml.SafeLoader))
config = _get_default()
_merge_config(cfg, config)

## Environmental settings

In [10]:
ct.utils.set_global_seed(config.seed)
ct.utils.prepare_cudnn(deterministic=True)

In [0]:
output_base_dir = Path("output")
output_base_dir.mkdir(exist_ok=True, parents=True)

train_images_path = Path(config.data.train_images_path)

## Data and utilities preparation

### validation utils

In [0]:
def no_fold(df: pd.DataFrame,
            config: edict) -> List[Tuple[np.ndarray, np.ndarray]]:
    params = config.val.params
    idx = np.arange(len(df))
    trn_idx, val_idx = train_test_split(idx, **params)
    return [(trn_idx, val_idx)]


def kfold(df: pd.DataFrame,
          config: edict) -> List[Tuple[np.ndarray, np.ndarray]]:
    params = config.val.params
    kf = KFold(shuffle=True, **params)
    splits = list(kf.split(df))
    return splits


def get_validation(df: pd.DataFrame,
                   config: edict) -> List[Tuple[np.ndarray, np.ndarray]]:
    name: str = config.val.name

    func = globals().get(name)
    if func is None:
        raise NotImplementedError

    return func(df, config)

### transforms

In [0]:
def get_transforms(config: edict):
    list_transforms = []
    if config.transforms.HorizontalFlip:
        list_transforms.append(A.HorizontalFrip())
    if config.transforms.VerticalFlip:
        list_transforms.append(A.VerticalFlip())
    if config.transforms.Noise:
        list_transforms.append(
            A.OneOf(
                [A.GaussNoise(), A.IAAAdditiveGaussianNoise()], p=0.5))
    if config.transforms.Contrast:
        list_transforms.append(
            A.OneOf(
                [A.RandomContrast(0.5),
                 A.RandomGamma(),
                 A.RandomBrightness()],
                p=0.5))
    if config.transforms.Cutout.num_holes > 0:
        list_transforms.append(A.Cutout(**config.Cutout))

    list_transforms.append(
        A.Normalize(
            mean=config.transforms.mean, std=config.transforms.std, p=1))

    return A.Compose(list_transforms, p=1.0)

### Data Loading

In [0]:
df = pd.read_csv(config.data.train_df_path)
splits = get_validation(df, config)
transforms = get_transforms(config)

cls_levels = {
    "grapheme": df.grapheme_root.nunique(),
    "vowel": df.vowel_diacritic.nunique(),
    "consonant": df.consonant_diacritic.nunique()
}

## Dataset and DataLoader

In [0]:
def crop_image(image: np.ndarray, threshold=5. / 255.) -> np.ndarray:
    assert image.ndim == 2
    is_black = image > threshold
    is_black_vertical = np.sum(is_black, axis=0) > 0
    is_black_horizontal = np.sum(is_black, axis=1) > 0

    left = np.argmax(is_black_horizontal)
    right = np.argmax(is_black_horizontal[::-1])
    top = np.argmax(is_black_vertical)
    bottom = np.argmax(is_black_vertical[::-1])
    height, width = image.shape
    cropped_image = image[left:height - right, top:width - bottom]
    return cropped_image


def resize(image, size=(128, 128)) -> np.ndarray:
    return cv2.resize(image, size)


def crop_and_embed(image: np.ndarray, size=(128, 128), threshold=20. / 255.):
    cropped = crop_image(image, threshold)
    height, width = cropped.shape
    aspect_ratio = height / width
    embedded = np.zeros(size)
    if aspect_ratio > 1.0:
        if height > size[0]:
            new_height = size[0]
            new_width = int(size[0] * 1 / aspect_ratio)
            image = resize(cropped, size=(new_width, new_height))

            margin = size[1] - new_width
            head = margin // 2
            embedded[:, head:head + new_width] = image
        else:
            margin = size[0] - height

            new_height = height + np.random.randint(0, margin)
            new_width = int(new_height * 1 / aspect_ratio)
            image = resize(cropped, size=(new_width, new_height))

            margin_height = size[0] - new_height
            margin_width = size[1] - new_width

            head_height = margin_height // 2
            head_width = margin_width // 2
            embedded[head_height:head_height +
                     new_height, head_width:head_width + new_width] = image
    else:
        if width > size[1]:
            new_width = size[1]
            new_height = int(size[1] * aspect_ratio)
            image = resize(cropped, size=(new_width, new_height))

            margin = size[0] - new_height
            head = margin // 2
            embedded[head:head + new_height, :] = image
        else:
            margin = size[1] - width

            new_width = width + np.random.randint(0, margin)
            new_height = int(new_width * aspect_ratio)
            image = resize(cropped, size=(new_width, new_height))

            margin_height = size[0] - new_height
            margin_width = size[1] - new_width

            head_height = margin_height // 2
            head_width = margin_width // 2
            embedded[head_height:head_height +
                     new_height, head_width:head_width + new_width] = image

    return embedded


def normalize(image: np.ndarray):
    if image.ndim == 3:
        image = image[:, :, 0]
    image = (255 - image).astype(np.float32) / 255.0
    return image


def to_image(image: np.ndarray):
    if image.ndim == 2:
        image = np.stack([image, image, image])
        image = np.moveaxis(image, 0, -1)
    image = (255 - image * 255).astype(np.uint8)
    return image


def affine_image(image: np.ndarray):
    assert image.ndim == 2
    min_scale = 0.8
    max_scale = 1.2
    sx = np.random.uniform(min_scale, max_scale)
    sy = np.random.uniform(min_scale, max_scale)

    max_rot_angle = 10
    rot_angle = np.random.uniform(-max_rot_angle, max_rot_angle) * np.pi / 180.

    max_shear_angle = 10
    shear_angle = np.random.uniform(-max_shear_angle,
                                    max_shear_angle) * np.pi / 180.

    max_translation = 4
    tx = np.random.randint(-max_translation, max_translation)
    ty = np.random.randint(-max_translation, max_translation)

    tform = AffineTransform(
        scale=(sx, sy),
        rotation=rot_angle,
        shear=shear_angle,
        translation=(tx, ty))
    transformed_image = warp(image, tform)
    return transformed_image


def random_erosion_or_dilation(image: np.ndarray):
    dice = np.random.randint(0, 3)
    if dice == 0:
        return image
    elif dice == 1:
        kernel = np.ones((3, 3), dtype=np.uint8)
        return cv2.erode(image, kernel, iterations=1)
    else:
        kernel = np.ones((3, 3), dtype=np.uint8)
        return cv2.dilate(image, kernel, iterations=1)

In [0]:
class TrainDataset(torchdata.Dataset):
    def __init__(self,
                 image_dir: Path,
                 df: pd.DataFrame,
                 transforms,
                 size: Tuple[int, int],
                 cls_levels: Dict[str, int] = None,
                 affine=True,
                 morphology=True,
                 onehot=True):
        self.df = df
        self.image_dir = image_dir
        self.transforms = transforms
        self.size = size
        self.onehot = onehot
        self.cls_levels = cls_levels
        self.affine = affine
        self.morphology = morphology

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        image_id = self.df.loc[idx, "image_id"]
        image_path = self.image_dir / f"{image_id}.png"

        image = cv2.imread(str(image_path))
        image = normalize(image)
        image = crop_and_embed(image, size=self.size, threshold=5. / 255.)
        if self.affine:
            image = affine_image(image)
        if self.morphology:
            image = random_erosion_or_dilation(image)
        image = to_image(image)
        if self.transforms is not None:
            image = self.transforms(image=image)["image"]
        if image.shape[2] == 3:
            image = np.moveaxis(image, -1, 0)
        grapheme = self.df.loc[idx, "grapheme_root"]
        vowel = self.df.loc[idx, "vowel_diacritic"]
        consonant = self.df.loc[idx, "consonant_diacritic"]

        if self.onehot:
            grapheme_levels = self.cls_levels["grapheme"]
            vowel_levels = self.cls_levels["vowel"]
            consonant_levels = self.cls_levels["consonant"]
            total_n_levels = grapheme_levels + vowel_levels + consonant_levels
            label = np.zeros(total_n_levels, dtype=np.float32)
            label[grapheme] = 1.0
            label[grapheme_levels + vowel] = 1.0
            label[grapheme_levels + vowel_levels + consonant] = 1.0

        else:
            label = np.zeros(3, dtype=int)
            label[0] = grapheme
            label[1] = vowel
            label[2] = consonant
        return {"images": image, "targets": label}


class TestDataset(torchdata.Dataset):
    def __init__(self,
                 image_dir: Path,
                 df: pd.DataFrame,
                 transforms,
                 size: Tuple[int, int],
                 affine=True,
                 morphology=True):
        self.image_dir = image_dir
        self.df = df
        self.transforms = transforms
        self.size = size
        self.affine = affine
        self.morphology = morphology

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        image_id = self.df.loc[idx, "image_id"]
        image_path = self.image_dir / f"{image_id}.png"

        image = cv2.imread(image_path)
        image = normalize(image)
        image = crop_and_embed(image, size=self.size, threshold=5. / 255.)
        if self.affine:
            image = affine_image(image)
        if self.morphology:
            image = random_erosion_or_dilation(image)
        image = to_image(image)
        if self.transforms is not None:
            image = self.transforms(image=image)["image"]
        if image.shape[2] == 3:
            image = np.moveaxis(image, -1, 0)
        return image


def get_loader(df: pd.DataFrame,
               image_dir: Path,
               phase: str = "train",
               size: Tuple[int, int] = (128, 128),
               batch_size=256,
               num_workers=2,
               transforms=None,
               cls_levels=None,
               affine=True,
               morphology=True,
               onehot=None):
    assert phase in ["train", "valid", "test"]
    if phase == "test":
        dataset = TestDataset(image_dir, df, transforms, size, affine,
                              morphology)
        is_shuffle = False
        drop_last = False
    else:
        if phase == "train":
            is_shuffle = True
            drop_last = True
        else:
            is_shuffle = False
            drop_last = False
        if onehot is not None:
            if cls_levels is None:
                raise ValueError(
                    "if 'onehot' is set to None, cls_levels must be set")
            else:
                dataset = TrainDataset(  # type: ignore
                    image_dir,
                    df,
                    transforms,
                    size,
                    cls_levels,
                    affine=affine,
                    morphology=morphology,
                    onehot=onehot)
        else:
            dataset = TrainDataset(  # type: ignore
                image_dir,
                df,
                transforms,
                size,
                affine=affine,
                morphology=morphology)
    return torchdata.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=is_shuffle,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=drop_last)

## Model and Loss

### Model

In [0]:
class BengaliClassifier(nn.Module):
    def __init__(self, model_name: str, num_classes: int, pretrained=True):
        super().__init__()
        self.model_name = model_name
        self.num_classes = num_classes
        self.pretrained = pretrained

        if "se_resnext" in self.model_name:
            self.base = getattr(pretrainedmodels,
                                self.model_name)(pretrained=pretrained)
            self.base.avg_pool = nn.AdaptiveAvgPool2d(1)
            self.base.last_linear = nn.Linear(
                self.base.last_linear.in_features, self.num_classes)
        elif "resnet" in self.model_name:
            self.base = getattr(pretrainedmodels,
                                self.model_name)(pretrained=pretrained)
            self.base.avg_pool = nn.AdaptiveAvgPool2d(1)
            self.base.fc = nn.Linear(self.base.fc.in_features,
                                     self.num_classes)
        elif "efficientnet" in self.model_name:
            if pretrained:
                self.base = EfficientNet.from_pretrained(self.model_name)
            else:
                self.base = EfficientNet.from_name(self.model_name)
            self.base._fc = nn.Linear(self.base._fc.in_features,
                                      self.num_classes)
        else:
            raise NotImplementedError

    def fresh_params(self):
        if "se_resnext" in self.model_name:
            return self.base.last_linear.parameters()
        elif "resnet" in self.model_name:
            return self.base.fc.parameters()
        elif "efficientnet" in self.model_name:
            return self.base._fc.parameters()
        else:
            raise NotImplementedError

    def base_params(self):
        params = []
        if "se_resnext" in self.model_name:
            fc_name = "last_linear"
        elif "resnet" in self.model_name:
            fc_name = "fc"
        elif "efficientnet" in self.model_name:
            fc_name = "_fc"
        else:
            raise NotImplementedError
        for name, param in self.net.named_parameters():
            if fc_name not in name:
                params.append(param)
        return params

    def forward(self, x):
        return self.base(x)

### Loss

In [0]:
class BengaliCrossEntropyLoss(nn.Module):
    def __init__(self, n_grapheme: int, n_vowel: int, n_consonant: int):
        super().__init__()
        self.n_grapheme = n_grapheme
        self.n_vowel = n_vowel
        self.n_consonant = n_consonant
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, pred, true):
        head = 0
        tail = self.n_grapheme
        grapheme_pred = pred[:, head:tail]
        grapheme_true = true[:, 0]

        head = tail
        tail = head + self.n_vowel
        vowel_pred = pred[:, head:tail]
        vowel_true = true[:, 1]

        head = tail
        tail = head + self.n_consonant
        consonant_pred = pred[:, head:tail]
        consonant_true = true[:, 2]

        return self.cross_entropy(grapheme_pred, grapheme_true) + \
            self.cross_entropy(vowel_pred, vowel_true) + \
            self.cross_entropy(consonant_pred, consonant_true)


class BengaliBCELoss(nn.Module):
    def __init__(self, n_grapheme: int, n_vowel: int, n_consonant: int):
        super().__init__()
        self.n_grapheme = n_grapheme
        self.n_vowel = n_vowel
        self.n_consonant = n_consonant
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, pred, true):
        head = 0
        tail = self.n_grapheme
        grapheme_pred = pred[:, head:tail]
        grapheme_true = true[:, head:tail]

        head = tail
        tail = head + self.n_vowel
        vowel_pred = pred[:, head:tail]
        vowel_true = true[:, head:tail]

        head = tail
        tail = head + self.n_consonant
        consonant_pred = pred[:, head:tail]
        consonant_true = true[:, head:tail]

        return self.bce(grapheme_pred, grapheme_true) + \
            self.bce(vowel_pred, vowel_true) + \
            self.bce(consonant_pred, consonant_true)


def get_loss(config: edict):
    name = config.loss.name
    params = config.loss.params
    if name == "bce":
        criterion = BengaliBCELoss(**params)
    elif name == "cross_entropy":
        criterion = BengaliCrossEntropyLoss(**params)  # type: ignore
    else:
        raise NotImplementedError
    return criterion

## Optimizer and Scheduler

### Optimizer

In [0]:
Optimizer = Union[Adam, SGD]


def get_optimizer(model, config: edict) -> Optimizer:
    name = config.optimizer.name
    params = config.optimizer.params
    if name == "Adam":
        optimizer = Adam(model.parameters(), **params)
    elif name == "SGD":
        optimizer = Adam(model.parameters(), **params)
    else:
        raise NotImplementedError
    return optimizer

### Scheduler

In [0]:
Scheduler = Optional[
    Union[ReduceLROnPlateau, CosineAnnealingLR, CosineAnnealingWarmRestarts]]


def get_scheduler(optimizer, config: edict) -> Scheduler:
    params = config.scheduler.params
    name = config.scheduler.name
    scheduler: Scheduler = None
    if name == "plateau":
        scheduler = ReduceLROnPlateau(optimizer, **params)
    elif name == "cosine":
        scheduler = CosineAnnealingLR(optimizer, **params)
    elif name == "cosine_warmup":
        scheduler = CosineAnnealingWarmRestarts(optimizer, **params)

    return scheduler

## Callbacks

In [0]:
class MacroAverageRecall(Callback):
    def __init__(self,
                 n_grapheme=168,
                 n_vowel=11,
                 n_consonant=7,
                 loss_type: str = "bce",
                 prefix: str = "mar",
                 output_key: str = "logits",
                 target_key: str = "targets"):
        self.prefix = prefix
        self.output_key = output_key
        self.target_key = target_key
        self.n_grapheme = n_grapheme
        self.n_vowel = n_vowel
        self.n_consonant = n_consonant
        self.loss_type = loss_type
        super().__init__(CallbackOrder.Metric)

    def on_batch_end(self, state: RunnerState):
        targ = state.input[self.target_key].detach()
        out = state.output[self.output_key]
        head = 0
        tail = self.n_grapheme
        grapheme = torch.sigmoid(out[:, head:tail])
        grapheme_np = torch.argmax(grapheme, dim=1).detach().cpu().numpy()
        if self.loss_type == "bce":
            grapheme_target = torch.argmax(
                targ[:, head:tail], dim=1).cpu().numpy()
        else:
            grapheme_target = targ[:, 0].cpu().numpy()

        head = tail
        tail = head + self.n_vowel
        vowel = torch.sigmoid(out[:, head:tail])
        vowel_np = torch.argmax(vowel, dim=1).detach().cpu().numpy()
        if self.loss_type == "bce":
            vowel_target = torch.argmax(
                targ[:, head:tail], dim=1).cpu().numpy()
        else:
            vowel_target = targ[:, 1].cpu().numpy()

        head = tail
        tail = head + self.n_consonant
        consonant = torch.sigmoid(out[:, head:tail])
        consonant_np = torch.argmax(consonant, dim=1).detach().cpu().numpy()
        if self.loss_type == "bce":
            consonant_target = torch.argmax(
                targ[:, head:tail], dim=1).cpu().numpy()
        else:
            consonant_target = targ[:, 2].cpu().numpy()

        scores = []
        scores.append(
            recall_score(
                grapheme_target, grapheme_np, average="macro",
                zero_division=0))
        scores.append(
            recall_score(
                vowel_target, vowel_np, average="macro", zero_division=0))
        scores.append(
            recall_score(
                consonant_target,
                consonant_np,
                average="macro",
                zero_division=0))
        final_score = np.average(scores, weights=[2, 1, 1])
        state.metrics.add_batch_value(name=self.prefix, value=final_score)


class SaveWeightsCallback(Callback):
    def __init__(self, to: Optional[Path] = None, name: str=""):
        self.to = to
        self.name = name
        super().__init__(CallbackOrder.External)

    def on_epoch_end(self, state: RunnerState):
        weights = state.model.state_dict()
        logdir = state.logdir / "checkpoints"
        logdir.mkdir(exist_ok=True, parents=True)
        if self.name == "":
            torch.save(weights, logdir / "temp.pth")
        else:
            torch.save(weights, logdir / f"{self.name}.pth")

        if self.to is not None:
            if self.name == "":
                torch.save(weights, self.to / "temp.pth")
            else:
                torch.save(weights, self.to / f"{self.name}.pth")

## KFold Training

In [0]:
trn_idx, val_idx = splits[i]

print(f"Fold: {i}")

output_dir = output_base_dir / f"fold{i}"
output_dir.mkdir(exist_ok=True, parents=True)

trn_df = df.loc[trn_idx, :].reset_index(drop=True)
val_df = df.loc[val_idx, :].reset_index(drop=True)
data_loaders = {
    phase: get_loader(
        df,
        train_images_path,
        phase=phase,
        size=(config.img_size, config.img_size),
        batch_size=config.train.batch_size,
        num_workers=config.num_workers,
        transforms=transforms,
        cls_levels=cls_levels,
        affine=config.dataset.train.affine
        if phase == "train" else config.dataset.val.affine,
        morphology=config.dataset.train.morphology
        if phase == "train" else config.dataset.val.morphology,
        onehot=config.loss.name == "bce")
    for phase, df in zip(["train", "valid"], [trn_df, val_df])
}
model = BengaliClassifier(**config.model)
criterion = get_loss(config)
optimizer = get_optimizer(model, config)
scheduler = get_scheduler(optimizer, config)
callbacks = [
    MacroAverageRecall(
        n_grapheme=cls_levels["grapheme"],
        n_vowel=cls_levels["vowel"],
        n_consonant=cls_levels["consonant"],
        loss_type=config.loss.name),
    SaveWeightsCallback(
        to=Path(config.checkpoints
                ) if config.checkpoints is not None else None,
        name=trial)
]

runner = SupervisedRunner(
    device=ct.utils.get_device(),
    input_key="images",
    input_target_key="targets",
    output_key="logits")
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    loaders=data_loaders,
    logdir=output_dir,
    scheduler=scheduler,
    num_epochs=config.train.num_epochs,
    callbacks=callbacks,
    main_metric="mar",
    minimize_metric=False,
    monitoring_params=None,
    verbose=False)

Fold: 0
[2020-01-27 16:30:41,520] 
1/50 * Epoch 1 (train): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=982.7494 | _timers/batch_time=0.0330 | _timers/data_time=0.0010 | _timers/model_time=0.0319 | loss=4.1251 | mar=0.4379
1/50 * Epoch 1 (valid): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=649.4201 | _timers/batch_time=0.0576 | _timers/data_time=0.0133 | _timers/model_time=0.0442 | loss=1.6391 | mar=0.6934
[2020-01-27 16:30:41,520] 
1/50 * Epoch 1 (train): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=982.7494 | _timers/batch_time=0.0330 | _timers/data_time=0.0010 | _timers/model_time=0.0319 | loss=4.1251 | mar=0.4379
1/50 * Epoch 1 (valid): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=649.4201 | _timers/batch_time=0.0576 | _timers/data_time=0.0133 | _timers/model_time=0.0442 | loss=1.6391 | mar=0.6934
[2020-01-27 16:42:59,981] 
2/50 * Epoch 2 (train): _base/lr=0.0010 | _base/momentum=0.9000 | _timers/_fps=1004.0703 | _timers/batch_time=0.0324 | 