In [None]:
%load_ext autoreload
%load_ext memory_profiler 
%load_ext dotenv
%autoreload 2
%dotenv

In [None]:
from lightning import Trainer
from torchvision.transforms import v2 as T # type: ignore
from geovision.logging import get_logger
from geovision.config.basemodels import ExperimentConfig # noqa
from geovision.data.module import ImageDatasetDataModule
from geovision.io.local import get_ckpt_path, get_experiments_dir
from geovision.training.module import ClassificationModule
from geovision.training.loggers import (
    get_csv_logger, 
    get_wandb_logger,
    get_ckpt_logger,
    get_lr_logger,
    get_classification_logger
)
from geovision.analysis.plot_experiment import plot_experiment

In [None]:
from torch import float32
from geovision.data.imagenette import Imagenette

# Experiment setup: transforms, parsed config, data module, and the Lightning
# loggers/callbacks that the Trainer cell consumes.
logger = get_logger("experiment_logger")

# Per-image preprocessing: tensor image -> 224x224 -> float32 in [0, 1] -> normalised
# with the Imagenette channel statistics.
image_transform = T.Compose([
    T.ToImage(),
    T.Resize((224, 224), antialias=True),
    T.ToDtype(float32, scale=True),
    T.Normalize(Imagenette.means, Imagenette.std_devs),
])

# Augmentation: RandomChoice picks one of these ops per call, and the chosen op
# then fires with probability 0.5. Use None instead to train without augmentation.
# NOTE(review): RandomInvert / RandomAutocontrast assume un-normalised pixel values;
# confirm the datamodule applies "common_transform" before "image_transform".
augmentation = T.RandomChoice([
    T.RandomHorizontalFlip(0.5),
    T.RandomVerticalFlip(0.5),
    T.RandomInvert(0.5),
    T.RandomAutocontrast(0.5),
])

transforms: dict[str, T.Transform | None] = {
    "image_transform": image_transform,
    "target_transform": None,
    "common_transform": augmentation,
}

config = ExperimentConfig.from_yaml("config.yaml", transforms)
experiments_dir = get_experiments_dir(config)
datamodule = ImageDatasetDataModule(config)

# Experiment loggers. Each one is also bound to its own name so later cells can inspect it.
csv_logger = get_csv_logger(config)
# Optional: wandb_logger = get_wandb_logger(config), then add it to `loggers`.
loggers: list = [csv_logger]

# Trainer callbacks: checkpointing plus classification-metric logging.
ckpt_logger = get_ckpt_logger(config)
metrics_logger = get_classification_logger(config)
# Optional: lr_logger = get_lr_logger(config); a LearningRateFinder callback would
# additionally need `from lightning.pytorch.callbacks import LearningRateFinder`.
callbacks: list = [ckpt_logger, metrics_logger]

In [None]:
# Train (or resume) the classifier. Only Trainer-level settings live here; the
# model and the data are both built from the shared `config` object.
# NOTE(review): Lightning validates when (epoch + 1) % check_val_every_n_epoch == 0,
# so with max_epochs=15 and an interval of 2 the 15th (final) epoch is never
# validated. Consider an even max_epochs if the last epoch's metrics/checkpoint matter.
trainer = Trainer(
    max_epochs=15,
    check_val_every_n_epoch=2,
    num_sanity_val_steps=0,
    log_every_n_steps=1,
    # limit_train_batches=5,  # uncomment both limits for a quick smoke-test run
    # limit_val_batches=5,
    logger=loggers,
    callbacks=callbacks,
    enable_checkpointing=True,
    enable_model_summary=False,
)

# Reuse the already-parsed `config` (the one the datamodule, loggers and ckpt path use)
# instead of re-reading config.yaml. A second parse could silently diverge from
# the datamodule's config if the file was edited between cell runs.
litmodule = ClassificationModule(config)
trainer.fit(
    model=litmodule,
    datamodule=datamodule,
    # Checkpoint to resume from; presumably None starts fresh. Verify get_ckpt_path.
    ckpt_path=get_ckpt_path(config),
)

In [None]:
# Visualise the results of the experiment described by `config` (geovision.analysis.plot_experiment).
plot_experiment(config)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from geovision.analysis.viz import plot_confusion_matrix, plot_metrics_table
%matplotlib tk

fig, ax = plt.subplots(1, 1, figsize = (5, 5), layout = "constrained")
plot_confusion_matrix(ax, np.random.randint(0, 10, (5, 5)))