In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
from pathlib import Path
import pandas as pd

from fastai.vision import *
from fastai.callbacks import *
from fastai.layers import AdaptiveConcatPool2d
import torch
from efficientnet_pytorch import EfficientNet
import wandb
from wandb.fastai import WandbCallback
import timm

from sklearn.metrics import recall_score

sys.path.append('/home/lextoumbourou/bengaliai-cv19')

from bengali.model import MishHead


def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

## Params

In [2]:
NAME = ''
WANDB_MODE = 'dryrun'

DATA_PATH = Path('/home/lextoumbourou/bengaliai-cv19/data')
IMAGE_DATA_PATH = Path(DATA_PATH/'grapheme-imgs-128x128')
OUTPUT_PATH = Path(DATA_PATH/'working')
LABELS_PATH = Path(DATA_PATH/'iterative-stratification')

VALID_PCT = 0.2
SEED = 420
BATCH_SIZE = 128
IMG_SIZE = 128

MAX_WARP = 0.2
P_AFFINE = 0.75
MAX_ROTATE = 40.
MAX_ZOOM = 1.1
P_LIGHTING = 0.75
MAX_LIGHTING = 0.2
MAX_COUNT_RANDOM_ERASING = 3
LABEL_SMOOTHING_EPS = 0.1

MAX_EPOCHS = 150

ENCODER_ARCH = 'efficientnet_b0'

GRAPHEME_ROOT_WEIGHT = 0.7
VOWEL_DIACRITIC_WEIGHT = 0.1
CONSONANT_DIACRITIC_WEIGHT = 0.2
MODEL_HEAD = 'mish_head'

SAMPLE_SIZE = None

PROG_SPRINKES = False

In [3]:
# Parameters
LABEL_SMOOTHING_EPS = 0.0
ENCODER_ARCH = "se_resnext50_32x4d"
BATCH_SIZE = 64
WANDB_MODE = "run"
NAME = "se_resnext50_32x4d"
OUTPUT_VAL_SIZE = None


In [4]:
seed_everything(seed=SEED)

