<a href="https://colab.research.google.com/github/joe-jachim/cassava-leaf-classifier/blob/main/full_pytorch_lightning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -U git+https://github.com/albu/albumentations --no-cache-dir

Collecting git+https://github.com/albu/albumentations
  Cloning https://github.com/albu/albumentations to /tmp/pip-req-build-cinqrrba
  Running command git clone -q https://github.com/albu/albumentations /tmp/pip-req-build-cinqrrba
Building wheels for collected packages: albumentations
  Building wheel for albumentations (setup.py) ... [?25l[?25hdone
  Created wheel for albumentations: filename=albumentations-0.5.2-cp36-none-any.whl size=79946 sha256=f876052ff1e55da8240ae6fef7502ed12545bfe962aa0413367a0ee31a4821cb
  Stored in directory: /tmp/pip-ephem-wheel-cache-n395d7rv/wheels/45/8b/e4/2837bbcf517d00732b8e394f8646f22b8723ac00993230188b
Successfully built albumentations
Installing collected packages: albumentations
  Found existing installation: albumentations 0.5.2
    Uninstalling albumentations-0.5.2:
      Successfully uninstalled albumentations-0.5.2
Successfully installed albumentations-0.5.2


In [2]:
# NOTE(review): the version is unpinned. Later cells rely on pl.metrics.Accuracy()
# and pl.Trainer(gpus=1); confirm the installed release still provides them and
# consider pinning a known-good version.
! pip install pytorch-lightning



In [3]:
%%javascript
// Keep-alive helper: once a minute, log a heartbeat and click the Colab
// "connect" toolbar button if it can be found in this document.
function ClickConnect() {
  console.log("Working");
  const connectButton = document.querySelector("colab-toolbar-button#connect");
  // querySelector returns null when nothing matches; skip the click instead of throwing.
  if (connectButton !== null) {
    connectButton.click();
  }
}

// Re-running this cell must not stack a second timer on top of the previous one.
if (window.clickConnectTimer !== undefined) {
  clearInterval(window.clickConnectTimer);
}
window.clickConnectTimer = setInterval(ClickConnect, 60000);

<IPython.core.display.Javascript object>

In [4]:
# geffnet supplies create_model(), which get_efficientnet() calls further down.
# NOTE(review): the version is unpinned; consider pinning for reproducible runs.
! pip install geffnet



In [5]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import sys
from typing import Tuple
import PIL
from torch.utils.data import Dataset
from pathlib import Path
from PIL import Image
from PIL.Image import Image as PILImage
from torch.utils.data.dataloader import DataLoader
import numpy as np
import pandas as pd
from pytorch_lightning import LightningDataModule
from sklearn.model_selection import train_test_split, StratifiedKFold
import albumentations as A
from albumentations.pytorch.transforms import ToTensor

from torchvision import models
import torch.nn as nn
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from torch import optim
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

# Make a local checkout of gen-efficientnet-pytorch importable.
# NOTE(review): this relative path looks like a Kaggle input directory; confirm it
# is still needed now that geffnet is also pip-installed in an earlier cell.
sys.path.append('../input/geneffnet/gen-efficientnet-pytorch-master')

import geffnet

# Root of the dataset; later cells read train.csv and train_images/ beneath it.
path = Path("drive/MyDrive/data/cassava-leaf-disease-classification/")

In [6]:
class CassavaDataset(Dataset):
    """Map-style dataset that loads the images listed in a DataFrame.

    Args:
        path: Directory that contains the image files.
        df: DataFrame whose rows unpack to ``(image file name, label)``.
        transform: Optional callable invoked as ``transform(image=array)``
            that returns a mapping with an ``"image"`` entry.
    """

    def __init__(self, path, df, transform=None) -> None:
        super().__init__()
        self.df = df
        self.path = path
        self.transform = transform
        # Not read anywhere in this class; kept so existing attribute access still works.
        self.num_workers = 2

    def __getitem__(self, index) -> Tuple[object, int]:
        """Return ``(image, label)`` for the row at positional ``index``.

        The image is an RGB ``numpy`` array, or whatever ``transform`` puts
        under ``"image"`` when a transform is configured.
        """
        img_id, label = self.df.iloc[index]
        # Image.open is lazy and keeps the file handle open; decode the pixels
        # inside the context manager so the handle is released right away.
        with Image.open(self.path / img_id) as img:
            # Force three channels so downstream per-channel normalisation
            # always sees the layout it expects.
            image = np.array(img.convert("RGB"))
        if self.transform is not None:
            image = self.transform(image=image)["image"]
        return image, label

    def __len__(self):
        """Number of rows in the backing DataFrame."""
        return len(self.df)

In [7]:
class CassavaDataModule(LightningDataModule):
    """LightningDataModule serving one stratified fold of the training CSV.

    ``prepare_data`` reads ``train.csv`` under ``path``, builds a 5-fold
    stratified split on the ``label`` column and pickles the selected fold's
    train/validation frames into the current working directory. ``setup``
    loads those pickles and builds the augmentation pipelines.

    Args:
        path: Dataset root holding ``train.csv`` and ``train_images/``.
            Required in practice: the ``None`` default is rejected by ``Path``.
        aug_p: Probability forwarded to ``get_augmentations``.
        val_pct: Stored on the instance; no method of this class reads it.
        img_sz: Image size forwarded to ``get_augmentations``.
        batch_size: Batch size of both dataloaders.
        num_workers: Worker count of both dataloaders.
        fold_id: Index (0-4) of the fold used for validation.
    """

    def __init__(
        self,
        path: str = None,
        aug_p: float = 0.5,
        val_pct: float = 0.2,
        img_sz: int = 224,
        batch_size: int = 64,
        num_workers: int = 4,
        fold_id: int = 0,
    ):
        super().__init__()
        self.path = Path(path)
        self.aug_p = aug_p
        self.val_pct = val_pct
        self.img_sz = img_sz
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.fold_id = fold_id

    def prepare_data(self):
        """Split ``train.csv`` into folds and pickle the selected train/valid frames."""
        # only called on 1 GPU/TPU in distributed
        df = pd.read_csv(self.path / "train.csv")
        skf = StratifiedKFold(n_splits=5)
        targets = df.label
        folds = list(skf.split(np.zeros(len(targets)), targets))
        train_index, valid_index = folds[self.fold_id]
        # split() yields positional indices, so select rows by position (iloc)
        # rather than by index label (loc).
        train_df = df.iloc[train_index]
        valid_df = df.iloc[valid_index]

        train_df.to_pickle("train_df.pkl")
        valid_df.to_pickle("valid_df.pkl")

    def setup(self, stage=None):
        """Load the pickled fold frames and build the transforms.

        Args:
            stage: Stage name supplied by Lightning when it invokes the hook;
                unused here. Optional so ``setup()`` can still be called by hand.
        """
        # called on every process in DDP
        self.train_transform, self.test_transform = get_augmentations(
            p=self.aug_p, image_size=self.img_sz
        )
        self.train_df = pd.read_pickle("train_df.pkl")
        self.valid_df = pd.read_pickle("valid_df.pkl")

    def train_dataloader(self):
        """Shuffled loader over the training fold with the training transform."""
        train_dataset = CassavaDataset(
            self.path / "train_images", df=self.train_df, transform=self.train_transform
        )
        return DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=True,
        )

    def val_dataloader(self):
        """Unshuffled loader over the validation fold with the test transform."""
        valid_dataset = CassavaDataset(
            self.path / "train_images", df=self.valid_df, transform=self.test_transform
        )
        return DataLoader(
            valid_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            pin_memory=True,
        )

In [8]:
# Smoke test: build a dataset over the full training CSV with no transform attached.
df = pd.read_csv(path/'train.csv')
ds = CassavaDataset(path/'train_images',df=df)

In [9]:
# Display the directory the dataset reads image files from.
ds.path

PosixPath('drive/MyDrive/data/cassava-leaf-disease-classification/train_images')

In [10]:
def get_augmentations(p=0.5, image_size=224):
    """Build the albumentations pipelines for training and validation.

    Args:
        p: Probability given to each optional training augmentation and to
            each ``OneOf`` group.
        image_size: Side length of the square crop both pipelines produce.

    Returns:
        ``(train_tfms, valid_tfms)``, two ``A.Compose`` pipelines. Both end
        with ``ToTensor`` normalised with the ImageNet mean/std.
    """
    imagenet_stats = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # Each group below applies at most one of its members per sample.
    colour_group = A.OneOf(
        [
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2),
            A.HueSaturationValue(
                hue_shift_limit=20, sat_shift_limit=50, val_shift_limit=50
            ),
        ],
        p=p,
    )
    noise_group = A.OneOf([A.IAAAdditiveGaussianNoise(), A.GaussNoise()], p=p)
    blur_group = A.OneOf(
        [
            A.MotionBlur(p=0.2),
            A.MedianBlur(blur_limit=3, p=0.1),
            A.Blur(blur_limit=3, p=0.1),
        ],
        p=p,
    )
    distortion_group = A.OneOf(
        [
            A.OpticalDistortion(p=0.3),
            A.GridDistortion(p=0.1),
            A.IAAPiecewiseAffine(p=0.3),
        ],
        p=p,
    )

    training_steps = [
        A.RandomResizedCrop(image_size, image_size),
        A.ShiftScaleRotate(shift_limit=0.15, scale_limit=0.4, rotate_limit=45, p=p),
        A.Cutout(p=p),
        A.RandomRotate90(p=p),
        A.Flip(p=p),
        colour_group,
        noise_group,
        A.CoarseDropout(max_holes=10, p=p),
        blur_group,
        distortion_group,
        ToTensor(normalize=imagenet_stats),
    ]
    validation_steps = [
        A.CenterCrop(image_size, image_size),
        ToTensor(normalize=imagenet_stats),
    ]

    train_tfms = A.Compose(training_steps)
    valid_tfms = A.Compose(validation_steps)
    return train_tfms, valid_tfms

