<a href="https://colab.research.google.com/github/benihime91/leaf-disease-classification-kaggle/blob/main/002_train_fold%3D0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Check which GPU (if any) this runtime has attached before training.
!nvidia-smi

In [None]:
# setup
!pip install pytorch-lightning --quiet
!pip install --upgrade albumentations wandb --quiet
!git clone https://github.com/benihime91/leaf-disease-classification-kaggle.git

In [None]:
# Extract the competition archive under /content/. Presumably the archive holds a
# top-level cassava-leaf-disease-classification/ folder, matching the paths set below.
# NOTE(review): this reads from /content/drive/..., so Google Drive must already be
# mounted; no drive.mount() call is visible in this notebook.
! unzip -qq "/content/drive/MyDrive/cassava-leaf-disease-classification.zip" -d "/content/"

In [None]:
import warnings
warnings.filterwarnings("ignore")
import os
os.chdir("/content/leaf-disease-classification-kaggle/")

%matplotlib inline

In [None]:
# --------------------------------
# IMPORT LIBRARIES
# --------------------------------
import pytorch_lightning as pl
import albumentations as A
from albumentations.pytorch import ToTensorV2
import random
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torch
import torchvision
import logging
import argparse
import pandas as pd
import wandb

from lightning import LightningModel_resnext50_32x4d as LitModel
from lightning import LitDatatModule
from preprocess import Preprocessor

# set random seeds for python, numpy and torch (pl.seed_everything covers all
# three; the explicit calls are kept for clarity)
SEED = 42
random.seed(SEED)
pl.seed_everything(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)
pd.set_option("display.max_colwidth", None)

# set up paths to the data (note: csv_dir / json_dir are file paths, not directories)
image_dir = "/content/cassava-leaf-disease-classification/train_images"
csv_dir   = "/content/cassava-leaf-disease-classification/train.csv"
json_dir  = "/content/cassava-leaf-disease-classification/label_num_to_disease_map.json"

# Log in to Weights & Biases WITHOUT a hard-coded key. wandb.login() uses the
# WANDB_API_KEY environment variable (or a saved login) if present, and
# otherwise prompts for the key interactively.
# NOTE(review): the key previously hard-coded here remains in version history;
# revoke/rotate it in the W&B account settings.
wandb.login()
# !wandb login "a74f67fd5fae293e301ea8b6710ee0241f595a63"

In [None]:
# Return to /content now that the repo-local modules have been imported; the
# data/checkpoint/weights paths used from here on are absolute.
os.chdir("/content/")

In [None]:
# One-time setup, kept for reference: uncomment on the very first run to shuffle
# train.csv into 5 folds and save the result. Later runs load a saved copy (next cell).
# NOTE(review): this writes /content/fold_df.csv, but the next cell reads
# /content/leaf-disease-classification-kaggle/fold_df.csv. Copy the file there (or
# align the two paths) if you regenerate the folds.
# processor = Preprocessor(csv_dir, json_dir, image_dir, num_folds=5)
# processor._shuffle_and_create_folds()
# dataframe = processor.dataframe
# dataframe.to_csv("/content/fold_df.csv", index=False)
# dataframe.head()

In [None]:
# Reuse the fold assignments saved inside the cloned repo rather than
# re-shuffling, so fold membership stays the same from run to run.
fold_csv = "/content/leaf-disease-classification-kaggle/fold_df.csv"
processor = Preprocessor(csv_dir, json_dir, image_dir, num_folds=5)

# Replace the Preprocessor's dataframe with the saved folds.
fold_df = pd.read_csv(fold_csv)
processor.dataframe = fold_df
processor.dataframe.head()

In [None]:
# -------------------------------
# Grab one FOLD
# -------------------------------
# get_fold() returns (train, held-out) frames for `fold_num`. The held-out part
# is halved, stratified by `label`, into a test half and a validation half.
# `random_state` pins that split so re-running this cell reproduces the same
# partition. Previously it depended on the global numpy RNG state, so the split
# changed on every re-run.
fold_num = 0
trainFold, valFold = processor.get_fold(fold_num)
testFold, valFold = train_test_split(
    valFold, stratify=valFold.label, test_size=0.5, random_state=42
)

# Give each split a fresh 0..n-1 index. Reassign instead of inplace=True so
# the cell stays idempotent.
trainFold = trainFold.reset_index(drop=True)
testFold = testFold.reset_index(drop=True)
valFold = valFold.reset_index(drop=True)

In [None]:
# Report how many samples ended up in each split.
for split_name, split_df in (("train", trainFold), ("test", testFold), ("valid", valFold)):
    print(f"Length of {split_name} data:", len(split_df))

In [None]:
# Per-class weights from the Preprocessor, turned into a 1-D tensor that is
# later passed to the model as `class_weights`.
# NOTE(review): tensor order follows the dict's iteration order. This assumes
# processor.weights is ordered by class index 0..C-1 — TODO confirm.
weights = processor.weights
weights = torch.tensor(list(weights.values()))
weights

