# Fine-tuning SETR PUP on Parihaka Dataset

## Imports

In [None]:
from common import get_data_module, get_trainer_pipeline
import torch
from minerva.models.nets.image.setr import SETR_PUP
from functools import partial
from torchmetrics import JaccardIndex

## Variables

In [2]:
# Notebook-wide configuration. Every value below is consumed by the later
# cells that build the data module, the model and the training pipeline.

# Dataset locations: input images and their segmentation annotations.
root_data_dir = "/workspaces/HIAAC-KR-Dev-Container/shared_data/seam_ai_datasets/seam_ai/images"
root_annotation_dir = "/workspaces/HIAAC-KR-Dev-Container/shared_data/seam_ai_datasets/seam_ai/annotations"
img_size = (1008, 784)          # Must match the size of the images in the dataset (presumably (height, width) -- TODO confirm)
model_name = "setr_pup"       # Model name (identifier only)
dataset_name = "seam_ai"        # Dataset name (identifier only)
single_channel = False          # If True, train on single-channel images instead of 3-channel images

log_dir = "./logs"              # Directory where logs are saved
batch_size = 1                  # Batch size
seed = 42                       # Seed for reproducibility
num_epochs = 100                # Number of epochs to train
# NOTE(review): debug mode is enabled here; set to False for a full
# `num_epochs` training run.
is_debug = True                 # If True, only 3 batches are processed for 3 epochs
accelerator = "gpu"             # "cpu" or "gpu"
devices = 1                     # Number of GPUs

## Data Module

In [None]:
# Gather the data-module settings in one mapping so they can be inspected
# before the data module is built.
_data_module_kwargs = {
    "root_data_dir": root_data_dir,
    "root_annotation_dir": root_annotation_dir,
    "img_size": img_size,
    "batch_size": batch_size,
    "seed": seed,
    "single_channel": single_channel,
}
data_module = get_data_module(**_data_module_kwargs)

# Final expression of the cell: displays the data module in the notebook.
data_module

In [None]:
# Sanity check: prepare the "fit" stage, pull one training batch and show the
# shapes of its inputs and targets.
data_module.setup("fit")
_train_loader = data_module.train_dataloader()
train_batch_x, train_batch_y = next(iter(_train_loader))
(train_batch_x.shape, train_batch_y.shape)

## **** Create and Load model HERE ****

In [None]:
# Number of segmentation classes, shared by the model head and every metric.
_NUM_CLASSES = 6


def _make_iou_metric():
    """Return a new multiclass IoU (Jaccard index) metric instance."""
    return JaccardIndex(task="multiclass", num_classes=_NUM_CLASSES)


# Each stage (train / val / test) gets its own metric instance.
model = SETR_PUP(
    image_size=img_size,
    num_classes=_NUM_CLASSES,
    load_backbone_path="/home/shared/notebooks_e_pesos/SETR_pesos/vit_large_p16_new.pth",
    original_resolution=(256, 576),
    train_metrics={"IoU": _make_iou_metric()},
    val_metrics={"IoU": _make_iou_metric()},
    test_metrics={"IoU": _make_iou_metric()},
)

# Final expression of the cell: displays the model in the notebook.
model

## Pipeline

In [None]:
# Collect the trainer settings from the notebook configuration, then build
# the training pipeline from them in a single call.
_pipeline_settings = {
    "model": model,
    "model_name": model_name,
    "dataset_name": dataset_name,
    "log_dir": log_dir,
    "num_epochs": num_epochs,
    "accelerator": accelerator,
    "devices": devices,
    "is_debug": is_debug,
    "seed": seed,
}
pipeline = get_trainer_pipeline(**_pipeline_settings)

In [None]:
# Run the pipeline's "fit" task using `data_module` as the data source.
pipeline.run(data_module, task="fit")

In [None]:
# Report the last-model path recorded by the trainer's checkpoint callback.
# NOTE(review): presumably this is only populated when the callback is
# configured to save a "last" checkpoint -- confirm in get_trainer_pipeline.
_checkpoint_callback = pipeline.trainer.checkpoint_callback
print(f"Checkpoint saved at {_checkpoint_callback.last_model_path}")