# Train Semantic Segmentation

* Define `GarrulusAoiDataModule` that uses `RandomBatchAoiSampler` for the train data
* [`Update 23.08.2024`]: Test and validation also use `RandomBatchAoiSampler`. TODO: switch them to `GridAoiSampler`.

## Create Garrulus AOI Data Module

In [None]:
import os
import numpy as np

import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

from torchgeo.models import ResNet18_Weights, ResNet50_Weights

from gdl.datamodules.geo import GarrulusAoiDataModule
from gdl.trainers.segmentation import GarrulusSemanticSegmentationTask

In [None]:
# Input data locations (relative to the notebook's working directory).
grid_path = "../../../field-D/grid-10m-squares/grid-10m-squares.shp"
fenced_area_path = "../../../field-D/boundary-shape/boundary-shape.shp"
raster_image_root_path = "../../../field-D"
mask_root_path = "../../../field-D/d-RGB-9mm-mask"

# Sampling and batching settings.
batch_size = 64
size_lims = (128, 256)  # (min, max) size of the sampled windows
length = 1000  # number of windows to sample from the raster image within size_lims
img_size = 224  # sampled windows vary in size, so each is transformed to img_size for the model input

# Data module whose samplers draw random windows inside the area of interest.
gdm_aoi = GarrulusAoiDataModule(
    raster_image_path=raster_image_root_path,
    mask_path=mask_root_path,
    grid_shape_path=grid_path,
    fenced_area_shape_path=fenced_area_path,
    batch_size=batch_size,
    size_lims=size_lims,
    img_size=img_size,
    # NOTE(review): presumably selects the 5-class label set; keep in sync
    # with the `num_classes` passed to the segmentation task.
    class_set=5,
    length=length,
)

## Create Segmentation Task

In [None]:
# Training configuration.
num_workers = 16  # NOTE(review): not passed to the data module or the Trainer anywhere in this notebook
max_epochs = 100
fast_dev_run = False
num_classes = 5

# Prefer a CUDA GPU when one is visible, otherwise run on the CPU.
accelerator = "gpu" if torch.cuda.is_available() else "cpu"

# Checkpoints and TensorBoard logs are written below ./logs/experiments.
default_root_dir = os.path.join("./logs", "experiments")
checkpoint_callback = ModelCheckpoint(
    dirpath=default_root_dir,
    monitor="val_loss",
    save_top_k=1,  # keep only the single best checkpoint by validation loss
    save_last=True,  # additionally keep the most recent checkpoint
)
# Stop once validation loss has not improved for 10 consecutive validation checks.
early_stopping_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=0.0,
    patience=10,
)
logger = TensorBoardLogger(save_dir=default_root_dir, name="all_logs")

In [None]:
# Create the segmentation task: a U-Net with a ResNet-50 encoder initialised
# from torchgeo's Sentinel-2 RGB (SeCo) pretrained weights, on 3-band input.
# Class order used so far (from a previously commented-out `labels` argument):
# BACKGROUND, CWD, MISC, STUMP, VEGETATION.
task = GarrulusSemanticSegmentationTask(
    model="unet",
    backbone="resnet50",
    weights=ResNet50_Weights.SENTINEL2_RGB_SECO,
    in_channels=3,
    num_classes=num_classes,
    loss="ce",
    lr=0.001,
    patience=5,  # NOTE(review): presumably the LR-scheduler patience — confirm in GarrulusSemanticSegmentationTask
)

### Training

In [None]:
# Build the Lightning Trainer.
# `min_epochs=80` forces at least 80 epochs even if the EarlyStopping
# callback would fire earlier; after that, training ends on early stopping
# or at `max_epochs`, whichever comes first.
trainer = Trainer(
    callbacks=[checkpoint_callback, early_stopping_callback],
    fast_dev_run=fast_dev_run,
    log_every_n_steps=1,
    logger=logger,
    min_epochs=80,
    max_epochs=max_epochs,
    # Use the accelerator selected in the configuration cell; hard-coding
    # "auto" here silently ignored that setting.
    accelerator=accelerator,
    devices=1,
    strategy="auto",
)

In [None]:
# Train `task` on the AOI data module; checkpoints and logs go to the configured callbacks/logger.
trainer.fit(task, datamodule=gdm_aoi)

### Test

In [None]:
# Evaluate the most recent checkpoint written by this session's
# ModelCheckpoint. A hard-coded "last-v2.ckpt" is fragile: when a last
# checkpoint already exists, ModelCheckpoint appends a version suffix
# (-v1, -v2, ...), so the file name changes between runs.
# `last_model_path` is empty if `trainer.fit` has not run in this kernel;
# in that case fall back to the previously saved file.
test_ckpt_path = checkpoint_callback.last_model_path or "logs/experiments/last-v2.ckpt"
trainer.test(model=task, datamodule=gdm_aoi, ckpt_path=test_ckpt_path)

### Prediction (Inference)

In [None]:
# Run inference with the most recent checkpoint from this session's
# ModelCheckpoint. Fall back to the previously saved file when `trainer.fit`
# has not run in this kernel (then `last_model_path` is empty).
predict_ckpt_path = checkpoint_callback.last_model_path or "logs/experiments/last-v2.ckpt"
preds = trainer.predict(
    model=task,
    datamodule=gdm_aoi,
    ckpt_path=predict_ckpt_path,
    return_predictions=True,
)

In [None]:
# Load trained weights into `task` for the manual visualization that follows.
# Note: `model` holds the raw checkpoint dict, not a module.
# `map_location="cpu"` lets a checkpoint saved on a GPU be loaded on a
# CPU-only machine; `load_state_dict` then copies the tensors into `task`'s
# parameters on whatever device they live on.
# Prefer this session's latest checkpoint; fall back to the saved file when
# `trainer.fit` has not run in this kernel.
ckpt_file = checkpoint_callback.last_model_path or "logs/experiments/last-v2.ckpt"
model = torch.load(ckpt_file, map_location="cpu")
task.load_state_dict(model['state_dict'])

In [None]:
# Visualize predictions for the first few samples of the first test batch.
num_samples = 10

# Switch to inference mode so normalization layers (e.g. BatchNorm in the
# ResNet backbone) use their running statistics instead of per-batch ones,
# and do not update those statistics during this pass.
task.eval()

for batch in gdm_aoi.test_dataloader():
    image = batch['image']
    mask = batch['mask']

    with torch.no_grad():
        # Feed the model on the device its parameters currently live on.
        logits = task(image.to(task.device))
        # Use a separate name so the `preds` list returned by
        # `trainer.predict` is not overwritten; bring results back to the
        # CPU to match `image` and `mask` for plotting.
        pred_masks = torch.argmax(logits, dim=1).cpu()

    # The batch may contain fewer than `num_samples` items.
    for i in range(min(num_samples, len(image))):
        sample = {
            'image': image[i],
            'mask': mask[i],
            'prediction': pred_masks[i],
        }
        gdm_aoi.dataset.plot(sample)
    break