### Overall info

Folder structure will be: 

Dataset 
    |__RGB 
    |__HS 
    |__DEM 
    |__annotations 
    |__labels.csv 

### Observe
- The script below has been written but not yet tested.
- To run it, several packages may need to be installed.

##### Useful links: 
- [source code](https://github.com/huggingface/transformers/blob/main/src/transformers/models/segformer/image_processing_segformer.py)
- https://github.com/NielsRogge/Transformers-Tutorials/blob/master/SegFormer/Fine_tune_SegFormer_on_custom_dataset.ipynb



#### Define dataset 

In [None]:
from transformers import SegformerImageProcessor
import pandas as pd 
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import os
from PIL import Image

# adapted from https://github.com/NielsRogge/Transformers-Tutorials/blob/master/SegFormer/Fine_tune_SegFormer_on_custom_dataset.ipynb
class SemanticSegmentationDataset(Dataset):
    """Image (semantic) segmentation dataset.

    Each sample pairs the image ``<root_dir>/rgb/<name>.jpg`` with the
    annotation ``<root_dir>/annotations/<name>.png``, where ``<name>`` comes
    from the ``'filenames'`` column of ``labels_df``.
    """

    def __init__(self, root_dir, labels_df, image_processor):
        """
        Args:
            root_dir (string): Root directory of the dataset containing the images + annotations.
            labels_df (pd.DataFrame): Image names (without file extension) to be used in
                the dataset, stored in a column called 'filenames'.
            image_processor (SegformerImageProcessor): image processor to prepare images + segmentation maps.
        """
        self.root_dir = root_dir
        self.image_processor = image_processor
        self.labels_df = labels_df

        # NOTE(review): the folder overview at the top of this notebook spells
        # the image folder "RGB"; on a case-sensitive file system confirm that
        # the on-disk name really is "rgb".
        self.img_dir = os.path.join(self.root_dir, "rgb")
        self.ann_dir = os.path.join(self.root_dir, "annotations")

    def __len__(self):
        """Return the number of samples (rows of ``labels_df``)."""
        return len(self.labels_df)

    def __getitem__(self, idx):
        """Load, preprocess and return the sample at position ``idx``.

        Returns:
            The image processor's output for one (image, segmentation map)
            pair, with the leading batch dimension removed from every tensor.

        Raises:
            IndexError: if ``idx`` is outside ``range(len(self))``.
        """
        # Positional lookup: the frame may come from DataFrame.sample()/drop(),
        # so its index labels are not guaranteed to be 0..len-1.
        img_name = self.labels_df['filenames'].iloc[idx]
        img_path = os.path.join(self.img_dir, f"{img_name}.jpg")
        ann_path = os.path.join(self.ann_dir, f"{img_name}.png")

        # Force three channels so the per-channel normalisation always applies.
        image = Image.open(img_path).convert("RGB")
        segmentation_map = Image.open(ann_path)

        # Let the image processor prepare the image and its segmentation map
        # together (resizing, normalisation, label handling) as tensors.
        encoded_inputs = self.image_processor(image, segmentation_map, return_tensors="pt")

        for tensor in encoded_inputs.values():
            # Drop only the batch dimension; a bare squeeze_() would also
            # remove any other axis that happens to have size 1.
            tensor.squeeze_(0)

        return encoded_inputs
    
"""
WARNING:
By default the image processor below will resize the images (to 512*512).
Essentially, I don't want this; HOWEVER, it might be necessary in order
to use the weights from pretraining. TODO: Find out whether that's so.
"""

# Split the image list into train / validation / test subsets (80% / 10% / 10%).
labels_file = "/Users/nadja/Documents/UU/Thesis/Data/TINYsampleFINAL/tinydataset/palsa_labels.csv"
all_imgs = pd.read_csv(labels_file, usecols=[0], header=0, names=['filenames'])

split_seed = 42  # fixed seed so the same split is produced on every run
train_labels = all_imgs.sample(frac=0.8, random_state=split_seed)
held_out = all_imgs.drop(train_labels.index)
valid_labels = held_out.sample(frac=0.5, random_state=split_seed)
test_labels = held_out.drop(valid_labels.index)

# sample()/drop() keep the original (now shuffled, non-contiguous) row labels.
# Renumber each split 0..n-1 so that integer lookups into them are well defined.
train_labels = train_labels.reset_index(drop=True)
valid_labels = valid_labels.reset_index(drop=True)
test_labels = test_labels.reset_index(drop=True)

root_dir = "/Users/nadja/Documents/UU/Thesis/Data/TINYsampleFINAL/tinydataset"

# Per-channel mean / std calculated over our dataset, on the 0-255 pixel scale.
dataset_mean_255 = [74.90, 85.26, 80.06]
dataset_std_255 = [15.05, 13.88, 12.01]

image_processor = SegformerImageProcessor(
    # The keyword names are `image_mean` / `image_std`. By default the
    # processor rescales pixels to [0, 1] before normalising, so the
    # statistics have to be expressed on that scale as well.
    image_mean=[channel_mean / 255 for channel_mean in dataset_mean_255],
    image_std=[channel_std / 255 for channel_std in dataset_std_255],
    # NOTE(review): reducing labels turns the background value 0 into the
    # ignore index and shifts every other class down by one. Combined with a
    # single-class model this may leave only one target value to learn from -
    # confirm against the encoding of the annotation PNGs.
    do_reduce_labels=True,
    # additionally, do i want some augmentation?
)

train_dataset = SemanticSegmentationDataset(root_dir, train_labels, image_processor)
valid_dataset = SemanticSegmentationDataset(root_dir, valid_labels, image_processor)

batch_size = 20
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size)

