In [None]:
%%capture
def is_running_in_colab():
    """Report whether this kernel is a Google Colab runtime.

    Detection is done by attempting ``import google.colab``; the package is
    only imported for its presence, never used.

    Returns:
        True if the import succeeds, False if it raises ImportError.
    """
    try:
        import google.colab  # noqa: F401 -- presence check only
    except ImportError:
        return False
    return True

if is_running_in_colab():
  # Install third-party dependencies into the running kernel only on Colab;
  # elsewhere the environment is expected to already provide them (verify).
  # NOTE(review): segmentation_models_pytorch is not imported anywhere in the
  # visible notebook, and torchmetrics (imported below) is not listed here --
  # presumably it arrives as a dependency of lightning; confirm.
  # Normal packages
  %pip install lightning polars segmentation_models_pytorch
  # Dev packages (rich backs RichProgressBar; icecream provides `ic`)
  %pip install icecream rich tqdm

In [None]:
from pathlib import Path

import polars as pl
import numpy as np
import torch
import torch.nn as nn
from torchvision.io import decode_image
from torchvision.transforms import v2
import lightning as L
from lightning.pytorch.callbacks import RichProgressBar
from lightning.pytorch.loggers import CSVLogger
import torchmetrics

# Dev Imports
from icecream import ic

class PlantVillageData(L.LightningDataModule):
  """Lightning datamodule over the PlantVillage metadata CSVs.

  Reads ``resampled_training_set.csv``, ``validation_set.csv`` and
  ``test_set.csv`` from ``<ws_root>/plantvillage_dataset/metadata`` and wraps
  each in an ``ImageDataset`` (the training split with augmentation enabled).

  Args:
    ws_root: workspace root that contains ``plantvillage_dataset/``.
    num_workers: DataLoader worker processes; workers are kept alive across
      epochs when this is > 0.
    batch_size: training batch size (default 32, the previous fixed value).
    eval_batch_size: validation/test batch size (default 64, the previous
      fixed value).

  Raises:
    ValueError: if the validation or test split contains a ``disease_type``
      label that does not occur in the training split.
  """
  def __init__(self, ws_root: Path = Path("."), num_workers=0, batch_size=32, eval_batch_size=64):
    super().__init__()
    metadata_path = ws_root / 'plantvillage_dataset' / 'metadata'
    # Drop rows whose image_path mentions 'augment' (presumably pre-generated
    # augmented copies; ImageDataset augments on the fly instead -- verify).
    train_df = pl.read_csv(metadata_path / 'resampled_training_set.csv').filter(
        pl.col('image_path').str.contains('augment', literal=True).not_()
    )
    self.train_ds = ImageDataset(train_df, training=True)
    self.val_ds = ImageDataset(pl.read_csv(metadata_path / 'validation_set.csv'))
    self.test_ds = ImageDataset(pl.read_csv(metadata_path / 'test_set.csv'))

    # ImageDataset indexes labels from whatever labels its own table contains.
    # Pin the evaluation splits to the training mapping so a class missing from
    # a split cannot shift the indices of every later class.
    train_map = self.train_ds.disease_to_idx
    for split_name, split_ds in (('validation', self.val_ds), ('test', self.test_ds)):
      unknown = set(split_ds.disease_to_idx) - set(train_map)
      if unknown:
        raise ValueError(
            f"{split_name} split has labels not present in the training split: {sorted(unknown)}")
      split_ds.disease_to_idx = train_map

    self.n_classes = len(train_map)
    self.idx_to_disease = {v: k for k, v in train_map.items()}

    self.batch_size = batch_size
    self.eval_batch_size = eval_batch_size
    self.dataloader_extras = dict(
        num_workers = num_workers,
        pin_memory = True,
        # Persistent workers are only meaningful when worker processes exist.
        persistent_workers = num_workers > 0
    )

  def train_dataloader(self):
    return torch.utils.data.DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True, **self.dataloader_extras)

  def val_dataloader(self):
    return torch.utils.data.DataLoader(self.val_ds, batch_size=self.eval_batch_size, **self.dataloader_extras)

  def test_dataloader(self):
    return torch.utils.data.DataLoader(self.test_ds, batch_size=self.eval_batch_size, **self.dataloader_extras)

