In [1]:
!pip install catalyst==20.01



In [0]:
import os
import random
from typing import List

import albumentations as A
import catalyst
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from PIL import Image as PImage
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import Dataset, DataLoader
%matplotlib inline

In [3]:
# Mount Google Drive: the data is read from it and the training logdir is written under it.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [0]:
# Root folder on the mounted Drive: label files are read from here and the
# training log directory is created beneath it.
DATA_FOLDER = '/content/drive/My Drive/bengali'
# Sub-folder that holds the train_image_data_*.parquet pixel shards.
TRAIN_DATA = os.path.join(DATA_FOLDER, 'train')

In [5]:
# The three target columns; cast to uint8 in both splits.
LABEL_COLUMNS = ['grapheme_root', 'vowel_diacritic', 'consonant_diacritic']

# NOTE(review): the files are named *.tsv but are read with the default comma
# separator; the displayed head shows that this parses into proper columns.

# Training split: compact label dtype, drop the fold column, tag the rows.
train_data = pd.read_csv(os.path.join(DATA_FOLDER, './train_data.tsv'))
train_data[LABEL_COLUMNS] = train_data[LABEL_COLUMNS].astype('uint8')
train_data = train_data.drop(columns=['test_fold'])
train_data['type'] = 'train'

# Validation split: same label cast, tagged as validation.
validation_data = pd.read_csv(os.path.join(DATA_FOLDER, './validation_data.tsv'))
validation_data[LABEL_COLUMNS] = validation_data[LABEL_COLUMNS].astype('uint8')
validation_data['type'] = 'validation'

# ignore_index=True: each file is read with its own 0-based row index, so a
# plain concat would leave duplicate index labels in all_data.
all_data = pd.concat([train_data, validation_data], ignore_index=True)
all_data.head()

Unnamed: 0,image_id,grapheme_root,vowel_diacritic,consonant_diacritic,type
0,Train_0,15,9,5,train
1,Train_1,159,0,0,train
2,Train_2,22,3,5,train
3,Train_3,53,2,2,train
4,Train_4,71,9,5,train


In [0]:
# Boolean row selectors for the two splits (anything not tagged 'train' counts as validation).
train_mask = all_data['type'] == 'train'
val_mask = ~train_mask
# Image ids of each split, in the row order of all_data.
train_image_ids = all_data.loc[train_mask, 'image_id'].values
valid_image_ids = all_data.loc[val_mask, 'image_id'].values

Prepare dataset 

In [0]:
# Read the four parquet shards of flattened pixel rows, index each by image_id
# and store the pixels as uint8.
image_data = []
for i in range(4):
  chunk = pd.read_parquet(os.path.join(TRAIN_DATA, 'train_image_data_{}.parquet'.format(i)))
  # astype returns a new frame, so the result must be kept (it used to be discarded).
  chunk = chunk.set_index('image_id').astype(np.uint8)
  image_data.append(chunk)