In [5]:
!wandb login 563765550fd7b64fd10129216209724e03f3f20c

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/lextoumbourou/.netrc
[32mSuccessfully logged in to Weights & Biases![0m


In [6]:
# Turn this off before running!!
os.environ['WANDB_MODE'] = WANDB_MODE

In [7]:
wandb.init(project="bengaliai-cv19", name=NAME)

wandb.config.img_size = IMG_SIZE
wandb.config.batch_size = BATCH_SIZE
wandb.config.seed = SEED

wandb.config.max_warp = MAX_WARP
wandb.config.p_affine = P_AFFINE
wandb.config.max_rotate = MAX_ROTATE
wandb.config.max_zoom = MAX_ZOOM
wandb.config.p_lighting = P_LIGHTING
wandb.config.max_lighting = MAX_LIGHTING
wandb.config.max_count_random_erasing = MAX_COUNT_RANDOM_ERASING
wandb.config.grapheme_root_weight = GRAPHEME_ROOT_WEIGHT
wandb.config.vowel_diacritic_weight = VOWEL_DIACRITIC_WEIGHT
wandb.config.consonant_diacritic_weight = CONSONANT_DIACRITIC_WEIGHT
wandb.config.sample_size = SAMPLE_SIZE
wandb.config.encoder_arch = ENCODER_ARCH
wandb.config.max_epochs = MAX_EPOCHS
wandb.config.prog_sprinkles = PROG_SPRINKES
wandb.config.label_smoothing_eps = LABEL_SMOOTHING_EPS
wandb.config.encoder_arch = ENCODER_ARCH

Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable


wandb: Wandb version 0.8.25 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


## Create datasets and dataloaders

In [8]:
train_df = pd.read_csv(LABELS_PATH/'train_with_fold.csv')
if wandb.config.sample_size:
    print("About to reduce train size.")
    train_df = train_df.sample(n=wandb.config.sample_size, random_state=SEED).reset_index(drop=True)

In [9]:
imagenet_stats

([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

In [10]:
train_tfms = [
    symmetric_warp(
        magnitude=(-wandb.config.max_warp, wandb.config.max_warp),
        p=wandb.config.p_affine
    ),
    rotate(
        degrees=(-wandb.config.max_rotate, wandb.config.max_rotate),
        p=wandb.config.p_affine
    ),
    rand_zoom(
        scale=(1., wandb.config.max_zoom), p=wandb.config.p_affine
    ),
    brightness(
        change=(0.5*(1 - wandb.config.max_lighting), 0.5*(1 + wandb.config.max_lighting)),
        p=wandb.config.p_lighting
    ),
    contrast(
        scale=(1-wandb.config.max_lighting, 1/(1-wandb.config.max_lighting)),
        p=wandb.config.p_lighting
    ),
    cutout(n_holes=(1, 6), length=(5, 15), p=.5)
]

In [11]:
train_df['is_valid'] = train_df.fold == 0

In [12]:
stats = ([0.485], [0.229])

data = (
    ImageList.from_df(
        path=DATA_PATH, df=train_df, folder='./grapheme-imgs-128x128/', suffix='.png',
        cols='image_id', convert_mode='L'
    )
    .split_from_df(col='is_valid')
    .label_from_df(cols=['grapheme_root', 'vowel_diacritic', 'consonant_diacritic'])
    .transform((train_tfms, []), size=wandb.config.img_size, padding_mode='zeros')
    .databunch(bs=wandb.config.batch_size)
).normalize(stats)

In [13]:
data.show_batch()

In [14]:
class MixUpLoss(Module):
    "Adapt the loss function `crit` to go with mixup."
    
    def __init__(self, crit, reduction='mean'):
        super().__init__()
        if hasattr(crit, 'reduction'): 
            self.crit = crit
            self.old_red = crit.reduction
            setattr(self.crit, 'reduction', 'none')
        else: 
            self.crit = partial(crit, reduction='none')
            self.old_crit = crit
        self.reduction = reduction
        
    def forward(self, output, target):
        if len(target.shape) == 2 and target.shape[1] == 7:
            loss1, loss2 = self.crit(output,target[:,0:3].long()), self.crit(output,target[:,3:6].long())
            d = loss1 * target[:,-1] + loss2 * (1-target[:,-1])
        else:  d = self.crit(output, target)
        if self.reduction == 'mean':    return d.mean()
        elif self.reduction == 'sum':   return d.sum()
        return d
    
    def get_old(self):
        if hasattr(self, 'old_crit'):  return self.old_crit
        elif hasattr(self, 'old_red'): 
            setattr(self.crit, 'reduction', self.old_red)
            return self.crit


class MixUpCallback(LearnerCallback):
    "Callback that creates the mixed-up input and target."
    def __init__(self, learn:Learner, alpha:float=0.4, stack_x:bool=False, stack_y:bool=True):
        super().__init__(learn)
        self.alpha,self.stack_x,self.stack_y = alpha,stack_x,stack_y
    
    def on_train_begin(self, **kwargs):
        if self.stack_y: self.learn.loss_func = MixUpLoss(self.learn.loss_func)
        
    def on_batch_begin(self, last_input, last_target, train, **kwargs):
        "Applies mixup to `last_input` and `last_target` if `train`."
        if not train: return
        lambd = np.random.beta(self.alpha, self.alpha, last_target.size(0))
        lambd = np.concatenate([lambd[:,None], 1-lambd[:,None]], 1).max(1)
        lambd = last_input.new(lambd)
        shuffle = torch.randperm(last_target.size(0)).to(last_input.device)
        x1, y1 = last_input[shuffle], last_target[shuffle]
        if self.stack_x:
            new_input = [last_input, last_input[shuffle], lambd]
        else: 
            out_shape = [lambd.size(0)] + [1 for _ in range(len(x1.shape) - 1)]
            new_input = (last_input * lambd.view(out_shape) + x1 * (1-lambd).view(out_shape))
        if self.stack_y:
            new_target = torch.cat([last_target.float(), y1.float(), lambd[:,None].float()], 1)
        else:
            if len(last_target.shape) == 2:
                lambd = lambd.unsqueeze(1).float()
            new_target = last_target.float() * lambd + y1.float() * (1-lambd)
        return {'last_input': new_input, 'last_target': new_target}  
    
    def on_train_end(self, **kwargs):
        if self.stack_y: self.learn.loss_func = self.learn.loss_func.get_old()

## Loss and metrics

In [15]:
class LabelSmoothingCrossEntropy(Module):
    def __init__(self, eps:float=0.1, reduction='mean'):
        self.eps, self.reduction = eps, reduction

    def forward(self, output, target):
        c = output.size()[-1]
        log_preds = F.log_softmax(output, dim=-1)
        if self.reduction=='sum': loss = -log_preds.sum()
        else:
            loss = -log_preds.sum(dim=-1)
            if self.reduction=='mean':  loss = loss.mean()
        return loss*self.eps/c + (1-self.eps) * F.nll_loss(log_preds, target, reduction=self.reduction)

In [16]:
class LossCombine(nn.Module):
    def __init__(self):
        super().__init__()
        self.label_smoothing_ce = LabelSmoothingCrossEntropy(eps=wandb.config.label_smoothing_eps)
        
    def forward(self, input, target, reduction='mean'):
        x1,x2,x3 = input
        x1,x2,x3 = x1.float(),x2.float(),x3.float()
        y = target.long()
        return (
            wandb.config.grapheme_root_weight * self.label_smoothing_ce(x1, y[:,0]) +
            wandb.config.vowel_diacritic_weight * self.label_smoothing_ce(x2, y[:,1]) +
            wandb.config.consonant_diacritic_weight * self.label_smoothing_ce(x3, y[:,2])
        )

In [17]:
class MetricBase(Callback):
    def __init__(self, average='macro'):
        super().__init__()
        self.n_classes = 0
        self.average = average
        self.cm = None
        self.eps = 1e-9
        
    def on_epoch_begin(self, **kwargs):
        self.tp = 0
        self.fp = 0
        self.cm = None
    
    def on_batch_end(self, last_output:Tensor, last_target:Tensor, **kwargs):
        last_output = last_output[self.idx]
        last_target = last_target[:,self.idx]
        preds = last_output.argmax(-1).view(-1).cpu()
        targs = last_target.long().cpu()
        
        if self.n_classes == 0:
            self.n_classes = last_output.shape[-1]
            self.x = torch.arange(0, self.n_classes)
        cm = ((preds==self.x[:, None]) & (targs==self.x[:, None, None])) \
          .sum(dim=2, dtype=torch.float32)
        if self.cm is None: self.cm =  cm
        else:               self.cm += cm

    def _weights(self, avg:str):
        if self.n_classes != 2 and avg == "binary":
            avg = self.average = "macro"
            warn("average=`binary` was selected for a non binary case. \
                 Value for average has now been set to `macro` instead.")
        if avg == "binary":
            if self.pos_label not in (0, 1):
                self.pos_label = 1
                warn("Invalid value for pos_label. It has now been set to 1.")
            if self.pos_label == 1: return Tensor([0,1])
            else: return Tensor([1,0])
        elif avg == "micro": return self.cm.sum(dim=0) / self.cm.sum()
        elif avg == "macro": return torch.ones((self.n_classes,)) / self.n_classes
        elif avg == "weighted": return self.cm.sum(dim=1) / self.cm.sum()
        
    def _recall(self):
        rec = torch.diag(self.cm) / (self.cm.sum(dim=1) + self.eps)
        if self.average is None: return rec
        else:
            if self.average == "micro": weights = self._weights(avg="weighted")
            else: weights = self._weights(avg=self.average)
            return (rec * weights).sum()
    
    def on_epoch_end(self, last_metrics, **kwargs):
        return add_metrics(last_metrics, self._recall())

    
class GraphemeRoot(MetricBase):
    idx = 0

    
class VowelDiacritic(MetricBase):
    idx = 1


class ConsonantDiacritic(MetricBase):
    idx = 2


class RecallCombine(Callback):
    def __init__(self):
        super().__init__()
        self.grapheme = GraphemeRoot()
        self.vowel = VowelDiacritic()
        self.consonant = ConsonantDiacritic()
        
    def on_epoch_begin(self, **kwargs):
        self.grapheme.on_epoch_begin(**kwargs)
        self.vowel.on_epoch_begin(**kwargs)
        self.consonant.on_epoch_begin(**kwargs)
    
    def on_batch_end(self, last_output:Tensor, last_target:Tensor, **kwargs):
        self.grapheme.on_batch_end(last_output, last_target, **kwargs)
        self.vowel.on_batch_end(last_output, last_target, **kwargs)
        self.consonant.on_batch_end(last_output, last_target, **kwargs)
        
    def on_epoch_end(self, last_metrics, **kwargs):
        return add_metrics(
            last_metrics,
            0.5 * self.grapheme._recall() +
            0.25 * self.vowel._recall() +
            0.25 * self.consonant._recall()
        )

## Model

In [18]:
class BengaliModel(nn.Module):
    def __init__(self, encoder, encoder_output_features):
        super().__init__()
        self.input_conv = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=1)
        
        self.pooling = AdaptiveConcatPool2d(1)
        
        self.encoder = encoder
    
        self.fc_grapheme_root = MishHead(encoder_output_features, output_size=168)
        self.fc_vowel_diacritic = MishHead(encoder_output_features, output_size=11)
        self.fc_consonant_diacritic = MishHead(encoder_output_features, output_size=7)
        
    def forward(self, inputs):
        bs = inputs.size(0)
        
        # Convolve to 3 channels
        x = self.input_conv(inputs)

        # Convolution layers
        x = self.encoder(x)
        
        # Pooling
        x = self.pooling(x)
        
        # Final layers
        x = x.view(bs, -1)

        return [
            self.fc_grapheme_root(x),
            self.fc_vowel_diacritic(x),
            self.fc_consonant_diacritic(x)
        ]

In [19]:
import pretrainedmodels

In [20]:
model_name = 'se_resnext50_32x4d' # could be fbresnet152 or inceptionresnetv2
model = pretrainedmodels.__dict__[model_name](num_classes=1000, pretrained='imagenet')

In [21]:
class RWightmanModel(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.m =  timm.create_model(model_name, pretrained=True)
        
    def forward(self, x):
        return self.m.forward_features(x)

In [22]:
class CadeneModel(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.m = pretrainedmodels.__dict__[model_name](num_classes=1000, pretrained='imagenet')
        
    def forward(self, x):
        return self.m.features(x)

## Training

In [23]:
@dataclass
class CutoutScheduler(Callback):
    learn:Learner
    n_holes_max:int=80
    max_at_epoch:int=30
        
    def on_epoch_begin(self, **kwargs):
        i = kwargs['iteration']
        n_epochs = kwargs['n_epochs']
        n_holes_max = int(round(annealing_linear(0, self.n_holes_max, min(i / self.max_at_epoch, 1.))))
        n_holes = (max(n_holes_max-10, 0), n_holes_max)
        print(n_holes)
        self.learn.data.train_ds.tfms[-1].kwargs['n_holes'] = n_holes

In [24]:
if ENCODER_ARCH in pretrainedmodels.model_names:
    print(f"Loading Cadene model: {ENCODER_ARCH}")
    encoder = CadeneModel(ENCODER_ARCH)
    output_feats = encoder.m.last_linear.in_features
else:
    print("Loading R WIghtman model: {ENCODER_ARCH}")
    encoder = RWightmanModel(ENCODER_ARCH)
    output_feats = encoder.m.classifier.in_features

Loading Cadene model: se_resnext50_32x4d


In [25]:
model = BengaliModel(encoder=encoder, encoder_output_features=output_feats)

learner = Learner(
    data, model, loss_func=LossCombine(),
    metrics=[GraphemeRoot(), VowelDiacritic(), ConsonantDiacritic(), RecallCombine()],
    callback_fns=WandbCallback
)
learner.clip_grad = 1.0
learner.unfreeze()

In [26]:
# learner.lr_find()
# learner.recorder.plot()

In [27]:
callbacks=[
    CSVLogger(learner, OUTPUT_PATH/'history.csv'),
    SaveModelCallback(
        learner, monitor='recall_combine', mode='max', name=NAME
    ),
    MixUpCallback(learner),
    ReduceLROnPlateauCallback(learner, patience=2, monitor='recall_combine', mode='max', min_lr=1e-5),
    EarlyStoppingCallback(learner, monitor='recall_combine', mode='max', patience=3)
]

if wandb.config.prog_sprinkles:
    callbacks.append(CutoutScheduler(learner, n_holes_max=80))

learner.fit(
    wandb.config.max_epochs,
    3e-3,
    callbacks=callbacks
)

epoch,train_loss,valid_loss,grapheme_root,vowel_diacritic,consonant_diacritic,recall_combine,time
0,2.090501,0.963677,0.554857,0.814893,0.770113,0.67368,36:22
1,1.545009,0.571088,0.7961,0.895978,0.874038,0.840554,36:08
2,1.452983,0.450001,0.881916,0.935726,0.926816,0.906593,36:12
3,1.374492,0.394108,0.872319,0.946087,0.921715,0.90311,36:10
4,1.361522,0.379101,0.908699,0.955931,0.935627,0.927239,36:01
5,1.331822,0.352243,0.911196,0.965366,0.940421,0.932045,36:00
6,1.291239,0.342641,0.914152,0.958502,0.944919,0.932931,36:00
7,1.274053,0.345318,0.917023,0.96715,0.957935,0.939783,36:00
8,1.283398,0.321884,0.925945,0.963151,0.944528,0.939892,36:00
9,1.252005,0.337019,0.913521,0.970539,0.954628,0.938052,35:59


Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable


wandb: Wandb version 0.8.25 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


Better model found at epoch 0 with valid_loss value: 0.9636773467063904.


Better model found at epoch 0 with recall_combine value: 0.6736799478530884.


Better model found at epoch 1 with valid_loss value: 0.5710877776145935.


Better model found at epoch 1 with recall_combine value: 0.8405538201332092.


Better model found at epoch 2 with valid_loss value: 0.4500010013580322.


Better model found at epoch 2 with recall_combine value: 0.906593382358551.


Better model found at epoch 3 with valid_loss value: 0.39410820603370667.


Better model found at epoch 5 with valid_loss value: 0.3522430956363678.


Better model found at epoch 5 with recall_combine value: 0.9320445656776428.


Better model found at epoch 8 with valid_loss value: 0.3218839466571808.


Better model found at epoch 8 with recall_combine value: 0.9398922324180603.


Better model found at epoch 11 with valid_loss value: 0.3016285300254822.


Better model found at epoch 11 with recall_combine value: 0.9456005692481995.


Epoch 18: reducing lr to 0.0006000000000000001


Better model found at epoch 21 with valid_loss value: 0.2489168792963028.


Better model found at epoch 24 with valid_loss value: 0.23974283039569855.


Better model found at epoch 27 with recall_combine value: 0.9680794477462769.


Better model found at epoch 29 with valid_loss value: 0.23702527582645416.


Epoch 40: early stopping


Loaded best saved model from /home/lextoumbourou/bengaliai-cv19/wandb/run-20200208_015420-h03lkrz9/bestmodel.pth


In [28]:
learner.recorder.plot_losses()