In [1]:
%load_ext autoreload
%load_ext dotenv

%autoreload 2
%dotenv

In [2]:
# System Modules
import os
from pathlib import Path

# General Purpose Libraries 
import torch
import wandb
import torchvision
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib widget
import imageio.v3 as iio

# Paths, Datasets and Datamodules
from etl.pathfactory import PathFactory
from etl.etl import reset_dir
from geovision.data.datamodules import ImageDatasetDataModule 
from data.oxfordiiitpet import OxfordIIITPetSegmentation

# Transforms
import torchvision.transforms.v2 as t

# Models
from torchvision.models import alexnet, AlexNet_Weights

# Tasks
from training.tasks import SegmentationTask
from training.callbacks import SegmentationReport 
from training.callbacks import setup_logger, setup_wandb_logger, setup_checkpoint
from lightning.pytorch.loggers import WandbLogger

#Trainers
from lightning import Trainer

# Type Hints
from typing import Callable, Any, Optional, Literal
from numpy.typing import NDArray

# Logging
import logging
from lightning.pytorch.utilities import disable_possible_user_warnings # type: ignore
logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)
disable_possible_user_warnings()

In [3]:
os.environ["WANDB_NOTEBOOK_NAME"] = "segmentation_experiments.ipynb" 
experiment = {
    "name": "test_run",

    "model_name": "unet",
    "model_params": {
        "encoder": "resnet18",
        "decoder": "deconvolution",
        "weights": "imagenet",
    },

    "dataset_name": "oxford-iiit-pet",
    "task": "segmentation",
    "num_classes": OxfordIIITPetSegmentation.NUM_CLASSES,
    "class_names": OxfordIIITPetSegmentation.CLASS_NAMES,

    "random_seed": 69,

    "test_split": 0.2,
    "val_split": 0.2,
    "batch_size": 32,
    "grad_accum": 2,
    "num_workers": 4,

    "loss": "cross_entropy",
    "loss_params": {
        "reduction": "mean",
    },

    "optimizer": "adam",
    "optimizer_params": {
        "lr": 5e-6,
    },

    "monitor_metric": "jaccard",
    "monitor_mode": "max",
}
PATHS = PathFactory(experiment["dataset_name"], experiment["task"])

image_transform = t.Compose([
    t.ToImage(),
    t.ToDtype(torch.float32, scale=True),
])

mask_transform = t.Compose([
    t.ToImage(),
    t.ToDtype(torch.float32, scale=True),
])

augmentations = t.Compose([
    t.RandomResizedCrop((256, 256), interpolation=0, antialias=False)
])

datamodule = ImageDatasetDataModule(
    root = PATHS.path,
    is_remote = False,
    is_streaming = False,
    dataset_constructor = OxfordIIITPetSegmentation, 
    image_transform = image_transform,
    target_transform = mask_transform,
    common_transform = augmentations,
    **experiment
)
display(datamodule)


        
Local Dataset: oxford-iiit-pet @ [/home/sambhav/datasets/oxford-iiit-pet]
Configured For: segmentation
        
Random Seed: 69
Train: 60.0%
Val: 20.0%
Test: 20.0%
Batch Size: 16
        

In [4]:
from segmentation_models_pytorch import Unet
# from torch.nn import CrossEntropyLoss
# from torchmetrics import JaccardIndex
unet = Unet(experiment["model_params"]["encoder"], classes=experiment["num_classes"]) 

# datamodule.setup("fit")
# ds = datamodule.train_dataset 
# image, mask, idx = ds[np.random.randint(0, len(ds))]
# mask = mask.unsqueeze(0)
# pred = model(mask).detach()

# print(mask.shape, mask.dtype, mask.min(), mask.max())
# print(pred.shape, pred.dtype, pred.min(), pred.max())

# criterion = CrossEntropyLoss()
# print(f"{criterion._get_name()}: {criterion(pred, mask)}") 

# pred = torch.argmax(pred, 1)
# mask = torch.argmax(mask, 1)

# print(mask.shape, mask.dtype, mask.min(), mask.max())
# print(pred.shape, pred.dtype, pred.min(), pred.max())

# metric = JaccardIndex(task = "multiclass", num_classes = 3)
# print(f"{metric._get_name()}: {metric(pred, mask)}") 

# _, (l, m, r) = plt.subplots(1, 3, figsize = (10, 5))
# l.imshow(image.permute(1,2,0))
# m.imshow(mask.squeeze())
# r.imshow(pred.squeeze())
# plt.show()

In [5]:
logger = setup_logger(PATHS.experiments_path, experiment["name"])
reset_dir(logger.log_dir)
wandb_logger = setup_wandb_logger(PATHS.experiments_path, experiment["name"])
checkpoint = setup_checkpoint(Path(logger.log_dir, "model_ckpts"), experiment["monitor_metric"], experiment["monitor_mode"]) 
segmentation_report = SegmentationReport()

BEST_CKPT = checkpoint.best_model_path 
LAST_CKPT = checkpoint.last_model_path

trainer = Trainer(
    callbacks=[segmentation_report],
    logger = [logger],
    enable_model_summary=False,
    enable_checkpointing=False,
    accumulate_grad_batches=experiment["grad_accum"],
    #num_sanity_val_steps=0,

    #fast_dev_run=True
    max_epochs=9,
    check_val_every_n_epoch=2, 
    #limit_train_batches=10,
    #limit_test_batches=10
)

Monitoring: val_jaccard, Checkpoints Saved To: /home/sambhav/experiments/oxford-iiit-pet_segmentation/test_run/model_ckpts


In [6]:
trainer.fit(
    model=SegmentationTask(unet, **experiment),
    datamodule=datamodule,
    ckpt_path=LAST_CKPT if Path(LAST_CKPT).is_file() else None,
)

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [9]:
trainer.test(
    model=SegmentationTask(unet, **experiment),
    datamodule=datamodule,
    ckpt_path=LAST_CKPT if Path(LAST_CKPT).is_file() else None,
    verbose=False
)

Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_jaccard': 0.716504693031311,
  'test_CrossEntropyLoss': 0.3905423581600189,
  'test_MulticlassCohenKappa': 0.8001563549041748,
  'test_foreground_precision': 0.9031159281730652,
  'test_foreground_recall': 0.9353954195976257,
  'test_foreground_jaccard': 0.8500913381576538,
  'test_foreground_dice': 0.9189723134040833,
  'test_foreground_support': 43469532.0,
  'test_background_precision': 0.9303725957870483,
  'test_background_recall': 0.9090898036956787,
  'test_background_jaccard': 0.8511800765991211,
  'test_background_dice': 0.9196080565452576,
  'test_background_support': 39963832.0,
  'test_outline_precision': 0.6345722079277039,
  'test_outline_recall': 0.604204535484314,
  'test_outline_jaccard': 0.44824284315109253,
  'test_outline_dice': 0.6190161108970642,
  'test_outline_support': 13363308.0,
  'test_accuracy': 0.8788120746612549,
  'test_macro_precision': 0.8367182016372681,
  'test_macro_recall': 0.8318754434585571,
  'test_macro_jaccard': 0.7570815682411194,
  '