In [None]:
# Class-index -> disease-name mapping from the Preprocessor. `imshow` uses it to
# title preview grids; its length sets the number of output classes.
label_map = processor.label_map
label_map

In [None]:
def imshow(image, targets):
    """Show a batch of image tensors as a grid, titled with their class names.

    Args:
        image: batch of image tensors as produced by this notebook's
            dataloaders (assumed (N, C, H, W) — TODO confirm).
        targets: tensor of integer class indices, one per image; each index is
            looked up in the notebook-level ``label_map`` for the title.
    """
    # BUGFIX: the original ignored `image` and read the global `images` instead.
    grid = torchvision.utils.make_grid(image, normalize=True, nrow=4)
    # make_grid yields (C, H, W); matplotlib expects (H, W, C).
    grid = grid.permute(1, 2, 0).detach().cpu().numpy()
    # normalize=True rescales into [0, 1]; uint8 is the integer dtype imshow
    # expects for 0..255 RGB (the original used np.uint, i.e. uint64).
    grid = (grid * 255.0).astype(np.uint8)
    classes = targets.detach().cpu().numpy()
    plt.figure(figsize=(15, 10))
    plt.axis("off")
    plt.imshow(grid)
    plt.title([label_map[i] for i in classes])
    plt.show()

In [None]:
# Side length (pixels) of the square images fed to the model.
image_dim = 224

# Specify TRANSFORMATIONS for TRAIN/VAL/TEST DATALOADERS.
# TRAIN: random geometric/photometric/weather augmentation, then an
# always-applied crop to image_dim, normalization and tensor conversion.
train_transformations = A.Compose(
    [
        A.Rotate(limit=60, p=0.5),
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.5),
        A.CLAHE(p=0.5),
        # one of fog / rain / snow, at OneOf's default probability
        A.OneOf([A.RandomFog(), A.RandomRain(), A.RandomSnow()]),
        A.ShiftScaleRotate(p=0.5),
        A.RandomResizedCrop(image_dim, image_dim, always_apply=True),
        A.Normalize(always_apply=True),
        ToTensorV2(always_apply=True),
    ]
)

# VALID / TEST: deterministic resize -> normalize -> tensor; the same object is shared.
valid_transformations = A.Compose(
    [
        A.Resize(image_dim, image_dim, always_apply=True),
        A.Normalize(always_apply=True),
        ToTensorV2(always_apply=True),
    ]
)
test_transformations = valid_transformations

albu_transforms = dict(
    train=train_transformations,
    valid=valid_transformations,
    test=test_transformations,
)

# Generate LIGHTNING-DATAMODULE wrapping the three splits.
batch_size = 64
data_module = LitDatatModule(trainFold, valFold, testFold, batch_size, albu_transforms)

In [None]:
# ------------------------------------
# GENERATE TRAIN/VAL/TEST DATALOADERS
# ------------------------------------
# setup() prepares the DataModule; the loaders below are used for the previews
# and to derive steps_per_epoch.
data_module.setup()
train_dl, val_dl, test_dl = (
    data_module.train_dataloader(),
    data_module.val_dataloader(),
    data_module.test_dataloader(),
)

In [None]:
# TRAIN DATALOADER — grab one batch and preview its first four images.
batch = next(iter(train_dl))
images, targets = batch
# Keep the full-size batch: the model later uses a zero tensor of this shape
# as its example input (for graph logging).
example_input_array = images
images, targets = images[:4], targets[:4]
imshow(images, targets)

In [None]:
# VALID DATALOADER — keep one full batch for the W&B prediction logger,
# then preview its first four images.
batch = next(iter(val_dl))
val_samples = batch
images, targets = batch
images, targets = images[:4], targets[:4]
imshow(images, targets)

In [None]:
# TEST DATALOADER — preview the first four images of one batch.
batch = next(iter(test_dl))
images, targets = (part[:4] for part in batch)
imshow(images, targets)

In [None]:
class ImagePredsModelLogger(pl.Callback):
    """Log a fixed set of validation images with their predictions to W&B
    at the end of every validation epoch."""

    def __init__(self, val_samples, num_samples=batch_size):
        """
        Args:
            val_samples: an (images, labels) pair; here, one batch from the
                validation dataloader.
            num_samples: how many of those samples to log. Defaults to the
                notebook-level `batch_size` at the time this cell ran.
        """
        super().__init__()
        self.val_imgs, self.val_labels = val_samples
        self.val_imgs = self.val_imgs[:num_samples]
        self.val_labels = self.val_labels[:num_samples]

    def on_validation_epoch_end(self, trainer, pl_module):
        val_imgs = self.val_imgs.to(device=pl_module.device)

        # Inference only: no autograd graph is needed for logged predictions.
        with torch.no_grad():
            logits = pl_module(val_imgs)
        preds = torch.argmax(logits, -1)

        # Convert to plain Python ints so captions read "Pred:3, Label:3"
        # instead of tensor reprs such as "tensor(3, device='cuda:0')".
        pred_ids = preds.cpu().tolist()
        label_ids = self.val_labels.cpu().tolist()

        trainer.logger.experiment.log({
            "examples": [
                wandb.Image(x, caption=f"Pred:{pred}, Label:{y}")
                for x, pred, y in zip(val_imgs.cpu(), pred_ids, label_ids)
            ],
            "global_step": trainer.global_step,
        })

