In [1]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
%matplotlib inline

import os
import sys
from pathlib import Path
import pandas as pd

from fastai.vision import *
from fastai.callbacks import *
from fastai.layers import AdaptiveConcatPool2d
import torch
from efficientnet_pytorch import EfficientNet
import wandb
from wandb.fastai import WandbCallback
import timm
import pretrainedmodels
from kornia.losses import FocalLoss
from fastai2.layers import MishJit
from sklearn.utils.class_weight import compute_class_weight

from sklearn.metrics import recall_score


def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

## Params

In [2]:
NAME = ''
WANDB_MODE = 'dryrun'

DATA_PATH = Path('/home/lextoumbourou/bengaliai-cv19/data')
IMAGE_DATA_PATH = Path(DATA_PATH/'grapheme-imgs-128x128')
OUTPUT_PATH = Path(DATA_PATH/'working')
LABELS_PATH = Path(DATA_PATH/'iterative-stratification')

VALID_PCT = 0.2
SEED = 420
BATCH_SIZE = 128
IMG_SIZE = 128

MAX_WARP = 0.2
P_AFFINE = 0.75
MAX_ROTATE = 40.
MAX_ZOOM = 1.1
P_LIGHTING = 0.75
MAX_LIGHTING = 0.2
MAX_COUNT_RANDOM_ERASING = 3
LABEL_SMOOTHING_EPS = 0.1

MAX_EPOCHS = 150

ENCODER_ARCH = 'efficientnet-b0'

GRAPHEME_ROOT_WEIGHT = 0.7
VOWEL_DIACRITIC_WEIGHT = 0.1
CONSONANT_DIACRITIC_WEIGHT = 0.2
MODEL_HEAD = 'mish_head'

SAMPLE_SIZE = None

PROG_SPRINKES = False

In [3]:
# Parameters
ENCODER_ARCH = "efficientnet-b0"
BATCH_SIZE = 192
WANDB_MODE = "run"
NAME = "effb0_with_class_weights"
OUTPUT_VAL_SIZE = None


In [4]:
seed_everything(seed=SEED)

