<a href="https://colab.research.google.com/github/joe-jachim/cassava-leaf-classifier/blob/main/tez_faster_and_easier_training_for_leaf_detection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup: install Tez and other dependencies, then unzip the data

In [None]:
%%capture

import os

if not os.path.isfile('/content/data/train.csv'):
  from google.colab import drive
  drive.mount('/content/drive')
  !pip install tez
  !pip install efficientnet-pytorch
  !pip install efficientnet_pytorch
  !pip install -U git+https://github.com/albu/albumentations --no-cache-dir
  !sudo apt-get install unzip
  !unzip /content/drive/MyDrive/data/cassava-leaf-disease-classification.zip -d /content/data

# Import What You Need

In [None]:
import random

import albumentations
import numpy as np
import pandas as pd

import tez
from tez.datasets import ImageDataset
from tez.callbacks import EarlyStopping

import torch
import torch.nn as nn
from torch.nn import functional as F

#from efficientnet_pytorch import EfficientNet
from sklearn import metrics, model_selection, preprocessing

# Model via `tez.Model`

In [None]:
from efficientnet_pytorch import EfficientNet

In [None]:


class LeafModel(tez.Model):
    """EfficientNet backbone with a dropout + linear classification head.

    Written for tez's training loop: ``forward`` returns ``(outputs, loss, metrics)``
    and tez reads ``fetch_optimizer``, ``fetch_scheduler`` and
    ``step_scheduler_after`` from the model.

    Args:
        num_classes: number of output logits per image.
        model_name: EfficientNet variant passed to ``EfficientNet.from_pretrained``.
            The training log recorded in this notebook was produced with
            ``"efficientnet-b4"``.
    """

    def __init__(self, num_classes, model_name="efficientnet-b7"):
        super().__init__()

        self.effnet = EfficientNet.from_pretrained(model_name)
        # The pooled feature width depends on the variant (1792 for b4, 2560 for b7).
        # Read it from the backbone's own classifier layer instead of hard-coding
        # 1792, which only matches b4 and makes the head fail for any other variant.
        in_features = self.effnet._fc.in_features
        self.dropout = nn.Dropout(0.25)
        self.out = nn.Linear(in_features, num_classes)
        # Tells tez to step the LR scheduler once per epoch.
        self.step_scheduler_after = "epoch"

    def monitor_metrics(self, outputs, targets):
        """Return ``{"accuracy": ...}`` for a batch of logits, or ``{}`` without targets."""
        if targets is None:
            return {}
        # Predicted class = index of the largest logit along dim 1.
        outputs = torch.argmax(outputs, dim=1).cpu().detach().numpy()
        targets = targets.cpu().detach().numpy()
        accuracy = metrics.accuracy_score(targets, outputs)
        return {"accuracy": accuracy}

    def fetch_optimizer(self):
        """Adam over all model parameters with a 3e-4 learning rate."""
        opt = torch.optim.Adam(self.parameters(), lr=3e-4)
        return opt

    def fetch_scheduler(self):
        """Cosine annealing with warm restarts on ``self.optimizer``.

        ``self.optimizer`` is presumably set by tez from ``fetch_optimizer``.
        Because the scheduler is stepped per epoch, ``T_0=10`` / ``T_mult=1`` means the
        learning rate decays toward ``eta_min`` and jumps back to 3e-4 every 10 epochs
        (the worse 11th logged epoch likely coincides with this restart).
        """
        sch = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, T_0=10, T_mult=1, eta_min=1e-6, last_epoch=-1
        )
        return sch

    def forward(self, image, targets=None):
        """Run the backbone and classification head.

        Args:
            image: 4-D batch of images; assumes channels-first ``(B, C, H, W)`` as
                produced by tez's ``ImageDataset`` — TODO confirm.
            targets: optional class-index labels; when given, loss and metrics are
                computed.

        Returns:
            ``(outputs, loss, metrics)``; ``loss`` and ``metrics`` are ``None`` when
            ``targets`` is ``None``.
        """
        batch_size, _, _, _ = image.shape

        x = self.effnet.extract_features(image)
        # Global-average-pool the final feature map to one vector per image.
        x = F.adaptive_avg_pool2d(x, 1).reshape(batch_size, -1)
        outputs = self.out(self.dropout(x))

        if targets is None:
            return outputs, None, None
        loss = nn.CrossEntropyLoss()(outputs, targets)
        # Named `metric_values` so it does not shadow the module-level sklearn `metrics`.
        metric_values = self.monitor_metrics(outputs, targets)
        return outputs, loss, metric_values

