In [1]:
%load_ext autoreload
%load_ext dotenv

%autoreload 2
%dotenv

In [2]:
import logging
from pathlib import Path

import torch
import torchvision.transforms.v2 as t
from lightning import Trainer, seed_everything
from lightning.pytorch.utilities import disable_possible_user_warnings # type: ignore
from segmentation_models_pytorch import Unet

from data.imagenette import ImagenetteClassification
from data.plant_disease import PlantDiseaseClassification 
from data.inria import InriaHDF5, InriaImageFolder
from data.datamodules import ImageDatasetDataModule 

from training.tasks import ClassificationTask
from training.callbacks import (
    setup_logger, setup_wandb_logger, setup_checkpoint, eval_callback
)

from etl.pathfactory import PathFactory
from etl.etl import reset_dir
logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)
disable_possible_user_warnings()

In [3]:
DATASET = InriaImageFolder
experiment = {
    "name": "test_run",
    "model_name": "unet",
    "model_params": {
        "encoder": "resnet18",
        "decoder": "deconvolution",
        "weights": "imagenet",
    },

    "dataset_name": DATASET.NAME,
    "task": DATASET.TASK,
    "num_classes": DATASET.NUM_CLASSES,
    "class_names": DATASET.CLASS_NAMES,

    "random_seed": 69,

    "test_split": 0.2,
    "val_split": 0.2,
    "batch_size": 2,
    "grad_accum": 1,
    "num_workers": 4,

    "loss": "cross_entropy",
    "loss_params": {
        "reduction": "mean",
    },

    "optimizer": "adam",
    "optimizer_params": {
        "lr": 1e-5,
    },

    "monitor_metric": "iou",
    "monitor_mode": "max",

    "tile_size": (512, 512),
    "tile_stride": (512, 512),
}
PATHS = PathFactory(experiment["dataset_name"], experiment["task"])
LOGS_DIR = PATHS.experiments_path / experiment["name"]

# True: wipe LOGS_DIR and start a fresh run.
# False: keep earlier checkpoints, so the fit cell can resume from LAST_CKPT
# and the evaluation cells can still find previous runs.
RESET_LOGS_DIR = True

# Seed python/numpy/torch (and dataloader workers) so weight init and shuffling
# are reproducible. The datamodule also receives this seed via **experiment.
seed_everything(experiment["random_seed"], workers=True)

# Images: converted to float32 and scaled (to [0, 1] assuming uint8 input).
# Normalisation is currently disabled: t.Normalize(DATASET.MEANS, DATASET.STD_DEVS)
image_transform = t.Compose([t.ToImage(), t.ToDtype(torch.float32, scale=True)])
# Masks keep their raw label values (scale=False).
# NOTE(review): "cross_entropy" usually expects integer (long) class-index targets;
# confirm ClassificationTask casts these float32 masks before computing the loss.
mask_transform = t.Compose([t.ToImage(), t.ToDtype(torch.float32, scale=False)])
# Optional augmentation passed as common_transform,
# e.g. t.Compose([t.RandomHorizontalFlip(0.5)])
augmentations = None

datamodule = ImageDatasetDataModule(
    root = PATHS.path,
    is_remote = False,
    is_streaming = False,
    dataset_constructor = DATASET, 
    image_transform = image_transform,
    target_transform = mask_transform,
    common_transform = augmentations,
    **experiment
)
display(datamodule)

if RESET_LOGS_DIR:
    # Reset BEFORE the loggers and the checkpoint callback are configured.
    # They all write under LOGS_DIR, and resetting afterwards would delete
    # whatever they had already set up there.
    reset_dir(LOGS_DIR)
logger = setup_logger(PATHS.experiments_path, experiment["name"])
wandb_logger = setup_wandb_logger(PATHS.experiments_path, experiment["name"])
checkpoint = setup_checkpoint(Path(logger.log_dir, "model_ckpts"), experiment["monitor_metric"], experiment["monitor_mode"], "all") 


        
Local Dataset: urban_footprint @ [/home/sambhav/datasets/urban_footprint]
Configured For: segmentation
        
