In [None]:
%load_ext autoreload
%autoreload 2

import os
import pandas as pd
import sys

import torch
from efficientnet_pytorch import EfficientNet
import wandb

from fastai2.basics import *
from fastai2.data.all import *
from fastai2.callback.all import *
from fastai2.callback.wandb import *
from fastai2.vision.all import *
from sklearn.metrics import recall_score

sys.path.append('/home/lextoumbourou/bengaliai-cv19')

from bengali.model import MishHead

## Params

In [None]:
# Seeding below needs these; kept local so this cell stands on its own.
import random

import numpy as np

NAME = ''  # W&B run name
# 'dryrun' keeps the W&B run local; change it before a real run so results sync.
WANDB_MODE = 'dryrun'

# Default is the original author's machine; set BENGALI_DATA_PATH to run elsewhere.
DATA_PATH = Path(os.environ.get('BENGALI_DATA_PATH', '/home/lextoumbourou/bengaliai-cv19/data'))
IMAGE_DATA_PATH = DATA_PATH/'grapheme-imgs-128x128'
OUTPUT_PATH = DATA_PATH/'working'
LABELS_PATH = DATA_PATH/'iterative-stratification'

# NOTE(review): appears unused — the validation split comes from the `fold` column (fold 0).
VALID_PCT = 0.2
SEED = 420
BATCH_SIZE = 64
IMG_SIZE = 64

# Augmentation settings (consumed via wandb.config when building AUGMENTATIONS).
MAX_WARP = 0.2
P_AFFINE = 0.75
MAX_ROTATE = 10.
MAX_ZOOM = 1.1
P_LIGHTING = 0.75
MAX_LIGHTING = 0.2
MAX_COUNT_RANDOM_ERASING = 3

ENCODER_ARCH = 'efficientnet-b0'

# Per-target multipliers applied to the cross-entropy terms in loss_func.
GRAPHEME_ROOT_WEIGHT = 2
VOWEL_DIACRITIC_WEIGHT = 1
CONSONANT_DIACRITIC_WEIGHT = 1
# NOTE(review): appears unused — BengaliModel always builds MishHead heads.
MODEL_HEAD = 'mish_head'

SAMPLE_SIZE = None
if not torch.cuda.is_available():
    # Subsample on CPU-only machines so a run stays tractable.
    SAMPLE_SIZE = 10_000

# SEED was previously only logged (and used for row sampling). Seed every RNG the
# pipeline draws from so augmentations, shuffling and weight init are reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Authenticate with Weights & Biases without embedding a secret in the notebook.
# SECURITY: the API key previously hardcoded in this cell was committed with the
# notebook and must be treated as leaked — revoke it in the W&B account settings.
# wandb reads credentials from the WANDB_API_KEY environment variable; prompt for
# it only when it is missing and the run will actually sync (not in dryrun).
import getpass

if WANDB_MODE != 'dryrun' and not os.environ.get('WANDB_API_KEY'):
    os.environ['WANDB_API_KEY'] = getpass.getpass('W&B API key: ')

In [None]:
# Export the W&B mode chosen in the Params cell before wandb.init runs (next cell).
# With the default 'dryrun' nothing is synced: change WANDB_MODE in Params to a
# syncing mode before a real training run.
os.environ['WANDB_MODE'] = WANDB_MODE

In [None]:
# Start the W&B run and record every hyper-parameter in the same call, so the run's
# config is complete as soon as it is created. Later cells read these values back
# through `wandb.config` (e.g. `wandb.config.img_size`); the keys are unchanged.
wandb.init(
    project="bengaliai-cv19",
    # NAME defaults to '' in Params; pass None rather than an empty string so wandb
    # generates a run name instead of (possibly) using a blank one.
    name=NAME or None,
    config=dict(
        img_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        seed=SEED,
        max_warp=MAX_WARP,
        p_affine=P_AFFINE,
        max_rotate=MAX_ROTATE,
        max_zoom=MAX_ZOOM,
        p_lighting=P_LIGHTING,
        max_lighting=MAX_LIGHTING,
        max_count_random_erasing=MAX_COUNT_RANDOM_ERASING,
        grapheme_root_weight=GRAPHEME_ROOT_WEIGHT,
        vowel_diacritic_weight=VOWEL_DIACRITIC_WEIGHT,
        consonant_diacritic_weight=CONSONANT_DIACRITIC_WEIGHT,
        sample_size=SAMPLE_SIZE,
        # Upper-case key kept as-is so runs stay comparable with earlier ones.
        ENCODER_ARCH=ENCODER_ARCH,
    ),
)

## Create datasets and dataloaders

In [None]:
# Settings shared by the geometric transforms: output size, bilinear
# interpolation, reflection padding at the borders, and batch=False.
aug_kwargs = {
    'size': wandb.config.img_size,
    'mode': 'bilinear',
    'pad_mode': PadMode.Reflection,
    'batch': False,
}

AUGMENTATIONS = [
    # Geometric transforms, each applied with probability p_affine.
    Warp(magnitude=wandb.config.max_warp, p=wandb.config.p_affine, **aug_kwargs),
    Rotate(max_deg=wandb.config.max_rotate, p=wandb.config.p_affine, **aug_kwargs),
    Zoom(max_zoom=wandb.config.max_zoom, p=wandb.config.p_affine, **aug_kwargs),
    # Photometric transforms, each applied with probability p_lighting.
    Brightness(max_lighting=wandb.config.max_lighting, p=wandb.config.p_lighting, batch=False),
    Contrast(max_lighting=wandb.config.max_lighting, p=wandb.config.p_lighting, batch=False),
    # Random erasing; max_count caps the number of erased regions.
    RandomErasing(max_count=wandb.config.max_count_random_erasing),
]

In [None]:
# Training labels plus the precomputed `fold` column used for the validation split.
train_df = pd.read_csv(LABELS_PATH/'train_with_fold.csv')

# Fail fast if the CSV lacks a column the DataBlock getters or the splitter read.
required_cols = {'image_id', 'grapheme_root', 'vowel_diacritic', 'consonant_diacritic', 'fold'}
missing_cols = required_cols - set(train_df.columns)
assert not missing_cols, f'train_with_fold.csv is missing columns: {sorted(missing_cols)}'

if wandb.config.sample_size:
    print("About to reduce train size.")
    # DataFrame.sample(n > len) without replacement raises, so cap at the row count.
    n_rows = min(wandb.config.sample_size, len(train_df))
    # reset_index keeps a 0..n-1 RangeIndex, which the IndexSplitter relies on.
    train_df = train_df.sample(n=n_rows, random_state=SEED).reset_index(drop=True)

In [None]:
# Show fastai's ImageNet (mean, std) stats for reference; the single-channel
# Normalize in the transforms cell uses values that appear to be this tuple's
# first-channel entries.
imagenet_stats

In [None]:
# One grayscale image input followed by three categorical targets
# (grapheme root, vowel diacritic, consonant diacritic), read from `train_df` columns.
datablock = DataBlock(
    blocks=(ImageBlock(cls=PILImageBW), CategoryBlock, CategoryBlock, CategoryBlock),
    getters=[
        ColReader('image_id', pref=IMAGE_DATA_PATH, suff='.png'),
        ColReader('grapheme_root'),
        ColReader('vowel_diacritic'),
        ColReader('consonant_diacritic')
    ],
    # Only the image is a model input. By default fastai2 treats all but the last
    # block as inputs, which here would make the first two targets inputs too.
    n_inp=1,
    # Rows in fold 0 form the validation set. IndexSplitter takes positions, so this
    # relies on train_df having a 0..n-1 RangeIndex (true after read_csv/reset_index).
    splitter=IndexSplitter(train_df.loc[train_df.fold==0].index))

In [None]:
# Batch transforms: the augmentations, then normalisation. The images are single
# channel, so one (mean, std) pair is used: the first-channel entries of fastai's
# `imagenet_stats` (0.485, 0.229), looked up instead of hardcoded.
tfms = AUGMENTATIONS + [Normalize(mean=imagenet_stats[0][0], std=imagenet_stats[1][0])]

In [None]:
# Inspect the composed batch-transform list before building the DataLoaders.
tfms

In [None]:
data = datablock.dataloaders(
    train_df,
    bs=wandb.config.batch_size,
    batch_tfms=tfms,
)
# Each item is (image, grapheme_root, vowel_diacritic, consonant_diacritic):
# mark only the first element as model input so the other three are targets.
data.n_inp = 1

In [None]:
# Visual sanity check: a grid of images with their decoded labels.
data.show_batch()

## Loss and metrics

In [None]:
def loss_func(inp, grapheme_root_targ, vowel_diacritic_targ, consonant_diacritic_targ, *args, **kwargs):
    """Weighted sum of the three per-target cross-entropy losses.

    Args:
        inp: the model output — exactly three logit tensors, in the order
            grapheme root, vowel diacritic, consonant diacritic.
        grapheme_root_targ, vowel_diacritic_targ, consonant_diacritic_targ:
            class-index targets matching those outputs.
        *args, **kwargs: accepted and ignored.

    Returns:
        Sum of each cross-entropy term multiplied by its weight from `wandb.config`.
    """
    # Unpack explicitly so an output with the wrong number of heads still fails loudly.
    grapheme_root_inp, vowel_diacritic_inp, consonant_diacritic_inp = inp
    cfg = wandb.config

    terms = (
        (grapheme_root_inp, grapheme_root_targ, cfg.grapheme_root_weight),
        (vowel_diacritic_inp, vowel_diacritic_targ, cfg.vowel_diacritic_weight),
        (consonant_diacritic_inp, consonant_diacritic_targ, cfg.consonant_diacritic_weight),
    )

    total = None
    for logits, targ, weight in terms:
        weighted = F.cross_entropy(logits, targ) * weight
        # Left-to-right accumulation keeps the same summation order as a + b + c.
        total = weighted if total is None else total + weighted
    return total

In [None]:
class RecallPartial(Metric):
    """Macro-averaged recall for one of the model's three outputs.

    `accumulate` stores the argmax predictions and the targets for output `a`
    (detached via `to_detach`); `value` concatenates everything stored and scores
    it once with `sklearn.metrics.recall_score(average='macro')`.

    Args:
        a: index of the output/target to score — 0 = grapheme_root,
           1 = vowel_diacritic, 2 = consonant_diacritic (the order of the model
           heads and of the DataBlock target getters).
        **kwargs: accepted and ignored.
    """
    # Target names indexed by `a`; must follow the model-head / getter order.
    _TARGET_NAMES = ('grapheme_root', 'vowel_diacritic', 'consonant_diacritic')

    def __init__(self, a=0, **kwargs):
        # zero_division=0: a class that appears only in the predictions scores 0
        # recall instead of triggering sklearn's undefined-metric warning.
        self.func = partial(recall_score, average='macro', zero_division=0)
        self.a = a

    def reset(self):
        """Clear the predictions and targets stored for the current pass."""
        self.targs, self.preds = [], []

    def accumulate(self, learn):
        """Store this batch's argmax predictions and targets for output `a`."""
        pred = learn.pred[self.a].argmax(dim=-1)
        targ = learn.y[self.a]
        pred, targ = to_detach(pred), to_detach(targ)
        pred, targ = flatten_check(pred, targ)
        self.preds.append(pred)
        self.targs.append(targ)

    @property
    def value(self):
        """Macro recall over everything accumulated, or None if nothing was stored."""
        if len(self.preds) == 0: return
        preds, targs = torch.cat(self.preds), torch.cat(self.targs)
        return self.func(targs, preds)

    @property
    def name(self):
        # Fixed names instead of `train_df.columns[self.a+1]`: that read a global
        # DataFrame and depended on its column order, so an extra or reordered CSV
        # column would silently mislabel the metric.
        return self._TARGET_NAMES[self.a]
    

class RecallCombine(Metric):
    """Weighted average of the per-target macro recalls (the combined score).

    Combines the `value`s of the first `len(weights)` entries of `learn.metrics`,
    which are expected to be the RecallPartial metrics for grapheme_root,
    vowel_diacritic and consonant_diacritic, in that order.

    Args:
        weights: weight of each partial recall; defaults to (2, 1, 1), i.e.
            grapheme root counts double.
    """
    def __init__(self, *args, weights=(2, 1, 1), **kwargs):
        super().__init__(*args, **kwargs)
        self.weights = weights
        self.combine = 0
        self._partials = None

    def reset(self):
        """Forget the metrics captured in the previous pass."""
        self._partials = None

    def accumulate(self, learn):
        # Only remember which metrics to combine. Scoring here recomputed recall
        # over every prediction seen so far on each batch (quadratic over an epoch);
        # the partial metrics already keep all predictions, so one computation in
        # `value` yields the same result.
        self._partials = learn.metrics[:len(self.weights)]

    @property
    def value(self):
        """Weighted average of the partial recalls; the last computed score (0 initially) before any batch."""
        if self._partials is not None:
            scores = [metric.value for metric in self._partials]
            self.combine = np.average(scores, weights=self.weights)
        return self.combine

## Model

In [None]:
from fastai2.layers import AdaptiveConcatPool2d

In [None]:
class BengaliModel(nn.Module):
    """Single-channel image classifier with one head per target.

    A 1x1 conv maps the 1-channel input to 3 channels for the encoder. The
    encoder's feature map is pooled with AdaptiveConcatPool2d(1) and flattened.
    Three MishHead heads then produce logits for grapheme root (168 classes),
    vowel diacritic (11) and consonant diacritic (7).

    Args:
        encoder: backbone mapping a 3-channel image batch to a feature map.
        encoder_output_features: input size passed to each MishHead.
            NOTE(review): AdaptiveConcatPool2d concatenates average- and max-pooled
            features, so the flattened vector is presumably twice the encoder's
            channel count — confirm MishHead accounts for that.
    """
    def __init__(self, encoder, encoder_output_features):
        super().__init__()
        # Submodules are registered in the original order so parameter and
        # state_dict ordering are unchanged.
        self.input_conv = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=1)
        self.pooling = AdaptiveConcatPool2d(1)
        self.encoder = encoder

        head_in = encoder_output_features
        self.fc_grapheme_root = MishHead(head_in, output_size=168)
        self.fc_vowel_diacritic = MishHead(head_in, output_size=11)
        self.fc_consonant_diacritic = MishHead(head_in, output_size=7)

    def forward(self, inputs):
        """Return [grapheme_root, vowel_diacritic, consonant_diacritic] logits."""
        batch_size = inputs.size(0)
        pooled = self.pooling(self.encoder(self.input_conv(inputs)))
        flat = pooled.view(batch_size, -1)
        heads = (self.fc_grapheme_root, self.fc_vowel_diacritic, self.fc_consonant_diacritic)
        return [head(flat) for head in heads]

In [None]:
class EfficientNetEncoder(EfficientNet):
    """EfficientNet used as a convolutional feature extractor.

    `forward` is overridden to return only `extract_features(x)`. Whatever the
    base class's forward does after feature extraction (in efficientnet_pytorch:
    pooling, dropout and the final linear layer) is skipped, so the output is a
    feature map, not class logits.
    """
    def forward(self, x):
        """Return the feature map from `extract_features(x)`; no classifier layer is applied."""
        return self.extract_features(x)

## Training

In [None]:
encoder = EfficientNetEncoder.from_pretrained(ENCODER_ARCH)
# `_fc` is presumably EfficientNet's final linear layer (skipped by
# EfficientNetEncoder.forward); its input width is the encoder's feature size.
model = BengaliModel(encoder=encoder, encoder_output_features=encoder._fc.in_features)

if torch.cuda.is_available():
    print("Cuda is available")
    model = model.cuda()
    data = data.cuda()

# CSVLogger writes history.csv into OUTPUT_PATH, which may not exist on a fresh machine.
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)

learner = Learner(
    data,
    model,
    loss_func=loss_func,
    cbs=[CSVLogger(OUTPUT_PATH/'history.csv'), WandbCallback(log_preds=False)],
    # One RecallPartial per model head (3, as in BengaliModel, loss_func and
    # RecallCombine) instead of relying on `len(data.c)`. RecallCombine must come
    # after them: it combines the first three metrics.
    metrics=[RecallPartial(a=i) for i in range(3)] + [RecallCombine()]
)

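The cell below trains the whole model end-to-end for 12 epochs with the one-cycle policy: the pretrained EfficientNet encoder plus the randomly initialised input conv and heads. (The original plan was to first train only the randomly initialised layers, the input conv and the heads, for one epoch, but the encoder is not frozen here.)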

In [None]:
# One-cycle training of the full model: 12 epochs, max learning rate 1e-3.
# TODO(review): these values are not in Params or wandb.config, so they are not
# recorded with the run — consider promoting them to constants there.
learner.fit_one_cycle(12, 1e-3)

In [None]:
# Plot the losses the Recorder collected during training.
learner.recorder.plot_loss()

In [None]:
# Checkpoint the trained learner under the name 'model' (fastai writes it to its models directory).
learner.save('model')

In [None]:
# Reload the 'model' checkpoint so the analysis below can run after a kernel restart without retraining.
learner.load('model')

## Error analysis

In [None]:
# Predictions for error analysis. Presumably a (predictions, targets) tuple for the
# validation set, fastai2's default for get_preds — TODO confirm. The name `o` is
# kept in case later cells rely on it.
o = learner.get_preds()