Training notebook for Google Colab

# 🚀 Installing and importing

In [None]:
# Show the GPU (if any) attached to this runtime, as reported by `nvidia-smi`.
!nvidia-smi

In [None]:
# Mount Google Drive at /content/drive; a later cell reads the dataset archive from
# /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Ensure colab doesn't disconnect
%%javascript
function ClickConnect(){
console.log("Working");
document.querySelector("colab-toolbar-button#connect").click()
}setInterval(ClickConnect,60000)

In [None]:
# Extract the competition archive stored on Drive into the current working directory
# (-qq keeps unzip quiet).
!unzip -qq '/content/drive/MyDrive/cassava-leaf-disease-classification.zip'

In [None]:
!pip install --upgrade wandb albumentations pytorch-lightning timm --quiet

In [None]:
!git clone https://github.com/benihime91/leaf-disease-classification-kaggle.git

!wandb login a74f67fd5fae293e301ea8b6710ee0241f595a63

In [None]:
import sys
import warnings

# Make the cloned repository importable so that `src.*` resolves. The membership
# check keeps the cell idempotent: re-running it does not stack duplicate entries
# on sys.path.
REPO_DIR = "/content/leaf-disease-classification-kaggle"
if REPO_DIR not in sys.path:
    sys.path.append(REPO_DIR)

# Silence all Python warnings for the rest of the session.
warnings.filterwarnings('ignore')

In [None]:
import logging
import os

import pytorch_lightning as pl
import torch
from torch import nn, optim
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2

from fastai.torch_core import apply_init
from functools import partial
import wandb

from src.core import *
from src.lightning.core import *
from src.layers import *
from src.mixmethods import *
from src.networks import *

# Raise the threshold of the "wandb" logger to ERROR so that lower-severity records
# from it are not emitted.
logger = logging.getLogger("wandb")
logger.setLevel(logging.ERROR)

**Set random seeds so that results are reproducible.**

In [None]:
# Seed the run with 42 and create an identifier for it; both values are stored in
# `config` below, and `idx` is also part of the checkpoint/weights file name.
# NOTE(review): `seed_everything` and `generate_random_id` come from the star imports
# above — confirm which RNGs are seeded and what format the id has.
seed = seed_everything(42)
idx  = generate_random_id()

# ⚡ 💘 🏋️‍♀️ Configure the Training Parameters

In [None]:
# Job-level settings for this run: bookkeeping, data locations, model choice and
# the knobs that are forwarded to the data module and the Trainer.
config = {
    # --- bookkeeping ---
    "random_seed": seed,
    "unique_idx": idx,
    "project_name": "kaggle-leaf-disease-v2",
    # --- data ---
    "curr_fold": 0,
    "image_dir": "cassava-leaf-disease-classification/train_images/",
    "csv_path": "leaf-disease-classification-kaggle/data/stratified-data-5folds.csv",
    # --- model ---
    "encoder": "tf_efficientnet_b3_ns",
    "activation": {"type": "src.layers.Mish"},
    # --- training loop ---
    "image_dims": 512,
    "num_epochs": 30,
    "batch_size": 32,
    "accumulate_batches": 1,
    "clip_grad_norm": 0.1,
}

# Hyper-parameters handed to the LightningModule. The nested dicts carry a dotted
# `type` path plus the keyword arguments for that object.
hparams = {
    "mixmethod": {"type": "src.mixmethods.Mixup", "alpha": 0.5},
    "loss_function": {"type": "src.core.LabelSmoothingCrossEntropy", "eps": 0.1},
    # learning rate settings
    "learning_rate": 1e-03,
    "lr_mult": 100,
    # optimizer / scheduler specifications
    "optimizer": {
        "type": "torch.optim.Adam",
        "betas": (0.9, 0.99),
        "eps": 1e-06,
        "weight_decay": 1e-06,
    },
    "scheduler": {
        "type": "torch.optim.lr_scheduler.CosineAnnealingWarmRestarts",
        "T_0": 10,
        "T_mult": 2,
    },
    # scheduler stepping behaviour
    "metric_to_track": None,
    "step_after": "step",
    "frequency": 1,
}


# Albumentations pipelines for the train and validation splits. Both end by
# normalising the image and converting it to a tensor.
_side = config["image_dims"]

TRAIN_AUGS = A.Compose([
    # 70% of the time: take either a random resized crop or a center crop
    A.OneOf(
        [A.RandomResizedCrop(_side, _side), A.CenterCrop(_side, _side)],
        p=0.7,
    ),
    # always bring the image to the target square size
    A.Resize(_side, _side, p=1.0),
    # one geometric perturbation
    A.OneOf(
        [A.ShiftScaleRotate(), A.HorizontalFlip(), A.Transpose()],
        p=0.8,
    ),
    # one colour perturbation
    A.OneOf(
        [A.RandomBrightnessContrast(0.1, 0.1), A.HueSaturationValue(20, 20, 20)],
        p=0.5,
    ),
    A.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
        max_pixel_value=255.0,
        p=1.0,
    ),
    A.CoarseDropout(p=0.5),
    ToTensorV2(p=1.0),
])

