In [None]:
#    [x] add fn to plot metrics from the experiment logger, with interactive buttons to choose runs
#    [] add fn to stream metrics.csv over ssh
# 5. [] Refactor and document geovision.io.local, add SSH and S3 handlers to geovision.io.remote
# 6. [] add LR logging (train_lr_epoch), read up on Lightning's LR logger callback

# Datasets
# 1. [x] Fix ImageNet (e.g. errors when loading images). Add the default ILSVRC-2012 train/val split
# 2. Add FMOW 
# 3. Add Pascal VOC / MS COCO / OxfordIIITPets
# 4. [x] Figure out how to do transformations properly, with preprocessing, train augmentations, eval augmentations, etc.
# 5. Add geosampler for large georegistered scenes (Geo-Tiling for Semantic Segmentation)

# Analysis
# 1. Add ~efficient (numba/cupy/mojo) functions to compute image dataset statistics.
# -> pixel values (bucket sort?), 

# Tests
# 1. Refactor dataset and datamodule tests. Add to test_dataset/datamodule.py to run with pytest. 
# 2. Test for expected output shapes for each sample and batch size after transformations, print any errors/inconsistencies  
# 3. Test for overlapping samples both across splits (inter-split) and within each split (intra-split)

In [1]:
# Notebook setup: load the IPython extensions this notebook relies on.
# Comments are kept on their own lines because a line magic receives the rest
# of its line as its argument.
%load_ext autoreload
%load_ext memory_profiler 
%load_ext dotenv
# autoreload mode 2: re-import modules before each cell runs, so edits to the
# package source are picked up without restarting the kernel.
%autoreload 2
# Load environment variables via the dotenv extension
# (presumably from a .env file near the notebook; confirm the location).
%dotenv

In [2]:
import yaml
import logging
from lightning import Trainer
from geovision.config import ExperimentConfig
from geovision.data.module import ImageDatasetDataModule
from geovision.models.module import ClassificationModule

from geovision.io.local import get_new_dir, get_ckpt_path
from geovision.loggers.experiment_loggers import (
    get_csv_logger, 
    get_ckpt_logger,
    get_classification_logger
)

In [3]:
# Build the experiment configuration from config.yaml.
#
# NOTE(security): yaml.Loader can construct arbitrary Python objects, and the
# exec() call below runs source code stored in the config file. Only run this
# cell against a config.yaml you trust.
with open("config.yaml") as f:
    config_dict = yaml.load(f, Loader=yaml.Loader)
# The dataset shares the experiment-level seed.
config_dict["dataset_params"]["random_seed"] = config_dict["random_seed"]

# The transforms script must define these four names; exec() places them in the
# notebook's global namespace, exactly as before.
_transform_names = ("image_pre", "target_pre", "train_aug", "eval_aug")
# Drop values left over from an earlier run of this cell, so a script that no
# longer defines one of them fails loudly instead of silently reusing stale state.
for _name in _transform_names:
    globals().pop(_name, None)
exec(config_dict["transforms_script"])
_missing = [_name for _name in _transform_names if _name not in globals()]
if _missing:
    raise NameError(f"transforms_script did not define: {', '.join(_missing)}")
for _name in _transform_names:
    config_dict["dataset_params"][_name] = globals()[_name]
config = ExperimentConfig.model_validate(config_dict)

# Log to a file named after the experiment, inside the directory returned by
# get_new_dir("logs"). Single quotes inside the f-string keep this line valid on
# Python versions older than 3.12 as well.
logger = logging.getLogger(__name__)
logging.basicConfig(
    filename = f"{get_new_dir('logs')/config.name}.log",
    filemode = "a",
    format = "%(asctime)s : %(name)s : %(levelname)s : %(message)s",
    level = logging.INFO
)
datamodule = ImageDatasetDataModule(config)

# Experiment loggers and callbacks handed to the Trainer in the next cell. The
# walrus bindings keep each object reachable by name for later inspection.
loggers: list = [csv_logger := get_csv_logger(config)]

callbacks: list = [
    ckpt_logger := get_ckpt_logger(config),
    metrics_logger := get_classification_logger(config),
]

In [None]:
# Train the classification model.
# NOTE(review): this rebinds `config`, while the datamodule, loggers and
# callbacks used below were built from the instance validated in the previous
# cell. Confirm that from_config("config.yaml") yields an equivalent config.
config = ExperimentConfig.from_config("config.yaml")
trainer = Trainer(
    logger=loggers,
    callbacks=callbacks,
    **config.trainer_params,
)
# ckpt_path comes from get_ckpt_path(config); presumably it selects a checkpoint
# to resume from. Verify against geovision.io.local.
trainer.fit(
    ClassificationModule(config),
    datamodule=datamodule,
    ckpt_path=get_ckpt_path(config),
)