### Define model

### IMPORTANT: DO I WANT TO FREEZE LAYERS??

In [None]:
from transformers import SegformerForSemanticSegmentation

# Load the pretrained backbone and configure a single output label.
# NOTE(review): confirm that the installed transformers version accepts
# num_labels=1 for semantic segmentation.
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b0",
    num_labels=1,  # since we treat '0' as a background, the only class is palsa.
)

# Step 1: switch off gradients for every encoder parameter.
model.segformer.encoder.requires_grad_(False)

# Step 2 (optional): switch gradients back on for the trailing encoder blocks.
# Adjust the number of unfrozen blocks as needed.
num_unfrozen_blocks = 2
encoder_blocks = model.segformer.encoder.block
for block_idx in range(len(encoder_blocks) - num_unfrozen_blocks, len(encoder_blocks)):
    encoder_blocks[block_idx].requires_grad_(True)

# Nothing is frozen in the decoder (model.decode_head), so it is trained as usual.

### Finetune
based on [huggingface tutorial](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/SegFormer/Fine_tune_SegFormer_on_custom_dataset.ipynb)

Do I want to log all of this with wandb?

In [None]:
import torch
from torch import nn
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm
import evaluate # "from Datasets library" 

# Mean-IoU metric; loaded here but not yet fed with predictions below.
metric = evaluate.load("mean_iou")

epochs = 20
lr = 0.00006

# define optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=lr) # these are the params used during training
# move model to GPU (falls back to CPU when CUDA is unavailable)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(epochs):  # loop over the dataset multiple times
    print("Epoch:", epoch)

    # ---- training pass ----
    # Re-enter training mode every epoch: the validation pass below switches
    # the model to eval mode.
    model.train()
    train_loss_sum = 0.0
    train_batches = 0
    for batch in train_dataloader:
        # get the inputs
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(pixel_values=pixel_values, labels=labels)
        loss, logits = outputs.loss, outputs.logits

        loss.backward()
        optimizer.step()

        train_loss_sum += loss.item()
        train_batches += 1

        #### Look at tutorial notebook for metrics logging (and calculating) options

    # ---- validation pass ----
    # Eval mode disables training-only behaviour such as dropout, so the
    # validation loss reflects the model as it will be used for inference.
    model.eval()
    valid_loss_sum = 0.0
    valid_batches = 0
    with torch.no_grad():
        for batch in valid_dataloader: # tqdm does a progress bar: do we want that?
            # get the inputs
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)

            # forward only: no gradients are tracked and no optimizer step is taken
            outputs = model(pixel_values=pixel_values, labels=labels)
            loss, logits = outputs.loss, outputs.logits

            valid_loss_sum += loss.item()
            valid_batches += 1

    # max(..., 1) guards against an empty dataloader.
    print(
        f"  train loss: {train_loss_sum / max(train_batches, 1):.4f}"
        f" | valid loss: {valid_loss_sum / max(valid_batches, 1):.4f}"
    )


### Visualize a result with the trained model 

In [None]:
# Load the image to segment; force three channels so that it matches the
# processor's per-channel normalisation and the RGB overlay drawn afterwards.
image = Image.open('imgpath').convert("RGB")
pixel_values = image_processor(image, return_tensors="pt").pixel_values.to(device)

# Make sure training-only behaviour (e.g. dropout) is off for prediction.
model.eval()
with torch.no_grad():
    outputs = model(pixel_values=pixel_values)
logits = outputs.logits.cpu()

# PIL reports size as (width, height); the post-processing expects (height, width).
# NOTE(review): if the post-processing picks the class by argmax over the
# channel axis, a model with a single output channel would always yield 0 -
# confirm, and threshold the logits instead if that is the case.
predicted_segmentation_map = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
predicted_segmentation_map = predicted_segmentation_map.cpu().numpy()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Build an RGB overlay with the same height and width as the predicted map.
map_height, map_width = predicted_segmentation_map.shape[:2]
color_seg = np.zeros((map_height, map_width, 3), dtype=np.uint8) # height, width, 3

# Paint every pixel predicted as class 0 with the overlay colour.
color = np.array([4, 250, 7])
color_seg[predicted_segmentation_map == 0, :] = color
# The overlay stays in RGB order: the PIL image is RGB and plt.imshow
# interprets (H, W, 3) arrays as RGB, so no BGR conversion is applied.

# Show image + mask (50/50 blend)
img = np.array(image) * 0.5 + color_seg * 0.5
img = img.astype(np.uint8)

plt.figure(figsize=(15, 10))
plt.imshow(img)
plt.show()