In [1]:
# Load the IPython extensions this notebook relies on.
%load_ext autoreload
%load_ext dotenv

# autoreload mode 2: pick up edits to imported project modules without restarting the kernel.
%autoreload 2
# Invoke the dotenv extension's line magic (presumably loads variables from a .env file -- confirm).
%dotenv

In [None]:
## Environment Variables
from dotenv import load_dotenv
load_dotenv(".env");

## System Modules
from pathlib import Path

## General Purpose Libraries 
import torch
import matplotlib.pyplot as plt

## Paths and Directory Management
from etl.pathfactory import PathFactory
from etl.etl import reset_dir

## Datasets and Datamodules
from data.datamodules import ImageDatasetDataModule 
from datasets.inria import InriaBase, InriaImageFolder, InriaStreaming, InriaHDF5 

## Transforms
import torchvision.transforms.v2 as t

## Tasks
from training.tasks import SegmentationTask 

## Loggers
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger, CSVLogger
from lightning import seed_everything

##Trainers
from lightning import Trainer

#Types
from typing import Literal

from datasets.test import test_func

In [None]:
# Train Val Test Splits
# 1. Random Split: Each location will be split based on the test_split and val_split parameters
# 2. Continental Split: Train on Europe, Test on NA or Vice Versa 
# 3. Cultural Split: Train on Developed Locations Like Paris, Chicago and Zurich and Test on Rwanda, Kenya and Rio
# 4. Unsupervised Split: Unsupervised Training on Inria-Test and Finetune on Inria-Train (with varying fractions of training data)

In [None]:
# Run artifacts live under ./logs (relative to the current working directory),
# with checkpoints kept in a dedicated subfolder.
LOGS_DIR = Path.cwd().joinpath("logs")
CHECKPOINTS_DIR = LOGS_DIR.joinpath("checkpoints")

# Prepare the parent directory first, then the child, via the project helper.
# NOTE(review): if reset_dir empties an existing directory, the "last.ckpt"
# resume lookup later in this notebook can never find a file -- confirm
# against etl.etl.reset_dir.
reset_dir(LOGS_DIR)
reset_dir(CHECKPOINTS_DIR)

def setup_checkpoint(
        ckpt_dir: Path,
        metric: str,
        mode: Literal["min", "max"],
    ) -> ModelCheckpoint:
    """Build the ModelCheckpoint callback used for a training run.

    Args:
        ckpt_dir: Directory handed to the callback as ``dirpath``.
        metric: Name of the logged metric to monitor; it is also embedded
            in the checkpoint filename template.
        mode: ``"min"`` or ``"max"``, forwarded unchanged as ``mode``.

    Returns:
        A ModelCheckpoint configured with ``save_top_k=1``,
        ``save_last=True`` and ``save_on_train_epoch_end=True``.
    """
    # Produces a template such as "{epoch}-{val_loss:.2f}"; the braces are
    # left in place for the callback to fill in.
    metric_field = "{" + metric + ":.2f}"
    checkpoint_options = dict(
        dirpath=ckpt_dir,
        monitor=metric,
        mode=mode,
        filename="{epoch}-" + metric_field,
        save_top_k=1,
        save_last=True,
        save_on_train_epoch_end=True,
    )
    return ModelCheckpoint(**checkpoint_options)

def setup_logger(
        logs_dir: Path,
        name: str,
        version: int
    ):
    """Create the CSVLogger for a run.

    Args:
        logs_dir: Passed to the logger as ``save_dir``.
        name: Experiment name, passed as ``name``.
        version: Run version, passed as ``version``.

    Returns:
        The constructed CSVLogger.
    """
    logger_options = {"save_dir": logs_dir, "name": name, "version": version}
    return CSVLogger(**logger_options)

