In [1]:
%load_ext autoreload
%load_ext dotenv
%autoreload 2
%dotenv

In [2]:
import sys
sys.path.append("../") if "../" not in sys.path else None

In [3]:
from pathlib import Path

import torch
import torchvision.transforms.v2 as t
from datasets.datamodules import ImageDatasetDataModule
#from datasets.inria import InriaHDF5
from datasets.plant_disease import PlantDiseaseClassification
from etl.pathfactory import PathFactory
from lightning import Trainer

# setup_eval
from etl.etl import reset_dir
from training.callbacks import EvaluateClassification, EvaluateSegmentation
from training.tasks import ClassificationTask
from training.utils import (
    setup_checkpoint,
    setup_logger,
    setup_wandb_logger,
)


In [4]:
import logging
import os

from lightning.pytorch.utilities import disable_possible_user_warnings  # type: ignore

logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)
disable_possible_user_warnings()

os.environ["WANDB_CONSOLE"] = "off"
os.environ["WANDB_SILENT"] = "true"

In [7]:
DATASET = PlantDiseaseClassification 
# MODEL = Unet
experiment = {
    "name": "plant_disease_tuning",
    "model_name": "resnet18",
    "model_params": {
        "encoder": "resnet18",
        "decoder": "mlp",
        "weights": "imagenet",
    },
    "dataset_name": DATASET.NAME,
    "task": DATASET.TASK,
    "num_classes": DATASET.NUM_CLASSES,
    "class_names": DATASET.CLASS_NAMES,
    "random_seed": 69,
    "test_split": 0.2,
    "val_split": 0.1,
    "batch_size": 32,
    "grad_accum": 1,
    "num_workers": 4,
    "loss": "cross_entropy",
    "loss_params": {
        "reduction": "mean",
    },
    "optimizer": "adam",
    "optimizer_params": {
        "lr": 1e-5,
    },
    "monitor_metric": "f1",
    "monitor_mode": "max",
    # "tile_size": (512, 512),
    # "tile_stride": (512, 512),
}
PATHS = PathFactory(experiment["dataset_name"], experiment["task"])
LOGS_DIR = PATHS.experiments_path / experiment["name"]

# NOTE: t.Normalize(DATASET.MEANS, DATASET.STD_DEVS),
image_transform = t.Compose([t.ToImage(), t.ToDtype(torch.float32, scale=True), t.Resize((224, 224))])
# mask_transform = t.Compose([t.ToImage(), t.ToDtype(torch.float32, scale=False)])
#augmentations = t.AutoAugment(t.AutoAugmentPolicy.IMAGENET, )
augmentations = t.RandomChoice([
    t.RandomHorizontalFlip(),
    t.RandomEqualize(), 
    t.RandomAdjustSharpness(sharpness_factor=0.5), 
    t.RandomAutocontrast()
    ], p = [0.8, 0.5, 0.5, 0.2])

# import pandas as pd
# train_df = DATASET.scene_df(**experiment)
# train_df = train_df[train_df["split"] == "train"]
# eval_df = DATASET.tiled_df(**experiment)
# eval_df = eval_df.drop(columns = "tile_name")
# eval_df = eval_df[eval_df["split"] != "train"]
# dataset_df = pd.concat([train_df, eval_df], axis = 0)
# dataset_df = dataset_df[dataset_df["split"] != "unsup"]

datamodule = ImageDatasetDataModule(
    root= Path.home() / "datasets" / "plant_village",
    is_remote=False,
    is_streaming=False,
    dataset_constructor=DATASET,
    # dataframe = dataset_df,
    image_transform=image_transform,
    #target_transform=mask_transform,
    common_transform=augmentations,
    **experiment,
)

# Models

from torchvision.models import resnet18, ResNet18_Weights # noqa
model = resnet18(weights = ResNet18_Weights.DEFAULT)
model.fc = torch.nn.Linear(512, experiment["num_classes"])

# from torchvision.models import alexnet, AlexNet_Weights
# model = alexnet(weights=AlexNet_Weights.DEFAULT)
# model = alexnet(weights=None)
# model.classifier[-1] = torch.nn.Linear(4096, experiment.get("num_classes", 10))

# from segmentation_models_pytorch import Unet
# model = Unet(experiment["model_params"]["encoder"], classes=experiment["num_classes"])