In [None]:
# ----------------------------------
# TRAINING ARGUMENTS
# ------------------------------------
num_epochs = 15
steps_per_epoch = len(train_dl)
total_steps = num_epochs * steps_per_epoch

learning_rate = 3e-04
weight_decay = 0.001

output_dims = len(label_map)

# Parse arguments. parse_known_args ignores unrecognised argv (presumably the
# kernel's own launch flags), so in a notebook every value falls back to the
# defaults above.
parser = argparse.ArgumentParser()
parser.add_argument("--learning_rate", type=float, default=learning_rate, help="AdamW: learning rate")
parser.add_argument("--weight_decay", type=float, default=weight_decay, help="AdamW: weight_decay")
parser.add_argument("--total_steps", type=int, default=total_steps, help="total steps to train for")
parser.add_argument("--output_dims", type=int, default=output_dims, help="number of output classes")
args, _ = parser.parse_known_args()

# Log the values actually used, i.e. after any command-line overrides.
logger = logging.getLogger("lightning")
logger.info(f"num_epochs: {num_epochs}")
logger.info(f"steps_per_epoch: {steps_per_epoch}")
for arg_name, arg_value in vars(args).items():
    logger.info(f"{arg_name}: {arg_value}")

# -----------------------------------
# LIGHTNING TRAINER
# ------------------------------------

# Init trainer callbacks.
# NOTE(review): PATH is on Google Drive; if Drive is not mounted, this silently
# creates a local (ephemeral) directory instead.
PATH = "/content/drive/MyDrive/modelCheckpoint"
os.makedirs(PATH, exist_ok=True)
# Keep only the best checkpoint by validation loss. It is passed through
# `callbacks` with `dirpath=`, replacing the deprecated `filepath=` /
# `Trainer(checkpoint_callback=<ModelCheckpoint>)` usage.
model_checkpoint = pl.callbacks.ModelCheckpoint(dirpath=PATH, monitor="val_loss", save_top_k=1, mode="min")

# BUGFIX: the callback list referenced an undefined name `stopping`, which
# raised NameError on a fresh kernel. It was presumably an EarlyStopping
# callback from a deleted cell; define it explicitly on the same metric.
early_stopping_patience = 3
stopping = pl.callbacks.EarlyStopping(monitor="val_loss", mode="min", patience=early_stopping_patience)

lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step")

# Run name follows the fold actually being trained.
run_name = f"fold={fold_num}"
wb_logger = pl.loggers.WandbLogger(project="kaggle-leaf-disease", name=run_name)
wb_logger.log_hyperparams(vars(args))

callbacks = [model_checkpoint, lr_monitor, ImagePredsModelLogger(val_samples), stopping]

# Init trainer
trainer = pl.Trainer(
    precision=16,
    gpus=-1,
    logger=wb_logger,
    callbacks=callbacks,
    max_epochs=num_epochs,
    max_steps=args.total_steps,
    gradient_clip_val=0.1,
    benchmark=True,
    )

In [None]:
# -------------------------------
# INSTANTIATE AND FIT MODEL
# --------------------------------

# A fresh DataModule for training; passing it to trainer.fit lets it supply the
# dataloaders in place of the model's own hooks.
batch_size = 64
data_module = LitDatatModule(trainFold, valFold, testFold, batch_size, albu_transforms)

# Build the model from the parsed hyper-parameters plus the per-class weights.
hparams = vars(args)
model = LitModel(class_weights=weights, **hparams)
# Zero tensor shaped like a real training batch, used when logging the graph.
model.example_input_array = torch.zeros_like(example_input_array)

# NOTE(review): the earlier comment said this freezes the feature extractor/base,
# but the method is named `freeze_classifier`. Confirm which part it freezes.
model.freeze_classifier()

# Log model topology of the wrapped network to W&B.
wb_logger.watch(model.net)

trainer.fit(model, datamodule=data_module)

In [None]:
# Evaluate the fitted model on the held-out test split, then close the W&B run.
test_results = trainer.test(model, datamodule=data_module)
wandb.finish()

In [None]:
# Save the trained network's weights (state dict only) for this fold.
torchmodel = model.net
# BUGFIX: the original called `torch.model.state_dict()`, which raises
# AttributeError (the torch module has no `model` attribute).
# NOTE(review): /content is ephemeral on Colab; consider saving under the
# Drive checkpoint directory if the weights must persist.
torch.save(torchmodel.state_dict(), f"/content/modelWeights-Fold={fold_num}.pth")