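# Banana Leaf Disease Classifier

This notebook trains an MLP-Mixer image classifier (PyTorch Lightning) on a folder-per-class banana-leaf image dataset. It uses a random 80/20 train/test split and reports loss and accuracy on the held-out split. In the recorded run (5 epochs), test accuracy was about 0.68 and test loss about 0.69.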

In [15]:
# Unzip dataset

In [21]:
# Extract dataset.zip into the current working directory (-o: overwrite existing files without prompting).
# NOTE(review): assumes the archive unpacks to ./dataset/<class_name>/*.jpg, the layout the Dataset cell reads — confirm.
!unzip -o dataset.zip

In [None]:
# Dataset Helper Functions

In [2]:
import glob
import os
from torch.utils.data import Dataset, DataLoader
import torch
from PIL import Image
from torchvision import transforms
import numpy as np
import cv2
from pathlib import Path

In [3]:
class BananaLeafDiseaseDataset(Dataset):
    """Image-classification dataset laid out as ``<target_dir>/<class_name>/<image>``.

    Each immediate sub-directory of ``target_dir`` is one class. Class indices
    follow the sorted sub-directory names. Items are ``(image, class_idx)``.
    ``image`` is an RGB PIL image, or whatever ``transform`` returns for it.

    Args:
        target_dir: Root directory containing one sub-directory per class.
        transform: Optional callable applied to each loaded PIL image.
        extensions: File suffixes (including the dot, compared case-sensitively)
            that count as images. The default ``(".jpg",)`` matches the previous
            ``*/*.jpg`` glob. A single string is also accepted.
    """

    def __init__(self, target_dir: str, transform=None, extensions=(".jpg",)):
        self.transform = transform

        if isinstance(extensions, str):
            # Guard against substring matching ("jpg" in ".jpg").
            extensions = (extensions,)
        extensions = tuple(extensions)

        # Sort so the index -> file mapping is stable. Path.glob order depends on
        # the filesystem, which would make even a seeded random_split irreproducible.
        self.image_paths = sorted(
            path
            for path in Path(target_dir).glob("*/*")
            if path.is_file() and path.suffix in extensions
        )
        self.classes, self.class_to_idx = self.load_classes(target_dir)

    def load_image(self, idx: int):
        """Load image ``idx`` fully into memory as a 3-channel RGB PIL image."""
        # Image.open is lazy and keeps the file open. The `with` block closes it,
        # and convert("RGB") forces the pixel data to load first. RGB conversion
        # also guarantees 3 channels for grayscale/RGBA/CMYK files.
        with Image.open(self.image_paths[idx]) as img:
            return img.convert("RGB")

    def __len__(self):
        """Number of image files found under ``target_dir``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, class_idx)``, applying ``transform`` to the image if set."""
        img = self.load_image(idx)
        # The class is the name of the directory the file lives in.
        class_idx = self.class_to_idx[self.image_paths[idx].parent.name]
        if self.transform:
            img = self.transform(img)
        return img, class_idx

    def load_classes(self, target_dir):
        """Discover classes from the immediate sub-directories of ``target_dir``.

        Returns:
            classes: Sorted list of sub-directory names.
            class_to_idx: Mapping from class name to its index in ``classes``.

        Raises:
            FileNotFoundError: If ``target_dir`` does not exist (raised by ``os.scandir``).
        """
        with os.scandir(target_dir) as entries:
            classes = sorted(entry.name for entry in entries if entry.is_dir())
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

In [4]:
# Resize the PIL image first: torchvision always antialiases PIL resizes, and
# resizing the small uint8 image is cheaper than resizing a float tensor.
# Then convert to a float C x H x W tensor (scaled to [0, 1] for 8-bit images).
# (144, 144) squashes non-square images to a square (aspect ratio not kept) and
# must match MLPMixer(image_size=144).
img_transforms = transforms.Compose([
    transforms.Resize((144, 144)),
    transforms.ToTensor(),
])

In [5]:
# Folder-per-class image dataset rooted at ./dataset (presumably extracted from dataset.zip above).
dataset = BananaLeafDiseaseDataset(target_dir='./dataset', transform=img_transforms)

In [6]:
# 80/20 train/test split. Without an explicit generator, random_split draws from
# the unseeded global RNG, so the split changed on every kernel restart. A seeded
# generator makes it reproducible, provided the dataset's file order is stable.
SPLIT_SEED = 42
train_size = 0.8
test_size = 1 - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [7]:
# Shuffle only the training data. The epoch-level test metrics do not depend
# on batch order, and a fixed order keeps test runs repeatable.
BATCH_SIZE = 64
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [8]:
import lightning as pl
import torch
import numpy as np
from torch import nn
from einops.layers.torch import Rearrange
import torch.nn.functional as F

In [9]:
class FeedForward(nn.Module):
    """Two-layer MLP ``dim -> hidden_dim -> dim`` with GELU and dropout.

    The Linear layers act on the last axis of the input, so the output has the
    same shape as the input.

    Args:
        dim: Size of the input and output features.
        hidden_dim: Width of the hidden layer.
        dropout: Dropout probability applied after the activation and again
            after the output projection.
    """

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),   # net.0: expand
            nn.GELU(),                    # net.1
            nn.Dropout(dropout),          # net.2
            nn.Linear(hidden_dim, dim),   # net.3: project back to dim
            nn.Dropout(dropout),          # net.4
        ]
        # A single `net` Sequential keeps state_dict keys as `net.<i>.*`.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out

In [10]:
class MixerBlock(nn.Module):
    """One MLP-Mixer layer: a token-mixing MLP followed by a channel-mixing MLP.

    Each sub-layer applies LayerNorm first and is wrapped in a residual
    connection. Inputs are expected as ``(batch, num_patch, dim)``, the layout
    implied by the ``'b n d'`` Rearrange patterns.

    Args:
        dim: Channel (embedding) size ``d``.
        num_patch: Number of tokens ``n``; the token-mixing MLP acts across this axis.
        token_dim: Hidden width of the token-mixing MLP.
        channel_dim: Hidden width of the channel-mixing MLP.
        dropout: Dropout probability passed to both FeedForward MLPs.
    """

    def __init__(self, dim, num_patch, token_dim, channel_dim, dropout = 0.):
        super().__init__()

        # Transpose to (b, d, n) so FeedForward's Linear layers mix across
        # tokens, then transpose back to (b, n, d).
        token_layers = [
            nn.LayerNorm(dim),
            Rearrange('b n d -> b d n'),
            FeedForward(num_patch, token_dim, dropout),
            Rearrange('b d n -> b n d'),
        ]
        self.token_mix = nn.Sequential(*token_layers)

        channel_layers = [
            nn.LayerNorm(dim),
            FeedForward(dim, channel_dim, dropout),
        ]
        self.channel_mix = nn.Sequential(*channel_layers)

    def forward(self, x):
        tokens_mixed = x + self.token_mix(x)
        return tokens_mixed + self.channel_mix(tokens_mixed)

class MLPMixer(pl.LightningModule):
    """MLP-Mixer image classifier wrapped as a LightningModule.

    Pipeline:
    - A strided Conv2d embeds non-overlapping ``patch_size`` patches.
    - The patch grid is flattened into a ``(batch, num_patch, dim)`` token sequence.
    - The tokens pass through ``depth`` MixerBlocks.
    - The result is layer-normed, mean-pooled over tokens, and projected to
      ``num_classes`` logits.

    Args:
        in_channels: Number of input image channels.
        dim: Token embedding size.
        num_classes: Number of output logits.
        patch_size: Side length of each square patch (Conv2d kernel and stride).
        image_size: Side length of the square input image. Must be divisible
            by ``patch_size``.
        depth: Number of MixerBlocks.
        token_dim: Hidden width of each token-mixing MLP.
        channel_dim: Hidden width of each channel-mixing MLP.
        dropout: Dropout probability inside every MixerBlock. Defaults to 0.0
            (no dropout), the previous fixed behavior.
        lr: Adam learning rate used by ``configure_optimizers``. Defaults to
            1e-5, the previous hard-coded value.

    Raises:
        AssertionError: If ``image_size`` is not divisible by ``patch_size``.
    """

    def __init__(
        self,
        in_channels,
        dim,
        num_classes,
        patch_size,
        image_size,
        depth,
        token_dim,
        channel_dim,
        dropout=0.,
        lr=1e-5,
    ):
        super().__init__()
        # Store constructor args in self.hparams and in checkpoints, so the model
        # can be rebuilt with load_from_checkpoint.
        self.save_hyperparameters()
        self.lr = lr

        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        self.num_patch = (image_size // patch_size) ** 2

        # kernel == stride == patch_size embeds each patch independently.
        # Rearrange then flattens the (h, w) patch grid into a token axis.
        self.to_patch_embedding = nn.Sequential(
            nn.Conv2d(in_channels, dim, patch_size, patch_size),
            Rearrange('b c h w -> b (h w) c'),
        )

        self.mixer_blocks = nn.ModuleList(
            MixerBlock(dim, self.num_patch, token_dim, channel_dim, dropout)
            for _ in range(depth)
        )

        self.layer_norm = nn.LayerNorm(dim)

        # Kept as a Sequential so state_dict keys stay `mlp_head.0.*`.
        self.mlp_head = nn.Sequential(
            nn.Linear(dim, num_classes)
        )

    def forward(self, x):
        """Map a batch of images to ``(batch, num_classes)`` logits."""
        x = self.to_patch_embedding(x)
        for mixer_block in self.mixer_blocks:
            x = mixer_block(x)

        x = self.layer_norm(x)
        # Global average pool over the token axis.
        x = x.mean(dim=1)
        return self.mlp_head(x)

    def configure_optimizers(self):
        """Adam over all parameters with learning rate ``self.lr``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def _calculate_loss(self, batch, mode="train"):
        """Compute cross-entropy loss and batch accuracy, and log both.

        Metrics are logged as ``<mode>_loss`` / ``<mode>_acc`` at both step and
        epoch level. Returns the loss tensor.
        """
        # assumes batch is (images, integer labels), as the Dataset yields — TODO confirm for other loaders
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = (y_hat.argmax(dim=-1) == y).float().mean()

        self.log("%s_loss" % mode, loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        self.log("%s_acc" % mode, acc, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._calculate_loss(batch, mode="train")

    def validation_step(self, batch, batch_idx):
        return self._calculate_loss(batch, mode="val")

    def test_step(self, batch, batch_idx):
        return self._calculate_loss(batch, mode="test")

In [11]:
# Seed weight initialisation so re-running the notebook builds the same model.
torch.manual_seed(42)
# Take the class count from the dataset folders so the head always matches the
# labels. cross_entropy raises if a label index is >= num_classes.
# NOTE: image_size must match the Resize((144, 144)) in img_transforms.
model = MLPMixer(in_channels=3, image_size=144, patch_size=16,
                 num_classes=len(dataset.classes),
                 dim=512, depth=8, token_dim=256, channel_dim=2048)

In [12]:
# No validation dataloader is given to trainer.fit, so the validation loop never
# runs. check_val_every_n_epoch would be a no-op here (and 1 is its default anyway).
trainer = pl.Trainer(max_epochs=5)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [13]:
# Train on the training split only; no validation dataloader is supplied.
trainer.fit(model=model, train_dataloaders=train_loader)

  rank_zero_warn(

  | Name               | Type       | Params
--------------------------------------------------
0 | to_patch_embedding | Sequential | 393 K 
1 | mixer_blocks       | ModuleList | 17.1 M
2 | layer_norm         | LayerNorm  | 1.0 K 
3 | mlp_head           | Sequential | 1.5 K 
--------------------------------------------------
17.5 M    Trainable params
0         Non-trainable params
17.5 M    Total params
70.179    Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


In [14]:
# Evaluate on the held-out split. Returns a list holding one dict of epoch-level test metrics.
trainer.test(model=model, dataloaders=test_loader)

  rank_zero_warn(
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

[{'test_loss_epoch': 0.6882150173187256, 'test_acc_epoch': 0.6770427823066711}]