### Overall info

Folder structure will be: 

Dataset 
    |__RGB 
    |__HS 
    |__DEM 
    |__annotations 
    |__labels.csv 

### Observe
- All of the below script has been written but not tested 
- to run it, several installs might be needed. 

##### Useful links: 
- [source code](https://github.com/huggingface/transformers/blob/main/src/transformers/models/segformer/image_processing_segformer.py)
- https://github.com/NielsRogge/Transformers-Tutorials/blob/master/SegFormer/Fine_tune_SegFormer_on_custom_dataset.ipynb



#### Define dataset 

In [2]:
from transformers import SegformerImageProcessor
import pandas as pd 
from torch.utils.data import Dataset, random_split
from torch.utils.data import DataLoader
import os
from PIL import Image
import numpy as np

# adapted from https://github.com/NielsRogge/Transformers-Tutorials/blob/master/SegFormer/Fine_tune_SegFormer_on_custom_dataset.ipynb
class SemanticSegmentationDataset(Dataset):
    """Image (semantic) segmentation dataset."""

    def __init__(self, root_dir, image_processor):
        """
        Args:
            root_dir (string): Root directory of the dataset containing the images + annotations.
            image_processor (SegformerImageProcessor): image processor to prepare images + segmentation maps.
        """
        self.root_dir = root_dir
        self.image_processor = image_processor

        self.img_dir = os.path.join(self.root_dir, "images")
        self.ann_dir = os.path.join(self.root_dir, "masks")
        
        # Get all image filenames without extension
        self.filenames = [os.path.splitext(f)[0] for f in os.listdir(self.img_dir) if f.endswith('.jpg')]

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        img_name = self.filenames[idx]
        img_path = os.path.join(self.img_dir, f"{img_name}.jpg")
        ann_path = os.path.join(self.ann_dir, f"{img_name}.png")

        image = Image.open(img_path)
        segmentation_map = Image.open(ann_path)

        # Convert segmentation map to numpy array
        segmentation_map_np = np.array(segmentation_map)
        
        # Convert 255 to 1 (target class) and keep 0 as background
        segmentation_map_np = (segmentation_map_np > 0).astype(np.uint8)
        
        # Convert back to PIL Image
        segmentation_map = Image.fromarray(segmentation_map_np)

        # randomly crop + pad both image and segmentation map to same size
        encoded_inputs = self.image_processor(image, segmentation_map, return_tensors="pt")

        for k,v in encoded_inputs.items():
          encoded_inputs[k].squeeze_() # remove batch dimension

        return encoded_inputs
    
"""
WARNING: 
by default the image processor below will resize the image (to 512*512).
Essentially i don't want this, HOWEVER it might be necessary to be able
to use the weights from pretraining. TODO: Find out whether that's so.
"""

root_dir = "/Users/nadja/Documents/1. Project/Thesis/Permafrost-Segmentation/Supervised_dataset"
image_processor = SegformerImageProcessor(
    image_mean = [74.90, 85.26, 80.06], # use mean calculated over our dataset
    image_std = [15.05, 13.88, 12.01], # use std calculated over our dataset
    do_reduce_labels=True
    # additionally, do i want some augmentation? 
    )

# Create the full dataset
full_dataset = SemanticSegmentationDataset(root_dir, image_processor)

# Split the dataset into 85% train and 15% validation
total_size = len(full_dataset)
train_size = int(0.85 * total_size)
valid_size = total_size - train_size

train_dataset, valid_dataset = random_split(full_dataset, [train_size, valid_size])

train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=64)

NameError: name '_C' is not defined

### Define model

### IMPORTANT: DO I WANT TO FREEZE LAYERS??

In [5]:
from transformers import SegformerForSemanticSegmentation

# define model
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b0", 
    num_labels=1# since we treat '0' as a background, the only class is palsa.
) 

# Freeze encoder layers
for param in model.segformer.encoder.parameters():
    param.requires_grad = False

# Optionally, unfreeze the last few layers of the encoder
# Adjust the number of unfrozen blocks as needed
num_unfrozen_blocks = 2
for i in range(len(model.segformer.encoder.block) - num_unfrozen_blocks, len(model.segformer.encoder.block)):
    for param in model.segformer.encoder.block[i].parameters():
        param.requires_grad = True

# The decoder (model.decode_head) will be trained by default

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


check the size of next iteration of valid loader image and labels, as well as last model output size. 

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from torchmetrics.functional import jaccard_index

# Set device to CPU
device = torch.device("cpu")
model.to(device)

# Minimal example of forward & backward pass on one image
def train_step(model, batch):
    pixel_values = batch["pixel_values"].to(device)
    labels = batch["labels"].to(device)

    model.train()
    optimizer.zero_grad()

    outputs = model(pixel_values=pixel_values, labels=labels)
    loss, logits = outputs.loss, outputs.logits

    loss.backward()
    optimizer.step()

    return loss.item()

# Minimal example of evaluation on one image
def eval_step(model, batch):
    pixel_values = batch["pixel_values"].to(device)
    labels = batch["labels"].to(device)

    model.eval()
    with torch.no_grad():
        outputs = model(pixel_values=pixel_values, labels=labels)
        logits = outputs.logits

        predicted = (logits.squeeze(1) > 0).float()
        upsampled_predicted = F.interpolate(predicted.unsqueeze(1), size=labels.shape[-2:], mode="nearest")

        jaccard = jaccard_index(upsampled_predicted.squeeze(1), labels, task="binary", num_classes=2)
        target_jaccard = jaccard_index(upsampled_predicted.squeeze(1), labels, task="binary", num_classes=2, ignore_index=0)

    return jaccard.item(), target_jaccard.item()