class ImageDataset(torch.utils.data.Dataset):
  """Map-style dataset of ``(image, class_index)`` pairs read from a metadata table.

  Args:
    dataframe: table with ``image_path`` and ``disease_type`` columns.
    training: if True, ``__getitem__`` applies random horizontal/vertical flips
      and random erasing before the float conversion.
    disease_to_idx: optional label -> index mapping to use instead of deriving
      one. Pass the training split's mapping so indices agree across splits.
      When None (default), labels are indexed in the sorted order of the unique
      values present in ``dataframe``.
  """
  def __init__(self, dataframe: pl.DataFrame, training=False, disease_to_idx=None):
    super().__init__()
    # Read each column as a 1-D array. ``select(col).to_numpy().squeeze()``
    # collapses a one-row table to a 0-d array, which breaks len() and indexing.
    self.image_path = dataframe.get_column('image_path').to_numpy().copy()
    self.disease_type = dataframe.get_column('disease_type').to_numpy().copy()
    if disease_to_idx is None:
      disease_to_idx = {disease: i for i, disease in enumerate(np.unique(self.disease_type))}
    self.disease_to_idx = dict(disease_to_idx)

    self.training = training
    # Augmentations, applied to the decoded image before dtype conversion
    # (training split only).
    self.train_transforms = v2.Compose([
        v2.RandomHorizontalFlip(),
        v2.RandomVerticalFlip(),
        v2.RandomErasing(),
    ])
    # Convert to float32 with value scaling (assumes decode_image yields uint8,
    # so the result lands in [0, 1]).
    self.transforms = v2.Compose([
        v2.ToDtype(torch.float32, scale=True),
    ])

  def __len__(self):
    return len(self.image_path)

  def __getitem__(self, idx):
    """Decode the image at ``idx`` and return it with its integer class label."""
    image = decode_image(self.image_path[idx])
    if self.training:
      image = self.train_transforms(image)
    image = self.transforms(image)
    # Raises KeyError if the label is missing from disease_to_idx.
    disease = self.disease_to_idx[self.disease_type[idx]]
    return image, disease

def channel_shuffle(x, groups=2):
  """Interleave the channels (dim 1) of a 4-D tensor across ``groups`` groups.

  The channels are viewed as ``groups`` consecutive blocks of
  ``channels // groups``; the output takes one channel from each block in turn.
  The result has the same shape as ``x``.
  """
  n, c, dim2, dim3 = x.shape
  per_group = c // groups
  grouped = x.view(n, groups, per_group, dim2, dim3)
  # Swap the group and within-group axes; make contiguous so view() is legal.
  swapped = grouped.transpose(1, 2).contiguous()
  return swapped.view(n, -1, dim2, dim3)