# Augmentations

In [None]:
# augmentations taken from: https://www.kaggle.com/khyeh0719/pytorch-efficientnet-baseline-train-amp-aug
# Training: random resized 256x256 crops, flips/transposes, geometric and colour
# jitter, ImageNet normalisation, then random rectangular hole dropout.
train_aug = albumentations.Compose([
            albumentations.RandomResizedCrop(256, 256),
            albumentations.Transpose(p=0.5),
            albumentations.HorizontalFlip(p=0.5),
            albumentations.VerticalFlip(p=0.5),
            albumentations.ShiftScaleRotate(p=0.5),
            # NOTE(review): these limits are in pixel-value units (albumentations
            # defaults are 20/30/20), so 0.2 is close to a no-op. Kept as in the
            # referenced kernel — confirm whether fractional limits were intended.
            albumentations.HueSaturationValue(
                hue_shift_limit=0.2, 
                sat_shift_limit=0.2, 
                val_shift_limit=0.2, 
                p=0.5
            ),
            albumentations.RandomBrightnessContrast(
                brightness_limit=(-0.1,0.1), 
                contrast_limit=(-0.1, 0.1), 
                p=0.5
            ),
            albumentations.Normalize(
                mean=[0.485, 0.456, 0.406], 
                std=[0.229, 0.224, 0.225], 
                max_pixel_value=255.0, 
                p=1.0
            ),
            # `Cutout` used to be applied here as well; it is deprecated in favour of
            # CoarseDropout (same hole-masking idea) and removed in newer
            # albumentations releases, so CoarseDropout alone is used.
            albumentations.CoarseDropout(p=0.5)], p=1.)


# Validation: deterministic crop + ImageNet normalisation.
# NOTE(review): CenterCrop(256, 256) already yields 256x256, so Resize(256, 256) is a
# no-op, and validation sees only a small central patch while training sees resized
# crops of up to the whole image — consider Resize-only or a larger crop first.
valid_aug = albumentations.Compose([
            albumentations.CenterCrop(256, 256, p=1.),
            albumentations.Resize(256, 256),
            albumentations.Normalize(
                mean=[0.485, 0.456, 0.406], 
                std=[0.229, 0.224, 0.225], 
                max_pixel_value=255.0, 
                p=1.0
            )], p=1.)

# Read CSV, split into train/validation & create datasets

In [None]:
# Each CSV row names one image file (`image_id`) and its class index (`label`).
dfx = pd.read_csv('data/train.csv')

# Stratified 90/10 train/validation split with a fixed seed; indices are reset so
# each frame is 0..n-1.
df_train, df_valid = (
    split.reset_index(drop=True)
    for split in model_selection.train_test_split(
        dfx, test_size=0.1, random_state=42, stratify=dfx.label.values
    )
)

# Full image paths and label arrays for each split.
image_path = "data/train_images/"
train_image_paths, valid_image_paths = (
    [os.path.join(image_path, image_id) for image_id in split.image_id.values]
    for split in (df_train, df_valid)
)
train_targets, valid_targets = df_train.label.values, df_valid.label.values

# tez image datasets: random augmentations for training, deterministic ones for validation.
train_dataset, valid_dataset = (
    ImageDataset(image_paths=paths, targets=labels, augmentations=aug)
    for paths, labels, aug in (
        (train_image_paths, train_targets, train_aug),
        (valid_image_paths, valid_targets, valid_aug),
    )
)

In [None]:
import matplotlib.pyplot as plt

# Load, Train & Save Model

In [None]:
# Sanity check: how many CUDA devices are visible (training below uses device="cuda").
n_visible_gpus = torch.cuda.device_count()
n_visible_gpus

1

In [None]:
# Seed the Python, NumPy and torch RNGs (presumably covering head init, shuffling and
# augmentation) so reruns are comparable; GPU kernels may still be nondeterministic.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