In [5]:
!wandb login 563765550fd7b64fd10129216209724e03f3f20c

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/lextoumbourou/.netrc
[32mSuccessfully logged in to Weights & Biases![0m


In [6]:
# Turn this off before running!!
os.environ['WANDB_MODE'] = WANDB_MODE

In [7]:
wandb.init(project="bengaliai-cv19", name=NAME)

wandb.config.img_size = IMG_SIZE
wandb.config.batch_size = BATCH_SIZE
wandb.config.seed = SEED

wandb.config.max_warp = MAX_WARP
wandb.config.p_affine = P_AFFINE
wandb.config.max_rotate = MAX_ROTATE
wandb.config.max_zoom = MAX_ZOOM
wandb.config.p_lighting = P_LIGHTING
wandb.config.max_lighting = MAX_LIGHTING
wandb.config.max_count_random_erasing = MAX_COUNT_RANDOM_ERASING
wandb.config.grapheme_root_weight = GRAPHEME_ROOT_WEIGHT
wandb.config.vowel_diacritic_weight = VOWEL_DIACRITIC_WEIGHT
wandb.config.consonant_diacritic_weight = CONSONANT_DIACRITIC_WEIGHT
wandb.config.sample_size = SAMPLE_SIZE
wandb.config.encoder_arch = ENCODER_ARCH
wandb.config.max_epochs = MAX_EPOCHS
wandb.config.prog_sprinkles = PROG_SPRINKES
wandb.config.label_smoothing_eps = LABEL_SMOOTHING_EPS
wandb.config.encoder_arch = ENCODER_ARCH

Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable


wandb: Wandb version 0.8.26 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


## Create datasets and dataloaders

In [8]:
train_df = pd.read_csv(LABELS_PATH/'train_with_fold.csv')
if wandb.config.sample_size:
    print("About to reduce train size.")
    train_df = train_df.sample(n=wandb.config.sample_size, random_state=SEED).reset_index(drop=True)

In [9]:
train_df.head()

Unnamed: 0,image_id,grapheme_root,vowel_diacritic,consonant_diacritic,grapheme,id,fold
0,Train_0,15,9,5,ক্ট্রো,0,3
1,Train_1,159,0,0,হ,1,2
2,Train_2,22,3,5,খ্রী,2,4
3,Train_3,53,2,2,র্টি,3,2
4,Train_4,71,9,5,থ্রো,4,1


In [10]:
train_df.grapheme_root.value_counts().plot.bar(figsize=(20, 4), title="Grapheme root", rot=90)

<matplotlib.axes._subplots.AxesSubplot at 0x7f7d69bf8150>

In [11]:
train_df.vowel_diacritic.value_counts().plot.bar(title="Vowel diacrititc")

<matplotlib.axes._subplots.AxesSubplot at 0x7f7d69bf8150>

In [12]:
train_df.consonant_diacritic.value_counts().plot.bar(title="Consonant diacrititc")

<matplotlib.axes._subplots.AxesSubplot at 0x7f7d69bf8150>

In [13]:
imagenet_stats

([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

In [14]:
train_tfms = [
    symmetric_warp(
        magnitude=(-wandb.config.max_warp, wandb.config.max_warp),
        p=wandb.config.p_affine
    ),
    rotate(
        degrees=(-wandb.config.max_rotate, wandb.config.max_rotate),
        p=wandb.config.p_affine
    ),
    rand_zoom(
        scale=(1., wandb.config.max_zoom), p=wandb.config.p_affine
    ),
    brightness(
        change=(0.5*(1 - wandb.config.max_lighting), 0.5*(1 + wandb.config.max_lighting)),
        p=wandb.config.p_lighting
    ),
    contrast(
        scale=(1-wandb.config.max_lighting, 1/(1-wandb.config.max_lighting)),
        p=wandb.config.p_lighting
    ),
    cutout(n_holes=(1, 6), length=(5, 15), p=.5)
]

In [15]:
train_df['is_valid'] = train_df.fold == 0

In [16]:
grapheme_root_class_weight = compute_class_weight('balanced', range(train_df.grapheme_root.max() + 1), y=train_df.grapheme_root)
vowel_diacritic_class_weight = compute_class_weight('balanced', range(train_df.vowel_diacritic.max() + 1), y=train_df.vowel_diacritic)
consonant_diacritic_class_weight = compute_class_weight('balanced', range(train_df.consonant_diacritic.max() + 1), y=train_df.consonant_diacritic)

In [17]:
stats = ([0.485], [0.229])

data = (
    ImageList.from_df(
        path=DATA_PATH, df=train_df, folder='./grapheme-imgs-128x128/', suffix='.png',
        cols='image_id', convert_mode='L'
    )
    .split_from_df(col='is_valid')
    .label_from_df(cols=['grapheme_root', 'vowel_diacritic', 'consonant_diacritic'])
    .transform((train_tfms, []), size=wandb.config.img_size, padding_mode='zeros')
    .databunch(bs=wandb.config.batch_size)
).normalize(stats)

In [18]:
data.show_batch()

In [19]:
class MixUpLoss(Module):
    "Adapt the loss function `crit` to go with mixup."
    
    def __init__(self, crit, reduction='mean'):
        super().__init__()
        if hasattr(crit, 'reduction'): 
            self.crit = crit
            self.old_red = crit.reduction
            setattr(self.crit, 'reduction', 'none')
        else: 
            self.crit = partial(crit, reduction='none')
            self.old_crit = crit
        self.reduction = reduction
        
    def forward(self, output, target):
        if len(target.shape) == 2 and target.shape[1] == 7:
            loss1, loss2 = self.crit(output,target[:,0:3].long()), self.crit(output,target[:,3:6].long())
            d = loss1 * target[:,-1] + loss2 * (1-target[:,-1])
        else:  d = self.crit(output, target)
        if self.reduction == 'mean':    return d.mean()
        elif self.reduction == 'sum':   return d.sum()
        return d
    
    def get_old(self):
        if hasattr(self, 'old_crit'):  return self.old_crit
        elif hasattr(self, 'old_red'): 
            setattr(self.crit, 'reduction', self.old_red)
            return self.crit


class MixUpCallback(LearnerCallback):
    "Callback that creates the mixed-up input and target."
    def __init__(self, learn:Learner, alpha:float=0.4, stack_x:bool=False, stack_y:bool=True):
        super().__init__(learn)
        self.alpha,self.stack_x,self.stack_y = alpha,stack_x,stack_y
    
    def on_train_begin(self, **kwargs):
        if self.stack_y: self.learn.loss_func = MixUpLoss(self.learn.loss_func)
        
    def on_batch_begin(self, last_input, last_target, train, **kwargs):
        "Applies mixup to `last_input` and `last_target` if `train`."
        if not train: return
        lambd = np.random.beta(self.alpha, self.alpha, last_target.size(0))
        lambd = np.concatenate([lambd[:,None], 1-lambd[:,None]], 1).max(1)
        lambd = last_input.new(lambd)
        shuffle = torch.randperm(last_target.size(0)).to(last_input.device)
        x1, y1 = last_input[shuffle], last_target[shuffle]
        if self.stack_x:
            new_input = [last_input, last_input[shuffle], lambd]
        else: 
            out_shape = [lambd.size(0)] + [1 for _ in range(len(x1.shape) - 1)]
            new_input = (last_input * lambd.view(out_shape) + x1 * (1-lambd).view(out_shape))
        if self.stack_y:
            new_target = torch.cat([last_target.float(), y1.float(), lambd[:,None].float()], 1)
        else:
            if len(last_target.shape) == 2:
                lambd = lambd.unsqueeze(1).float()
            new_target = last_target.float() * lambd + y1.float() * (1-lambd)
        return {'last_input': new_input, 'last_target': new_target}  
    
    def on_train_end(self, **kwargs):
        if self.stack_y: self.learn.loss_func = self.learn.loss_func.get_old()

## Loss and metrics

In [20]:
class MetricBase(Callback):
    def __init__(self, average='macro'):
        super().__init__()
        self.n_classes = 0
        self.average = average
        self.cm = None
        self.eps = 1e-9
        
    def on_epoch_begin(self, **kwargs):
        self.tp = 0
        self.fp = 0
        self.cm = None
    
    def on_batch_end(self, last_output:Tensor, last_target:Tensor, **kwargs):
        last_output = last_output[self.idx]
        last_target = last_target[:,self.idx]
        preds = last_output.argmax(-1).view(-1).cpu()
        targs = last_target.long().cpu()
        
        if self.n_classes == 0:
            self.n_classes = last_output.shape[-1]
            self.x = torch.arange(0, self.n_classes)
        cm = ((preds==self.x[:, None]) & (targs==self.x[:, None, None])) \
          .sum(dim=2, dtype=torch.float32)
        if self.cm is None: self.cm =  cm
        else:               self.cm += cm

    def _weights(self, avg:str):
        if self.n_classes != 2 and avg == "binary":
            avg = self.average = "macro"
            warn("average=`binary` was selected for a non binary case. \
                 Value for average has now been set to `macro` instead.")
        if avg == "binary":
            if self.pos_label not in (0, 1):
                self.pos_label = 1
                warn("Invalid value for pos_label. It has now been set to 1.")
            if self.pos_label == 1: return Tensor([0,1])
            else: return Tensor([1,0])
        elif avg == "micro": return self.cm.sum(dim=0) / self.cm.sum()
        elif avg == "macro": return torch.ones((self.n_classes,)) / self.n_classes
        elif avg == "weighted": return self.cm.sum(dim=1) / self.cm.sum()
        
    def _recall(self):
        rec = torch.diag(self.cm) / (self.cm.sum(dim=1) + self.eps)
        if self.average is None: return rec
        else:
            if self.average == "micro": weights = self._weights(avg="weighted")
            else: weights = self._weights(avg=self.average)
            return (rec * weights).sum()
    
    def on_epoch_end(self, last_metrics, **kwargs):
        return add_metrics(last_metrics, self._recall())

    
class GraphemeRoot(MetricBase):
    idx = 0

    
class VowelDiacritic(MetricBase):
    idx = 1


class ConsonantDiacritic(MetricBase):
    idx = 2


class RecallCombine(Callback):
    def __init__(self):
        super().__init__()
        self.grapheme = GraphemeRoot()
        self.vowel = VowelDiacritic()
        self.consonant = ConsonantDiacritic()
        
    def on_epoch_begin(self, **kwargs):
        self.grapheme.on_epoch_begin(**kwargs)
        self.vowel.on_epoch_begin(**kwargs)
        self.consonant.on_epoch_begin(**kwargs)
    
    def on_batch_end(self, last_output:Tensor, last_target:Tensor, **kwargs):
        self.grapheme.on_batch_end(last_output, last_target, **kwargs)
        self.vowel.on_batch_end(last_output, last_target, **kwargs)
        self.consonant.on_batch_end(last_output, last_target, **kwargs)
        
    def on_epoch_end(self, last_metrics, **kwargs):
        return add_metrics(
            last_metrics,
            0.5 * self.grapheme._recall() +
            0.25 * self.vowel._recall() +
            0.25 * self.consonant._recall()
        )

## Model

In [21]:
from torch.nn.parameter import Parameter


def gem(x, p=3, eps=1e-6):
    return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p)


class GeM(nn.Module):
    def __init__(self, p=3, eps=1e-6):
        super(GeM,self).__init__()
        self.p = Parameter(torch.ones(1)*p)
        self.eps = eps

    def forward(self, x):
        return gem(x, p=self.p, eps=self.eps)

    def __repr__(self):
        return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')'

In [22]:
class AdaptiveConcatWithGemPool2d(Module):
    "Layer that concats `AdaptiveAvgPool2d` and `AdaptiveMaxPool2d`."
    def __init__(self, sz:Optional[int]=None):
        "Output will be 2*sz or 2 if sz is None"
        self.output_size = sz or 1
        self.ap = GeM()
        self.mp = nn.AdaptiveMaxPool2d(self.output_size)

    def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1)

In [23]:
class MishHead(nn.Module):
    def __init__(self, input_size: int, output_size: int, dropout_ps=0.5):
        super().__init__()

        layers = [
            AdaptiveConcatWithGemPool2d(),
            Flatten(),
            nn.Linear(input_size * 2, 512),
            MishJit(),
            nn.BatchNorm1d(512),
            nn.Dropout(dropout_ps),
            nn.Linear(512, output_size)
        ]

        self.fc = nn.Sequential(*layers)
        self._init_weight()

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1.0)
                m.bias.data.zero_()

    def forward(self, x):
        return self.fc(x)

In [24]:
class MishHead(nn.Module):
    def __init__(self, input_size: int, output_size: int, dropout_ps=0.5):
        super().__init__()

        layers = [
            AdaptiveConcatWithGemPool2d(),
            Flatten(),
            nn.Linear(input_size * 2, 512),
            MishJit(),
            nn.BatchNorm1d(512),
            nn.Dropout(dropout_ps),
            nn.Linear(512, output_size)
        ]

        self.fc = nn.Sequential(*layers)
        self._init_weight()

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1.0)
                m.bias.data.zero_()

    def forward(self, x):
        return self.fc(x)