class ShuffleBlock(nn.Module):
  """ShuffleNetV2-style unit.

  With ``downsample=True`` the full input feeds two branches (``branch1`` and
  ``branch2``). Each branch contains a stride-2 depthwise conv and produces
  ``out_c // 2`` channels, and the two outputs are concatenated. With
  ``downsample=False`` the input is split in half along dim 1. Only the second
  half goes through ``branch2``; its output is added back to that half (a
  residual), and the halves are concatenated. Both paths end with a 2-group
  channel shuffle.

  Args:
    in_c: input channels.
    out_c: output channels; must be even so the two halves sum to ``out_c``.
    downsample: build the stride-2, two-branch variant.

  Raises:
    ValueError: if ``out_c`` is odd, or if ``in_c != out_c`` when
      ``downsample`` is False.
  """
  def __init__(self, in_c, out_c, downsample=False):
    super().__init__()
    # Validate with real exceptions: an assert disappears under `python -O`, and
    # an odd out_c would silently produce out_c - 1 output channels.
    if out_c % 2:
      raise ValueError(f"out_c must be even, got {out_c}")
    if not downsample and in_c != out_c:
      raise ValueError(f"in_c ({in_c}) must equal out_c ({out_c}) when downsample=False")

    self.downsample = downsample
    half_c = out_c // 2
    if downsample:
      self.branch1 = nn.Sequential(
          # 3*3 dw conv, stride = 2
          nn.Conv2d(in_c, in_c, 3, 2, 1, groups=in_c, bias=False),
          nn.BatchNorm2d(in_c),
          # 1*1 pw conv
          nn.Conv2d(in_c, half_c, 1, 1, 0, bias=False),
          nn.BatchNorm2d(half_c),
          nn.ReLU(True)
      )

      self.branch2 = nn.Sequential(
          # 1*1 pw conv
          nn.Conv2d(in_c, half_c, 1, 1, 0, bias=False),
          nn.BatchNorm2d(half_c),
          nn.ReLU(True),
          # 3*3 dw conv, stride = 2
          nn.Conv2d(half_c, half_c, 3, 2, 1, groups=half_c, bias=False),
          nn.BatchNorm2d(half_c),
          # 1*1 pw conv
          nn.Conv2d(half_c, half_c, 1, 1, 0, bias=False),
          nn.BatchNorm2d(half_c),
          nn.ReLU(True)
      )
    else:
      # Operates on half of the (in_c == out_c) input channels.
      self.branch2 = nn.Sequential(
          # 1*1 pw conv
          nn.Conv2d(half_c, half_c, 1, 1, 0, bias=False),
          nn.BatchNorm2d(half_c),
          nn.ReLU(True),
          # 3*3 dw conv, stride = 1
          nn.Conv2d(half_c, half_c, 3, 1, 1, groups=half_c, bias=False),
          nn.BatchNorm2d(half_c),
          # 1*1 pw conv
          nn.Conv2d(half_c, half_c, 1, 1, 0, bias=False),
          nn.BatchNorm2d(half_c),
          nn.ReLU(True)
      )

  def forward(self, x):
    if self.downsample:
      # Downsampling unit: no channel split, both branches see the whole input.
      out = torch.cat((self.branch1(x), self.branch2(x)), 1)
    else:
      # Channel split: first half passes through untouched.
      c = x.shape[1] // 2
      x1 = x[:, :c, :, :]
      x2 = x[:, c:, :, :]
      x2 = x2 + self.branch2(x2)  # Residual connection
      out = torch.cat((x1, x2), 1)
    return channel_shuffle(out, 2)


class CustomModel(nn.Module):
  """ShuffleNetV2-style classifier assembled from ``ShuffleBlock`` stages.

  The network is a stem (conv stride 2 + max-pool stride 2), then one stage per
  entry of ``stage_repeat_num``. The first block of each stage downsamples and
  the rest keep the shape. After the stages come a 1x1 conv head, global
  average pooling and a linear classifier.

  Args:
    num_classes: number of output logits.
    stage_repeat_num: number of ShuffleBlocks in each stage.
    out_channels: widths ``(input, stem, stage_1, ..., stage_k, head)``; must
      have ``len(stage_repeat_num) + 3`` entries. The default equals the
      ShuffleNetV2 0.5x widths; the other standard widths are
      1.0x ``(3, 24, 116, 232, 464, 1024)``, 1.5x ``(3, 24, 176, 352, 704, 1024)``
      and 2.0x ``(3, 24, 244, 488, 976, 2048)``.

  Raises:
    ValueError: if ``out_channels`` has the wrong length for ``stage_repeat_num``.
  """
  def __init__(self, num_classes=2, stage_repeat_num=(4, 8, 4),
               out_channels=(3, 24, 48, 96, 192, 1024)):
    super().__init__()

    # Stored as lists, matching the attribute types used previously.
    self.stage_repeat_num = list(stage_repeat_num)
    self.out_channels = list(out_channels)
    if len(self.out_channels) != len(self.stage_repeat_num) + 3:
      raise ValueError(
          f"out_channels needs {len(self.stage_repeat_num) + 3} entries "
          f"(input, stem, one per stage, head), got {len(self.out_channels)}")

    # Stem. Layers are created in the same order as before, so parameter
    # initialisation order and state_dict keys are unchanged.
    self.conv1 = nn.Conv2d(self.out_channels[0], self.out_channels[1], 3, 2, 1)
    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    in_c = self.out_channels[1]

    blocks = []
    for stage_idx, repeat_num in enumerate(self.stage_repeat_num):
      out_c = self.out_channels[2 + stage_idx]
      for i in range(repeat_num):
        if i == 0:
          blocks.append(ShuffleBlock(in_c, out_c, downsample=True))
        else:
          blocks.append(ShuffleBlock(in_c, in_c, downsample=False))
        in_c = out_c
    self.stages = nn.Sequential(*blocks)

    in_c = self.out_channels[-2]
    out_c = self.out_channels[-1]
    self.conv5 = nn.Sequential(
      nn.Conv2d(in_c, out_c, 1, 1, 0, bias=False),
      nn.BatchNorm2d(out_c),
      nn.ReLU(True)
    )

    # fc layer
    self.fc = nn.Linear(out_c, num_classes)

  def forward(self, x):
    x = self.conv1(x)
    x = self.maxpool(x)
    x = self.stages(x)
    x = self.conv5(x)
    # Global average pool over the last two (spatial) dims.
    x = x.mean(-1).mean(-1)
    x = x.view(-1, self.out_channels[-1])
    x = self.fc(x)
    return x

