In [None]:
# Load the IPython extensions used by this notebook. Comments are kept on
# their own lines because trailing text on a magic line is parsed as an argument.
%load_ext autoreload
%load_ext dotenv
# Mode 2 of autoreload re-imports all modules before each cell execution,
# so edits to local project modules are picked up without a kernel restart.
%autoreload 2
# Called without arguments; presumably loads variables from a `.env` file
# discovered by the dotenv extension — confirm the file location if needed.
%dotenv 

In [None]:
import torch
from pathlib import Path
from tqdm.auto import tqdm
from torch.nn import Module
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torchvision.transforms import v2 as t
from typing import Callable, Optional, Literal, Any

# Put the parent directory on the import path exactly once, so re-running
# this cell does not keep appending duplicate entries.
import sys

if "../" not in sys.path:
    sys.path.append("../")
from viz.dataset_plots import plot_segmentation_samples

import logging
from lightning.pytorch.utilities import disable_possible_user_warnings # type: ignore

# Keep notebook output quiet: only records at ERROR level or above are let
# through on the "lightning.pytorch" logger.
lightning_logger = logging.getLogger("lightning.pytorch")
lightning_logger.setLevel(logging.ERROR)
# Going by its name, this helper suppresses Lightning's "possible user
# warning" messages — see the Lightning docs for the exact scope.
disable_possible_user_warnings()

In [None]:
#from datasets.inria import InriaLitData, InriaSegmentation, InriaHDF5, InriaImageFolder
#inria_kwargs = {
    ##"root" : Path.home() / "datasets" / "urban_footprint",
    ###"shards": Path.home() / "shards" / "urban_footprint",
    #"test_split": 0.2, "val_split": 0.1, "random_seed": 69,
    #"tile_size": (512, 512), "tile_stride": (512, 512),
    ##"split": "train",
    ##"shard_size_in_mb": 256 
#}

#ds = InriaHDF5(
    #root = Path.home() / "datasets" / "urban_footprint" / "tiled" / "inria.h5",
    #split = "train",
    #**inria_kwargs,
#)
#dataset_df = InriaSegmentation.tiled_df(**inria_kwargs)
#dataset_df
#import pandas as pd
#train_df = InriaSegmentation.scene_df(**inria_kwargs)
#train_df.assign(tile_name = lambda df: df.apply(lambda x: x.iloc[0]), axis = 0)
#train_df = train_df[train_df["split"] == "train"]
#eval_df = InriaSegmentation.tiled_df(**inria_kwargs)
#eval_df = eval_df[eval_df["split"] != "train"]
#dataset_df = pd.concat([train_df, eval_df], axis = 0)
#dataset_df = dataset_df[dataset_df["split"] != "unsup"]

In [None]:
#InriaSegmentation.write_to_hdf(
    #root = Path.home() / "datasets" / "urban_footprint",
    #target = Path.home() / "datasets" / "urban_footprint" / "file_write_test",
    #df = dataset_df
#)

In [None]:
#plot_segmentation_samples(ds, 32)

In [None]:
# TODO: Use Pytorch Profiler Here

from segmentation_models_pytorch import Unet
from datasets.inria import InriaImageFolder, InriaLitData, InriaHDF5

# Keyword arguments forwarded verbatim to the InriaHDF5 dataset constructor.
# NOTE(review): the meaning of each key is defined by `datasets.inria`, which
# is not visible here. Both this dict and the DataLoader below request
# shuffling — confirm that the duplication is intended.
inria_kwargs = dict(
    root = Path.home().joinpath("datasets", "urban_footprint", "tiled", "inria.h5"),
    split = "train",
    shuffle = True,
    test_split = 0.2,
    val_split = 0.1,
    random_seed = 69,
    tile_size = (512, 512),
    tile_stride = (512, 512),
)

# Batches of two samples, loaded by four worker processes that are kept alive
# across iterations; pinned host memory is requested for the loaded batches.
dl = DataLoader(
    InriaHDF5(**inria_kwargs),
    batch_size = 2,
    shuffle = True,
    num_workers = 4,
    persistent_workers = True,
    pin_memory = True,
    #prefetch_factor = 10,
)

# U-Net with a "resnet18" encoder, "imagenet" encoder weights and two output classes.
unet = Unet("resnet18", encoder_weights="imagenet", classes=2)

# Loss and optimizer handed to the training loop.
loss_fn = torch.nn.BCEWithLogitsLoss()
adam = torch.optim.Adam(params = unet.parameters(), lr = 1e-5)

def train_one_epoch(dataloader: DataLoader, model: Module, criterion: Module, optimizer: Optimizer, limit_train_batches: Optional[int] = None):
    """Run one optimisation pass of ``model`` over ``dataloader``.

    The model is moved to CUDA when available (CPU otherwise) and switched to
    training mode. For every batch the criterion is evaluated on the raw model
    output, the result is reduced with ``mean()``, back-propagated, and one
    optimizer step is taken.

    Args:
        dataloader: Iterable of batches; ``batch[0]`` is used as the model
            input and ``batch[1]`` as the target. ``len(dataloader)`` is only
            needed when ``limit_train_batches`` is ``None``.
        model: Network to train; modified in place.
        criterion: Loss callable invoked as ``criterion(output, target)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        limit_train_batches: Maximum number of batches to process. ``None``
            (default) processes ``len(dataloader)`` batches.

    Returns:
        None.
    """
    if limit_train_batches is None:
        limit_train_batches = len(dataloader)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.train()
    for idx, batch in tqdm(enumerate(dataloader), total = limit_train_batches, unit = "steps"):
        if idx >= limit_train_batches:
            break

        images, masks = batch[0].to(device), batch[1].to(device)
        logits = model(images)

        # Feed the raw logits to the criterion. Taking argmax first is not
        # differentiable: it detaches the loss from the model parameters, so
        # backward() would produce no parameter gradients and step() would
        # leave the weights untouched.
        # Assumes the masks have the same shape as the model output (e.g.
        # one-hot, channel-first) — TODO confirm against the dataset.
        targets = masks.to(logits.dtype)
        # mean() is a no-op for criteria that already reduce to a scalar and
        # collapses the result of criteria that return per-element losses.
        loss = criterion(logits, targets).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Train for a single pass over the loader, using the objects configured above.
train_one_epoch(dataloader = dl, model = unet, criterion = loss_fn, optimizer = adam)
# Hand cached, currently unused GPU memory back from PyTorch's allocator.
torch.cuda.empty_cache()