In [8]:
# Stack the per-shard frames into a single frame indexed by image_id.
# NOTE(review): this rebinds `image_data` from a list of frames to a DataFrame,
# so re-running this cell on its own no longer sees the list of shards.
image_data = pd.concat(image_data)
image_data.head()

Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,32292,32293,32294,32295,32296,32297,32298,32299,32300,32301,32302,32303,32304,32305,32306,32307,32308,32309,32310,32311,32312,32313,32314,32315,32316,32317,32318,32319,32320,32321,32322,32323,32324,32325,32326,32327,32328,32329,32330,32331
image_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1,Unnamed: 51_level_1,Unnamed: 52_level_1,Unnamed: 53_level_1,Unnamed: 54_level_1,Unnamed: 55_level_1,Unnamed: 56_level_1,Unnamed: 57_level_1,Unnamed: 58_level_1,Unnamed: 59_level_1,Unnamed: 60_level_1,Unnamed: 61_level_1,Unnamed: 62_level_1,Unnamed: 63_level_1,Unnamed: 64_level_1,Unnamed: 65_level_1,Unnamed: 66_level_1,Unnamed: 67_level_1,Unnamed: 68_level_1,Unnamed: 69_level_1,Unnamed: 70_level_1,Unnamed: 71_level_1,Unnamed: 72_level_1,Unnamed: 73_level_1,Unnamed: 74_level_1,Unnamed: 75_level_1,Unnamed: 76_level_1,Unnamed: 77_level_1,Unnamed: 78_level_1,Unnamed: 79_level_1,Unnamed: 80_level_1,Unnamed: 81_level_1
Train_0,254,253,252,253,251,252,253,251,251,253,254,253,253,253,254,253,252,253,253,253,253,252,252,253,253,252,252,253,252,252,252,253,254,253,253,252,252,252,253,252,...,252,252,252,252,252,252,252,252,252,252,252,252,252,252,252,252,252,252,252,252,252,252,253,253,253,253,253,253,253,253,253,253,253,253,253,253,253,253,253,251
Train_1,251,244,238,245,248,246,246,247,251,252,250,250,246,249,248,250,249,251,252,253,253,253,253,253,253,253,250,249,251,252,251,251,251,251,252,253,251,250,252,251,...,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,254
Train_2,251,250,249,250,249,245,247,252,252,252,253,252,252,251,250,251,253,254,251,251,252,252,253,253,252,251,251,249,251,252,252,253,252,251,251,251,250,250,252,253,...,253,253,252,252,252,253,253,253,253,253,252,251,251,250,250,250,251,251,251,250,250,250,251,252,253,253,253,253,254,254,254,253,252,252,253,253,253,253,251,249
Train_3,247,247,249,253,253,252,251,251,250,250,251,250,249,251,251,251,250,252,251,245,245,251,252,251,252,252,250,249,250,251,250,249,250,251,252,253,252,252,252,252,...,253,252,252,254,253,253,254,253,252,253,254,253,252,253,254,254,254,254,254,254,254,254,253,252,253,254,253,252,253,254,254,254,254,254,254,253,253,252,251,252
Train_4,249,248,246,246,248,244,242,242,229,225,231,229,229,228,221,224,226,221,221,220,217,217,218,219,222,224,214,218,227,227,227,228,224,231,235,235,233,212,183,196,...,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,255


Prepare Callbacks and Dataset for training

In [0]:
class ImageDataset(Dataset):
    """Map-style dataset pairing flattened grapheme images with their three labels.

    Args:
        df: frame whose rows are flattened images; every row must hold
            137 * 236 values, because that size is fixed by the reshape.
        labels: frame with 'grapheme_root', 'vowel_diacritic' and
            'consonant_diacritic' columns, positionally aligned with ``df``.
        transforms: optional callable invoked as ``transforms(image=...)``
            that returns a mapping with an 'image' entry.
    """

    _LABEL_NAMES = ('grapheme_root', 'vowel_diacritic', 'consonant_diacritic')

    def __init__(self, df, labels, transforms=None):
        self.df, self.labels, self.transforms = df, labels, transforms

    def __len__(self):
        """Number of samples, i.e. rows of the pixel frame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the sample at position ``idx``.

        The result is a dict with one int64 scalar tensor per label and an
        'image' tensor in channel-first (1, H, W) layout.
        """
        pixels = self.df.iloc[idx].values.astype(np.uint8)
        image = pixels.reshape(137, 236)[:, :, np.newaxis]  # H x W x 1

        raw_labels = {name: self.labels[name].values[idx] for name in self._LABEL_NAMES}

        if self.transforms is not None:
            image = self.transforms(image=image)['image']

        sample = {name: torch.tensor(value).long() for name, value in raw_labels.items()}
        # HWC -> CHW for the convolutional model.
        sample['image'] = torch.from_numpy(image.transpose((2, 0, 1)))
        return sample

Set loader parameters and define the augmentations used for the train and validation datasets

In [0]:
# Samples per mini-batch, used by both DataLoaders in the training function.
batch_size = 128
# Worker processes per DataLoader.
num_workers = 1