In [11]:
# Backbone names selectable via model_name; Resnext hands the chosen name to
# torch.hub.load for the facebookresearch/semi-supervised-ImageNet1K-models repo.
ssl_models = [
    "resnet18_ssl",
    "resnet50_ssl",
    "resnext50_32x4d_ssl",
    "resnext101_32x4d_ssl",
    "resnext101_32x8d_ssl",
    "resnext101_32x16d_ssl",
]

class Resnext(nn.Module):
    """Backbone with its last two child modules removed plus a new linear head.

    Args:
        model_name: Entry-point name given to ``torch.hub.load`` for the
            ``facebookresearch/semi-supervised-ImageNet1K-models`` repo. When
            ``kaggle`` is true it is instead evaluated as a Python expression
            that must yield a zero-argument model constructor.
        pool_type: Callable applied as ``pool_type(feature_map, 1)`` to pool
            the backbone output.
        num_classes: Output size of the new linear classifier.
        kaggle: Build the backbone locally via ``eval`` instead of torch.hub.
    """

    def __init__(
        self,
        model_name="resnet18_ssl",
        pool_type=F.adaptive_avg_pool2d,
        num_classes=1000,
        kaggle=False,
    ):
        super().__init__()
        self.pool_type = pool_type

        if kaggle:
            # SECURITY(review): eval executes model_name as code. Only pass
            # trusted, hard-coded names here; never user-supplied text.
            backbone = eval(model_name)()
        else:
            backbone = torch.hub.load(
                "facebookresearch/semi-supervised-ImageNet1K-models", model_name
            )
        # Keep everything except the final two children of the backbone.
        self.backbone = nn.Sequential(*list(backbone.children())[:-2])
        # The replacement head takes the same input width as the original fc layer.
        in_features = backbone.fc.in_features
        self.classifier = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return classifier outputs of shape ``(x.size(0), num_classes)``."""
        features = self.pool_type(self.backbone(x), 1)
        features = features.view(x.size(0), -1)
        return self.classifier(features)


def get_efficientnet(model_name, pretrained=True, num_classes=5):
    """Create a geffnet model and swap its classifier for a new linear head.

    Args:
        model_name: Name passed to ``geffnet.create_model``.
        pretrained: Forwarded to ``geffnet.create_model``.
        num_classes: Output size of the replacement ``nn.Linear`` classifier.

    Returns:
        The created model with ``classifier`` replaced.
    """
    net = geffnet.create_model(model_name, pretrained=pretrained)
    # Reuse the input width of the head being replaced.
    head_width = net.classifier.in_features
    net.classifier = nn.Linear(head_width, num_classes)
    return net

In [12]:
class CassavaModel(pl.LightningModule):
    """LightningModule wrapping a ResNet/ResNeXt or EfficientNet classifier.

    Args:
        model_name: Backbone name. Names containing ``"res"`` build a
            ``Resnext``; otherwise names containing ``"effi"`` go through
            ``get_efficientnet``.
        num_classes: Size of the classification head. When ``None`` the
            EfficientNet path falls back to ``get_efficientnet``'s default.
        data_path: Stored on the instance; no method of this class reads it.
        loss_fn: Callable invoked as ``loss_fn(logits, targets)``.
        lr: Learning rate for AdamW.
        wd: Weight decay for AdamW.

    Raises:
        ValueError: If ``model_name`` is missing or matches neither family.
    """

    def __init__(
        self,
        model_name: str = None,
        num_classes: int = None,
        data_path: Path = None,
        loss_fn=F.cross_entropy,
        lr=1e-4,
        wd=1e-6,
    ):
        super().__init__()

        if model_name is None:
            raise ValueError("model_name is required")

        if "res" in model_name:
            self.model = Resnext(model_name=model_name, num_classes=num_classes)
        elif "effi" in model_name:
            # Forward the requested head size; previously it was dropped and the
            # helper's default was always used.
            if num_classes is None:
                self.model = get_efficientnet(model_name)
            else:
                self.model = get_efficientnet(model_name, num_classes=num_classes)
        else:
            # Fail at construction instead of leaving self.model undefined.
            raise ValueError(
                f"Unsupported model_name {model_name!r}: expected a name "
                "containing 'res' or 'effi'"
            )
        self.data_path = data_path
        self.loss_fn = loss_fn
        self.lr = lr
        self.accuracy = pl.metrics.Accuracy()
        self.wd = wd

    def forward(self, x):
        """Run the wrapped model on ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the loss for one batch, log it as ``train_loss`` and return it."""
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log ``valid_loss`` and ``val_acc`` for one validation batch."""
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log("valid_loss", loss, prog_bar=True)
        # Logs the value returned by calling the accuracy metric on this batch.
        self.log("val_acc", self.accuracy(y_hat, y), prog_bar=True)

    def configure_optimizers(self):
        """Return AdamW plus a cosine schedule spanning ``trainer.max_epochs``."""
        optimizer = optim.AdamW(
            self.model.parameters(), lr=self.lr, weight_decay=self.wd
        )
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, self.trainer.max_epochs, 0
        )

        return [optimizer], [scheduler]