class WrappedModel(torch.nn.Module):
  """A torchvision backbone (randomly initialised) followed by a linear head.

  The backbone is loaded via ``torch.hub.load("pytorch/vision", ...)`` with
  ``weights=None``; this fetches the hub repo on first use. Its 1000-way output
  is mapped to ``n_classes`` logits by ``out_layer``.

  Args:
    n_classes: number of output logits.
    model_type: one of ``"ShuffleNetV2"``, ``"ResNet50"``, ``"MobileNetV2"``.

  Raises:
    ValueError: if ``model_type`` is not a supported name.
  """
  def __init__(self, n_classes, model_type):
    super().__init__()

    model_dict = {
      "ShuffleNetV2": "shufflenet_v2_x0_5",
      "ResNet50": "resnet50",
      "MobileNetV2": "mobilenet_v2",
    }

    # Validate before touching the network so a typo fails fast and clearly.
    if model_type not in model_dict:
      raise ValueError(
          f"model_type {model_type} not supported; expected one of {sorted(model_dict)}")
    self.model = torch.hub.load("pytorch/vision", model_dict[model_type], weights=None)
    # Assumes the backbone keeps its default 1000-class output layer.
    self.out_layer = torch.nn.Linear(1000, n_classes)

  def forward(self, x):
    x = self.model(x)
    x = self.out_layer(x)
    return x

