In [1]:
import os
import pandas as pd
import numpy as np
import gc

import torch as th
from torchvision import transforms   
import albumentations as alb

from src.config import Config
from src.dataset import DataModule
from src.model import GraphemeClassifier

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, GPUStatsMonitor
from pytorch_lightning.loggers import TensorBoardLogger

from tqdm.auto import tqdm

from sklearn.utils.class_weight import compute_class_weight

Global seed set to 2021
Global seed set to 2021


# Setup dataset

In [2]:
# Per-split preprocessing pipelines, keyed by split name and handed to the
# DataModule as `transform`. Each is a th.nn.Sequential of torchvision
# transform modules (presumably chosen over transforms.Compose so the pipeline
# can be scripted / applied to tensors — confirm against src.dataset).
data_transforms = {
    # Training: crop to the model input size, then random augmentation.
    # NOTE(review): a flipped grapheme is a mirrored glyph that may not be a
    # valid sample of the same class — confirm these flips actually help.
    'train': th.nn.Sequential(
        transforms.CenterCrop(Config.resize_shape),
        # The default (nearest) interpolation matches what the deprecated
        # `resample=False` argument resolved to; expand/center were defaults too.
        transforms.RandomRotation(degrees=35, fill=0.0),
        transforms.RandomVerticalFlip(p=0.6),
        transforms.RandomHorizontalFlip(p=0.6),
    ),
    # Validation must be deterministic: 'val_recall' is monitored by both
    # EarlyStopping and ModelCheckpoint, so random rotation here would make
    # model selection depend on augmentation noise.
    "validation": th.nn.Sequential(
        transforms.CenterCrop(Config.resize_shape),
    ),
    # Test: deterministic crop only.
    'test': th.nn.Sequential(
        transforms.CenterCrop(Config.resize_shape),
    ),
}

In [3]:
# Load the training metadata and wrap it in the project's DataModule.
# validation_split=0.25 holds out a quarter of the rows (the recorded output
# shows 150630 training / 50210 validation samples); frac=1 presumably keeps
# every row — confirm against src.dataset.DataModule.
train_df = pd.read_csv(os.path.join(Config.data_dir, 'train.csv'))

dm = DataModule(df=train_df,
                frac=1,
                validation_split=0.25,
                train_batch_size=Config.train_batch_size,
                test_batch_size=Config.test_batch_size,
                transform=data_transforms)
dm.setup()

[INFO] Training on 150630 samples
[INFO] Validating on 50210 samples


# Training pipeline
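* Model definition: `GraphemeClassifier`
* Callbacks: `pytorch_lightning.callbacks` (early stopping, checkpointing, GPU stats)
* Logger: `pytorch_lightning.loggers` (TensorBoard)
* Trainer: `pytorch_lightning.Trainer`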


### Model definition

In [4]:
# Sanity-check the label columns before deriving class weights from them:
# a per-class weight vector can only be indexed by class id if each target
# uses contiguous integer ids 0..K-1.
for label_col in ('grapheme_root', 'vowel_diacritic', 'consonant_diacritic'):
    label_ids = np.unique(train_df[label_col].values)
    assert np.array_equal(label_ids, np.arange(len(label_ids))), (
        f"{label_col}: expected contiguous class ids 0..{len(label_ids) - 1}, "
        f"got {label_ids.min()}..{label_ids.max()} over {len(label_ids)} distinct values"
    )


In [5]:
# Class weights: scikit-learn's 'balanced' mode weights each class inversely
# to its frequency, so rare classes are up-weighted.
# NOTE(review): computed over the whole of train_df, validation rows included.


def balanced_class_weight(labels: pd.Series) -> np.ndarray:
    """Return 'balanced' class weights for a label Series, ordered by class id.

    compute_class_weight returns weights in the same order as `classes`.
    np.unique sorts ascending, so element i is the weight of class id i
    (given ids 0..K-1). Series.unique() would instead follow first-appearance
    order in the CSV and misalign weights with class indices.
    """
    classes = np.unique(labels.values)
    return compute_class_weight(class_weight='balanced', classes=classes, y=labels.values)


vowels_class_weight = balanced_class_weight(train_df.vowel_diacritic)
g_root_class_weight = balanced_class_weight(train_df.grapheme_root)
consonant_class_weight = balanced_class_weight(train_df.consonant_diacritic)

# Model definition.
# NOTE(review): assumes GraphemeClassifier passes these tensors to a weighted
# loss that indexes weight[i] by class id (e.g. CrossEntropyLoss(weight=...)) —
# confirm in src.model. drop=0.25 is presumably the dropout probability.
model = GraphemeClassifier(
    base_encoder=Config.base_model,
    arch_from='timm',
    vowels_class_weight=th.from_numpy(vowels_class_weight).float(),
    g_root_class_weight=th.from_numpy(g_root_class_weight).float(),
    consonant_class_weight=th.from_numpy(consonant_class_weight).float(),
    drop=0.25,
    lr=Config.learning_rate,
    pretrained=True,
)

### Callbacks

In [6]:
# Training callbacks. Checkpointing and early stopping both key on
# 'val_recall', which the model is assumed to log during validation —
# TODO confirm the metric name in src.model.
model_ckpt = ModelCheckpoint(
    # Pass the directory via `dirpath`: ModelCheckpoint joins `filename` onto
    # its own dirpath (the logger's directory by default), so a relative
    # models_dir passed inside `filename` would land under the log dir.
    dirpath=Config.models_dir,
    filename=f"bengali_grapheme-{Config.base_model}",
    monitor='val_recall',
    mode="max",
)
# Stop once val_recall has not improved for 10 consecutive validation checks.
es = EarlyStopping(
    monitor='val_recall',
    patience=10,
    mode="max",
)
# Log GPU memory/utilization, fan speed and temperature; per-step timings off.
gpu_stats = GPUStatsMonitor(
    memory_utilization=True,
    gpu_utilization=True,
    intra_step_time=False,
    inter_step_time=False,
    fan_speed=True,
    temperature=True,
)

callbacks_list = [es, model_ckpt, gpu_stats]

### Logger

In [7]:
# TensorBoard logger: runs are written under Config.logs_dir in an experiment
# named 'kaggle-bengali-ai'. default_hp_metric=False disables Lightning's
# placeholder 'hp_metric' entry.
tb_logger = TensorBoardLogger(save_dir=Config.logs_dir,
                              name='kaggle-bengali-ai',
                              default_hp_metric=False)


### Trainer

In [8]:
# Single-GPU, full-precision (fp32) training for at least 2 and at most
# Config.epochs epochs; the EarlyStopping callback in callbacks_list may end
# training before max_epochs. Add fast_dev_run=True for a quick smoke test.
trainer = Trainer(
    logger=tb_logger,
    callbacks=callbacks_list,
    gpus=1,
    precision=32,
    min_epochs=2,
    max_epochs=Config.epochs,
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores


### Train model(s)

In [9]:
# Return unused blocks from PyTorch's CUDA caching allocator before training.
th.cuda.empty_cache()
trainer.fit(model=model, datamodule=dm)
# Collect Python garbage after training; the trailing ';' stops the number of
# collected objects from being echoed as this cell's output.
gc.collect();


  | Name                        | Type       | Params
-----------------------------------------------------------
0 | extractor                   | ResNet     | 21.8 M
1 | encoder                     | Sequential | 21.8 M
2 | dropout_layer               | Dropout    | 0     
3 | grapheme_root_decoder       | Linear     | 168 K 
4 | vowel_diacritic_decoder     | Linear     | 11.0 K
5 | consonant_diacritic_decoder | Linear     | 7.0 K 
-----------------------------------------------------------
22.0 M    Trainable params
0         Non-trainable params
22.0 M    Total params
87.936    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

61

# Export model to TorchScript

In [10]:
# Export the best checkpoint (highest val_recall) instead of the weights left in
# `model` after fit(). Those come from the last epoch trained, which with
# EarlyStopping(patience=10) can be several epochs past the best one.
# If no checkpoint was written (empty best_model_path), export the current weights.
# The checkpoint is a local file produced by this notebook's own training run.
if model_ckpt.best_model_path:
    best_ckpt = th.load(model_ckpt.best_model_path, map_location='cpu')
    model.load_state_dict(best_ckpt['state_dict'])

th.jit.save(
    model.to_torchscript(),
    os.path.join(Config.models_dir, 'grapheme-classifier-3-in-1.pt'),
)