model = LeafModel(num_classes=dfx.label.nunique())
# EarlyStopping saves the weights to model.bin whenever valid_loss improves and stops
# after 3 epochs without improvement.
es = EarlyStopping(
    monitor="valid_loss", model_path="model.bin", patience=3, mode="min"
)
model.fit(
    train_dataset,
    valid_dataset=valid_dataset,
    train_bs=64,
    valid_bs=64,
    device="cuda",
    epochs=100,
    callbacks=[es],
    fp16=True,
)
# The final-epoch weights are generally not the best ones (in the run below the last
# epoch's valid_loss was worse than the saved best). Write them to a separate file so
# the best checkpoint in model.bin is not overwritten.
model.save("model_last.bin")

Loaded pretrained weights for efficientnet-b4


100%|██████████| 301/301 [05:23<00:00,  1.07s/it, accuracy=0.778, loss=0.633, stage=train]
100%|██████████| 34/34 [00:22<00:00,  1.51it/s, accuracy=0.804, loss=0.571, stage=valid]


Validation score improved (inf --> 0.5709431486971238). Saving model!


100%|██████████| 301/301 [05:23<00:00,  1.08s/it, accuracy=0.835, loss=0.47, stage=train]
100%|██████████| 34/34 [00:22<00:00,  1.53it/s, accuracy=0.843, loss=0.469, stage=valid]


Validation score improved (0.5709431486971238 --> 0.46927128995166106). Saving model!


100%|██████████| 301/301 [05:23<00:00,  1.07s/it, accuracy=0.852, loss=0.431, stage=train]
100%|██████████| 34/34 [00:22<00:00,  1.51it/s, accuracy=0.842, loss=0.446, stage=valid]


Validation score improved (0.46927128995166106 --> 0.4462587298715816). Saving model!


100%|██████████| 301/301 [05:24<00:00,  1.08s/it, accuracy=0.861, loss=0.395, stage=train]
100%|██████████| 34/34 [00:22<00:00,  1.52it/s, accuracy=0.853, loss=0.437, stage=valid]


Validation score improved (0.4462587298715816 --> 0.43744300130535574). Saving model!


100%|██████████| 301/301 [05:22<00:00,  1.07s/it, accuracy=0.869, loss=0.38, stage=train]
100%|██████████| 34/34 [00:22<00:00,  1.52it/s, accuracy=0.847, loss=0.452, stage=valid]
  0%|          | 0/301 [00:00<?, ?it/s]

EarlyStopping counter: 1 out of 3


100%|██████████| 301/301 [05:23<00:00,  1.08s/it, accuracy=0.874, loss=0.36, stage=train]
100%|██████████| 34/34 [00:22<00:00,  1.52it/s, accuracy=0.842, loss=0.458, stage=valid]
  0%|          | 0/301 [00:00<?, ?it/s]

EarlyStopping counter: 2 out of 3


100%|██████████| 301/301 [05:23<00:00,  1.08s/it, accuracy=0.882, loss=0.334, stage=train]
100%|██████████| 34/34 [00:22<00:00,  1.51it/s, accuracy=0.852, loss=0.425, stage=valid]


Validation score improved (0.43744300130535574 --> 0.42465033031561794). Saving model!


100%|██████████| 301/301 [05:23<00:00,  1.08s/it, accuracy=0.887, loss=0.319, stage=train]
100%|██████████| 34/34 [00:22<00:00,  1.53it/s, accuracy=0.857, loss=0.412, stage=valid]


Validation score improved (0.42465033031561794 --> 0.4124784688739216). Saving model!


100%|██████████| 301/301 [05:23<00:00,  1.08s/it, accuracy=0.89, loss=0.31, stage=train]
100%|██████████| 34/34 [00:22<00:00,  1.52it/s, accuracy=0.854, loss=0.416, stage=valid]
  0%|          | 0/301 [00:00<?, ?it/s]

EarlyStopping counter: 1 out of 3


100%|██████████| 301/301 [05:23<00:00,  1.08s/it, accuracy=0.894, loss=0.297, stage=train]
100%|██████████| 34/34 [00:22<00:00,  1.51it/s, accuracy=0.859, loss=0.415, stage=valid]
  0%|          | 0/301 [00:00<?, ?it/s]

EarlyStopping counter: 2 out of 3


100%|██████████| 301/301 [05:23<00:00,  1.07s/it, accuracy=0.873, loss=0.365, stage=train]
100%|██████████| 34/34 [00:22<00:00,  1.53it/s, accuracy=0.832, loss=0.464, stage=valid]


EarlyStopping counter: 3 out of 3