VALID_AUGS = A.Compose([
    A.CenterCrop(_side, _side, p=1.0),
    A.Resize(_side, _side, p=1.0),
    A.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
        max_pixel_value=255.0,
        p=1.0,
    ),
    ToTensorV2(p=1.0),
])

# Base name used for the checkpoint file and the exported weights of this run.
MODEL_SAVE_PATH = f"{config['encoder']}-fold={config['curr_fold']}-{idx}"

# 🏗️ Building a Model with Lightning

In [None]:
# Build the network.
# Per the original author's note: TransferLearningModel yields a model similar to
# the one fast.ai's cnn_learner creates, while BasicTransferLearningModel is the
# variant meant for snapmix.

# 1) pretrained timm backbone
encoder = timm.create_model(config["encoder"], pretrained=True)

# 2) wrap the backbone; `c` is the number of entries in `idx2lbl`
#    NOTE(review): `idx2lbl` comes from the star imports — presumably the
#    class-index -> label mapping; confirm.
model = TransferLearningModel(
    encoder,
    cut=-2,
    c=len(idx2lbl),
    act=object_from_dict(config["activation"]),
)

# 3) swap the encoder's activations for the configured one
replace_activs(model.encoder, func=object_from_dict(config["activation"]))

# 4) Kaiming-normal initialisation for the freshly created (untrained) head
apply_init(model.fc, torch.nn.init.kaiming_normal_)

In [None]:
# Wrap the network in the LightningCassava module, configured by the `hparams` dict.
litModel = LightningCassava(model=model, conf=hparams)

In [None]:
# Print the module's text representation to inspect what was assembled.
print(litModel)

# 🛒 Loading data

In [None]:
# Build the LightningDataModule for the selected fold, using the augmentation
# pipelines and batch size defined in the configuration cell.
dm = CassavaLightningDataModule(
    config["csv_path"],
    config["image_dir"],
    curr_fold=config["curr_fold"],
    train_augs=TRAIN_AUGS,
    valid_augs=VALID_AUGS,
    bs=config["batch_size"],
    num_workers=0,
)

# 📲 Callbacks ➕ Optional methods for even better logging

In [None]:
# Objects handed to the Trainer in the next section.

# Logging callbacks: learning-rate monitor ("step" interval) and the project's
# W&B image-classification callback.
callbacks = [
    pl.callbacks.LearningRateMonitor("step"),
    WandbImageClassificationCallback(dm, default_config=config),
]

# Keep only the single best checkpoint, ranked by the highest "valid/acc".
chkpt_callback = pl.callbacks.ModelCheckpoint(
    filename=MODEL_SAVE_PATH,
    monitor="valid/acc",
    mode='max',
    save_top_k=1,
)

# Weights & Biases logger; log_model=True asks it to log the model checkpoints too.
wb_logger = pl.loggers.WandbLogger(
    project=config["project_name"],
    log_model=True,
)

# 👟 Making a Trainer

In [None]:
# Build the Trainer.
# The ModelCheckpoint instance is registered through `callbacks`: passing a callback
# object via `checkpoint_callback=` is deprecated in PyTorch Lightning (that argument
# is meant to be a bool), whereas a ModelCheckpoint inside `callbacks` is supported.
trainer = pl.Trainer(
    gpus=-1,                     # use every available GPU
    precision=16,                # 16-bit (mixed) precision
    logger=wb_logger,
    callbacks=[*callbacks, chkpt_callback],
    max_epochs=config["num_epochs"],
    gradient_clip_val=config["clip_grad_norm"],
    accumulate_grad_batches=config["accumulate_batches"],
    log_every_n_steps=1,
    deterministic=True,
)

In [None]:
# Run the learning-rate finder and plot its curve with the suggested value marked.
# The suggestion is only displayed here; the learning rate actually used is set
# manually in the training cell.
lr_finder = trainer.tuner.lr_find(litModel, datamodule=dm)

fig = lr_finder.plot(suggest=True)
fig.show()

# 🏃‍♀️ Running our Model

In [None]:
# Set the initial learning rate on the module's hparams (chosen by hand; this is
# the same value as hparams["learning_rate"] in the configuration cell).
litModel.hparams['learning_rate'] = 1e-03

# Start the training job.
trainer.fit(litModel, datamodule=dm)

# 💾 Testing and saving the model

In [None]:
# Evaluate on the test split of the data module.
# NOTE(review): which weights `ckpt_path=None` evaluates depends on the installed
# PyTorch Lightning version (the model's current in-memory weights on some releases,
# the best checkpoint on others when no model is passed) — confirm against the
# installed version, and pass ckpt_path="best" explicitly if the best checkpoint
# from `chkpt_callback` is what should be tested.
results = trainer.test(datamodule=dm, ckpt_path=None)

In [None]:
# Export the model weights to "<MODEL_SAVE_PATH>.pt" and attach the file to the W&B run.
# NOTE(review): `save_model_weights` is defined on LightningCassava (not shown here) —
# presumably it writes the weights currently held by `litModel`, i.e. not necessarily
# the best checkpoint; confirm.
path = f"{MODEL_SAVE_PATH}.pt"
# save the weights of the model
litModel.save_model_weights(path)
wandb.save(path)

In [None]:
# Mark the W&B run as finished.
wandb.finish()