display(datamodule)
logger = setup_logger(PATHS.experiments_path, experiment["name"])
wandb_logger = setup_wandb_logger(PATHS.experiments_path, experiment["name"])
checkpoint = setup_checkpoint(
    Path(logger.log_dir, "model_ckpts"),
    experiment["monitor_metric"],
    experiment["monitor_mode"],
    "all",
)
eval_callback = (
    EvaluateSegmentation()
    if experiment["task"] == "segmentation"
    else EvaluateClassification()
)
#reset_dir(LOGS_DIR)


        
Local Dataset: plant_village @ [/home/sambhav/datasets/plant_village]
Configured For: classification
        
Random Seed: 69
Train: 70.0%
Val: 10.0%
Test: 20.0%
Batch Size(//grad_accum): 32
        

Local Logging To : /home/sambhav/experiments/plant_village_classification/plant_disease_tuning
WandB Logging To: /home/sambhav/experiments/plant_village_classification/plant_disease_tuning/wandb
Monitoring: val/f1, Checkpoints Saved To: /home/sambhav/experiments/plant_village_classification/plant_disease_tuning/model_ckpts


PosixPath('/home/sambhav/experiments/plant_village_classification/plant_disease_tuning')

In [8]:
BEST_CKPT = checkpoint.best_model_path
LAST_CKPT = checkpoint.last_model_path

trainer = Trainer(
    callbacks=[checkpoint, eval_callback],
    # enable_checkpointing=False,
    logger=[logger, wandb_logger],
    enable_model_summary=False,
    #fast_dev_run=True,
    num_sanity_val_steps=0,
    max_epochs=6,
    check_val_every_n_epoch=3,
    #limit_val_batches=10,
    #limit_train_batches=10,
)

In [9]:
experiment["optimizer_params"]["lr"] =  5e-6
trainer.fit(
    model=ClassificationTask(model, **experiment),
    datamodule=datamodule,
    ckpt_path=LAST_CKPT if Path(LAST_CKPT).is_file() else None,
    # verbose=False
)

trainer.test(
    model=ClassificationTask(model, **experiment),
    datamodule=datamodule,
    ckpt_path=LAST_CKPT if Path(LAST_CKPT).is_file() else None,
    verbose = False
)

Monitor Metric: f1


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011113053177779067, max=1.0…

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Monitor Metric: f1


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test/apple-apple_scab_precision': 0.9897959232330322,
  'test/apple-apple_scab_recall': 0.9700000286102295,
  'test/apple-apple_scab_iou': 0.9603960514068604,
  'test/apple-apple_scab_f1': 0.9797979593276978,
  'test/apple-apple_scab_support': 100.0,
  'test/apple-black_rot_precision': 1.0,
  'test/apple-black_rot_recall': 1.0,
  'test/apple-black_rot_iou': 1.0,
  'test/apple-black_rot_f1': 1.0,
  'test/apple-black_rot_support': 100.0,
  'test/apple-cedar_apple_rust_precision': 1.0,
  'test/apple-cedar_apple_rust_recall': 0.9900000095367432,
  'test/apple-cedar_apple_rust_iou': 0.9900000095367432,
  'test/apple-cedar_apple_rust_f1': 0.9949748516082764,
  'test/apple-cedar_apple_rust_support': 100.0,
  'test/apple-healthy_precision': 1.0,
  'test/apple-healthy_recall': 1.0,
  'test/apple-healthy_iou': 1.0,
  'test/apple-healthy_f1': 1.0,
  'test/apple-healthy_support': 100.0,
  'test/banana-bbs_precision': 1.0,
  'test/banana-bbs_recall': 0.15000000596046448,
  'test/banana-bbs_iou':

In [None]:
from training.evaluation import (
    checkpoints_df,
    plot_checkpoint_attribution,
    plot_checkpoints,
)

plot_checkpoints(
    LOGS_DIR,
    experiment["monitor_metric"],
    checkpoints_df(LOGS_DIR, experiment["monitor_metric"]),
)

In [10]:
epoch = 8
step = 7038
split = "worst"
k = 25
plot_checkpoint_attribution(
    model, LOGS_DIR, PATHS.path, DATASET, epoch, step, split, k, **experiment
)

NameError: name 'plot_checkpoint_attribution' is not defined