### Author
0001128790 - Christian Di Buò - christian.dibuo@studio.unibo.it

### Importing images

In [None]:
import os
from PIL import Image
import matplotlib.pyplot as plt
import os
import json
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from PIL import Image

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
import torchvision
from torchvision import transforms

from enum import IntEnum
from typing import Optional
import subprocess, sys

In [None]:
# Name of the package that must be importable before the rest of the notebook runs.
package_name = "evaluate"

# Probe for the package by importing it; fall back to a pip install, run with
# the interpreter that is executing this notebook, when the import fails.
try:
    __import__(package_name)
except ImportError:
    print(f"{package_name} is NOT installed! Installing now...")
    pip_command = [sys.executable, "-m", "pip", "install", package_name]
    subprocess.check_call(pip_command)
else:
    print('already installed')

In [None]:
import evaluate

"""
Source: https://github.com/hendrycks/anomaly-seg/issues/15#issuecomment-890300278
"""
COLORS = np.array([
    [  0,   0,   0],  # unlabeled    =   0,
    [ 70,  70,  70],  # building     =   1,
    [190, 153, 153],  # fence        =   2, 
    [250, 170, 160],  # other        =   3,
    [220,  20,  60],  # pedestrian   =   4, 
    [153, 153, 153],  # pole         =   5,
    [157, 234,  50],  # road line    =   6, 
    [128,  64, 128],  # road         =   7,
    [244,  35, 232],  # sidewalk     =   8,
    [107, 142,  35],  # vegetation   =   9, 
    [  0,   0, 142],  # car          =  10,
    [102, 102, 156],  # wall         =  11, 
    [220, 220,   0],  # traffic sign =  12,
    [ 60, 250, 240],  # anomaly      =  13,
]) 


def color(img_pil: Image.Image, colors: np.ndarray) -> Image.Image:
    """Render a class-index annotation as an RGB image.

    Args:
        img_pil: Annotation image (anything ``np.array`` can convert to a 2-D
            array) whose pixel values are 1-based class indexes.
        colors: Sequence of RGB triples; ``colors[i]`` is painted onto every
            pixel whose value equals ``i + 1``.

    Returns:
        An 8-bit RGB image with the same height and width as the input.
        Pixels whose value matches no entry of ``colors`` stay black.
    """
    img_np = np.array(img_pil)
    # Size the canvas from the input rather than assuming 720x1280, so that
    # annotations of any resolution can be coloured.
    img_new = np.zeros((*img_np.shape, 3))

    # ``rgb`` (not ``color``) so the loop variable does not shadow this function.
    for index, rgb in enumerate(colors):
        img_new[img_np == index + 1] = rgb

    return Image.fromarray(img_new.astype("uint8"), "RGB")

In [None]:
class StreetHazardsDataset(Dataset):
    """Image / segmentation-annotation pairs listed in an ``.odgt`` index file."""

    def __init__(self, odgt_file, transform1=None, transform2=None):
        """
        Args:
            odgt_file (str): Path to the .odgt file (train, val, or test). It is
                parsed as one JSON document whose entries provide the
                ``fpath_img`` and ``fpath_segm`` keys; both are resolved
                relative to the directory that contains ``odgt_file``.
            transform1 (callable, optional): Transformation applied to the RGB
                image. When omitted, ``__getitem__`` returns the image and the
                annotation as unmodified PIL images.
            transform2 (callable, optional): Transformation applied to the
                annotation tensor. Only used when ``transform1`` is given; may
                be omitted to keep the annotation at its native resolution.
        """

        self.transform1 = transform1
        self.transform2 = transform2

        # Load the .odgt file
        with open(odgt_file, "r") as f:
            odgt_data = json.load(f)

        # Relative paths in the index are anchored at the index file's folder.
        root = Path(odgt_file).parent
        self.paths = [
            {
                "image": os.path.join(root, data["fpath_img"]),
                "annotation": os.path.join(root, data["fpath_segm"]),
            }
            for data in odgt_data
        ]

    def __len__(self):
        """Return the number of image/annotation pairs in the index."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            dict: ``{"pixel_values": image, "labels": annotation}``.
        """
        sample = self.paths[idx]
        image = Image.open(sample["image"]).convert("RGB")
        annotation = Image.open(sample["annotation"])

        if self.transform1:
            image = self.transform1(image)
            # Make class indexes start from 0
            annotation = torch.as_tensor(
                transforms.functional.pil_to_tensor(annotation), dtype=torch.int64
            ) - 1
            # transform2 is optional: calling it unconditionally raised a
            # TypeError when only transform1 was supplied.
            if self.transform2 is not None:
                annotation = self.transform2(annotation)
            # Drop the leading channel axis added by pil_to_tensor.
            annotation = annotation.squeeze(0)

        return {"pixel_values": image, "labels": annotation}


def visualize_annotation(annotation_img: np.ndarray|torch.Tensor, ax=None):
    """
    Draw a class-index map as a colour image, using the COLORS table.

    Each value ``i`` in ``annotation_img`` is painted with ``COLORS[i]``; values
    with no matching row stay black. The result is shown on ``ax`` (or on the
    current matplotlib axes when ``ax`` is None) with both tick sets removed.

    Adapted from https://github.com/CVLAB-Unibo/ml4cv-assignment/blob/master/utils/visualize.py
    """
    target_ax = plt.gca() if ax is None else ax
    label_map = np.asarray(annotation_img)

    # One RGB triple per label position, initially all black.
    canvas = np.zeros(label_map.shape + (3,))
    for class_id, rgb in enumerate(COLORS):
        canvas[label_map == class_id] = rgb

    # Scale the 0-255 colour values to the 0-1 range before display.
    target_ax.imshow(canvas / 255.0)
    target_ax.set_xticks([])
    target_ax.set_yticks([])

