Training Notebook from Kaggle

# 🚀 Installing and importing

In [1]:
!git clone https://github.com/benihime91/leaf-disease-classification-kaggle.git

!wandb login a74f67fd5fae293e301ea8b6710ee0241f595a63

Cloning into 'leaf-disease-classification-kaggle'...
remote: Enumerating objects: 77, done.[K
remote: Counting objects: 100% (77/77), done.[K
remote: Compressing objects: 100% (58/58), done.[K
remote: Total 1329 (delta 37), reused 51 (delta 19), pack-reused 1252[K
Receiving objects: 100% (1329/1329), 43.77 MiB | 27.65 MiB/s, done.
Resolving deltas: 100% (751/751), done.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [2]:
import sys
import warnings

# Make the timm checkout from the Kaggle input dataset and the cloned project
# directory importable (appended in this order).
sys.path.extend((
    '../input/timmmodels/pytorch-image-models/',
    'leaf-disease-classification-kaggle/',
))

# Suppress all warnings for the remainder of the session.
warnings.filterwarnings('ignore')

In [3]:
import logging
import os

import pytorch_lightning as pl
import torch
from torch import nn, optim
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2

from fastai.torch_core import apply_init
from functools import partial
import wandb

from src.core import *
from src.lightning.core import *
from src.layers import *
from src.mixmethods import *
from src.networks import *

# Raise the "wandb" logger's threshold to ERROR, then keep a module-level
# handle to that same logger object under the name `logger`.
logging.getLogger("wandb").setLevel(logging.ERROR)
logger = logging.getLogger("wandb")

**Set the random seeds so that results are reproducible.**

In [4]:
# `seed` is whatever seed_everything(42) returns; `idx` is a generated identifier
# that is later embedded in the config and in the checkpoint/weights file name.
# The seeding call runs first, exactly as before.
seed, idx = seed_everything(42), generate_random_id()

# ⚡ 💘 🏋️‍♀️ Configure the Training Parameters

In [5]:
# Configure the training parameters / job.
config = {
    # --- run identity ---
    "random_seed": seed,
    "unique_idx": idx,
    "project_name": "kaggle-leaf-disease-v2",

    # --- data: which fold to train on and where the images / fold CSV live ---
    "curr_fold": 0,
    "image_dir": "../input/cassava-leaf-disease-classification/train_images/",
    "csv_path": "leaf-disease-classification-kaggle/data/stratified-data-5folds.csv",

    # --- model: timm encoder name and an object spec for the activation ---
    "encoder": "efficientnet_b3a",
    "activation": {"type": "torch.nn.ReLU", "inplace": True},

    # --- training loop knobs (consumed by the augmentations, datamodule and Trainer) ---
    "image_dims": 512,
    "num_epochs": 40,
    "batch_size": 30,
    "accumulate_batches": 2,
    "clip_grad_norm": 0.5,
}

# Hyper-parameters handed to the LightningModule as its `conf`.
# The nested dicts are object specs: a dotted `type` path plus constructor kwargs.
hparams = {
    "mixmethod": {"type": "src.mixmethods.SnapMix", "alpha": 5.0, "conf_prob": 1.0},
    "loss_function": {"type": "src.core.LabelSmoothingCrossEntropy", "eps": 0.1},

    "learning_rate": 1e-03,
    "lr_mult": 100,

    "optimizer": {
        "type": "torch.optim.AdamW",
        "betas": (0.9, 0.99),
        "eps": 1e-06,
        "weight_decay": 1e-03,
    },

    "scheduler": {
        "type": "torch.optim.lr_scheduler.CosineAnnealingWarmRestarts",
        "T_0": 10,
        "T_mult": 2,
    },

    "metric_to_track": None,
    "step_after": "step",
    "frequency": 1,
}


# Albumentations pipelines for the train / validation data.
# `_side` is passed as both height and width everywhere below.
_side = config["image_dims"]

TRAIN_AUGS = A.Compose([
    # with p=0.7 apply one of the two crops
    A.OneOf(
        [A.RandomResizedCrop(_side, _side), A.CenterCrop(_side, _side)],
        p=0.7,
    ),
    # always resize to the target dimensions, whether or not a crop ran
    A.Resize(_side, _side, p=1.0),
    # with p=0.8 apply one geometric transform
    A.OneOf(
        [A.ShiftScaleRotate(), A.HorizontalFlip(), A.Transpose()],
        p=0.8,
    ),
    # with p=0.5 apply one colour transform
    A.OneOf(
        [A.RandomBrightnessContrast(), A.HueSaturationValue(20, 20, 20)],
        p=0.5,
    ),
    A.Normalize(p=1.0),
    ToTensorV2(p=1.0),
])

# Validation: fixed centre crop + resize, then the same normalise / to-tensor steps.
VALID_AUGS = A.Compose([
    A.CenterCrop(_side, _side, p=1.0),
    A.Resize(_side, _side, p=1.0),
    A.Normalize(p=1.0),
    ToTensorV2(p=1.0),
])

# Base name (no extension) for saved artefacts: "<encoder>-fold=<fold>-<run id>".
MODEL_SAVE_PATH = "{}-fold={}-{}".format(config["encoder"], config["curr_fold"], idx)

# 🏗️ Building a Model with Lightning

In [6]:
# Initiate the model architecture.
# The pretrained timm encoder is wrapped in SnapMixTransferLearningModel, the
# wrapper used here alongside the SnapMix mix-method configured in `hparams`.
# The TransferLearningModel class can be used instead to get a model similar to
# the one created by the fast.ai cnn_learner function.
# NOTE(review): the previous comment named `BasicTransferLearningModel`, but the
# class actually instantiated below is `SnapMixTransferLearningModel`.
encoder = timm.create_model(config["encoder"], pretrained=True)

_model_kwargs = {
    "encoder": encoder,
    "c": len(idx2lbl),  # one output per entry in idx2lbl
    "cut": -2,
    "act": object_from_dict(config["activation"]),
}
model = SnapMixTransferLearningModel(**_model_kwargs)

# Initialise the weights of the final, untrained layer (`model.fc`) with Kaiming-normal.
apply_init(model.fc, torch.nn.init.kaiming_normal_)

Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b3_ra2-cf984f9c.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b3_ra2-cf984f9c.pth


In [7]:
# Wrap the network in the project's LightningModule; `hparams` is passed as its `conf`.
litModel = LightningCassava(
    conf=hparams,
    model=model,
)

Mixmethod : SnapMix
Loss Function : LabelSmoothingCrossEntropy()


In [8]:
# Print the nested module tree (LightningCassava -> model -> encoder layers).
# The output is long; consider clearing it before sharing the notebook.
print(litModel)

LightningCassava(
  (model): SnapMixTransferLearningModel(
    (encoder): Sequential(
      (0): Conv2d(3, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
      (3): Sequential(
        (0): Sequential(
          (0): DepthwiseSeparableConv(
            (conv_dw): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=40, bias=False)
            (bn1): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (act1): SiLU(inplace=True)
            (se): SqueezeExcite(
              (conv_reduce): Conv2d(40, 10, kernel_size=(1, 1), stride=(1, 1))
              (act1): SiLU(inplace=True)
              (conv_expand): Conv2d(10, 40, kernel_size=(1, 1), stride=(1, 1))
            )
            (conv_pw): Conv2d(40, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn2): BatchNorm2d(24, ep

# 🛒 Loading data

In [9]:
# Initialise the LightningDataModule for the selected fold.
# Only the datamodule is created here; the LightningModule was built earlier.
_dm_kwargs = {
    "curr_fold": config["curr_fold"],
    "train_augs": TRAIN_AUGS,
    "valid_augs": VALID_AUGS,
    "bs": config["batch_size"],
    "num_workers": 4,
}
dm = CassavaLightningDataModule(config["csv_path"], config["image_dir"], **_dm_kwargs)

# 📲 Callbacks ➕ Optional methods for even better logging

In [10]:
# Callbacks handed to the Trainer: a learning-rate monitor constructed with "step",
# plus the project's wandb image-classification callback.
_lr_monitor = pl.callbacks.LearningRateMonitor("step")
_wandb_images = WandbImageClassificationCallback(dm, default_config=config)
callbacks = [_lr_monitor, _wandb_images]

# Checkpointing: monitor "valid/acc" in 'max' mode and keep the top 1 checkpoint,
# written under the MODEL_SAVE_PATH file name.
chkpt_callback = pl.callbacks.ModelCheckpoint(
    filename=MODEL_SAVE_PATH,
    monitor="valid/acc",
    mode='max',
    save_top_k=1,
)

# Weights & Biases logger for the configured project (log_model is enabled).
wb_logger = pl.loggers.WandbLogger(
    project=config["project_name"],
    log_model=True,
)

# 👟 Making a Trainer

In [11]:
# Build the Trainer.
# The ModelCheckpoint instance is registered through `callbacks` rather than
# `checkpoint_callback=`: passing an instance to `checkpoint_callback` was
# deprecated in PyTorch Lightning 1.1 and removed in 1.3.
# A new list is built so the module-level `callbacks` list is not mutated and
# re-running this cell does not append the checkpoint callback twice.
# NOTE(review): on Lightning 1.0.x this may add a second default checkpoint
# callback — confirm the installed version is >= 1.1.
trainer = pl.Trainer(
    gpus=-1,                       # -1: use every available GPU
    precision=16,
    logger=wb_logger,
    callbacks=[*callbacks, chkpt_callback],
    max_epochs=config["num_epochs"],
    gradient_clip_val=config["clip_grad_norm"],
    accumulate_grad_batches=config["accumulate_batches"],
    log_every_n_steps=1,
    deterministic=True)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.


In [12]:
# # start learning_rate finder to find optimum starting Lr
# lr_finder = trainer.tuner.lr_find(litModel, datamodule=dm)

# fig = lr_finder.plot(suggest=True)
# fig.show()

# 🏃‍♀️ Running our Model

In [13]:
# Override the learning rate stored on the LightningModule's hparams before fitting.
# The value equals hparams["learning_rate"] from the config cell, so with the
# LR-finder cell commented out this assignment currently changes nothing.
# NOTE(review): presumably LightningCassava reads this entry when it builds the
# optimizer — confirm against its implementation.
litModel.hparams['learning_rate'] = 1e-03

# Start the training job using the datamodule's dataloaders.
trainer.fit(litModel, datamodule=dm)

Generating data for fold: 0
[34m[1mwandb[0m: Currently logged in as: [33mayushman[0m (use `wandb login --relogin` to force relogin)



  | Name      | Type                         | Params
-----------------------------------------------------------
0 | model     | SnapMixTransferLearningModel | 10.7 M
1 | loss_func | LabelSmoothingCrossEntropy   | 0     


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

wandb config updated -->


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…






1

# 💾 Testing and saving the model

In [14]:
# Evaluate on the datamodule's test dataloader and keep the returned metrics.
# NOTE(review): ckpt_path=None is passed explicitly. Per the PyTorch Lightning 1.x
# docs this evaluates the weights currently in memory (end of the last epoch),
# whereas ckpt_path='best' would reload the checkpoint selected by the
# ModelCheckpoint monitor. An earlier comment here said the best weights are
# loaded automatically, which does not match the argument — confirm the intent.
results = trainer.test(datamodule=dm, ckpt_path=None) # ckpt_path=None: no checkpoint is reloaded

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/acc': tensor(0.8075, device='cuda:0'),
 'test/loss': tensor(0.8046, device='cuda:0'),
 'train/acc': tensor(0.8235, device='cuda:0'),
 'train/acc_epoch': tensor(0.7000, device='cuda:0'),
 'train/acc_step': tensor(0.8235, device='cuda:0'),
 'train/loss': tensor(1.4048, device='cuda:0'),
 'train/loss_epoch': tensor(1.5233, device='cuda:0'),
 'train/loss_step': tensor(1.4048, device='cuda:0'),
 'valid/acc': tensor(0.8075, device='cuda:0'),
 'valid/loss': tensor(0.8046, device='cuda:0')}
--------------------------------------------------------------------------------



In [15]:
# Write the model weights to "<MODEL_SAVE_PATH>.pt" and register the file with
# the active wandb run. The return value of wandb.save is the cell's output.
path = MODEL_SAVE_PATH + ".pt"
litModel.save_model_weights(path)
wandb.save(path)

weights saved to efficientnet_b3a-fold=0-c181d280.pt


['/kaggle/working/wandb/run-20201227_092745-1v39vjad/files/efficientnet_b3a-fold=0-c181d280.pt']

In [16]:
# Finish the experiment: close the active wandb run (the captured output shows
# the remaining file upload followed by the run summary/history tables).
wandb.finish()

VBox(children=(Label(value=' 297.80MB of 339.16MB uploaded (0.00MB deduped)\r'), FloatProgress(value=0.8780455…

0,1
lr-AdamW/pg1,1e-05
lr-AdamW/pg2,0.00097
train/loss_step,1.40479
train/acc_step,0.82353
epoch,39.0
_step,22879.0
_runtime,23926.0
_timestamp,1609085191.0
valid/loss,0.80462
valid/acc,0.80748


0,1
lr-AdamW/pg1,▂▇█▅█▇▆▃▁██▇▆▄▃▂▁▁███▇▇▇▆▅▅▄▃▃▂▂▁▁▁▁████
lr-AdamW/pg2,▂▇█▅█▇▆▃▁██▇▆▄▃▂▁▁███▇▇▇▆▅▅▄▃▃▂▂▁▁▁▁████
train/loss_step,▄▄▄▄▆▁▆▃▄▃▄█▂▄▅▃▃▃▁▃▅▂▂▃▄▆▂▄▁▁▂▄▄▂▂▂▆▂▄▃
train/acc_step,▃▅▃▃▄▅▃▅▃▅▆▂▆▂▂▆▆▅▆▇▃▅▆▁▃▂▅▄▅█▅▄▅▅▄▆▄▄▆▅
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
valid/loss,█▅▄▄▄▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▄▂▂▂▂▁▁▁▁▁▂▁▁▁▁▂▁▁▁▁
valid/acc,▁▄▄▆▄▆▆▆▆▆▆▆▆▇▇▇▇▆▇▇▆▇▇▇█▇██▇██▇███▇▇▇█▇