In [13]:
# Run configuration consumed by the data-module, model and trainer cells that follow.
fold_id = 0  # validation fold index; also used as the checkpoint file name
aug_p = 0.5  # probability forwarded to get_augmentations
img_sz= 224  # square crop size forwarded to get_augmentations
batch_size = 64
num_workers = 4  # DataLoader worker count
num_classes = 5  # size of the classification head
loss_fn = F.cross_entropy
lr = 1e-4  # AdamW learning rate
epochs = 1  # passed to the Trainer as max_epochs
gradient_clip_val = 0.1  # passed through to the Trainer
precision = 16  # passed through to the Trainer
model_name=ssl_models[2]  # i.e. "resnext50_32x4d_ssl"

In [14]:
# Build the data module for the configured fold.
data_module = CassavaDataModule(
    path=path,
    aug_p=aug_p,
    img_sz=img_sz,
    batch_size=batch_size,
    num_workers=num_workers,
    fold_id=fold_id,
)
# Run the hooks by hand: prepare_data() writes train_df.pkl / valid_df.pkl to the
# working directory, and setup() reads them back and builds the transforms.
data_module.prepare_data()
data_module.setup()

In [15]:
# Instantiate the LightningModule. Weight decay is left at the class default.
# For "res*" names the backbone is fetched through torch.hub.
model = CassavaModel(
    model_name=model_name,
    num_classes=num_classes,
    data_path=path,
    lr=lr,
    loss_fn=loss_fn,
)