class LitWrappedModel(L.LightningModule):
  """Lightning module that trains a classifier with cross-entropy.

  It tracks accuracy, F1 and AUROC over the validation and test loops. Metric
  states are accumulated per batch and computed once per epoch. AUROC in
  particular is not the mean of per-batch AUROCs, so per-step values must not
  be averaged.

  Args:
    n_classes: number of target classes.
    model_type: ``"Custom"`` selects ``CustomModel``; any other value is passed
      to ``WrappedModel``.
    lr: Adam learning rate (default 1e-3, the previous fixed value).
  """
  def __init__(self, n_classes, model_type, lr=1e-3):
    super().__init__()
    if model_type == "Custom":
      self.model = CustomModel(num_classes = n_classes)
    else:
      self.model = WrappedModel(n_classes, model_type)
    self.n_classes = n_classes
    self.lr = lr

    self.val_metrics = torchmetrics.MetricCollection(
        {
            "accuracy": torchmetrics.classification.Accuracy(task="multiclass", num_classes=n_classes),
            "f1": torchmetrics.classification.F1Score(task="multiclass", num_classes=n_classes),
            "auroc": torchmetrics.classification.AUROC(task="multiclass", num_classes=n_classes)
        },
        prefix="val_",
    )
    self.test_metrics = self.val_metrics.clone(prefix="test_")

  def training_step(self, batch, batch_idx):
    x, y = batch
    y_pred = self.model(x)
    loss = torch.nn.functional.cross_entropy(y_pred, y)
    self.log("train_loss", loss, on_step=False, on_epoch=True)
    return loss

  def validation_step(self, batch, batch_idx):
    x, y = batch
    # Accumulate only; the epoch-level values are computed in on_validation_epoch_end.
    self.val_metrics.update(self.model(x), y)

  def test_step(self, batch, batch_idx):
    x, y = batch
    # Accumulate only; the epoch-level values are computed in on_test_epoch_end.
    self.test_metrics.update(self.model(x), y)

  def on_validation_epoch_end(self):
    # Logged keys stay "val_accuracy", "val_f1", "val_auroc".
    self.log_dict(self.val_metrics.compute(), prog_bar=True)
    self.val_metrics.reset()
    L.pytorch.utilities.memory.garbage_collection_cuda()

  def on_test_epoch_end(self):
    # Logged keys stay "test_accuracy", "test_f1", "test_auroc".
    self.log_dict(self.test_metrics.compute(), prog_bar=True)
    self.test_metrics.reset()

  def configure_optimizers(self):
    return torch.optim.Adam(self.parameters(), lr=self.lr)

# One datamodule (and its CSV-derived datasets) is shared by every experiment.
plantvillage_data = PlantVillageData(num_workers=15)

# For each architecture: train, save weights / pickled model / torch.export
# program, then evaluate on the test split.
for exp_name in ("Custom", "ShuffleNetV2", "ResNet50", "MobileNetV2"):
  lit_model = LitWrappedModel(plantvillage_data.n_classes, model_type=exp_name)

  # A fixed logger version means re-runs reuse csv_logs/classification/<exp_name>/version_0.
  trainer = L.Trainer(
      max_epochs=50,
      accelerator='gpu',
      callbacks=[RichProgressBar()],
      logger=CSVLogger("csv_logs/classification", name=exp_name, version=0),
  )
  trainer.fit(model=lit_model, datamodule=plantvillage_data)

  model_save_path = Path("models") / "classification"
  model_save_path = model_save_path / exp_name
  model_save_path.mkdir(exist_ok=True, parents=True)

  # nn.Module.eval()/.cpu() act in place, so lit_model's submodule is now on CPU
  # in eval mode as well. NOTE(review): trainer.test below relies on Lightning
  # moving it back to the accelerator -- confirm.
  model = lit_model.model
  model = model.eval().cpu()
  ## Save just weights
  torch.save(model.state_dict(), model_save_path / f"weights_{exp_name}.pt")
  ## Pickle the whole model (loading requires these class definitions to be importable)
  torch.save(model, model_save_path / f"model_{exp_name}.pt")
  ## Using experimental torch export
  # Batch size is dynamic. Height and width are dynamic but constrained to
  # multiples of 32, the total stride of CustomModel (stem conv /2, max-pool /2,
  # three downsampling stages /2 each); presumably the torchvision backbones
  # share it.
  _height = torch.export.Dim('_height', min=1)
  _width = torch.export.Dim('_width', min=1)
  dynamic_shapes = {"x": {
    0: torch.export.Dim("batch", min=1, max=9223372036854775806),  # effectively unbounded (int64 max - 1)
    2: 32*_height,
    3: 32*_width,
  }}
  # Example input 2x3x512x512 (512 = 32 * 16) satisfies the constraints above.
  ep = torch.export.export(model, (torch.randn(2, 3, 512, 512),), dynamic_shapes=dynamic_shapes, strict=True)
  torch.export.save(ep, model_save_path / f"export_{exp_name}.pt2")

  # No ckpt_path is given, so this evaluates lit_model's current (final-epoch) weights.
  trainer.test(model=lit_model, datamodule=plantvillage_data)