In [None]:
%load_ext autoreload
%load_ext dotenv

%autoreload 2
%dotenv

In [None]:
# System Modules
import os
from pathlib import Path

# General Purpose Libraries
import torch
import wandb
import torchvision
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib widget
import imageio.v3 as iio

# Paths, Datasets and Datamodules
from etl.pathfactory import PathFactory
from etl.etl import reset_dir
from data.datamodules import ImageDatasetDataModule
from datasets.imagenette import ImagenetteClassification

# Transforms
import torchvision.transforms.v2 as t

# Models
from torchvision.models import alexnet, AlexNet_Weights

# Tasks
from training.tasks import ClassificationTask
from training.callbacks import ClassificationReport
from training.callbacks import setup_logger, setup_wandb_logger, setup_checkpoint
from lightning.pytorch.loggers import WandbLogger

# Trainers
from lightning import Trainer, seed_everything

# Type Hints
from typing import Callable, Any, Optional, Literal
from numpy.typing import NDArray

# Logging
import logging
from lightning.pytorch.utilities import disable_possible_user_warnings # type: ignore
logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)
disable_possible_user_warnings()

In [None]:
os.environ["WANDB_NOTEBOOK_NAME"] = "experiments.ipynb"
experiment = {
    "name": "test_run",

    "dataset_name": "imagenette",
    "model_name": "alexnet-pretrained",
    "task": "classification",
    "num_classes": 10,
    "class_names": ImagenetteClassification.CLASS_NAMES,

    "random_seed": 69,

    "test_split": 0.2,
    "val_split": 0.2,
    "batch_size": 32,
    "grad_accum": 1,
    "num_workers": 4,

    "loss": "cross_entropy",
    "loss_params": {
        "reduction": "mean",
    },

    "optimizer": "adam",
    "optimizer_params": {
        "lr": 5e-6,
    },

    "monitor_metric": "accuracy",
    "monitor_mode": "max",
}
PATHS = PathFactory(experiment["dataset_name"], experiment["task"])

# Seed Python, NumPy and torch (plus dataloader workers) before anything random happens
# below, e.g. the initialisation of the new classifier head. `random_seed` was previously
# declared but never applied to these global RNGs.
# NOTE(review): the datamodule may also use `random_seed` for its splits -- confirm.
seed_everything(experiment["random_seed"], workers=True)

# Convert to a float image in [0, 1], apply geometric ops, then normalise last.
# Resize/flip are linear per channel, so this order gives the same result as normalising
# first, but normalisation now runs on the 224x224 image rather than the full-size one.
# NOTE(review): this one pipeline is handed to the datamodule; unless it applies it to the
# training split only, RandomHorizontalFlip also perturbs val/test images -- confirm.
image_transform = t.Compose([
    t.ToImage(),
    t.ToDtype(torch.float32, scale=True),
    t.Resize((224, 224), antialias=True),
    t.RandomHorizontalFlip(0.5),
    t.Normalize(ImagenetteClassification.MEANS, ImagenetteClassification.STD_DEVS),
])

datamodule = ImageDatasetDataModule(
    root=PATHS.path,
    is_remote=False,
    is_streaming=False,
    dataset_constructor=ImagenetteClassification,
    image_transform=image_transform,
    **experiment
)
display(datamodule)

# Pretrained AlexNet with its final layer replaced by a fresh head sized for this dataset.
# The input width is read from the layer being replaced instead of hard-coding 4096.
alexnet_p = alexnet(weights=AlexNet_Weights.DEFAULT)
alexnet_p.classifier[-1] = torch.nn.Linear(
    alexnet_p.classifier[-1].in_features, experiment["num_classes"]
)

In [None]:
logger = setup_logger(PATHS.experiments_path, experiment["name"])
# NOTE(review): reset_dir presumably clears this run's log dir, which also holds the
# `model_ckpts` directory configured below, so no earlier checkpoint survives to resume
# from. Drop this call when resuming a run is intended.
reset_dir(logger.log_dir)
wandb_logger = setup_wandb_logger(PATHS.experiments_path, experiment["name"])
checkpoint = setup_checkpoint(Path(logger.log_dir, "model_ckpts"), experiment["monitor_metric"], experiment["monitor_mode"])
classification_report = ClassificationReport()

# These are snapshots taken *before* training. Assuming setup_checkpoint returns a
# Lightning ModelCheckpoint, its paths stay empty until it has saved a checkpoint, so both
# are "" here. After `fit`, read checkpoint.best_model_path / checkpoint.last_model_path
# (or pass ckpt_path="best"/"last") instead.
BEST_CKPT = checkpoint.best_model_path
LAST_CKPT = checkpoint.last_model_path

trainer = Trainer(
    callbacks=[checkpoint, classification_report],
    logger=[logger, wandb_logger],
    enable_model_summary=False,
    max_epochs=8,
    check_val_every_n_epoch=2,
    # `grad_accum` was configured above but never reached the Trainer.
    # NOTE(review): assumes ClassificationTask uses automatic optimisation and does not
    # accumulate gradients itself -- confirm.
    accumulate_grad_batches=experiment["grad_accum"],
    # For quick smoke tests, add e.g. limit_train_batches=10, limit_val_batches=10,
    # limit_test_batches=10, num_sanity_val_steps=0.
)

In [None]:
# Resume from the newest "last" checkpoint if one exists. With ckpt_path="last", Lightning
# resolves the checkpoint at call time; if there is none, it warns and trains from scratch.
# The LAST_CKPT snapshot was taken before training and was always empty, so the previous
# check never resumed.
# NOTE(review): a last checkpoint only exists if setup_checkpoint enables save_last --
# confirm.
trainer.fit(
    model=ClassificationTask(alexnet_p, **experiment),
    datamodule=datamodule,
    ckpt_path="last",
)

In [None]:
# Evaluate the checkpoint the callback ranked best on `monitor_metric`. The path is read
# now, after `fit`, because the callback only fills it in once it has saved a checkpoint.
# The pre-training BEST_CKPT/LAST_CKPT snapshots were always empty, so the checkpoint was
# never loaded. If no checkpoint file exists, validate the model's current weights instead.
best_ckpt_path = checkpoint.best_model_path
trainer.validate(
    model=ClassificationTask(alexnet_p, **experiment),
    datamodule=datamodule,
    ckpt_path=best_ckpt_path if Path(best_ckpt_path).is_file() else None,
    verbose=False,
)