Random Seed: 69
Train: 60.0%
Val: 20.0%
Test: 20.0%
Batch Size(//grad_accum): 2
Tile Kernel: (512, 512), Stride: (512, 512)
        

Local Logging To : /home/sambhav/experiments/urban_footprint_segmentation/test_run
WandB Logging To: /home/sambhav/experiments/urban_footprint_segmentation/test_run/wandb
Monitoring: val/iou, Checkpoints Saved To: /home/sambhav/experiments/urban_footprint_segmentation/test_run/model_ckpts


In [4]:
# Model: segmentation_models_pytorch U-Net built from experiment["model_params"].
# Passing encoder_weights explicitly means the config value is respected
# (e.g. set "weights": None for a randomly initialised encoder).
# NOTE(review): model_params["decoder"] ("deconvolution") is not forwarded to Unet;
# confirm whether it is meant to select the decoder's upsampling mode.
model = Unet(
    encoder_name=experiment["model_params"]["encoder"],
    encoder_weights=experiment["model_params"]["weights"],
    classes=experiment["num_classes"],
)

In [5]:
def latest_checkpoint(ckpt_dir) -> str:
    """Return the most recently modified ``*.ckpt`` file in ``ckpt_dir`` as a string.

    Returns "" when the directory is missing or contains no checkpoints. The
    caller can therefore keep using ``Path(...).is_file()`` as its
    "can we resume?" check.
    """
    ckpts = sorted(Path(ckpt_dir).glob("*.ckpt"), key=lambda p: p.stat().st_mtime)
    return str(ckpts[-1]) if ckpts else ""


# The checkpoint callback (presumably a Lightning ModelCheckpoint) only fills in
# best_model_path / last_model_path while trainer.fit runs. Before fit they are
# empty, so LAST_CKPT falls back to the newest checkpoint already on disk in the
# directory passed to setup_checkpoint. This is what makes resuming possible.
# BEST_CKPT is a snapshot taken here; after fit, read checkpoint.best_model_path
# again instead.
BEST_CKPT = checkpoint.best_model_path
LAST_CKPT = checkpoint.last_model_path or latest_checkpoint(Path(logger.log_dir, "model_ckpts"))

evaluation = eval_callback(experiment["task"])
# NOTE(review): experiment["grad_accum"] is not forwarded here as
# accumulate_grad_batches; confirm the datamodule/ClassificationTask handle
# accumulation themselves before setting it above 1.
trainer = Trainer(
    callbacks=[checkpoint, evaluation],
    logger = [logger, wandb_logger],
    enable_model_summary=False,
    #fast_dev_run=True,
    num_sanity_val_steps=0,
    # Short-run limits; raise these for a full training run.
    max_epochs=1,
    #check_val_every_n_epoch=3, 
    limit_train_batches=100,
    limit_val_batches=100,
)

In [6]:

# Wrap the segmentation model in the Lightning task and train it. Training resumes
# only when LAST_CKPT names an existing checkpoint file; otherwise it starts from scratch.
task = ClassificationTask(model, **experiment)
resume_from = LAST_CKPT if Path(LAST_CKPT).is_file() else None
trainer.fit(model=task, datamodule=datamodule, ckpt_path=resume_from)

# To evaluate on the test split:
# trainer.test(model=ClassificationTask(model, **experiment), datamodule=datamodule,
#              ckpt_path=LAST_CKPT if Path(LAST_CKPT).is_file() else None, verbose=False)


Monitor Metric: iou


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011113569433333245, max=1.0…

Tiled Dataset
train dataset at /home/sambhav/datasets/urban_footprint
Tiled Dataset
val dataset at /home/sambhav/datasets/urban_footprint


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [None]:
from training.evaluation import checkpoints_df, plot_checkpoints, plot_checkpoint_attribution

# Build the checkpoint table for LOGS_DIR with checkpoints_df, then plot the
# monitored metric from it.
monitor_metric = experiment["monitor_metric"]
ckpt_table = checkpoints_df(LOGS_DIR, monitor_metric)
plot_checkpoints(LOGS_DIR, monitor_metric, ckpt_table)

In [None]:
# Attribution for one checkpoint of this experiment.
# NOTE(review): epoch/step are hard-coded from an earlier run; presumably they must
# match a checkpoint still present under LOGS_DIR (which reset_dir may have wiped).
epoch, step = 11, 3024
split = "best"
k = 25  # NOTE(review): meaning defined by plot_checkpoint_attribution — presumably a sample count
plot_checkpoint_attribution(
    model, LOGS_DIR, PATHS.path, DATASET,
    epoch, step, split, k,
    **experiment,
)