In [1]:
# Reload edited project modules (e.g. geovision.*) before each cell executes, so
# library changes are picked up without restarting the kernel.
%load_ext autoreload
# Provides the %memit / %%memit memory-profiling magics (not used in the visible cells).
%load_ext memory_profiler 
# Provides the %dotenv magic, which loads variables from a .env file into the environment.
# NOTE(review): presumably this supplies secrets such as a W&B API key — confirm.
%load_ext dotenv
# Mode 2: reload all modules (except those explicitly excluded) before running code.
%autoreload 2
%dotenv

In [4]:
from lightning import Trainer
from torchvision.transforms import v2 as T # type: ignore
from geovision.logging import get_logger
from geovision.config.basemodels import ExperimentConfig # noqa
from geovision.data.module import ImageDatasetDataModule
from geovision.io.local import get_ckpt_path, get_experiments_dir
from geovision.training.module import ClassificationModule
from geovision.training.loggers import (
    get_csv_logger, 
    get_wandb_logger,
    get_ckpt_logger,
    get_lr_logger,
    get_classification_logger
)

In [5]:
from torch import float32
from geovision.data.imagenette import Imagenette

logger = get_logger("experiment_logger")

# Per-sample transforms handed to the data module through the experiment config.
#   image_transform  : deterministic preprocessing — convert to an image tensor, resize to
#                      224x224, convert to float32 (scaling integer pixels into [0, 1]) and
#                      normalise with the Imagenette channel statistics.
#   target_transform : none; targets pass through unchanged.
#   common_transform : random augmentations, applied in a freshly shuffled order each call.
# NOTE(review): which splits common_transform runs on, and whether it runs before or after
# image_transform, is decided inside the project code and is not visible here. If it runs
# after Normalize, RandomInvert / RandomAutocontrast see normalised floats although they
# assume a [0, 1] range — confirm.
transforms: dict[str, T.Transform | None] = dict(
    image_transform=T.Compose(
        [
            T.ToImage(),
            T.Resize((224, 224), antialias=True),
            T.ToDtype(float32, scale=True),
            T.Normalize(Imagenette.means, Imagenette.std_devs),
        ]
    ),
    target_transform=None,
    # To disable augmentation instead: common_transform=None,
    common_transform=T.RandomOrder(
        [
            T.RandomResizedCrop(224),
            T.RandomHorizontalFlip(),
            T.RandomVerticalFlip(),
            T.RandomInvert(),
            T.RandomAutocontrast(),
            T.RandomAdjustSharpness(2.0),
        ]
    ),
)

config = ExperimentConfig.from_yaml("config.yaml", transforms)
experiments_dir = get_experiments_dir(config)
datamodule = ImageDatasetDataModule(config)

# Experiment loggers passed to the Trainer; commented entries are optional toggles.
loggers: list = [
    csv_logger := get_csv_logger(config),
    # wandb_logger := get_wandb_logger(config),
]

# Trainer callbacks; only checkpointing is currently enabled.
# NOTE(review): LearningRateFinder is not imported anywhere in this notebook — add
# `from lightning.pytorch.callbacks import LearningRateFinder` before re-enabling it.
callbacks: list = [
    # metrics_logger := get_classification_logger(config),
    # lr_logger := get_lr_logger(config),
    # LearningRateFinder(num_training_steps=147, early_stop_threshold=None),
    ckpt_logger := get_ckpt_logger(config),
]

In [8]:
# Show the optimizer hyper-parameters from the loaded config as a dict; exclude_none=True
# omits fields whose value is None (presumably pydantic's model_dump — confirm).
config.optimizer_params.model_dump(exclude_none=True)

{'lr': 1e-05, 'momentum': 0.9, 'weight_decay': 0.0005}

In [10]:
# Trainer for a 2-epoch run: validation after every epoch, no sanity-check validation
# steps, metrics logged at every step, checkpointing on, model summary off.
trainer = Trainer(
    max_epochs = 2,
    check_val_every_n_epoch = 1,
    num_sanity_val_steps = 0,
    log_every_n_steps = 1,

    logger = loggers,
    callbacks = callbacks,
    enable_checkpointing = True,
    enable_model_summary = False
)