In [None]:
# Single source of configuration for the run. The same mapping is unpacked
# into both the data module and the segmentation task further down.
experiment = dict(
    # Dataset identity, seeding and tiling
    dataset_name="urban-footprint",
    task="segmentation",
    random_seed=69,
    tile_size=(512, 512),
    tile_stride=(512, 512),
    # Splits and data loading
    val_split=0.2,
    test_split=0.2,
    batch_size=4,
    grad_accum=1,
    num_workers=4,
    # Model head and optimisation
    num_classes=2,
    loss="binary_cross_entropy",
    optimizer="adam",
    learning_rate=1e-5,
    # Checkpoint selection
    checkpoint_metric="val_macro_precision",
    checkpoint_mode="max",
)
# Seed the random number generators through Lightning before anything
# stochastic is constructed.
seed_everything(experiment["random_seed"])

# Checkpoint callback tracking the metric and mode chosen in the config.
model_ckpt = setup_checkpoint(
    ckpt_dir=CHECKPOINTS_DIR,
    metric=experiment["checkpoint_metric"],
    mode=experiment["checkpoint_mode"],
)

# CSV logger named "<dataset_name>-<task>", pinned to version 1.
logger = setup_logger(
    logs_dir=LOGS_DIR,
    name="-".join((experiment["dataset_name"], experiment["task"])),
    version=1,
)

# Project helper keyed by dataset name and task; its `.path` attribute is
# used as the dataset root when the data module is built.
paths = PathFactory(experiment["dataset_name"], experiment["task"])

# Images: convert to an image tensor, then cast to float32 with scaling enabled.
image_transform = t.Compose([t.ToImage(), t.ToDtype(torch.float32, scale=True)])

# Masks go through the same two steps as the images.
mask_transform = t.Compose([t.ToImage(), t.ToDtype(torch.float32, scale=True)])

# Placeholder augmentation pipeline containing only an Identity transform.
augmentations = t.Compose([t.Identity()])

# Non-remote, non-streaming data module backed by the InriaHDF5 dataset class.
# The whole experiment config is forwarded as additional keyword arguments.
datamodule = ImageDatasetDataModule(
    root=paths.path,
    is_remote=False,
    is_streaming=False,
    dataset_constructor=InriaHDF5,
    image_transform=image_transform,
    target_transform=mask_transform,
    common_transform=augmentations,
    **experiment,
)

In [None]:
from torchgeo.models import FCN

# Segmentation network from torchgeo.
# The class count is read from the experiment config instead of being
# hard-coded, so the network and the SegmentationTask (which receives the
# same config) cannot silently disagree when `num_classes` is changed.
model = FCN(
    in_channels=3,  # presumably 3-band RGB imagery -- confirm against the dataset
    classes=experiment["num_classes"],
    num_filters=32,
)

In [None]:
# Resume from a previous run's "last.ckpt" when one is present on disk;
# otherwise training starts from scratch (ckpt_path=None).
# NOTE(review): reset_dir(CHECKPOINTS_DIR) runs earlier in this notebook; if
# it clears the directory, this lookup always yields None -- confirm.
last_ckpt_file = CHECKPOINTS_DIR / "last.ckpt"
last_ckpt_path = last_ckpt_file.as_posix() if last_ckpt_file.is_file() else None

# Validation has to run within `max_epochs`: the checkpoint callback monitors
# a validation metric, which is only produced when the validation loop runs.
# The previous interval (11) exceeded max_epochs (10), so validation never
# ran during fit and the monitored metric was never available.
trainer = Trainer(
    logger=logger,
    callbacks=model_ckpt,
    max_epochs=10,
    check_val_every_n_epoch=1,
)

In [None]:
# Train the segmentation task; `last_ckpt_path` is passed as the resume
# checkpoint and may be None.
trainer.fit(
    SegmentationTask(model, **experiment),
    datamodule=datamodule,
    ckpt_path=last_ckpt_path,
)

In [None]:
# Run the validation loop on a freshly wrapped task around the same `model`.
# NOTE(review): `last_ckpt_path` was computed before training started, so it
# reflects whether "last.ckpt" existed at that point, not after fit -- confirm
# this is the checkpoint intended for evaluation.
trainer.validate(
    SegmentationTask(model, **experiment),
    datamodule=datamodule,
    ckpt_path=last_ckpt_path,
)