In [None]:
%load_ext autoreload
%load_ext memory_profiler 
%load_ext dotenv
%autoreload 2
%dotenv

In [None]:
from lightning import Trainer
from torchvision.transforms import v2 as T # type: ignore
from geovision.logging import get_logger
from geovision.config.basemodels import ExperimentConfig # noqa
from geovision.data.module import ImageDatasetDataModule
from geovision.io.local import get_ckpt_path, get_experiments_dir
from geovision.training.module import ClassificationModule
from geovision.training.loggers import (
    get_csv_logger, 
    get_wandb_logger,
    get_ckpt_logger,
    get_lr_logger,
    get_classification_logger
)

In [None]:
from torch import float32
from geovision.data.imagenette import Imagenette

logger = get_logger("experiment_logger")

# Optional integrations: flip these flags instead of (un)commenting code below.
USE_WANDB = False
LOG_CLASSIFICATION_METRICS = False
LOG_LEARNING_RATE = False

# NOTE(review): ToDtype(float32) is used without scale=True, so uint8 images keep
# values in [0, 255] before Normalize; confirm Imagenette.means / std_devs are on
# that same scale (otherwise use T.ToDtype(float32, scale=True)).
transforms: dict[str, T.Transform | None] = {
    "image_transform": T.Compose([
        T.ToImage(),
        T.ToDtype(float32),
        T.Normalize(Imagenette.means, Imagenette.std_devs),
        T.RandomResizedCrop(224, antialias=True),
    ]),
    "target_transform": None,
    "common_transform": T.RandomHorizontalFlip(),
}
config = ExperimentConfig.from_yaml("config.yaml", transforms)
datamodule = ImageDatasetDataModule(config)

# Experiment loggers handed to the Trainer; csv_logger is also read back later for plotting.
csv_logger = get_csv_logger(config)
loggers: list = [csv_logger]
if USE_WANDB:
    wandb_logger = get_wandb_logger(config)
    loggers.append(wandb_logger)

# Trainer callbacks; the checkpoint callback is always registered last.
callbacks: list = []
if LOG_CLASSIFICATION_METRICS:
    metrics_logger = get_classification_logger(config)
    callbacks.append(metrics_logger)
if LOG_LEARNING_RATE:
    lr_logger = get_lr_logger(config)
    callbacks.append(lr_logger)
ckpt_logger = get_ckpt_logger(config)
callbacks.append(ckpt_logger)

In [None]:

# Training-loop length; tune here rather than inside the Trainer call.
MAX_EPOCHS = 6

trainer = Trainer(
    logger=loggers,
    callbacks=callbacks,
    max_epochs=MAX_EPOCHS,
    check_val_every_n_epoch=1,
    num_sanity_val_steps=0,  # skip Lightning's pre-training validation sanity check
    log_every_n_steps=1,
    # Quick smoke-run limits; trailing commas keep the call valid when uncommented.
    # limit_train_batches=50,
    # limit_val_batches=1,
)



In [None]:
# litmodule = ClassificationModule(ExperimentConfig.from_yaml("config.yaml", transforms))
# wandb_logger.watch(litmodule)

def _latest_checkpoint(root):
    """Return the most recently modified ``*.ckpt`` file under ``root``, or None if there is none.

    Ordering by modification time, rather than by sorted path strings, avoids
    picking e.g. ``epoch=9`` over ``epoch=10`` or ``version_9`` over ``version_10``.
    Returning None (instead of indexing an empty list) lets a fresh run start.
    """
    return max(root.rglob("*.ckpt"), key=lambda path: path.stat().st_mtime, default=None)

# Resume from the newest checkpoint when one exists; ckpt_path=None trains from scratch.
trainer.fit(
    model=ClassificationModule(ExperimentConfig.from_yaml("config.yaml", transforms)),
    datamodule=datamodule,
    ckpt_path=_latest_checkpoint(get_experiments_dir(config)),
)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

pd.set_option("display.max_rows", 10)
pd.set_option("display.max_columns", None)

# Metrics written by the CSV logger for this run.
df = pd.read_csv(Path(csv_logger.log_dir) / "metrics.csv")
# Fill gaps in the epoch column with the last epoch value seen above each row.
df["epoch"] = df["epoch"].ffill()
df

In [None]:
# Loss series to plot against the global step, in drawing order.
LOSS_COLUMNS = ["train/loss_step", "train/loss_epoch", "val/loss_epoch", "val/loss_step"]


def plot_loss_curves(metrics: pd.DataFrame, columns: list[str], ax=None):
    """Plot each column in ``columns`` against ``metrics["step"]`` on one axis.

    NaN rows are dropped per series, because each column is only populated on the
    rows where that quantity was recorded. Columns absent from ``metrics`` (e.g. a
    metric not logged in this run) are skipped rather than raising KeyError.

    Returns the matplotlib Axes that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots()
    for column in columns:
        if column not in metrics.columns:
            continue
        series = metrics[["step", column]].dropna(axis=0)
        ax.plot(series["step"], series[column], label=column)
    ax.set(xlabel="step", ylabel="loss", title="Training and validation loss")
    ax.legend()
    return ax


plot_loss_curves(df, LOSS_COLUMNS);


In [None]:
# Per-epoch validation value of the configured metric, one row per logged epoch.
val_metric_col = f"val/{config.metric}_epoch"
vmedf = (
    df[["epoch", "step", val_metric_col]]
    .dropna(axis=0)
    .reset_index(drop=True)
)
vmedf