# Reuse the config object the data module and loggers were built from instead of
# re-reading config.yaml, so the model, the data module and get_ckpt_path(config) all see
# the same settings (including any edits made to `config` in earlier cells).
litmodule = ClassificationModule(config)

# Training is currently disabled; uncomment to run. ckpt_path is taken from
# get_ckpt_path(config).
# trainer.fit(
#     model = litmodule,
#     datamodule = datamodule,
#     ckpt_path = get_ckpt_path(config)
# )

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Keep long metric tables short, but never hide columns.
pd.set_option("display.max_rows", 10)
pd.set_option("display.max_columns", None)

# Metrics written by the CSV logger for this experiment. Rows with a missing epoch
# value inherit the epoch of the nearest preceding row that has one.
df = pd.read_csv(experiments_dir / "metrics.csv").assign(
    epoch=lambda metrics: metrics["epoch"].ffill()
)
# Optional narrowing: df = df[["epoch", "step", "train/loss_epoch"]]
df

In [None]:
def _plot_metric(ax, frame, column):
    """Plot ``column`` against ``step`` on ``ax`` and return the rows that were plotted.

    Rows where either ``step`` or ``column`` is NaN are dropped first. Each metric column
    is presumably populated only on the rows where it was logged.

    Returns a frame with columns ``["step", column]`` and a fresh 0..n-1 index.
    """
    plotted = frame[["step", column]].dropna(axis=0).reset_index(drop=True)
    ax.plot(plotted["step"], plotted[column], label=column)
    return plotted


fig, ax = plt.subplots()
tlsdf = _plot_metric(ax, df, "train/loss_step")
tledf = _plot_metric(ax, df, "train/loss_epoch")
vledf = _plot_metric(ax, df, "val/loss_epoch")

# Keep the per-series x/y names this cell exposes to the notebook namespace.
xts, yts = tlsdf["step"], tlsdf["train/loss_step"]
xte, yte = tledf["step"], tledf["train/loss_epoch"]
xve, yve = vledf["step"], vledf["val/loss_epoch"]

# Uncomment to clip the y-range: ax.set_ylim(0, 6)
ax.set(title="Loss curves", xlabel="step", ylabel="loss")
ax.legend();

In [None]:
vlsdf = df[["epoch", "step", "val/loss_step"]].dropna(axis=0).reset_index(drop=True)
xvs, yvs = vlsdf["step"], vlsdf["val/loss_step"]

# NOTE(review): the logging step presumably does not advance during validation, so all
# val/loss_step rows of one epoch may share the same `step` and would stack vertically if
# plotted against it. Plot against the running validation-batch index instead (vlsdf's
# 0..n-1 index after reset_index) — confirm against metrics.csv.
fig, ax = plt.subplots()
ax.plot(vlsdf.index, yvs, label="val/loss_step")
ax.set(title="Validation loss per batch", xlabel="validation batch (cumulative)", ylabel="loss")
ax.legend();
# To inspect every row: pd.set_option("display.max_rows", None), then display vlsdf.

In [None]:
# Epoch-level values of the configured metric (config.metric) for validation and training.
val_metric_col = f"val/{config.metric}_epoch"
train_metric_col = f"train/{config.metric}_epoch"

# The y-series is the metric value itself, not the epoch number, plotted against step.
vmedf = df[["epoch", "step", val_metric_col]].dropna(axis=0).reset_index(drop=True)
xve, yve = vmedf["step"], vmedf[val_metric_col]

tmedf = df[["epoch", "step", train_metric_col]].dropna(axis=0).reset_index(drop=True)
xte, yte = tmedf["step"], tmedf[train_metric_col]

fig, ax = plt.subplots()
ax.plot(xve, yve, label=val_metric_col)
ax.plot(xte, yte, label=train_metric_col)
ax.set(title=f"{config.metric} per epoch", xlabel="step", ylabel=str(config.metric))
ax.legend()

display(vmedf)
display(tmedf)