In [None]:
%load_ext autoreload
%load_ext dotenv
%autoreload 2
%dotenv

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
import torch
import logging
from lightning import Trainer
from geovision.data import ImageDatasetDataModule
from geovision.experiment.config import ExperimentConfig
from geovision.models.interfaces import ClassificationModule

from geovision.io.local import FileSystemIO as fs
from geovision.experiment.loggers import (
    get_csv_logger, 
    get_ckpt_logger,
    get_classification_logger
)

In [3]:
# Load the experiment configuration once; later cells reuse this same object so the
# datamodule, loggers, callbacks, model and trainer are all built from one snapshot.
config = ExperimentConfig.from_yaml("config.yaml")

# Configure logging BEFORE constructing the datamodule so records emitted while the
# dataset is built also reach the log file.
# NOTE(review): `fs.get_new_dir` presumably creates a fresh sub-directory under
# `config.experiments_dir` on every call — confirm.
# `force=True` removes root handlers left over from a previous run of this cell;
# without it `basicConfig` is a silent no-op on re-run and records keep going to
# the old file. The path is built outside the f-string so no quote nesting is
# needed (reusing `"` inside an f-string field is only legal from Python 3.12).
log_dir = fs.get_new_dir(config.experiments_dir, "logs")
logging.basicConfig(
    filename=f"{log_dir}/logfile.log",
    filemode="a",
    format="%(asctime)s : %(name)s : %(levelname)s : %(message)s",
    level=logging.INFO,
    force=True,
)

datamodule = ImageDatasetDataModule(
    config.dataset_constructor,
    config.dataset_config,
    config.dataloader_config,
)

# Experiment loggers, later handed to `Trainer(logger=...)`.
csv_logger = get_csv_logger(config)
loggers: list = [csv_logger]

# Trainer callbacks, later handed to `Trainer(callbacks=...)`; built in the same
# order as before (checkpoint logger first, then the classification logger).
ckpt_logger = get_ckpt_logger(config)
metrics_logger = get_classification_logger(config)
callbacks: list = [ckpt_logger, metrics_logger]

In [None]:
# Build the model and trainer from the `config` already loaded in the setup cell.
# The YAML is deliberately NOT re-read here: a second `from_yaml` call would let the
# model/trainer drift from the datamodule, loggers and callbacks if config.yaml
# were edited between cell executions.
model = ClassificationModule(
    model_config=config.model_config,
    criterion_constructor=config.criterion_constructor,
    criterion_params=config.criterion_params,
    optimizer_constructor=config.optimizer_constructor,
    optimizer_params=config.optimizer_params,
    lr_scheduler_constructor=config.scheduler_constructor,
    lr_scheduler_params=config.scheduler_params,
    warmup_scheduler_constructor=config.warmup_scheduler_constructor,
    warmup_scheduler_params=config.warmup_scheduler_params,
    warmup_steps=config.warmup_steps,
    scheduler_config_params=config.scheduler_config_params,
)
trainer = Trainer(logger=loggers, callbacks=callbacks, **config.trainer_params)

In [None]:
# Launch training; `config.ckpt_path` is forwarded as Lightning's resume checkpoint.
trainer.fit(model=model, datamodule=datamodule, ckpt_path=config.ckpt_path)