In [0]:
from albumentations.core.transforms_interface import DualTransform
from albumentations.augmentations import functional as F
class GridDropout(DualTransform):
    """
    GridDropout, drops out rectangular regions of an image and the corresponding mask in a grid fashion.
        Args:
            ratio (float): the ratio of the mask holes to the unit_size (same for horizontal and vertical directions).
                Must be greater than 0 and at most 1. Default: 0.5.
            unit_size_min (int): minimum size of the grid unit. Must be between 2 and the image shorter edge.
                If 'None', holes_number_x and holes_number_y are used to setup the grid. Default: `None`.
            unit_size_max (int): maximum size of the grid unit. Must be between 2 and the image shorter edge.
                If 'None', holes_number_x and holes_number_y are used to setup the grid. Default: `None`.
            holes_number_x (int): the number of grid units in x direction. Must be between 1 and image width//2.
                If 'None', grid unit width is set as image_width//10. Default: `None`.
            holes_number_y (int): the number of grid units in y direction. Must be between 1 and image height//2.
                If `None`, grid unit height is set equal to the grid unit width or image height, whatever is smaller.
            shift_x (int): offsets of the grid start in x direction from (0,0) coordinate.
                Clipped between 0 and grid unit_width - hole_width. Default: 0.
            shift_y (int): offsets of the grid start in y direction from (0,0) coordinate.
                Clipped between 0 and grid unit height - hole_height. Default: 0.
            random_offset (boolean): whether to offset the grid randomly between 0 and grid unit size - hole size
                If 'True', entered shift_x, shift_y are ignored and set randomly. Default: `False`.
            fill_value (int): value for the dropped pixels. Default = 0
            mask_fill_value (int): value for the dropped pixels in mask.
                If `None`, transformation is not applied to the mask. Default: `None`.
        Targets:
            image, mask
        Image types:
            uint8, float32
        References:
            https://arxiv.org/abs/2001.04086
    """
    def __init__(
        self,
        ratio: float = 0.5,
        unit_size_min: int = None,
        unit_size_max: int = None,
        holes_number_x: int = None,
        holes_number_y: int = None,
        shift_x: int = 0,
        shift_y: int = 0,
        random_offset: bool = False,
        fill_value: int = 0,
        mask_fill_value: int = None,
        always_apply: bool = False,
        p: float = 0.5,
    ):
        super(GridDropout, self).__init__(always_apply, p)
        self.ratio = ratio
        self.unit_size_min = unit_size_min
        self.unit_size_max = unit_size_max
        self.holes_number_x = holes_number_x
        self.holes_number_y = holes_number_y
        self.shift_x = shift_x
        self.shift_y = shift_y
        self.random_offset = random_offset
        self.fill_value = fill_value
        self.mask_fill_value = mask_fill_value
        if not 0 < self.ratio <= 1:
            raise ValueError("ratio must be between 0 and 1.")

    def apply(self, image, holes=(), **params):
        """Fill every hole rectangle of the image with ``fill_value``."""
        return F.cutout(image, holes, self.fill_value)

    def apply_to_mask(self, image, holes=(), **params):
        """Fill the same holes in the mask, or return it untouched when ``mask_fill_value`` is None."""
        if self.mask_fill_value is None:
            return image
        return F.cutout(image, holes, self.mask_fill_value)

    def get_params_dependent_on_targets(self, params):
        """Compute the hole rectangles for the current image.

        Returns:
            dict with a single key ``"holes"``: a list of (x1, y1, x2, y2)
            tuples, clipped to the image bounds.

        Raises:
            ValueError: if the unit-size limits or hole counts are outside
                the ranges allowed for this image size.
        """
        img = params["image"]
        height, width = img.shape[:2]
        # set grid using unit size limits
        if self.unit_size_min and self.unit_size_max:
            if not 2 <= self.unit_size_min <= self.unit_size_max:
                raise ValueError("Max unit size should be >= min size, both at least 2 pixels.")
            if self.unit_size_max > min(height, width):
                raise ValueError("Grid size limits must be within the shortest image edge.")
            # random.randint is inclusive on both ends, so no "+ 1" here:
            # otherwise the unit could exceed unit_size_max (and the image edge).
            unit_width = random.randint(self.unit_size_min, self.unit_size_max)
            unit_height = unit_width
        else:
            # set grid using holes numbers
            if self.holes_number_x is None:
                unit_width = max(2, width // 10)
            else:
                if not 1 <= self.holes_number_x <= width // 2:
                    raise ValueError("The hole_number_x must be between 1 and image width//2.")
                unit_width = width // self.holes_number_x
            if self.holes_number_y is None:
                unit_height = max(min(unit_width, height), 2)
            else:
                if not 1 <= self.holes_number_y <= height // 2:
                    raise ValueError("The hole_number_y must be between 1 and image height//2.")
                unit_height = height // self.holes_number_y

        hole_width = int(unit_width * self.ratio)
        hole_height = int(unit_height * self.ratio)
        # min 1 pixel and max unit length - 1
        hole_width = min(max(hole_width, 1), unit_width - 1)
        hole_height = min(max(hole_height, 1), unit_height - 1)
        # set offset of the grid
        if self.shift_x is None:
            shift_x = 0
        else:
            shift_x = min(max(0, self.shift_x), unit_width - hole_width)
        if self.shift_y is None:
            shift_y = 0
        else:
            shift_y = min(max(0, self.shift_y), unit_height - hole_height)
        if self.random_offset:
            # a random offset overrides the configured shifts
            shift_x = random.randint(0, unit_width - hole_width)
            shift_y = random.randint(0, unit_height - hole_height)
        holes = []
        # "+ 1" covers the partial grid unit at the right/bottom edge; the
        # coordinates are clipped to the image, so edge holes may be empty.
        for i in range(width // unit_width + 1):
            for j in range(height // unit_height + 1):
                x1 = min(shift_x + unit_width * i, width)
                y1 = min(shift_y + unit_height * j, height)
                x2 = min(x1 + hole_width, width)
                y2 = min(y1 + hole_height, height)
                holes.append((x1, y1, x2, y2))

        return {"holes": holes}

    @property
    def targets_as_params(self):
        """The image is needed to size the grid."""
        return ["image"]

    def get_transform_init_args_names(self):
        """Names of the constructor arguments stored as same-named attributes."""
        return (
            "ratio",
            "unit_size_min",
            "unit_size_max",
            "holes_number_x",
            "holes_number_y",
            "shift_x",
            "shift_y",
            "fill_value",
            "mask_fill_value",
            "random_offset",
        )
    
class AugMix(DualTransform):
    """Augmentations mix to Improve Robustness and Uncertainty.

    Builds ``width`` augmentation chains from the input image, combines the
    normalized chain outputs with Dirichlet weights, then blends that mix
    with the normalized original using a Beta-sampled coefficient.

    Args:
        width (int): Number of augmentation chains that are mixed.
        depth (int): Number of operations per chain. A value <= 0 (e.g. -1)
            enables stochastic depth, drawn uniformly from [1, 3] per chain.
        alpha (float): Probability coefficient for Beta and Dirichlet distributions.
        augmentations (list or tuple of augmentations): Pool of transforms to sample
            operations from. Each is called as ``op(image=...)['image']``.
        mean (list of float): Per-channel mean passed to the normalization.
        std (list of float): Per-channel std passed to the normalization.
        always_apply (bool): Forwarded to the base transform.
        resize_width (int): Output width. Resizing only happens when both
            resize_width and resize_height are given.
        resize_height (int): Output height (see resize_width).
        p (float): Probability of applying the transform.

    Target:
        image

    Image types:
        uint8, float32

    Returns:
        mixed: Augmented and mixed image (normalized, float).

    Raises:
        ValueError: if ``augmentations`` is not a list or tuple.

    Reference:
    |   https://arxiv.org/abs/1912.02781
    |   https://github.com/google-research/augmix
    """
    def __init__(self, width=4,
                 depth=3,
                 alpha=0.5,
                 augmentations=None,
                 mean=[0.485, 0.456, 0.406],
                 std=[0.229, 0.224, 0.225],
                 always_apply=False,
                 resize_width = None,
                 resize_height = None,
                 p=0.5):
        super(AugMix, self).__init__(always_apply, p)
        if isinstance(augmentations, (list, tuple)):
            self.augmentations = augmentations
        else:
            raise ValueError("Augmentations list should be passed to 'augmentations' argument.")
        self.width = width
        self.depth = depth
        self.alpha = alpha
        self.mean = mean
        self.std = std
        self.resize_width = resize_width
        self.resize_height = resize_height

    def apply_op(self, image, op):
        """Apply one augmentation ``op`` to ``image`` and return the result.

        If the op's docstring does not advertise float32 support, a
        non-uint8 image is converted to uint8 first.
        """
        doc = op.__doc__ or ''
        # Only rescale images that are not uint8 already: multiplying an
        # 8-bit image by 255 would saturate it.
        # NOTE(review): assumes such images are floats in [0, 1] - confirm.
        if 'float32' not in doc and image.dtype != np.uint8:
            image = np.clip(image * 255., 0, 255).astype(np.uint8)
        return op(image=image)['image']

    def apply(self, img, **params):
        """Return the AugMix blend of ``img``, optionally resized."""
        ws = np.float32(np.random.dirichlet([self.alpha] * self.width))
        m = np.float32(np.random.beta(self.alpha, self.alpha))

        mix = np.zeros_like(img, dtype=np.float32)
        for i in range(self.width):
            image_aug = img.copy()
            # Documented behaviour: non-positive depth -> random depth in [1, 3].
            depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
            for _ in range(depth):
                op = np.random.choice(self.augmentations)
                # Chain the ops: feed the previous result, not the original image.
                image_aug = self.apply_op(image_aug, op)
            # Accumulate every chain (inside the width loop) with its own weight.
            # Preprocessing commutes since all coefficients are convex
            mix = np.add(mix, ws[i] * F.normalize(image_aug, mean=self.mean, std=self.std), out=mix, casting="unsafe")

        mixed = (1 - m) * F.normalize(img, mean=self.mean, std=self.std) + m * mix
        if self.resize_height is not None and self.resize_width is not None:
            mixed = F.resize(mixed, height = self.resize_height, width = self.resize_width)
        return mixed

    def get_transform_init_args_names(self):
        """Names of the constructor arguments stored as same-named attributes
        (the ``augmentations`` pool is not listed)."""
        return ('width', 'depth', 'alpha', 'mean', 'std', 'resize_width', 'resize_height')


In [0]:
from albumentations import (
    HorizontalFlip, ShiftScaleRotate,
    GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, Cutout
)
import albumentations as A

# Pool of single-step augmentations handed to AugMix; each is built with
# always_apply=True so it fires whenever it is picked.
augs = [
    transform_cls(always_apply=True)
    for transform_cls in (
        HorizontalFlip,
        MotionBlur,
        ShiftScaleRotate,
        GaussNoise,
        MedianBlur,
        Cutout,
        GridDropout,
    )
]

In [0]:
# Training pipeline: AugMix normalizes with mean/std 0.5 and resizes to
# 64 x 128 (height x width) itself.
transforms_train = A.Compose([
    AugMix(
        augmentations=augs,
        width=3,
        depth=8,
        alpha=.25,
        mean=[0.5],
        std=[0.5],
        resize_height=64,
        resize_width=128,
        p=1.,
    ),
])
# Validation pipeline: no augmentation, only the same resize and normalization.
transforms_val = A.Compose([
    A.Resize(height=64, width=128),
    A.Normalize(mean=0.5, std=0.5),
])

Callbacks for catalyst

In [14]:
from sklearn.metrics import recall_score
from catalyst.dl import Callback, RunnerState, MetricCallback, CallbackOrder, CriterionCallback

class TaskMetricCallback(Callback):
    '''
    Weighted macro-recall over the three prediction heads, computed per loader.

    Predictions (argmax over axis 1 of each head's output) and targets are
    collected batch by batch. At the end of each loader the macro recall of
    every head is stored under its class name, and their weighted average
    (weights 2, 1, 1 in the order of ``class_names``) under ``prefix``.

    Proposed metrics:
    import numpy as np
    import sklearn.metrics

    scores = []
    for component in ['grapheme_root', 'consonant_diacritic', 'vowel_diacritic']:
        y_true_subset = solution[solution[component] == component]['target'].values
        y_pred_subset = submission[submission[component] == component]['target'].values
        scores.append(sklearn.metrics.recall_score(
            y_true_subset, y_pred_subset, average='macro'))
    final_score = np.average(scores, weights=[2,1,1])

    Args:
        input_key: keys of the target tensors in ``state.input``, one per head.
        output_key: keys of the model outputs in ``state.output``, one per head.
        class_names: names under which the per-head recalls are stored.
        prefix: name under which the weighted average is stored.
        ignore_index: accepted but not used by this callback.
    '''

    # Head weights of the final score, aligned with the order of class_names.
    _WEIGHTS = (2, 1, 1)

    def __init__(
        self,
        input_key: List[str] = ['grapheme_root', 'consonant_diacritic', 'vowel_diacritic'],
        output_key: List[str] = ['grapheme_root', 'consonant_diacritic', 'vowel_diacritic'],
        class_names: List[str] = ['grapheme_root', 'consonant_diacritic', 'vowel_diacritic'],
        prefix: str = "taskmetric",
        ignore_index=None
    ):
        super().__init__(CallbackOrder.Metric)
        self.metric_fn = lambda outputs, targets: recall_score(targets, outputs, average="macro")
        self.prefix = prefix
        # Copy, so the shared default lists are never aliased by an instance.
        self.output_key = list(output_key)
        self.input_key = list(input_key)
        self.class_names = list(class_names)
        self._reset_buffers()

    def _reset_buffers(self):
        """Start empty prediction/target buffers, one pair per head."""
        self.outputs = [[] for _ in self.output_key]
        self.targets = [[] for _ in self.input_key]

    def on_batch_end(self, state: RunnerState):
        """Append this batch's predicted classes and targets for every head."""
        for i, (out_key, in_key) in enumerate(zip(self.output_key, self.input_key)):
            outputs = state.output[out_key].detach().cpu().numpy()
            targets = state.input[in_key].detach().cpu().numpy()
            # scores -> predicted class index
            outputs = np.argmax(outputs, axis=1)
            self.outputs[i].extend(outputs)
            self.targets[i].extend(targets)

    def on_loader_start(self, state):
        """Drop anything collected for the previous loader."""
        self._reset_buffers()

    def on_loader_end(self, state):
        """Compute per-head macro recall and the weighted total for this loader."""
        loader_metrics = state.metrics.epoch_values[state.loader_name]
        score_vec = []
        for name, head_outputs, head_targets in zip(self.class_names, self.outputs, self.targets):
            metric = self.metric_fn(np.array(head_outputs), np.array(head_targets))
            score_vec.append(metric)
            loader_metrics[name] = float(metric)

        # Plain float, consistent with the per-head entries above.
        loader_metrics[self.prefix] = float(np.average(score_vec, weights=self._WEIGHTS))

alchemy not available, to install alchemy, run `pip install alchemy-catalyst`.


In [0]:
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Residual block: two conv + batch-norm layers plus a shortcut connection.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by both convolutions.
        stride: stride of the first convolution (and of the shortcut projection).
        kernel_size: kernel size of the two main convolutions.
        padding: padding of the two main convolutions.
        bias: whether the convolutions carry a bias term. Default False.

    The shortcut is the identity unless the stride or channel count changes,
    in which case it is a 1x1 convolution followed by batch norm.
    """
    def __init__(self,in_channels,out_channels,stride=1,kernel_size=3,padding=1,bias=False):
        super(ResidualBlock,self).__init__()
        # `bias` is forwarded to the convolutions (it used to be ignored and
        # hard-coded to False).
        self.cnn1 =nn.Sequential(
            nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding,bias=bias),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )
        self.cnn2 = nn.Sequential(
            nn.Conv2d(out_channels,out_channels,kernel_size,1,padding,bias=bias),
            nn.BatchNorm2d(out_channels)
        )
        if stride != 1 or in_channels != out_channels:
            # Projection so the shortcut matches the main path's shape.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels,out_channels,kernel_size=1,stride=stride,bias=bias),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.shortcut = nn.Sequential()
        # Built once here instead of on every forward call; it has no
        # parameters, so state_dict keys are unchanged.
        self.relu = nn.ReLU(True)

    def forward(self,x):
        """Return relu(main_path(x) + shortcut(x))."""
        residual = x
        x = self.cnn1(x)
        x = self.cnn2(x)
        x += self.shortcut(residual)
        x = self.relu(x)
        return x
class ResNet18(nn.Module):
    """Single-channel residual network with three classification heads.

    A convolutional stem is followed by four stages of two ResidualBlocks
    each (the second block of every stage uses stride 2), global average
    pooling and one linear head per target.

    ``forward`` returns a dict with the logits of 'vowel_diacritic' (11),
    'grapheme_root' (168) and 'consonant_diacritic' (7).
    """

    def __init__(self):
        super().__init__()

        # Stem: 1 input channel -> 64 feature maps.
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=2, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

        # MaxPool2d with kernel 1 / stride 1 leaves the tensor unchanged; it
        # is kept so the child indices of block2 stay the same.
        self.block2 = nn.Sequential(nn.MaxPool2d(kernel_size=1, stride=1), *self._stage(64, 64))
        self.block3 = nn.Sequential(*self._stage(64, 128))
        self.block4 = nn.Sequential(*self._stage(128, 256))
        self.block5 = nn.Sequential(*self._stage(256, 512))

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(512, 11)    # vowel_diacritic
        self.fc2 = nn.Linear(512, 168)   # grapheme_root
        self.fc3 = nn.Linear(512, 7)     # consonant_diacritic

    @staticmethod
    def _stage(in_channels, out_channels):
        """Two residual blocks: the first changes the width, the second has stride 2."""
        return [
            ResidualBlock(in_channels, out_channels),
            ResidualBlock(out_channels, out_channels, 2),
        ]

    def forward(self, x):
        """Run the backbone and return the per-target logits as a dict."""
        for stage in (self.block1, self.block2, self.block3, self.block4, self.block5):
            x = stage(x)
        features = self.avgpool(x).flatten(1)
        return {
            'vowel_diacritic': self.fc1(features),
            'grapheme_root': self.fc2(features),
            'consonant_diacritic': self.fc3(features),
        }
class ResNet34(nn.Module):
    """Residual backbone with three linear heads (vowel, root, consonant).

    NOTE(review): layer for layer this is the same network as ``ResNet18``
    in this file (two ResidualBlocks per stage); it does not add the extra
    blocks the name suggests - confirm whether that is intended.

    ``forward`` maps a single-channel image batch to a dict of logits:
    'vowel_diacritic' (11), 'grapheme_root' (168), 'consonant_diacritic' (7).
    """

    def __init__(self):
        super().__init__()

        stem_conv = nn.Conv2d(1, 64, kernel_size=2, stride=2, padding=3, bias=False)
        self.block1 = nn.Sequential(stem_conv, nn.BatchNorm2d(64), nn.ReLU(inplace=True))

        # The 1x1 / stride-1 max-pool does not change the tensor.
        self.block2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=1, stride=1),
            ResidualBlock(64, 64),
            ResidualBlock(64, 64, stride=2),
        )
        self.block3 = nn.Sequential(
            ResidualBlock(64, 128),
            ResidualBlock(128, 128, stride=2),
        )
        self.block4 = nn.Sequential(
            ResidualBlock(128, 256),
            ResidualBlock(256, 256, stride=2),
        )
        self.block5 = nn.Sequential(
            ResidualBlock(256, 512),
            ResidualBlock(512, 512, stride=2),
        )

        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
        # Heads, in creation order: vowel_diacritic, grapheme_root, consonant_diacritic.
        self.fc1 = nn.Linear(in_features=512, out_features=11)
        self.fc2 = nn.Linear(in_features=512, out_features=168)
        self.fc3 = nn.Linear(in_features=512, out_features=7)

    def forward(self, x):
        """Return the logits of the three heads for the batch ``x``."""
        hidden = self.block5(self.block4(self.block3(self.block2(self.block1(x)))))
        pooled = self.avgpool(hidden)
        embedding = pooled.view(pooled.size(0), -1)
        logits = dict()
        logits['vowel_diacritic'] = self.fc1(embedding)
        logits['grapheme_root'] = self.fc2(embedding)
        logits['consonant_diacritic'] = self.fc3(embedding)
        return logits

In [0]:
import collections
from catalyst.utils import set_global_seed
from catalyst.dl.runner import SupervisedRunner
from catalyst.dl.callbacks import CriterionCallback, CriterionAggregatorCallback, EarlyStoppingCallback

In [17]:
# Seed via catalyst's helper before the model and loaders are built -
# presumably for run-to-run reproducibility.
set_global_seed(42)

In [0]:
def baseile_train_resnet18():
  """Train the ResNet18 baseline on the train split, validating on the validation split.

  Uses the module-level data (image_data, all_data, split ids/masks),
  transform pipelines and loader settings. Optimizes the sum of three
  cross-entropy losses (grapheme_root weighted 2.0, the others 1.0) with
  AdamW and ReduceLROnPlateau for up to 50 epochs, with early stopping
  (patience 7). Logs are written under
  DATA_FOLDER/baseline_resnet18_validation_train_split. Needs a CUDA
  device, because the model is moved with ``.cuda()``.
  """
  def make_loader(dataset, shuffle):
    # Both loaders share every setting except shuffling.
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        shuffle=shuffle,
    )

  train_dataset = ImageDataset(
      df=image_data.loc[train_image_ids, :],
      labels=all_data.loc[train_mask, :],
      transforms=transforms_train,
  )
  val_dataset = ImageDataset(
      df=image_data.loc[valid_image_ids, :],
      labels=all_data.loc[val_mask, :],
      transforms=transforms_val,
  )
  loaders = collections.OrderedDict([
      ("train", make_loader(train_dataset, shuffle=True)),
      ("valid", make_loader(val_dataset, shuffle=False)),
  ])

  model = ResNet18().cuda()
  # The batch dict's 'image' entry is fed to the model; the model's dict output is used as is.
  runner = SupervisedRunner(input_key='image', input_target_key=None, output_key=None)

  optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
  scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=5)

  criterions_dict = {
      'vowel_diacritic_loss': torch.nn.CrossEntropyLoss(),
      'grapheme_root_loss': torch.nn.CrossEntropyLoss(),
      'consonant_diacritic_loss': torch.nn.CrossEntropyLoss(),
  }

  # (head name, loss multiplier), in the order the losses are registered.
  head_weights = [
      ('grapheme_root', 2.0),
      ('vowel_diacritic', 1.0),
      ('consonant_diacritic', 1.0),
  ]
  callbacks = [
      CriterionCallback(
          input_key=head,
          output_key=head,
          prefix=head + '_loss',
          criterion_key=head + '_loss',
          multiplier=weight,
      )
      for head, weight in head_weights
  ]
  callbacks += [
      CriterionAggregatorCallback(
          prefix='loss',
          loss_keys=[head + '_loss' for head, _ in head_weights],
      ),
      TaskMetricCallback(),
      EarlyStoppingCallback(patience=7),
  ]

  runner.train(
      model=model,
      criterion=criterions_dict,
      optimizer=optimizer,
      scheduler=scheduler,
      loaders=loaders,
      callbacks=callbacks,
      main_metric='loss',
      minimize_metric=True,
      logdir=os.path.join(DATA_FOLDER, './baseline_resnet18_validation_train_split'),
      num_epochs=50,
      verbose=True,
  )

In [0]:
# Start the baseline ResNet18 training run.
baseile_train_resnet18()