Using cache found in /root/.cache/torch/hub/facebookresearch_semi-supervised-ImageNet1K-models_master


In [16]:
!mkdir drive/MyDrive/data/cassava-leaf-disease-classification/weights

mkdir: cannot create directory ‘drive/MyDrive/data/cassava-leaf-disease-classification/weights’: File exists


In [18]:
# Directory that receives the checkpoint files.
weights_path = Path("drive/MyDrive/data/cassava-leaf-disease-classification/weights")

# Track val_acc (higher is better), name the checkpoint after the fold,
# and also keep the last checkpoint.
checkpoint_callback = ModelCheckpoint(
    monitor="val_acc",
    mode="max",
    dirpath=weights_path,
    filename=str(fold_id),
    save_last=True,
    save_weights_only=True,
)

# Single-GPU trainer using the precision and gradient clipping from the config cell.
trainer = pl.Trainer(
    max_epochs=epochs,
    gpus=1,
    precision=precision,
    gradient_clip_val=gradient_clip_val,
    callbacks=[checkpoint_callback],
)

trainer.fit(model=model, datamodule=data_module)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.

  | Name     | Type     | Params
--------------------------------------
0 | model    | Resnext  | 23.0 M
1 | accuracy | Accuracy | 0     
--------------------------------------
23.0 M    Trainable params
0         Non-trainable params
23.0 M    Total params


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Saving latest checkpoint...





1