# Get one batch from each dataloader
train_batch = next(iter(train_dataloader))
valid_batch = next(iter(valid_dataloader))

# Perform one training step
loss = train_step(model, train_batch)
print(f"Training loss: {loss:.4f}")

# Perform one evaluation step
jaccard, target_jaccard = eval_step(model, valid_batch)
print(f"Validation Jaccard Score: {jaccard:.4f}, Target Class Jaccard Score: {target_jaccard:.4f}")


### Finetune
based on [huggingface tutorial](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/SegFormer/Fine_tune_SegFormer_on_custom_dataset.ipynb)

DO i want to log all of this with wandb?

In [7]:
import torch
from torch import nn
from torchmetrics.functional import jaccard_index
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup
import torch.nn.functional as F

epochs = 20
lr = 0.00006
warmup_steps = 100  # Adjust this value as needed

# define optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

# define scheduler
total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)

# move model to GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Move optimizer to GPU (possibly unneccessary)
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

# Early stopping parameters
patience = 5
best_target_jaccard = 0
epochs_no_improve = 0

model.train()
for epoch in range(epochs):
    print(f"Epoch: {epoch}")
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{epochs}")
    for batch in progress_bar:  
        # get the inputs;
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(pixel_values=pixel_values, labels=labels)
        loss, logits = outputs.loss, outputs.logits

        loss.backward()
        optimizer.step()
        scheduler.step()  # Update learning rate

        # Update progress bar
        progress_bar.set_postfix({"Loss": f"{loss.item():.4f}", "LR": f"{scheduler.get_last_lr()[0]:.6f}"})

    model.eval()
    jaccard_scores = []
    target_jaccard_scores = []
    with torch.no_grad():
        for batch in valid_dataloader:
            # get the inputs;
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)

            # forward pass
            outputs = model(pixel_values=pixel_values, labels=labels)
            loss, logits = outputs.loss, outputs.logits

            # Calculate Jaccard score
            # Since we only have one feature map, we can use a threshold to determine the segmentation
            predicted = (logits.squeeze(1) > 0).float()  # Threshold at 0
            upsampled_predicted = F.interpolate(predicted.unsqueeze(1), size=labels.shape[-2:], mode="nearest")

            # Calculate Jaccard score (IoU) for both classes
            jaccard = jaccard_index(upsampled_predicted.squeeze(1), labels, task="binary", num_classes=2)
            jaccard_scores.append(jaccard.item())

            # Calculate Jaccard score (IoU) for target class only
            target_jaccard = jaccard_index(upsampled_predicted.squeeze(1), labels, task="binary", num_classes=2, ignore_index=0)
            target_jaccard_scores.append(target_jaccard.item())

    avg_jaccard = sum(jaccard_scores) / len(jaccard_scores)
    avg_target_jaccard = sum(target_jaccard_scores) / len(target_jaccard_scores)
    print(f"Epoch {epoch}, Average Jaccard Score: {avg_jaccard:.4f}, Target Class Jaccard Score: {avg_target_jaccard:.4f}")
    
    # Early stopping check based on target Jaccard score
    if avg_target_jaccard > best_target_jaccard:
        best_target_jaccard = avg_target_jaccard
        epochs_no_improve = 0
        # Save the best model
        torch.save(model.state_dict(), 'best_model.pth')
    else:
        epochs_no_improve += 1
        if epochs_no_improve == patience:
            print(f"Early stopping triggered. No improvement in target Jaccard score for {patience} epochs.")
            break
    
    model.train()

# Load the best model after training
model.load_state_dict(torch.load('best_model.pth'))


Epoch: 0


Epoch 1/20: 100%|██████████| 194/194 [16:30<00:00,  5.11s/it, Loss=0.0190, LR=0.000059]


ValueError: Input and output must have the same number of spatial dimensions, but got input with spatial dimensions of [128] and output size of torch.Size([512, 512]). Please provide input tensor in (N, C, d1, d2, ...,dK) format and output size in (o1, o2, ...,oK) format.

### Visualize a result with the trained model 

In [None]:
image = Image.open('imgpath')
pixel_values = image_processor(image, return_tensors="pt").pixel_values.to(device)
with torch.no_grad():
    outputs = model(pixel_values=pixel_values)
logits = outputs.logits.cpu()
predicted_segmentation_map = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
predicted_segmentation_map = predicted_segmentation_map.cpu().numpy()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

color_seg = np.zeros((predicted_segmentation_map.shape[0],
                      predicted_segmentation_map.shape[1], 3), dtype=np.uint8) # height, width, 3

color = np.array([4, 250, 7])
color_seg[predicted_segmentation_map == 0, :] = color
# Convert to BGR
color_seg = color_seg[..., ::-1]

# Show image + mask
img = np.array(image) * 0.5 + color_seg * 0.5
img = img.astype(np.uint8)

plt.figure(figsize=(15, 10))
plt.imshow(img)
plt.show()