In [25]:
class BengaliModel(nn.Module):
    def __init__(self, encoder, encoder_output_features):
        super().__init__()
        self.input_conv = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=1)
        
        self.encoder = encoder
        
        self.fc_grapheme_root = MishHead(encoder_output_features, output_size=168)
        self.fc_vowel_diacritic = MishHead(encoder_output_features, output_size=11)
        self.fc_consonant_diacritic = MishHead(encoder_output_features, output_size=7)
        
    def forward(self, inputs):
        bs = inputs.size(0)
        
        # Convolve to 3 channels
        x = self.input_conv(inputs)

        # Convolution layers
        x = self.encoder(x)

        return [
            self.fc_grapheme_root(x),
            self.fc_vowel_diacritic(x),
            self.fc_consonant_diacritic(x)
        ]

In [26]:
class LossCombine(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.grapheme_root_class_weight = tensor(grapheme_root_class_weight).float()
        self.vowel_diacritic_class_weight = tensor(vowel_diacritic_class_weight).float()
        self.consonant_diacritic_class_weight = tensor(consonant_diacritic_class_weight).float()
    
        if torch.cuda.is_available():
            self.grapheme_root_class_weight = self.grapheme_root_class_weight.cuda()
            self.vowel_diacritic_class_weight = self.vowel_diacritic_class_weight.cuda()
            self.consonant_diacritic_class_weight = self.consonant_diacritic_class_weight.cuda()
        
    def forward(self, input, target, reduction='mean'):
        x1,x2,x3 = input
        x1,x2,x3 = x1.float(), x2.float(), x3.float()
        y = target.long()

        return (
            wandb.config.grapheme_root_weight * F.cross_entropy(
                x1, y[:,0], weight=self.grapheme_root_class_weight, reduction=reduction) +
            wandb.config.vowel_diacritic_weight * F.cross_entropy(
                x2,y[:,1], weight=self.vowel_diacritic_class_weight, reduction=reduction) +
            wandb.config.consonant_diacritic_weight * F.cross_entropy(
                x3, y[:,2], weight=self.consonant_diacritic_class_weight, reduction=reduction)
        )

In [27]:
class EfficientNetModel(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.m = EfficientNet.from_pretrained(model_name)
        
    def forward(self, x):
        return self.m.extract_features(x)

In [28]:
class RWightmanModel(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.m =  timm.create_model(model_name, pretrained=True)
        
    def forward(self, x):
        return self.m.forward_features(x)

In [29]:
class CadeneModel(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.m = pretrainedmodels.__dict__[model_name](num_classes=1000, pretrained='imagenet')
        
    def forward(self, x):
        return self.m.features(x)

## Training

In [30]:
@dataclass
class CutoutScheduler(Callback):
    learn:Learner
    n_holes_max:int=80
    max_at_epoch:int=30
        
    def on_epoch_begin(self, **kwargs):
        i = kwargs['iteration']
        n_epochs = kwargs['n_epochs']
        n_holes_max = int(round(annealing_linear(0, self.n_holes_max, min(i / self.max_at_epoch, 1.))))
        n_holes = (max(n_holes_max-10, 0), n_holes_max)
        print(n_holes)
        self.learn.data.train_ds.tfms[-1].kwargs['n_holes'] = n_holes

In [31]:
if ENCODER_ARCH.startswith('efficientnet-'):
    print(f'Loading EfficientNet models: {ENCODER_ARCH}')
    encoder = EfficientNetModel(ENCODER_ARCH)
    output_feats = encoder.m._fc.in_features
elif ENCODER_ARCH in pretrainedmodels.model_names:
    print(f"Loading Cadene model: {ENCODER_ARCH}")
    encoder = CadeneModel(ENCODER_ARCH)
    output_feats = encoder.m.last_linear.in_features
else:
    print(f"Loading R Wightman model: {ENCODER_ARCH}")
    encoder = RWightmanModel(ENCODER_ARCH)
    output_feats = encoder.m.classifier.in_features

Loading EfficientNet models: efficientnet-b0
Loaded pretrained weights for efficientnet-b0


In [32]:
model = BengaliModel(encoder=encoder, encoder_output_features=output_feats)

learner = Learner(
    data, model, loss_func=LossCombine(),
    metrics=[GraphemeRoot(), VowelDiacritic(), ConsonantDiacritic(), RecallCombine()],
    callback_fns=WandbCallback
)
learner.clip_grad = 1.0
learner.unfreeze()

In [33]:
# learner.lr_find()
# learner.recorder.plot()

In [34]:
callbacks=[
    CSVLogger(learner, OUTPUT_PATH/'history.csv'),
    SaveModelCallback(
        learner, monitor='recall_combine', mode='max', name=NAME
    ),
    MixUpCallback(learner),
    ReduceLROnPlateauCallback(learner, patience=2, monitor='recall_combine', mode='max', min_lr=1e-6),
    EarlyStoppingCallback(learner, monitor='recall_combine', mode='max', patience=6)
]

if wandb.config.prog_sprinkles:
    callbacks.append(CutoutScheduler(learner, n_holes_max=80))

learner.fit(
    wandb.config.max_epochs,
    5e-3,
    callbacks=callbacks
)

epoch,train_loss,valid_loss,grapheme_root,vowel_diacritic,consonant_diacritic,recall_combine,time
0,3.163283,2.555422,0.223216,0.484118,0.634814,0.391341,11:24
1,2.361298,1.223809,0.596086,0.797881,0.81017,0.700056,11:21
2,1.865513,0.805896,0.761809,0.879637,0.878693,0.820487,11:21
3,1.679841,1.056373,0.729215,0.675331,0.875928,0.752422,11:21
4,1.590842,0.504602,0.835044,0.89862,0.912395,0.870276,11:21
5,1.535017,0.498489,0.851066,0.923297,0.935425,0.890213,11:21
6,1.502162,0.700772,0.846955,0.925095,0.886863,0.876467,11:21
7,1.425802,0.522261,0.86778,0.937138,0.923848,0.899137,11:21
8,1.40744,0.401077,0.857696,0.938783,0.94302,0.899299,11:21
9,1.376898,0.446015,0.858173,0.926447,0.932909,0.893926,11:21


Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable


wandb: Wandb version 0.8.26 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


Better model found at epoch 0 with valid_loss value: 2.555421829223633.
Better model found at epoch 0 with recall_combine value: 0.3913410007953644.


Better model found at epoch 1 with valid_loss value: 1.2238086462020874.


Better model found at epoch 1 with recall_combine value: 0.7000559568405151.


Better model found at epoch 2 with valid_loss value: 0.8058962821960449.


Better model found at epoch 2 with recall_combine value: 0.8204872012138367.


Better model found at epoch 4 with valid_loss value: 0.5046020746231079.


Better model found at epoch 4 with recall_combine value: 0.8702757358551025.


Better model found at epoch 5 with valid_loss value: 0.4984886646270752.


Better model found at epoch 5 with recall_combine value: 0.8902134895324707.


Better model found at epoch 7 with recall_combine value: 0.8991369605064392.


Better model found at epoch 10 with recall_combine value: 0.9107271432876587.


Better model found at epoch 16 with valid_loss value: 0.185821071267128.


Better model found at epoch 16 with recall_combine value: 0.9533652067184448.


Better model found at epoch 19 with valid_loss value: 0.18294832110404968.


Better model found at epoch 19 with recall_combine value: 0.9546142816543579.


Better model found at epoch 24 with recall_combine value: 0.9594618678092957.


Better model found at epoch 29 with recall_combine value: 0.9613159894943237.


Better model found at epoch 32 with valid_loss value: 0.15331198275089264.


Better model found at epoch 32 with recall_combine value: 0.9619977474212646.


Epoch 35: reducing lr to 4e-05


Epoch 38: reducing lr to 8.000000000000001e-06


In [35]:
learner.recorder.plot_losses()