def visualize_scene(image: np.ndarray|torch.Tensor, ax=None):
    if ax is None: ax = plt.gca()
    image = np.asarray(image)
    ax.imshow(np.moveaxis(image, 0, -1))
    ax.set_xticks([])
    ax.set_yticks([])



In [None]:
from torch.utils.data import DataLoader

# Resize settings shared by the image pipeline and the annotation pipeline.
resize_to = (520, 520)
nearest = transforms.InterpolationMode.NEAREST

# Image pipeline: resize, then convert to a tensor.
# NOTE(review): NEAREST is used for the RGB image as well as for the labels;
# confirm this is intentional.
transform1 = transforms.Compose(
    [
        transforms.Resize(resize_to, interpolation=nearest),
        transforms.ToTensor(),
    ]
)

# Annotation pipeline: resize only.
transform2 = transforms.Compose(
    [
        transforms.Resize(resize_to, interpolation=nearest),
    ]
)

# Training and validation splits use the same pair of pipelines.
train_dataset = StreetHazardsDataset(
    "/kaggle/input/streethazards_train/train/train.odgt",
    transform1=transform1,
    transform2=transform2,
)
val_dataset = StreetHazardsDataset(
    "/kaggle/input/streethazards_train/train/validation.odgt",
    transform1=transform1,
    transform2=transform2,
)

In [None]:
# Fetch the first validation sample once; indexing the dataset twice would
# open and transform the same files a second time.
first_sample = val_dataset[0]
print(first_sample['labels'].shape, len(val_dataset))
visualize_annotation(first_sample['labels'])

In [None]:
from transformers import MobileViTImageProcessor, MobileViTForSemanticSegmentation

# Pretrained checkpoint to start from. Alternative checkpoint tried:
#   "apple/deeplabv3-mobilevit-xx-small"
model_name = "apple/deeplabv3-mobilevit-small"

# One output label per COLORS row, minus one.
num_classes = len(COLORS) - 1

feature_extractor = MobileViTImageProcessor.from_pretrained(model_name)
model = MobileViTForSemanticSegmentation.from_pretrained(
    model_name,
    num_labels=num_classes,
    ignore_mismatched_sizes=True,
)

# Use the GPU when one is available, otherwise stay on the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
model.to(device);

In [None]:
from transformers import TrainingArguments

# Run length and per-device batch size, reused below.
epochs = 20
batch_size = 8

training_args = TrainingArguments(
    # Output location, reporting and reproducibility.
    output_dir="test_dir",
    report_to="none",
    seed=42,
    # Optimisation settings.
    learning_rate=2e-5,
    weight_decay=0.01,
    num_train_epochs=epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # Evaluate, log and checkpoint once per epoch.
    eval_strategy="epoch",
    logging_strategy='epoch',
    save_strategy="epoch",
    save_total_limit=1,
    # Best-model selection on the metric reported by compute_metrics.
    load_best_model_at_end=True,
    metric_for_best_model="eval_mean_iou",
    greater_is_better=True,
)

In [None]:
import torch
import torch.nn.functional as F

# Load mean_iou metric
metric = evaluate.load("mean_iou")

def compute_metrics(eval_pred, chunk_size: int = 32):
    """Compute the mean IoU of predicted segmentation maps.

    Args:
        eval_pred: Pair ``(logits, labels)``. ``logits`` is indexed as
            (sample, class, height, width) and ``labels`` as
            (sample, height, width); NumPy arrays and tensors are accepted.
        chunk_size: Number of samples upsampled at a time. Upsampling the
            logits of the whole evaluation set in one call allocates a float
            tensor of size samples x classes x label height x label width,
            which can exhaust memory; chunking bounds the peak allocation and
            yields the same predictions because interpolation is per-sample.

    Returns:
        dict: ``{"eval_mean_iou": <mean IoU>}``.
    """
    logits, labels = eval_pred

    # as_tensor reuses the NumPy buffer instead of copying it.
    logits = torch.as_tensor(logits)
    labels = torch.as_tensor(labels)

    # Spatial size of the labels; the logits are upsampled to match it.
    target_size = labels.shape[1:]

    pred_chunks = []
    for logits_chunk in torch.split(logits, chunk_size):
        upsampled = F.interpolate(
            logits_chunk, size=target_size, mode='bilinear', align_corners=False
        )
        # Keep only the winning class per pixel so the float logits can be freed.
        pred_chunks.append(upsampled.argmax(dim=1))

    # Convert predictions and labels to numpy for the metric
    pred = torch.cat(pred_chunks).numpy()
    labels = labels.numpy()

    # Compute mean IoU
    result = metric.compute(
        predictions=pred,
        references=labels,
        num_labels=num_classes,
        ignore_index=255,
    )

    return {"eval_mean_iou": result["mean_iou"]}

In [None]:
# EarlyStoppingCallback is used below but was never imported, so this cell
# raised a NameError; import it together with Trainer.
from transformers import EarlyStoppingCallback, Trainer

# Fine-tuning driver: trains on train_dataset, evaluates on val_dataset with
# compute_metrics, and stops early via EarlyStoppingCallback (patience of 3).
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

In [None]:
# Start fine-tuning with the Trainer built in the previous cell.
trainer.train()

In [None]:
# Run the trained model over the validation split; `pred.predictions` and
# `pred.label_ids` are consumed by the next cell.
pred = trainer.predict(val_dataset)

In [None]:
# Recompute the mean IoU from the stored predictions and labels.
# NOTE(review): despite the name, `pred` comes from trainer.predict(val_dataset),
# so these are validation metrics, not test metrics.
test_metrics = compute_metrics((pred.predictions, pred.label_ids))
print(test_metrics)