### Installing required dependencies

In [None]:
!pip install segmentation-models-pytorch albumentations --no-deps

In [None]:
!pip install fvcore

### Importing libraries

In [None]:
import os, json
import subprocess, sys
import random
import math
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
from tqdm import tqdm
import copy
from copy import deepcopy
import cv2
from PIL import Image, ImageDraw

from sklearn.metrics import average_precision_score
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision import transforms
from torchvision.transforms import v2
from transformers import MobileViTImageProcessor, MobileViTForSemanticSegmentation
from torch.optim.lr_scheduler import _LRScheduler
import segmentation_models_pytorch as smp
from segmentation_models_pytorch import losses
from torchinfo import summary
from abc import ABC
import pathlib, torchvision

In [None]:
import wandb

WANDB_USER = "chri-project"
WANDB_PROJECT = "ML4CV--assignment"
# Never hard-code API keys in a notebook: read it from the environment instead.
# With key=None wandb falls back to its own credential lookup (environment,
# ~/.netrc or an interactive prompt).
# NOTE: the key that used to be written here must be considered leaked and rotated.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

In [None]:
def fix_random(seed: int) -> None:
    """
    Seed every random number generator used in the notebook and make cuDNN deterministic.

    Args:
        seed: value given to the Python, NumPy and PyTorch (CPU and CUDA) generators.
    """
    for seeder in (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    # Prefer reproducible cuDNN kernels over autotuned (possibly non-deterministic) ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


fix_random(seed=42)

In [None]:
# Use the default CUDA device when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# RGB palette for StreetHazards label maps: row k is the colour that
# visualize_annotation draws for label value k (class names are in the per-row
# comments; row 13 is the anomaly class).
"""
Source: https://github.com/hendrycks/anomaly-seg/issues/15#issuecomment-890300278
"""
COLORS = np.array([
    [  0,   0,   0],  # unlabeled    =   0,
    [ 70,  70,  70],  # building     =   1,
    [190, 153, 153],  # fence        =   2,
    [250, 170, 160],  # other        =   3,
    [220,  20,  60],  # pedestrian   =   4,
    [153, 153, 153],  # pole         =   5,
    [157, 234,  50],  # road line    =   6,
    [128,  64, 128],  # road         =   7,
    [244,  35, 232],  # sidewalk     =   8,
    [107, 142,  35],  # vegetation   =   9,
    [  0,   0, 142],  # car          =  10,
    [102, 102, 156],  # wall         =  11,
    [220, 220,   0],  # traffic sign =  12,
    [ 60, 250, 240],  # anomaly      =  13,
]) 

## TODO: Show the class imbalance, if any; it would help explain the per-class errors.

In [None]:
class StreetHazardsDataset(Dataset):
    """StreetHazards images and per-pixel labels listed in an ``.odgt`` index file.

    Each item is a dict ``{'image': float tensor (3, H, W) in [0, 1] (before
    ``images_only_transforms``), 'labels': int64 tensor}``. The label values are
    the raw mask values minus one; for single-channel masks the labels have
    shape (H, W).
    """

    def __init__(self, odgt_file, image_resize=(512, 896), spatial_transforms=None, images_only_transforms=None):
        """
        Args:
            odgt_file (str): Path to the .odgt index (train, val, or test). It is
                parsed as JSON; each record must provide ``fpath_img`` and
                ``fpath_segm``, relative to the directory that holds the file.
            image_resize (tuple[int, int] | None): target size for image and mask.
                The image is resized bilinearly, the mask with nearest-neighbour
                interpolation so no new label values appear. Falsy disables it.
            spatial_transforms (callable, optional): called as
                ``image, labels = spatial_transforms(image, labels)`` on the PIL
                pair, so geometric augmentations stay aligned.
            images_only_transforms (callable, optional): applied to the image
                tensor only (e.g. colour jitter, normalization).
        """
        self.spatial_transforms = spatial_transforms
        self.images_only_transforms = images_only_transforms
        self.image_resize = image_resize

        with open(odgt_file, "r", encoding="utf-8") as f:
            odgt_data = json.load(f)

        # Paths in the index are relative to the index file's directory.
        root = Path(odgt_file).parent
        self.paths = [
            {
                "image": os.path.join(root, record["fpath_img"]),
                "labels": os.path.join(root, record["fpath_segm"]),
            }
            for record in odgt_data
        ]

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        record = self.paths[idx]
        image = Image.open(record["image"]).convert("RGB")
        labels = Image.open(record["labels"])

        if self.image_resize:
            # Bilinear for the image, nearest for the labels so class ids are never blended.
            image = transforms.Resize(self.image_resize, transforms.InterpolationMode.BILINEAR)(image)
            labels = transforms.Resize(self.image_resize, transforms.InterpolationMode.NEAREST)(labels)

        if self.spatial_transforms:
            image, labels = self.spatial_transforms(image, labels)

        image = transforms.ToTensor()(image)
        # Cast to int64 before subtracting so that a raw value of 0 becomes -1
        # instead of wrapping around in uint8.
        labels = torch.as_tensor(transforms.functional.pil_to_tensor(labels), dtype=torch.int64) - 1
        labels = labels.squeeze(0)

        if self.images_only_transforms:
            image = self.images_only_transforms(image)

        return {'image': image, 'labels': labels}

In [None]:
def visualize_annotation(annotation_img: np.ndarray|torch.Tensor, ax=None):
    """
    Render a label map with the COLORS palette on ``ax`` (default: current axes), without ticks.

    Pixels whose value has no row in COLORS stay black.

    Adapted from https://github.com/CVLAB-Unibo/ml4cv-assignment/blob/master/utils/visualize.py
    """
    axis = plt.gca() if ax is None else ax
    label_map = np.asarray(annotation_img)
    rgb = np.zeros(label_map.shape + (3,))

    for class_id in range(len(COLORS)):
        rgb[label_map == class_id] = COLORS[class_id]

    axis.imshow(rgb / 255.0)
    axis.set_xticks([])
    axis.set_yticks([])

def visualize_scene(image: np.ndarray|torch.Tensor, ax=None):
    """Draw a channel-first image on ``ax`` (default: current axes) without ticks."""
    target_ax = plt.gca() if ax is None else ax
    # imshow expects channels last: (C, H, W) -> (H, W, C).
    channels_last = np.moveaxis(np.asarray(image), 0, -1)
    target_ax.imshow(channels_last)
    target_ax.set_xticks([])
    target_ax.set_yticks([])

In [None]:
class MixDataset(Dataset):
    """Wrap an inlier segmentation dataset and paste outlier objects into its samples.

    Each inlier item must be a dict with an ``'image'`` tensor (C, H, W) and a
    ``'labels'`` tensor (H, W). With probability ``anomaly_probability``, 1 to
    ``max_anomalies`` objects cut from random ``outlier_dataset`` items (pairs of
    PIL image and PIL mask, e.g. torchvision ``VOCSegmentation``) are pasted
    into the image, and their pixels are relabelled ``anomaly_idx``.
    """

    def __init__(self, inlier_dataset, outlier_dataset, images_only_transforms= None, anomaly_probability=1, max_anomalies = 4, anomaly_idx = 13, ignore_values=(0, 255)):
        """
        Args:
            inlier_dataset: dataset of dicts with ``'image'`` and ``'labels'``.
            outlier_dataset: dataset of (PIL image, PIL mask) pairs.
            images_only_transforms (callable, optional): applied to the image only.
            anomaly_probability: probability of injecting anomalies into a sample.
            max_anomalies: maximum number of pasted objects per sample.
            anomaly_idx: label written on pasted pixels.
            ignore_values: outlier-mask values that never identify an object
                (VOC: 0 = background, 255 = void border).
        """
        self.inlier_dataset = inlier_dataset
        self.outlier_dataset = outlier_dataset
        self.anomaly_idx = anomaly_idx
        self.anomaly_probability = anomaly_probability
        self.max_anomalies = max_anomalies
        self.images_only_transforms = images_only_transforms
        self.ignore_values = tuple(ignore_values)

    @staticmethod
    def _random_patch_size(extent):
        """Draw a patch side length in [10%, 30%) of ``extent``, never below 1 pixel."""
        low = max(1, int(extent * 0.1))
        high = max(low + 1, int(extent * 0.3))
        return np.random.randint(low, high)

    def _sample_outlier(self):
        """Draw outlier items until one contains at least one object class.

        Returns:
            (image tensor (C, H, W) in [0, 1], mask tensor (1, H, W), candidate class values)
        """
        while True:
            # Load the item ONCE so image and mask come from the same draw even if
            # the outlier dataset applies random transforms.
            outlier_img, outlier_mask = self.outlier_dataset[np.random.randint(0, len(self.outlier_dataset))]
            mask = torch.from_numpy(np.array(outlier_mask)).unsqueeze(0)
            values = np.unique(mask.numpy())
            classes = values[~np.isin(values, self.ignore_values)]
            if len(classes) > 0:
                return transforms.ToTensor()(outlier_img), mask, classes

    def inject_anomalies(self, sh_img, sh_lbl):
        """Paste 1..max_anomalies outlier objects into ``sh_img`` / ``sh_lbl`` in place.

        Returns:
            (sh_img, sh_lbl, anomaly_image, anomaly_annot), where the last two
            describe the last pasted patch: its resized image and a mask equal
            to ``anomaly_idx`` on the pasted object and 0 elsewhere.
        """
        n_anomalies = random.randint(1, self.max_anomalies)
        for _ in range(n_anomalies):
            patch_size = (self._random_patch_size(sh_img.shape[1]), self._random_patch_size(sh_img.shape[2]))
            top, left, h, w = transforms.RandomCrop.get_params(sh_img, output_size=patch_size)

            anomaly_image, anomaly_annot, possible_classes = self._sample_outlier()
            anomaly_class = np.random.choice(possible_classes)
            anomaly_image = F.interpolate(anomaly_image.unsqueeze(0), size=(h, w), mode="bilinear").squeeze(0)
            anomaly_annot = F.interpolate(anomaly_annot.unsqueeze(0), size=(h, w), mode="nearest").squeeze((0, 1))

            object_mask = anomaly_annot == anomaly_class
            # Basic slicing returns views, so these assignments write into sh_img / sh_lbl.
            sh_img[:, top:top + h, left:left + w][:, object_mask] = anomaly_image[:, object_mask]
            sh_lbl[top:top + h, left:left + w][object_mask] = self.anomaly_idx

            # Build the patch mask from object_mask only, so pixels of other
            # outlier classes that happen to equal anomaly_idx are not kept.
            anomaly_annot = torch.zeros_like(anomaly_annot)
            anomaly_annot[object_mask] = self.anomaly_idx

        return sh_img, sh_lbl, anomaly_image, anomaly_annot

    def __len__(self):
        return len(self.inlier_dataset)

    def __getitem__(self, idx):
        sample = self.inlier_dataset[idx]
        sh_img, sh_lbl = sample["image"], sample["labels"]
        if random.random() < self.anomaly_probability:
            # The last patch and its mask are not part of the returned sample.
            sh_img, sh_lbl, _, _ = self.inject_anomalies(sh_img, sh_lbl)

        if self.images_only_transforms:
            sh_img = self.images_only_transforms(sh_img)

        return {"image": sh_img, "labels": sh_lbl}

In [None]:
def compute_mean_std(loader):
    """
    Compute the per-channel mean and standard deviation of the images in ``loader``.

    Can be used instead of the ImageNet mean/std to normalize the images
    (in our experiments it seemed to give worse results).

    The statistics are taken over every pixel of every image, which gives the
    true dataset-level std (averaging per-image stds ignores the variation
    between images). Sums are accumulated in float64 to limit round-off.

    Args:
        loader: iterable of batches; each batch is a dict whose ``"image"``
            entry is a tensor with the batch on dim 0 and channels on dim 1.

    Returns:
        (mean, std): float32 tensors with one value per channel.

    Raises:
        ValueError: if ``loader`` yields no pixels.
    """
    pixel_sum = 0.0
    pixel_sq_sum = 0.0
    n_pixels = 0

    for batch in tqdm(loader):
        images = batch["image"]
        # (B, C, ...) -> (B, C, P); reshape (unlike view) also accepts non-contiguous input.
        flat = images.reshape(images.size(0), images.size(1), -1).double()
        pixel_sum = pixel_sum + flat.sum(dim=(0, 2))
        pixel_sq_sum = pixel_sq_sum + flat.pow(2).sum(dim=(0, 2))
        n_pixels += flat.size(0) * flat.size(2)

    if n_pixels == 0:
        raise ValueError("compute_mean_std: the loader yielded no pixels")

    mean = pixel_sum / n_pixels
    # Var[x] = E[x^2] - E[x]^2; clamp tiny negative values caused by round-off.
    var = (pixel_sq_sum / n_pixels - mean ** 2).clamp(min=0)
    return mean.float(), var.sqrt().float()

In [None]:
def plot_class_frequency(pixels_per_class, total_pixels, ax=None):
    """
    Draw a horizontal bar chart with the percentage of pixels belonging to each class.

    Args:
        pixels_per_class: 1-D sequence/tensor of pixel counts, indexed by class id.
        total_pixels: number of pixels the counts were taken from.
        ax: matplotlib axes to draw on; defaults to the current axes.

    Returns:
        The bars created by ``ax.barh``.
    """
    if ax is None:
        ax = plt.gca()

    percentages = (pixels_per_class / total_pixels) * 100
    class_ids = list(range(len(pixels_per_class)))

    bars = ax.barh(class_ids, percentages, color='skyblue')

    # The bars show percentages, and the ticks are class ids (not names).
    ax.set_ylabel('Class ID', fontsize=12)
    ax.set_xlabel('Pixel Share (%)', fontsize=12)
    ax.set_title('Pixel Frequencies Per Class', fontsize=14)
    ax.set_yticks(class_ids)
    ax.set_yticklabels(class_ids, rotation=0, fontsize=10)
    ax.tick_params(axis='x', labelsize=10)

    # Shift each value label right by 1% of the longest bar so it does not touch the bar.
    offset = float(max(percentages)) * 0.01 if len(percentages) else 0.0
    for bar in bars:
        width = bar.get_width()
        ax.text(width + offset,
                bar.get_y() + bar.get_height() / 2,
                f'{width:.2f}%', ha='left', va='center', fontsize=9)

    return bars


In [None]:
def compute_class_frequency(dataset, num_classes, normalize = False, plot_frequencies = False, ax= None):
    """
    Count the pixels of each class in ``dataset`` and derive inverse-frequency class weights.

    Args:
        dataset: dataset whose items are dicts with a ``'labels'`` tensor of class ids.
        num_classes: number of classes. Label values outside [0, num_classes)
            (e.g. -1 or an anomaly id) are ignored instead of crashing ``bincount``.
        normalize: if True, rescale the weights so that they average to 1.
        plot_frequencies: if True, plot the per-class pixel share on ``ax``.
        ax: matplotlib axes used when ``plot_frequencies`` is True.

    Returns:
        Float tensor of shape (num_classes,) with weight 1 / frequency per class.
        Classes that never occur get weight 0 instead of infinity, so
        normalization does not turn every weight into NaN.
    """
    data_loader = DataLoader(dataset, batch_size=8, shuffle=False, num_workers=2)
    pixels_per_class = torch.zeros(num_classes, dtype=torch.long)

    for batch in tqdm(data_loader):
        labels = batch['labels'].reshape(-1)  # flatten all pixels
        # bincount rejects negative values and grows past num_classes for large
        # ids (breaking the += below), so keep only valid class ids.
        labels = labels[(labels >= 0) & (labels < num_classes)]
        pixels_per_class += torch.bincount(labels, minlength=num_classes).cpu()

    total_pixels = pixels_per_class.sum()
    frequencies = pixels_per_class.float() / total_pixels
    present = pixels_per_class > 0
    weights = torch.zeros_like(frequencies)
    weights[present] = 1.0 / frequencies[present]
    if normalize and weights.sum() > 0:
        weights = weights / weights.sum() * len(weights)

    if plot_frequencies:
        plot_class_frequency(pixels_per_class, total_pixels, ax)

    return weights

#weight_ce = compute_class_frequency(dataset= train_dataset, num_classes=13, normalize = True)
#print(weight_ce)

In [None]:
anomaly_dataset_path = './data_voc'
# Pascal VOC segmentation train/val splits, downloaded into anomaly_dataset_path if missing.
voc_train, voc_val = (
    torchvision.datasets.VOCSegmentation(anomaly_dataset_path, image_set=split, download=True)
    for split in ("train", "val")
)

In [None]:
def as_numpy(obj):
    """
    Convert ``obj`` to a NumPy array.

    Tensors are detached from the autograd graph and moved to the CPU first
    (``Tensor.numpy()`` refuses tensors that require grad). Anything else is
    passed to ``np.array``.
    """
    if torch.is_tensor(obj):
        return obj.detach().cpu().numpy()
    return np.array(obj)

In [None]:
class AUPR:
    """
    Average precision (area under the precision-recall curve) for the anomaly class.

    ``update`` scores every image of a batch independently.
    ``get_results`` returns the mean of those per-image scores as a percentage.
    """

    def __init__(self, anomaly_idx = 13):
        # Average-precision value of every image seen so far.
        self.mean_aupr = []
        self.anomaly_idx = anomaly_idx

    def update(self, anomaly_score, labels):
        """
        Args:
            anomaly_score: tensor of per-pixel anomaly scores, batch on dim 0.
            labels: tensor of label maps, batch on dim 0; pixels equal to
                ``anomaly_idx`` are the positives.
        """
        batch_size = anomaly_score.shape[0]
        for b in range(batch_size):
            y_true = (labels[b] == self.anomaly_idx).cpu().type(torch.int32).flatten().numpy()
            y_score = anomaly_score[b].cpu().type(torch.float32).flatten().numpy()
            self.mean_aupr.append(average_precision_score(y_true, y_score))

    def get_results(self):
        """Return the mean per-image average precision, in percent."""
        total = sum(self.mean_aupr)
        return total / len(self.mean_aupr) * 100

In [None]:
class MeanIoU:
    """
    Streaming IoU metric based on an accumulated confusion matrix.

    Taken from https://github.com/Jun-CEN/Open-World-Semantic-Segmentation/blob/main/DeepLabV3Plus-Pytorch/metrics/stream_metrics.py

    Ground-truth pixels whose label is outside [0, n_classes) (e.g. -1 or the
    anomaly id) are ignored.
    """
    def __init__(self, n_classes= 13):
        self.n_classes = n_classes
        # confusion_matrix[t, p] = number of pixels of true class t predicted as class p.
        self.confusion_matrix = np.zeros((n_classes, n_classes))

    def update(self, label_trues, logits):
        """
        Args:
            label_trues: tensor of ground-truth class ids, batch on dim 0.
            logits: tensor of class scores with the batch on dim 0 and classes on
                dim 1; the prediction is the arg-max over dim 1.
        """
        label_preds = torch.argmax(logits, dim=1)
        label_preds, label_trues = label_preds.cpu().numpy(), label_trues.cpu().numpy()
        for lt, lp in zip(label_trues, label_preds):
            self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten())

    def _fast_hist(self, label_true, label_pred):
        """Confusion matrix of one image, computed with a single bincount over t * n + p."""
        mask = (label_true >= 0) & (label_true < self.n_classes)
        hist = np.bincount(
            self.n_classes * label_true[mask].astype(int) + label_pred[mask],
            minlength=self.n_classes ** 2,
        ).reshape(self.n_classes, self.n_classes)
        return hist

    def get_results(self):
        """
        Returns:
            dict with
              - "Mean IoU": mean of the per-class IoUs, skipping undefined (NaN) ones;
              - "Class IoU": {class_id: IoU}. The IoU is NaN for classes absent
                from both ground truth and predictions.
        """
        hist = self.confusion_matrix
        tp = np.diag(hist)
        # Classes absent from both GT and predictions give 0/0 -> NaN; that is expected.
        with np.errstate(divide="ignore", invalid="ignore"):
            iu = tp / (hist.sum(axis=1) + hist.sum(axis=0) - tp)
        mean_iu = np.nanmean(iu, axis=0)
        cls_iu = dict(zip(range(self.n_classes), iu))

        return {
                "Mean IoU": mean_iu,
                "Class IoU": cls_iu,
            }


In [None]:
from kornia.morphology import dilation, erosion
from scipy import ndimage as ndi

# Expansion kernels for radius r = 1..9: a (2r+1)x(2r+1) grid with only the
# centre set, then dilated r times with the 3x3 cross below. This should give a
# diamond-shaped (L1-ball) element of radius r.
# Tensors live on DEVICE rather than a hard-coded .cuda(), so the notebook also
# runs on CPU-only machines.
d_ks = {r: torch.zeros((1, 1, 2 * r + 1, 2 * r + 1), device=DEVICE) for r in range(1, 10)}

# 3x3 square element (used for erosion) and 3x3 cross / 4-connectivity (used for dilation).
selem = torch.ones((3, 3), device=DEVICE)
selem_dilation = torch.as_tensor(ndi.generate_binary_structure(2, 1), dtype=torch.float32, device=DEVICE)

for k, v in d_ks.items():
    v[:, :, k, k] = 1
    for _ in range(k):
        v = dilation(v, selem_dilation)
    d_ks[k] = v.squeeze(0).squeeze(0)

def find_boundaries(labels):
    """
    Return a float mask that is 1 where dilation(labels) != erosion(labels), i.e. near label changes.

    Dilation uses the module-level ``selem_dilation`` (3x3 cross) and erosion
    uses ``selem`` (3x3 square).

    Args:
        labels: 4-D tensor (this is asserted), e.g. a batch of single-channel label maps.
    """
    assert len(labels.shape) == 4
    as_float = labels.float()
    dilated = dilation(as_float, selem_dilation)
    eroded = erosion(as_float, selem)
    return (dilated != eroded).float()

def expand_boundaries(boundaries, r=0):
    """
    Widen a boundary mask by dilating it with the precomputed element ``d_ks[r]``.

    ``r == 0`` returns ``boundaries`` unchanged; otherwise ``r`` must be a key
    of ``d_ks`` (1..9).
    """
    return boundaries if r == 0 else dilation(boundaries, d_ks[r])

In [None]:
class BoundarySuppressionWithSmoothing(nn.Module):
    """
    Apply boundary suppression and dilated smoothing to a single-channel score map.

    Stage 1 (boundary suppression, needs ``prediction``): the boundaries of the
    predicted label map are found with ``find_boundaries``. Over
    ``boundary_iteration`` iterations, with an expansion width that shrinks each
    time, every score on the (expanded) boundary is replaced by the mean of
    the non-boundary scores in its 3x3 neighbourhood. A score is kept when that
    neighbourhood has no non-boundary pixel.
    Stage 2 (dilated smoothing, optional): the result is convolved with a 7x7
    Gaussian kernel (sigma 1) dilated by ``dilation``.

    Args:
        boundary_suppression: enable stage 1.
        boundary_width: initial expansion radius of the boundary. Must be divisible
            by ``boundary_iteration`` when the latter is non-zero (asserted).
        boundary_iteration: number of suppression iterations.
        dilated_smoothing: enable stage 2.
        kernel_size: passed to ``second_conv``, but its weight is always replaced
            by the 7x7 Gaussian and the padding in ``forward`` (``dilation * 3``)
            assumes a 7x7 kernel. NOTE(review): values other than 7 look unsupported.
        dilation: dilation of the smoothing convolution.
    """
    def __init__(self, boundary_suppression=True, boundary_width=4, boundary_iteration=4,
                 dilated_smoothing=True, kernel_size=7, dilation=6):
        super(BoundarySuppressionWithSmoothing, self).__init__()
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.boundary_suppression = boundary_suppression
        self.boundary_width = boundary_width
        self.boundary_iteration = boundary_iteration

        # 7x7 Gaussian (sigma 1), normalized to sum to 1, shaped (1, 1, 7, 7) for conv weights.
        sigma = 1.0
        size = 7
        gaussian_kernel = np.fromfunction(lambda x, y: (1/(2*math.pi*sigma**2)) * math.e ** ((-1*((x-(size-1)/2)**2+(y-(size-1)/2)**2))/(2*sigma**2)), (size, size))
        gaussian_kernel /= np.sum(gaussian_kernel)
        gaussian_kernel = torch.Tensor(gaussian_kernel).unsqueeze(0).unsqueeze(0)
        self.dilated_smoothing = dilated_smoothing

        # All-ones 3x3 conv (no padding): sums a 3x3 neighbourhood of a 1-pixel-padded input.
        self.first_conv = nn.Conv2d(1, 1, kernel_size=3, stride=1, bias=False)
        self.first_conv.weight = torch.nn.Parameter(torch.ones_like((self.first_conv.weight)))

        # Dilated smoothing conv; its weight is overwritten with the 7x7 Gaussian.
        self.second_conv = nn.Conv2d(1, 1, kernel_size=self.kernel_size, stride=1, dilation=self.dilation, bias=False)
        self.second_conv.weight = torch.nn.Parameter(gaussian_kernel)


    def forward(self, x, prediction=None):
        """
        Args:
            x: score map, 3-D or 4-D with a single channel on dim 1 (the convs
                have in_channels=1).
            prediction: label map used to locate boundaries; required when
                ``boundary_suppression`` is True (it is unsqueezed on dim 1 and
                must then be 4-D). Presumably (B, H, W) — TODO confirm.

        Returns:
            The processed map with dim 1 squeezed.
        """
        if len(x.shape) == 3:
            x = x.unsqueeze(1)
        x_size = x.size()
        # B x 1 x H x W
        assert len(x.shape) == 4
        out = x
        if self.boundary_suppression:
            # obtain the boundary map of width 2 by default
            # this can be calculated by the difference of dilation and erosion
            boundaries = find_boundaries(prediction.unsqueeze(1))
            expanded_boundaries = None
            # diff is only defined (and only needed) when the loop below runs.
            if self.boundary_iteration != 0:
                assert self.boundary_width % self.boundary_iteration == 0
                diff = self.boundary_width // self.boundary_iteration
            for iteration in range(self.boundary_iteration):
                if len(out.shape) != 4:
                    out = out.unsqueeze(1)
                prev_out = out
                # if it is the last iteration or boundary width is zero
                if self.boundary_width == 0 or iteration == self.boundary_iteration - 1:
                    expansion_width = 0
                # reduce the expansion width for each iteration
                else:
                    expansion_width = self.boundary_width - diff * iteration - 1
                # expand the boundary obtained from the prediction (width of 2) by expansion rate
                expanded_boundaries = expand_boundaries(boundaries, r=expansion_width)
                # invert it so that we can obtain non-boundary mask
                non_boundary_mask = 1. * (expanded_boundaries == 0)

                # pad by 1 so the unpadded 3x3 first_conv keeps the spatial size
                f_size = 1
                num_pad = f_size

                # making boundary regions to 0
                x_masked = out * non_boundary_mask
                x_padded = nn.ReplicationPad2d(num_pad)(x_masked)

                non_boundary_mask_padded = nn.ReplicationPad2d(num_pad)(non_boundary_mask)

                # sum up the values in the receptive field
                y = self.first_conv(x_padded)
                # count non-boundary elements in the receptive field
                num_calced_elements = self.first_conv(non_boundary_mask_padded)
                num_calced_elements = num_calced_elements.long()

                # take an average by dividing y by count
                # if there is no non-boundary element in the receptive field,
                # keep the original value
                avg_y = torch.where((num_calced_elements == 0), prev_out, y / num_calced_elements)
                out = avg_y

                # update boundaries only
                out = torch.where((non_boundary_mask == 0), out, prev_out)
                del expanded_boundaries, non_boundary_mask

            # second stage; apply dilated smoothing
            if self.dilated_smoothing == True:
                # padding dilation * 3 = dilation * (7 - 1) / 2 keeps H x W for the 7x7 kernel
                out = nn.ReplicationPad2d(self.dilation * 3)(out)
                out = self.second_conv(out)

            return out.squeeze(1)
        else:
            if self.dilated_smoothing == True:
                out = nn.ReplicationPad2d(self.dilation * 3)(out)
                out = self.second_conv(out)
            else:
                out = x

        return out.squeeze(1)


In [None]:
class MaximumSoftmaxProbability(nn.Module):
    """
    Maximum Softmax Probability anomaly scorer.

    The wrapped segmenter is put in eval mode on DEVICE. Every pixel gets the
    score ``1 - max_c softmax(logits)_c``. The optional ``multi_scale`` module
    then post-processes the score map using the arg-max prediction.
    """
    def __init__(self, segmenter, multi_scale = None):
        super().__init__()
        self.segmenter = segmenter.eval().to(DEVICE)
        self.multi_scale = multi_scale
        if self.multi_scale:
            self.multi_scale.to(DEVICE)

    @torch.no_grad()
    def forward(self, inputs):
        """Return ``(logits, anomaly_score)`` for a batch of images."""
        logits = self.segmenter(inputs)
        probs = nn.functional.softmax(logits, dim=1)
        top_prob, prediction = probs.max(dim=1)
        anomaly_score = 1 - top_prob

        if not self.multi_scale:
            return logits, anomaly_score
        return logits, self.multi_scale(anomaly_score, prediction)

In [None]:
class MaxLogit(nn.Module):
    """
    Max-logit anomaly scorer.

    The wrapped segmenter is put in eval mode on DEVICE. Every pixel gets the
    score ``1 - max_c logits_c``. The optional ``multi_scale`` module then
    post-processes the score map using the arg-max prediction.
    """
    def __init__(self, segmenter, multi_scale = None):
        super().__init__()
        self.segmenter = segmenter.eval().to(DEVICE)
        self.multi_scale = multi_scale
        if self.multi_scale:
            self.multi_scale.to(DEVICE)

    @torch.no_grad()
    def forward(self, inputs):
        """Return ``(logits, anomaly_score)`` for a batch of images."""
        logits = self.segmenter(inputs)
        top_logit, prediction = logits.max(dim=1)
        anomaly_score = 1 - top_logit

        if not self.multi_scale:
            return logits, anomaly_score
        return logits, self.multi_scale(anomaly_score, prediction)

In [None]:
class StandardizedMaxLogit(nn.Module):
    """
    Standardized Max Logit (SML) anomaly scorer.

    At construction, the mean and variance of the max logit are estimated per
    class on ``train_dl``. Only the pixels the segmenter predicts as class c
    contribute to class c. At inference, each pixel's max logit is standardized
    with the statistics of its predicted class, and the anomaly score is
    ``1 - standardized max logit``. It can then be post-processed by ``multi_scale``.
    """
    def __init__(self, segmenter, train_dl, multi_scale = None):
        super().__init__()
        self.segmenter = segmenter.eval().to(DEVICE)

        self.multi_scale = multi_scale
        if self.multi_scale:
            self.multi_scale.to(DEVICE)

        self.class_mean, self.class_var = self.compute_mean_std(train_dl)

    @staticmethod
    def _accumulate_max_logit_stats(logits, sums, sq_sums, counts):
        """Add the max logits of ``logits`` (classes on dim 1) to running per-predicted-class sums, in place."""
        max_logit, prediction = torch.max(logits, dim=1)
        max_logit = max_logit.double()
        for c in range(sums.numel()):
            # Only pixels predicted as c: filling other pixels with 0 instead
            # would pull both statistics toward 0.
            selected = max_logit[prediction == c]
            sums[c] += selected.sum()
            sq_sums[c] += (selected ** 2).sum()
            counts[c] += selected.numel()

    @staticmethod
    def _finalize_stats(sums, sq_sums, counts, min_var=1e-12):
        """Turn running sums into (mean, variance) NumPy arrays.

        Classes never predicted get mean 0 and variance 1, so their max logits
        are left as they are. The variance is floored at ``min_var`` to avoid a
        division by zero in ``forward``.
        """
        seen = counts > 0
        class_mean = torch.zeros_like(sums)
        class_var = torch.ones_like(sums)
        class_mean[seen] = sums[seen] / counts[seen]
        # Var = E[x^2] - E[x]^2, computed with float64 accumulators.
        class_var[seen] = (sq_sums[seen] / counts[seen] - class_mean[seen] ** 2).clamp(min=min_var)
        return class_mean.cpu().numpy(), class_var.cpu().numpy()

    @torch.no_grad()
    def compute_mean_std(self, train_dl, num_classes = 13):
        """
        Estimate the per-class mean and variance of the max logit over ``train_dl``.

        Returns:
            (class_mean, class_var): float64 NumPy arrays of shape (num_classes,).
        """
        sums = torch.zeros(num_classes, dtype=torch.float64, device=DEVICE)
        sq_sums = torch.zeros_like(sums)
        counts = torch.zeros_like(sums)

        for batch in tqdm(train_dl, desc="Computing mean and std on train"):
            logits = self.segmenter(batch['image'].to(DEVICE))
            self._accumulate_max_logit_stats(logits, sums, sq_sums, counts)

        return self._finalize_stats(sums, sq_sums, counts)

    @torch.no_grad()
    def forward(self, inputs):
        """Return ``(logits, anomaly_score)`` with anomaly_score = 1 - standardized max logit."""
        logits = self.segmenter(inputs)

        anomaly_score, prediction = torch.max(logits, dim=1)
        # Each pixel has exactly one predicted class, so updating one class at a time is safe.
        for c in range(len(self.class_mean)):
            anomaly_score = torch.where(
                prediction == c,
                (anomaly_score - self.class_mean[c]) / np.sqrt(self.class_var[c]),
                anomaly_score)

        anomaly_score = 1 - anomaly_score

        if self.multi_scale:
            anomaly_score = self.multi_scale(anomaly_score, prediction)

        return logits, anomaly_score

In [None]:
class EnergyScore(nn.Module):
    """
    Energy-based anomaly scorer.

    Taken from the paper "Residual Pattern Learning for Pixel-wise
    Out-of-Distribution Detection in Semantic Segmentation".

    The wrapped segmenter must return a pair; only its second element (the
    logits) is used. Each pixel gets the score ``-logsumexp(logits)`` over
    dim 1, optionally smoothed with a 7x7 Gaussian blur (sigma 1).
    """
    def __init__(self, segmenter, use_gaussian = True):
        super().__init__()
        self.segmenter = segmenter.eval().to(DEVICE)
        self.use_gaussian = use_gaussian

    @torch.no_grad()
    def forward(self, inputs):
        """Return ``(logits, anomaly_score)``."""
        _vanilla, logits = self.segmenter(inputs)
        anomaly_score = -torch.logsumexp(logits, dim=1)

        if not self.use_gaussian:
            return logits, anomaly_score

        # GaussianBlur works on the last two dims; a leading dim is added around the call and removed after.
        blurred = transforms.GaussianBlur(7, sigma=1)(anomaly_score.unsqueeze(0)).squeeze(0)
        return logits, blurred

In [None]:
class EnergyEntropyScore(nn.Module):
    """
    Energy + entropy anomaly scorer.

    Taken from the paper "Open-set Anomaly Segmentation in Complex Scenarios"
    (no code, since the paper has been published on arXiv 28/05/2025).

    The wrapped segmenter must return a pair. Its second element (the logits)
    is scored per pixel as ``-logsumexp(logits) + H(softmax(logits))``, which
    can then be smoothed with a 7x7 Gaussian blur (sigma 1).
    """
    def __init__(self, segmenter, use_gaussian = True):
        super().__init__()
        self.segmenter = segmenter.eval().to(DEVICE)
        self.use_gaussian = use_gaussian

    @staticmethod
    def _energy_entropy(score):
        """Per-pixel free energy plus softmax entropy of ``score`` (classes on dim 1)."""
        log_prob = torch.log_softmax(score, dim=1)
        # Entropy from log_softmax: avoids 0 * log(0) = NaN when a probability underflows.
        entropy_score = -torch.sum(log_prob.exp() * log_prob, dim=1)
        # Free energy; logsumexp avoids exp() overflowing to inf on large logits.
        energy = -torch.logsumexp(score, dim=1)
        return energy + entropy_score

    @torch.no_grad()
    def forward(self, inputs):
        """Return ``(logits, anomaly_score)``."""
        _, score = self.segmenter(inputs)

        anomaly_score = self._energy_entropy(score)

        if self.use_gaussian:
            anomaly_score = anomaly_score.unsqueeze(0)
            anomaly_score = transforms.GaussianBlur(7, sigma=1)(anomaly_score)
            anomaly_score = anomaly_score.squeeze(0)

        return score, anomaly_score

In [None]:
def disimilarity_entropy(logits, vanilla_logits, t=1.):
    """
    Per-pixel squared difference between the softmax entropies of two logit maps.

    Loss from RPL anomaly detection, taken from
    https://github.com/yyliu01/RPL/blob/main/rpl_corocl.code/loss/PositiveEnergy.py

    Probabilities are clamped to at least 1e-7 before the log, and both
    entropies are divided by ``t``. The result is unreduced (class dim removed).
    """
    def _clamped_probs(raw):
        return torch.clamp(torch.softmax(raw, dim=1), min=1e-7)

    def _scaled_entropy(prob):
        return -torch.sum(prob * torch.log(prob), dim=1) / t

    n_prob = _clamped_probs(vanilla_logits)
    a_prob = _clamped_probs(logits)
    n_entropy = _scaled_entropy(n_prob)
    a_entropy = _scaled_entropy(a_prob)

    if torch.isnan(logits).any():
        print(f"logits: {logits}, a_prob: {a_prob}, a_entropy: {a_entropy}")

    per_pixel_loss = F.mse_loss(input=a_entropy, target=n_entropy, reduction="none")
    assert ~torch.isnan(per_pixel_loss).any(), print(torch.min(n_entropy), torch.max(a_entropy))
    return per_pixel_loss

def energy_loss(logits, targets, vanilla_logits, out_idx=13, t=1.):
    """
    RPL training loss: inlier consistency with the frozen model plus an energy penalty on outliers.

    Loss from RPL anomaly detection, taken from
    https://github.com/yyliu01/RPL/blob/main/rpl_corocl.code/loss/PositiveEnergy.py

    Args:
        logits: logits of the model being trained, classes on dim 1.
        targets: label map with the shape of ``logits`` minus the class dim;
            ``out_idx`` marks outlier pixels and 255 marks void pixels.
        vanilla_logits: logits of the frozen segmenter. Their arg-max is the
            pseudo-label for the cross-entropy on inlier pixels.
        out_idx: label id of outlier pixels.
        t: unused; kept for signature compatibility.

    Returns:
        0-dim tensor: mean inlier cross-entropy + mean entropy dissimilarity
        + 0.05 * mean(ReLU(logsumexp(logits))) over outlier pixels (this term
        is 0 when the batch has no outlier pixel).
    """
    out_msk = (targets == out_idx)
    void_msk = (targets == 255)
    inlier_msk = ~(out_msk | void_msk)

    pseudo_targets = torch.argmax(vanilla_logits, dim=1)
    entropy_part = F.cross_entropy(input=logits, target=pseudo_targets, reduction='none')[inlier_msk]
    reg = disimilarity_entropy(logits=logits, vanilla_logits=vanilla_logits)[inlier_msk]

    if out_msk.any():
        # logsumexp over the class dim equals log(sum(exp(logits))) without overflowing.
        energy_part = F.relu(torch.logsumexp(logits, dim=1)[out_msk]).mean()
    else:
        # 0-dim zero, so the loss is a scalar in both branches.
        energy_part = logits.new_zeros(())

    inlier_loss = entropy_part.mean() + reg.mean()
    outlier_loss = energy_part * 0.05  # 0.05 = energy_weight (taken from paper)
    loss_res = inlier_loss + outlier_loss

    return loss_res

In [None]:
# Earlier draft of energy_entropy_loss, kept for reference only. It is a bare
# string expression that is never executed; the live definition follows it.
'''def energy_entropy_loss(logits, targets, vanilla_logits, out_idx=13, t=1., alpha = 1):
    """
    TODO: review that the implementation is correct
    loss from Open-set Anomaly Segmentation in Complex Scenarios (code not published yet)
    """
    out_msk = (targets == out_idx)
    void_msk = (targets == 255)

    pseudo_targets = torch.argmax(vanilla_logits, dim=1)
    outlier_msk = (out_msk | void_msk)
    entropy_part = F.cross_entropy(input=logits, target=pseudo_targets, reduction='none')[~outlier_msk]
    reg = disimilarity_entropy(logits=logits, vanilla_logits=vanilla_logits)[~outlier_msk]
    
    if torch.sum(out_msk) > 0:
        
        prob = torch.softmax(logits, dim=1).flatten(start_dim=2).permute(0, 2, 1)
        logits = logits.flatten(start_dim=2).permute(0, 2, 1)
        
        energy = torch.log(torch.sum(torch.exp(logits),dim=2))
        entropy = -torch.sum(prob * torch.log(prob), dim=2)# / torch.log(torch.tensor(13.))
        outlier_part = -torch.log(torch.sigmoid(-energy[out_msk.flatten(start_dim=1)])) - alpha*(entropy[out_msk.flatten(start_dim=1)])
        inlier_part = torch.log(1 - torch.sigmoid(-energy[~out_msk.flatten(start_dim=1)])) + alpha*(entropy[~out_msk.flatten(start_dim=1)])
        energy_entropy = outlier_part.mean() - inlier_part.mean()
        
    else:
        energy_entropy = torch.tensor([.0], device=targets.device).mean()
        
    inlier_loss = entropy_part.mean() + reg.mean()
    outlier_loss = energy_entropy*0.05
    loss_res = inlier_loss + outlier_loss
    
    return loss_res'''
    
def energy_entropy_loss(logits, targets, vanilla_logits, out_idx=13, t=1., alpha = 1):
    """
    Energy-entropy training loss, adapted from "Open-set Anomaly Segmentation in
    Complex Scenarios" (code not published yet).

    TODO: review that the implementation is correct (see the NOTE(review) comments below).

    Args:
        logits: logits of the model being trained, classes on dim 1.
        targets: label map; ``out_idx`` marks outlier pixels and 255 marks void pixels.
        vanilla_logits: logits of the reference model. Their arg-max is the
            pseudo-label for the cross-entropy on non-outlier, non-void pixels.
        out_idx: label id of outlier pixels.
        t: unused.
        alpha: weight of the entropy terms.

    Returns:
        Scalar tensor: mean cross-entropy + mean entropy dissimilarity
        + 0.05 * energy-entropy term (the last term is 0 when there is no outlier pixel).
    """
    out_msk = (targets == out_idx)
    void_msk = (targets == 255)

    # The reference model's arg-max is the target on inlier pixels.
    pseudo_targets = torch.argmax(vanilla_logits, dim=1)
    outlier_msk = (out_msk | void_msk)
    entropy_part = F.cross_entropy(input=logits, target=pseudo_targets, reduction='none')[~outlier_msk]
    reg = disimilarity_entropy(logits=logits, vanilla_logits=vanilla_logits)[~outlier_msk]
    
    if torch.sum(out_msk) > 0:
        
        eps = 1e-6  # stability constant

        # Flatten the spatial dims and move classes last, so that a per-sample
        # flattened pixel mask selects per-pixel class vectors.
        prob = torch.clamp(torch.softmax(logits, dim=1), min=eps).flatten(start_dim=2).permute(0, 2, 1)
        logits = logits.flatten(start_dim=2).permute(0, 2, 1)
        
        # logsumexp of the logits, i.e. the *negative* free energy.
        energy = torch.logsumexp(logits, dim=2)
        entropy = -torch.sum(prob * torch.log(prob), dim=2)  # no need for / log(C)
        
        sigmoid_energy = torch.sigmoid(-energy)
        # keep both log(sigmoid) and log(1 - sigmoid) finite
        sigmoid_energy = torch.clamp(sigmoid_energy, min=eps, max=1 - eps)
        
        outlier_mask_flat = out_msk.flatten(start_dim=1)
        # NOTE(review): this is ~out_msk only, so void (255) pixels count as
        # inliers here, unlike in entropy_part / reg above — confirm intent.
        inlier_mask_flat = ~outlier_mask_flat
        
        # Minimizing this lowers logsumexp and raises entropy on outlier pixels.
        outlier_part = -torch.log(sigmoid_energy[outlier_mask_flat]) - alpha * entropy[outlier_mask_flat]
        inlier_part = torch.log(1 - sigmoid_energy[inlier_mask_flat]) + alpha * entropy[inlier_mask_flat]
        
        # NOTE(review): subtracting inlier_part gives -log(1 - sigmoid(-lse)) - alpha * H
        # on inliers, so inlier entropy is *maximized* as well. If inliers should be
        # confident, the entropy sign in inlier_part looks inverted — confirm
        # against the paper.
        energy_entropy = outlier_part.mean() - inlier_part.mean()
        
    else:
        energy_entropy = torch.tensor([.0], device=targets.device).mean()
        
    inlier_loss = entropy_part.mean() + reg.mean()
    outlier_loss = energy_entropy*0.05
    loss_res = inlier_loss + outlier_loss
    
    return loss_res

In [None]:
# The following schedulers balance the weights of the region loss and the boundary loss during loss computation, as described in the paper "Boundary loss for highly unbalanced segmentation".

class Dummy_scheduler():
    """
    Constant weighting ``beta * loss1 + alpha * loss2``.

    Inspired by "Boundary loss for highly unbalanced segmentation".
    """
    def __init__(self, alpha = 0.5, beta = 0.5):
        """alpha == beta is the best-performing constant configuration reported in the paper."""
        self.alpha, self.beta = alpha, beta

    def __call__(self, loss1, loss2, epoch):
        """Return the weighted sum; ``epoch`` is ignored because the weights never change."""
        weighted_first = self.beta * loss1
        weighted_second = self.alpha * loss2
        return weighted_first + weighted_second

class Increamental_scheduler():
    """
    Weighting ``beta * loss1 + alpha * loss2`` with a fixed ``beta`` and an ``alpha`` growing linearly per epoch.

    Inspired by "Boundary loss for highly unbalanced segmentation".
    ``alpha`` starts at ``alpha_start``, grows by
    ``|alpha_start - beta| / max(n_epochs - 10, 1)`` per epoch and is capped
    at ``alpha_end``.
    """
    def __init__(self, alpha_start = 0.01, beta = 0.99, n_epochs = 40, alpha_end=0.99):
        self.alpha_end = alpha_end
        self.alpha_start = alpha_start
        self.alpha = alpha_start
        self.beta = beta
        # With the defaults, alpha reaches alpha_end 10 epochs before the end of
        # training. The max() avoids a division by zero (or a negative ramp) for short runs.
        ramp_epochs = max(n_epochs - 10, 1)
        self.increment = abs(self.alpha - self.beta)/ramp_epochs
        # Epoch whose weights are currently in use.
        self.e = 0

    def update_weights(self, epoch):
        self.alpha = min(self.alpha_start + epoch * self.increment, self.alpha_end)

    def __call__(self, loss1, loss2, epoch):
        # Update *before* weighting, so the first batch of a new epoch already
        # uses that epoch's weights. Remember the epoch itself (not a counter),
        # so skipped or resumed epochs do not re-trigger the update on every batch.
        if self.e != epoch:
            print(f"before update alpha: {self.alpha}, beta: {self.beta}")
            self.update_weights(epoch)
            print(f"after update alpha: {self.alpha}, beta: {self.beta}")
            self.e = epoch
        return self.beta*loss1 + self.alpha*loss2

class Rebalance_scheduler():
    """Combine two losses with weights ``1 - alpha`` and ``alpha``, ``alpha`` growing linearly.

    Inspired by "Boundary loss for highly unbalanced segmentation". This is the
    scheduler that leads to the best results in that paper. The combined loss
    is ``beta * loss1 + alpha * loss2`` with ``beta = 1 - alpha``. ``alpha``
    starts at ``alpha_start`` and reaches ``alpha_end`` after ``n_epochs - 10``
    epochs (at least one); it is clamped there afterwards.

    Args:
        alpha_start: initial weight of ``loss2``.
        n_epochs: total number of training epochs.
        alpha_end: final (maximum) weight of ``loss2``.
    """

    def __init__(self, alpha_start = 0.01, n_epochs = 40, alpha_end=0.99):
        self.alpha_end = alpha_end
        self.alpha_start = alpha_start
        self.alpha = alpha_start
        self.beta = 1-self.alpha
        # Ramp from alpha_start to alpha_end over (n_epochs - 10) epochs. The
        # max() guards against division by zero for short runs.
        ramp_epochs = max(n_epochs - 10, 1)
        self.increment = abs(alpha_end - alpha_start) / ramp_epochs
        self.e = 0  # epoch whose weights are currently applied

    def update_weights(self, epoch):
        """Set ``alpha`` for ``epoch`` and rebalance ``beta = 1 - alpha``."""
        self.alpha = min(self.alpha_start + epoch * self.increment, self.alpha_end)
        self.beta = 1-self.alpha

    def __call__(self, loss1, loss2, epoch):
        # Update *before* weighting, so the first batch of a new epoch already
        # uses that epoch's weights.
        if self.e != epoch:
            print(f"before update alpha: {self.alpha}, beta: {self.beta}")
            self.update_weights(epoch)
            print(f"after update alpha: {self.alpha}, beta: {self.beta}")
            self.e = epoch
        return self.beta*loss1 + self.alpha*loss2

In [None]:
class RPLDeepLab(nn.Module):
    """Residual Pattern Learning (RPL) wrapper around a DeepLabV3+-style segmenter.

    The constructor deep-copies sub-modules of ``segmenter`` (which is left
    untouched):

    - frozen copies: ``encoder``, ``decoder`` and ``final``
      (``decoder.block2`` followed by ``segmentation_head``);
    - a trainable residual branch: copies of ``decoder.aspp`` and ``decoder.up``,
      followed by a 1x1 conv (256 -> 304 channels).

    ``forward`` returns ``(out1, out2)``:

    - ``out1``: logits of the frozen path;
    - ``out2``: logits after adding the residual to the concatenated decoder
      features.

    Args:
        segmenter: model exposing ``encoder``, ``decoder`` (with ``aspp``,
            ``up``, ``block1``, ``block2``) and ``segmentation_head``, e.g.
            ``smp.DeepLabV3Plus``.
    """

    def __init__(self, segmenter):
        super().__init__()

        # Frozen copies of the closed-set segmenter.
        self.encoder = self.copy_un_freeze_params(segmenter.encoder, unfreeze=False)
        self.decoder = self.copy_un_freeze_params(segmenter.decoder, unfreeze=False)

        self.final = nn.Sequential(
            self.copy_un_freeze_params(segmenter.decoder.block2, unfreeze=False),
            self.copy_un_freeze_params(segmenter.segmentation_head, unfreeze=False),
        )

        # Maps the residual to the width of the concatenated decoder features.
        # Assumes ASPP outputs 256 channels and block1 outputs 48 (256 + 48 = 304).
        # TODO confirm for other encoders/configs.
        self.atten_aspp_final = nn.Conv2d(256, 304, kernel_size=1, bias=False)

        # Trainable residual branch.
        self.residual_anomaly_block = nn.Sequential(
            self.copy_un_freeze_params(segmenter.decoder.aspp, unfreeze=True),
            self.copy_un_freeze_params(segmenter.decoder.up, unfreeze=True),
            self.atten_aspp_final
        )

    def copy_un_freeze_params(self, layer: nn.Module, unfreeze: bool=True) -> nn.Module:
        """
        Create a deepcopy of ``layer`` and set ``requires_grad = unfreeze`` on all its parameters.

        return: deepcopy of the layer, frozen or unfrozen
        """
        layer_copy = deepcopy(layer)
        for param in layer_copy.parameters():
            param.requires_grad = unfreeze
        return layer_copy

    def train(self, mode: bool = True):
        """Set training mode, but always keep the frozen segmenter copies in eval mode.

        ``requires_grad=False`` alone does not freeze BatchNorm/Dropout. Without
        this override, ``model.train()`` would make the frozen BatchNorm layers
        normalise with batch statistics and overwrite their running statistics
        (and keep dropout active). ``out1`` would then drift away from the
        original segmenter.
        """
        super().train(mode)
        self.encoder.eval()
        self.decoder.eval()
        self.final.eval()
        return self

    def forward(self, x):
        features = self.encoder(x)

        # Frozen DeepLabV3+ decoder path up to the feature concatenation.
        aspp_features = self.decoder.aspp(features[-1])
        aspp_features = self.decoder.up(aspp_features)
        high_res_features = self.decoder.block1(features[2])
        concat_features = torch.cat([aspp_features, high_res_features], dim=1)

        # Residual computed from the deepest encoder features.
        res = self.residual_anomaly_block(features[-1])

        out1 = self.final(concat_features)
        out2 = self.final(concat_features + res)

        return out1, out2

In [None]:
class WarmUpPolyLRScheduler(_LRScheduler):
    """Linear warm-up followed by polynomial ("poly") learning-rate decay.

    Adapted from RPL for pixel-wise OOD detection in semantic segmentation.

    - For ``iter < warmup_steps``: ``lr = start_lr * iter / warmup_steps``.
    - Afterwards: ``lr = start_lr * (1 - iter / total_iters) ** lr_power``,
      clipped to ``[end_lr, start_lr]``.

    The same lr is assigned to every parameter group.

    Args:
        optimizer: wrapped optimizer.
        start_lr: peak learning rate.
        total_iters: number of iterations of the whole schedule. Stepping past
            it keeps the lr at ``end_lr``.
        warmup_steps: number of linear warm-up iterations.
        lr_power: exponent of the polynomial decay.
        end_lr: minimum learning rate.
        last_epoch: index of the last iteration (``-1`` to start fresh).
    """
    def __init__(self, optimizer, start_lr, total_iters, warmup_steps=0, lr_power=0.9, end_lr=1e-8, last_epoch=-1):
        self.start_lr = start_lr
        self.total_iters = total_iters
        self.warmup_steps = warmup_steps
        self.lr_power = lr_power
        self.end_lr = end_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        cur_iter = self.last_epoch
        if cur_iter < self.warmup_steps:
            lr = self.start_lr * (cur_iter / self.warmup_steps)
        else:
            # Clamp progress to [0, 1]. Past total_iters, (1 - progress) would be
            # negative and a negative float ** 0.9 is a complex number in Python.
            if self.total_iters > 0:
                progress = min(max(float(cur_iter) / self.total_iters, 0.0), 1.0)
            else:
                progress = 1.0
            lr = self.start_lr * ((1.0 - progress) ** self.lr_power)
            lr = float(np.clip(lr, a_min=self.end_lr, a_max=self.start_lr))

        return [lr for _ in self.optimizer.param_groups]

In [None]:
class Trainer:
    """Train a segmentation model with AdamW, evaluating and checkpointing it every epoch.

    After each training epoch the model is evaluated on the train and the
    validation loaders, and loss / mean IoU are logged to Weights & Biases.

    Checkpoints (all under ``ckpts/``):

    - ``<model_name>.pt`` and ``best_checkpoint``: written whenever the
      validation mean IoU improves;
    - ``<model_name>_lastepoch.pt`` and ``last_checkpoint``: written after the
      final epoch.

    Training stops early once the number of consecutive validation epochs
    without improvement exceeds ``cfg["patience"]``.

    Loss computation depends on the model output:

    - a ``(vanilla_logits, logits)`` tuple (e.g. ``RPLDeepLab``):
      ``loss1(logits=..., targets=..., vanilla_logits=...)``;
    - a single tensor and no ``loss2``: ``loss1(logits, labels)``;
    - a single tensor with ``loss2``:
      ``loss_scheduler(loss1=..., loss2=..., epoch=...)``.

    Args:
        model: network to train.
        train_loader, val_loader: loaders yielding dicts with ``'image'`` and
            ``'labels'``.
        cfg: must contain ``model_name``, ``patience``, ``num_epochs``, ``lr``
            and ``wd``. It must also contain ``scheduler`` when not resuming,
            called as ``scheduler(optimizer, lr, total_steps)``.
        loss1, loss2, loss_scheduler: see above.
        device: device the model and the batches are moved to.
        num_classes: stored on the instance.
        resume_ckpt: checkpoint dict as saved by this class. When given, the
            remaining epochs are trained and the lr scheduler is restored as
            ``OneCycleLR``.
    """

    def __init__(self,
                 #processor,
                 model: nn.Module,
                 train_loader: DataLoader,
                 val_loader: DataLoader,
                 cfg: dict,
                 loss1,
                 loss2 = None,
                 loss_scheduler= None,
                 device: torch.device = DEVICE,
                 num_classes: int = len(COLORS)-1,
                 resume_ckpt: dict = None,

        ) -> None:

        self.loss_scheduler = loss_scheduler
        self.model_name = cfg["model_name"]
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.num_classes = num_classes
        self.patience = cfg["patience"]
        self.loss1 = loss1
        self.loss2 = loss2
        # Deep copy of the model with the best validation mean IoU (set in eval).
        self.best_model = None

        if resume_ckpt:

            self.model = model.to(device)
            self.model.load_state_dict(resume_ckpt['model_state_dict'])

            # The saved 'epoch' is the 0-based index of the last completed
            # epoch, so epoch + 1 epochs are already done. Running one more
            # would also step the restored OneCycleLR past its total_steps.
            self.num_epochs = cfg["num_epochs"] - (resume_ckpt['epoch'] + 1)
            num_steps = self.num_epochs * len(train_loader)

            self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=cfg["lr"], weight_decay=cfg["wd"])
            self.optimizer.load_state_dict(resume_ckpt['optimizer_state_dict'])

            # total_steps is overwritten by load_state_dict. max() only keeps the
            # constructor valid when no epochs remain.
            self.scheduler = torch.optim.lr_scheduler.OneCycleLR(self.optimizer, cfg["lr"], total_steps=max(num_steps, 1))
            self.scheduler.load_state_dict(resume_ckpt['scheduler_state_dict'])

        else:
            self.model = model.to(device)
            self.num_epochs = cfg["num_epochs"]
            num_steps = self.num_epochs * len(train_loader)
            self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=cfg["lr"], weight_decay=cfg["wd"])
            self.scheduler = cfg["scheduler"](self.optimizer, cfg["lr"], num_steps)

        self.mean_iou = 0.0  # best validation mean IoU so far
        self.step = 0        # consecutive validation epochs without improvement
        self.best_acc = 0.0

        self.ckpt_path = Path("ckpts")
        self.ckpt_path.mkdir(exist_ok=True)

        wandb.init(name=self.model_name, entity=WANDB_USER, project=WANDB_PROJECT, config=cfg)

    def wandb_log(self, split, loss, mean_iou, epoch):
        """Log ``<split>_loss`` and ``<split>_mean_iou`` to W&B at step ``epoch``."""
        wandb.log({
            f"{split}_loss": loss,
            f"{split}_mean_iou": mean_iou,
            }, step=(epoch))

    def _compute_loss(self, output, labels, epoch):
        """Return ``(loss, logits)`` for a model ``output`` and its ``labels``.

        ``logits`` are the predictions used for the mean-IoU metric. They are
        the second element when the model returns a ``(vanilla_logits, logits)``
        tuple, otherwise ``output`` itself.
        """
        if isinstance(output, tuple):
            vanilla_logits, logits = output
            # Targets are cloned, presumably because the loss edits them in place.
            loss = self.loss1(logits=logits, targets=labels.clone(),
                              vanilla_logits=vanilla_logits)
        elif self.loss2 is None:
            logits = output
            loss = self.loss1(logits, labels)
        else:
            logits = output
            loss = self.loss_scheduler(loss1=self.loss1(logits, labels),
                                       loss2=self.loss2(logits, labels),
                                       epoch=epoch)
        return loss, logits

    def _save_checkpoint(self, epoch, mean_iou, weights_name, ckpt_name):
        """Save model weights and a full resumable checkpoint, and upload both to W&B."""
        weights_file = self.ckpt_path / weights_name
        ckpt_file = self.ckpt_path / ckpt_name
        torch.save(self.model.state_dict(), weights_file)
        torch.save({
            'epoch': epoch,
            'mean_iou': mean_iou,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            }, ckpt_file)
        wandb.save(weights_file)
        wandb.save(ckpt_file)

    def train(self, verbose= False) -> None:
        """Run the training loop (``verbose`` is currently unused).

        The scheduler is stepped after every batch. Both splits are evaluated
        after every epoch, and training stops early when ``self.step`` exceeds
        ``self.patience``.
        """
        for epoch in tqdm(range(self.num_epochs), desc="Epoch"):

            self.model.train()

            losses = []

            for batch in self.train_loader:

                imgs = batch['image'].to(self.device)
                labels = batch['labels'].to(self.device)

                output = self.model(imgs)
                if isinstance(output, tuple) and torch.isnan(output[1]).any():
                    print(f"check if an image is nan: {torch.isnan(imgs).any()}, \ncheck if a label is nan: {torch.isnan(labels).any()}")
                loss_res, _ = self._compute_loss(output, labels, epoch)
                del imgs, labels

                losses.append(loss_res.item())

                self.optimizer.zero_grad()
                loss_res.backward()
                self.optimizer.step()
                self.scheduler.step()

                del loss_res

            avg_loss = sum(losses) / len(losses)
            print(f'epoch {epoch} | Training')
            print(f'   total training loss : {avg_loss}')

            print(f"Epoch {epoch + 1}", end = ' ')
            self.eval("train", epoch)
            self.eval("val", epoch)

            if self.patience < self.step:
                break
        wandb.finish()

    @torch.no_grad()
    def eval(self, split: str, epoch: int) -> None:
        """Evaluate on ``split`` ("train" or anything else = validation), log, and checkpoint.

        Only the validation split updates the best score, the early-stopping
        counter and the checkpoints.
        """
        self.model.eval()

        loader = self.train_loader if split == "train" else self.val_loader

        metric = MeanIoU()
        losses = []

        for batch in loader:

            imgs = batch['image'].to(self.device)
            labels = batch['labels'].to(self.device)

            loss_res, logits = self._compute_loss(self.model(imgs), labels, epoch)
            losses.append(loss_res.item())

            metric.update(labels, logits)

        mean_iou = metric.get_results()['Mean IoU']

        avg_loss = sum(losses) / len(losses)
        print(f"| {split.upper()} Metrics:")
        print(f"  Loss: {avg_loss:.4f}")
        print(f"  Mean IoU: {mean_iou:.4f}\n")

        self.wandb_log(split= split, loss= avg_loss, mean_iou= mean_iou, epoch= epoch+1)

        if split != "val":
            return

        # Best-model bookkeeping is independent of the last-epoch save. If the
        # final epoch is also the best one, both checkpoints get written.
        if mean_iou > self.mean_iou:
            self.mean_iou = mean_iou
            self._save_checkpoint(epoch, mean_iou, f"{self.model_name}.pt", "best_checkpoint")
            self.best_model = copy.deepcopy(self.model)
            self.step = 0
        else:
            self.step += 1

        if epoch + 1 == self.num_epochs:
            self._save_checkpoint(epoch, mean_iou, f"{self.model_name}_lastepoch.pt", "last_checkpoint")

In [None]:
@torch.no_grad()
def compute_metrics(model, loader, aupr = True):
    """Evaluate ``model`` on ``loader`` with mean IoU and, optionally, AUPR.

    Args:
        model: module returning either logits, or a ``(logits, anomaly_score)``
            tuple (e.g. ``MaxLogit(...)`` / ``EnergyScore(...)`` wrappers).
        loader: yields dicts with ``'image'`` and ``'labels'``.
        aupr: if truthy, also compute AUPR from the anomaly scores. This
            requires the model to return a tuple.

    Returns:
        ``{"mean_aupr": ..., "mean_iou": ...}`` when ``aupr`` is truthy,
        otherwise ``{"mean_iou": ...}``. The values are whatever
        ``AUPR.get_results()`` / ``MeanIoU.get_results()`` return.

    Raises:
        ValueError: if AUPR is requested but the model does not return an
            anomaly score.
    """
    model.eval()
    model.to(DEVICE)
    mean_iou = MeanIoU()
    mean_aupr = AUPR() if aupr else None

    for batch in tqdm(loader):
        imgs = batch['image'].to(DEVICE)
        labels = batch['labels'].to(DEVICE)

        pred = model(imgs)

        if isinstance(pred, tuple):
            logits, anomaly_score = pred
        else:
            logits, anomaly_score = pred, None

        if mean_aupr is not None and anomaly_score is None:
            raise ValueError(
                "compute_metrics(aupr=True) requires a model returning (logits, anomaly_score)"
            )

        mean_iou.update(label_trues=labels, logits= logits)

        if mean_aupr is not None:
            mean_aupr.update(anomaly_score, labels)

    if mean_aupr is not None:
        return {"mean_aupr": mean_aupr.get_results(), "mean_iou": mean_iou.get_results()}
    return {"mean_iou": mean_iou.get_results()}

In [None]:
@torch.no_grad()
def plot_anomaly_heatmap(anomaly_score, image=None, title="Anomaly Heatmap"):
    """
    Plot an anomaly-score heatmap ('jet' colormap: blue = low, red = high).

    When ``image`` is given, the heatmap and the scene (drawn with
    ``visualize_scene``) are shown side by side.

    Parameters:
    - anomaly_score: torch.Tensor or np.array of shape [H, W]; a tensor is squeezed and moved to CPU
    - image: optional RGB image [H, W, 3] in range [0, 1] or [0, 255]; it is divided by 255 if its max exceeds 1
    - title: title of the heatmap axis
    """
    if isinstance(anomaly_score, torch.Tensor):
        anomaly_score = anomaly_score.squeeze().cpu().numpy()

    if image is not None:
        fig, (heat_ax, scene_ax) = plt.subplots(1, 2, figsize=(10, 5))
        if image.max() > 1.0:
            image = image / 255.0
    else:
        fig, heat_ax = plt.subplots()
        scene_ax = None

    heatmap = heat_ax.imshow(anomaly_score, cmap='jet')  # 'jet': blue -> red
    fig.colorbar(heatmap, ax=heat_ax, label='Anomaly Score')
    heat_ax.set_title(title)
    heat_ax.axis("off")

    if scene_ax is not None:
        visualize_scene(image, scene_ax)

    # Lay out the figure that holds the plots. Opening a new figure here would
    # leave an extra empty figure and apply tight_layout to the wrong one.
    fig.tight_layout()
    plt.show()

In [None]:
from fvcore.nn import FlopCountAnalysis, parameter_count
def compute_flops_params(model, input_shape=(8, 3, 512, 896)):
    """Print the fvcore FLOP count and parameter count of ``model``.

    Args:
        model: module to analyse. It must already live on ``DEVICE``, because
            the random input is created there.
        input_shape: shape of the random tensor used to trace the model. The
            default is a batch of 8 images at the 512x896 resolution used in
            this notebook.
    """
    input_tensor = torch.randn(*input_shape).to(DEVICE)
    flops = FlopCountAnalysis(model, input_tensor)
    params = parameter_count(model)

    print("FLOPs (GFLOPs):", f"{flops.total() / 1e9:.2f}")
    print("Parameters (Millions):", f"{params[''] / 1e6:.2f}")

In [None]:
class Segmenter(nn.Module):
    """Thin ``nn.Module`` wrapper around ``smp.DeepLabV3Plus``.

    Args:
        encoder_name: smp encoder name, e.g. ``'efficientnet-b0'``.
        encoder_weights: pretrained encoder weights (``"imagenet"`` by default).
        activation: optional output activation passed to smp.
        num_classes: number of output channels.
    """
    def __init__(self, encoder_name, encoder_weights = "imagenet", activation= None, num_classes= 13):
        super().__init__()

        self.model = smp.DeepLabV3Plus(encoder_name=encoder_name,
                                       encoder_weights=encoder_weights,
                                        classes=num_classes,
                                        activation=activation).to(DEVICE)

    def forward(self, inputs):
        return self.model(inputs)

    def _load_state_dict(self, state_dict):
        """Load a state dict saved from either this wrapper or the bare smp model.

        ``Segmenter.state_dict()`` prefixes every key with ``"model."``; the
        inner model's state dict has no prefix.
        """
        if state_dict and all(key.startswith("model.") for key in state_dict):
            self.load_state_dict(state_dict)
        else:
            self.model.load_state_dict(state_dict)

    def wandb_load_weights(self, run_id, model_name):
        """Download the ``ckpts/`` files of W&B run ``run_id`` and load ``<model_name>.pt``."""
        api = wandb.Api()
        run = api.run(run_id)

        for f in run.files():
            if f.name.startswith("ckpts/"):
                # Downloaded relative to the current working directory.
                f.download(replace=True)

        weight_path = Path("ckpts") / f"{model_name}.pt"
        # map_location="cpu" lets GPU-saved weights load on a CPU-only machine.
        self._load_state_dict(torch.load(weight_path, map_location="cpu", weights_only=True))

    def load_weights(self, weight_path):
        """Load weights from ``weight_path`` (with or without the ``model.`` prefix)."""
        self._load_state_dict(torch.load(weight_path, map_location="cpu", weights_only=True))

    @torch.no_grad()
    def predict(self, inputs):
        """Return the per-pixel argmax class map (on CPU) for a single un-batched image."""
        self.model.eval()
        pred = self.model(inputs.unsqueeze(0).to(DEVICE))
        pred = pred.squeeze(0)
        pred = torch.argmax(pred, dim = 0)
        return pred.cpu()

    def get_flops(self, input_shape):
        """Print the ``torchinfo`` summary of the inner model for ``input_shape``."""
        print(summary(self.model, input_shape))

In [None]:
def wandb_load_weights(model, run_id, model_name):
    """Download the ``ckpts/`` files of W&B run ``run_id`` and load ``<model_name>.pt`` into ``model``.

    Args:
        model: module whose ``load_state_dict`` accepts the saved weights.
        run_id: W&B run path, e.g. ``"entity/project/run"``.
        model_name: stem of the ``.pt`` file saved under ``ckpts/``.
    """
    api = wandb.Api()
    run = api.run(run_id)

    for f in run.files():
        if f.name.startswith("ckpts/"):
            # Downloaded relative to the current working directory.
            f.download(replace=True)

    # Read from the same relative location the files were downloaded to. Trainer
    # also writes checkpoints to a relative "ckpts" directory.
    weight_path = Path("ckpts") / f"{model_name}.pt"
    # map_location="cpu" lets GPU-saved weights load on a CPU-only machine.
    model_weigths = torch.load(weight_path, map_location="cpu", weights_only=True)
    model.load_state_dict(model_weigths)

def load_weights(model, weight_path):
    """Load a state dict saved with ``torch.save`` from ``weight_path`` into ``model``.

    ``map_location="cpu"`` allows weights saved on a GPU to be loaded on a
    CPU-only machine; ``load_state_dict`` then copies them onto the model's
    own device.
    """
    model_weigths = torch.load(weight_path, map_location="cpu", weights_only=True)
    model.load_state_dict(model_weigths)

@torch.no_grad()
def predict(model, inputs):
    """Return the per-pixel argmax class map of ``model`` for one un-batched input.

    The input gets a batch dimension and is moved to ``DEVICE``. The batch
    dimension is removed from the output, and the argmax is taken over the
    first remaining dimension (presumably the class channel — TODO confirm for
    non-standard models). The result is returned on the CPU.
    """
    model.eval()
    batched = inputs.unsqueeze(0).to(DEVICE)
    scores = model(batched).squeeze(0)
    return scores.argmax(dim=0).cpu()

# 1. Introduction
## 1.1 Overall Approach

The approach follows the idea used by many papers in the literature:
- train a segmentation model that solves inlier segmentation without considering anomalies;
- add an anomaly detector on top of the segmentation model.

For inlier segmentation the chosen architecture is DeepLabV3+, a well-known architecture developed for semantic segmentation. For the feature extractor (encoder), the performance of two lightweight models was compared:
- EfficientNet-b0
- MobileNetV2

In the end EfficientNet-b0 was chosen, since it gives the best performance without increasing the model complexity and execution time too much.

For the anomaly detection part, the implemented approach uses the Residual Pattern Learning (RPL) idea presented in the paper *Residual Pattern Learning for Pixel-wise Out-of-Distribution Detection in Semantic Segmentation*. The best performance was reached with the loss proposed in a recent publication: *Open-set Anomaly Segmentation in Complex Scenarios*, Song Xia et al.

All experiments were performed on the Kaggle platform. The choice of lighter architectures, which clearly does not allow reaching top scores, is motivated by the strong limits on computational resources. Nonetheless, I was able to reach performance in line with the results presented in the papers.

## 1.2 Metrics and Evaluation

The architectures have been compared, as required, using two metrics:
- Mean IoU, computed according to the definition given during the lectures, using the implementation provided in the DeepLabV3+ repository. It measures the inlier segmentation quality of the models.
- AUPR, computed with scikit-learn's `average_precision_score` on each sample individually and then averaged over all samples.

# 2. Dataset
The dataset used is StreetHazards. Two of its labels are worth noticing:
- BACKGROUND, which should mainly represent the sky, also contains other objects;
- OTHERS, as the name says, represents a multitude of different objects.

This makes the task even harder, since these heterogeneous classes can confuse the model.

In [None]:
# Resize target passed to every StreetHazardsDataset split (presumably (H, W) — TODO confirm).
shape_resize = (512, 896)
#shape_resize = (688, 688)

# ImageNet channel statistics, matching the ImageNet-pretrained encoders used below.
mean_imagenet, std_imagenet = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

# Training only: geometric augmentation, meant to be applied jointly to image and labels before tensor conversion.
spatial_transforms = transforms.v2.Compose([
    #transforms.v2.RandomCrop(shape_resize),
    transforms.v2.RandomHorizontalFlip(),
])

# Training: photometric transforms applied to images only, after tensor conversion (currently just normalization).
images_only_transforms = transforms.Compose([
    transforms.Normalize(mean = mean_imagenet, std = std_imagenet),
    #transforms.Normalize(mean = mean_streethazards, std = std_streethazards),
    #transforms.RandomErasing(scale=(0.02, 0.15))
])

# Validation/test: normalization only, with the ImageNet statistics (no augmentation).
val_test_transforms = transforms.Normalize(
    #mean = mean_streethazards, std = std_streethazards
    mean = mean_imagenet, std = std_imagenet
)

train_dataset = StreetHazardsDataset(
    odgt_file="/kaggle/input/streethazards_train/train/train.odgt",
    image_resize = shape_resize,
    spatial_transforms=spatial_transforms,
    images_only_transforms=images_only_transforms
)

val_dataset = StreetHazardsDataset(
    odgt_file="/kaggle/input/streethazards_train/train/validation.odgt",
    image_resize = shape_resize,
    spatial_transforms=None,
    images_only_transforms=val_test_transforms
)

test_dataset = StreetHazardsDataset(
    odgt_file="/kaggle/input/streethazards_test/test/test.odgt",
    image_resize = shape_resize,
    spatial_transforms=None,
    images_only_transforms=val_test_transforms
)

In [None]:
# Show one training sample: image (left) and annotation (right).
# train_dataset applies RandomHorizontalFlip, so the displayed sample may be flipped.
fig, axs = plt.subplots(1, 2, figsize=(10, 12))
idx = 2
# Assumes the dataset item is a dict ordered as (image, labels) — TODO confirm.
img, label = train_dataset[idx].values()
visualize_scene(img, axs[0])
visualize_annotation(label, axs[1])

In [None]:
# Only the training loader is shuffled; validation and test keep the file order.
train_dl = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=2)
val_dl = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=2)
test_dl = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=2)

## 2.2 Data Exploration
It can be observed that the images come in sequences that are spatially correlated. Since predictions are made on single images, the training set is shuffled. This should avoid leaking information during training that could illegitimately boost model performance.

In [None]:
# Disabled: class-frequency histograms of the train / val / test splits. The code
# is kept as a string literal, so running this cell has no effect. The test split
# uses len(COLORS) classes, presumably to include the anomaly class — TODO confirm.
'''fig, axs = plt.subplots(1, 3, figsize=(10, 12))
compute_class_frequency(dataset = train_dataset, num_classes = len(COLORS)-1, plot_frequencies = True, ax= axs[0]);
compute_class_frequency(dataset = val_dataset, num_classes = len(COLORS)-1, plot_frequencies = True, ax= axs[1]);
compute_class_frequency(dataset = test_dataset, num_classes = len(COLORS), plot_frequencies = True, ax= axs[2]);
plt.show()'''

## 2.3 Anomalies dataset

The Residual Pattern Learning architecture requires an additional training stage, on top of a frozen segmenter, using a dataset with anomalies. Anomaly pixels are only present in the test set, so both the training and validation sets for this stage are enhanced with anomalies taken from Pascal VOC. Most papers use COCO for anomaly injection, but Pascal VOC was chosen as a lighter alternative because of the limited storage in the Kaggle runtime directory.
As stated in *Open-set Anomaly Segmentation in Complex Scenarios* (Song Xia et al.), this copy-paste approach is not the best solution, since anomalies are placed randomly. However, using a diffusion model would have been too computationally expensive in my case.

In [None]:
# Anomaly-injected datasets for RPL training: StreetHazards inliers combined with
# Pascal VOC outliers. The inlier datasets are built without transforms; the
# normalization is passed to MixDataset instead (presumably applied after the
# objects are pasted — TODO confirm).
mix_train = MixDataset(inlier_dataset= StreetHazardsDataset(odgt_file="/kaggle/input/streethazards_train/train/train.odgt",
                                                            image_resize = shape_resize,
                                                            spatial_transforms=None,
                                                            images_only_transforms=None), 
                       outlier_dataset= voc_train, 
                       images_only_transforms= images_only_transforms)

mix_val = MixDataset(inlier_dataset= StreetHazardsDataset(odgt_file="/kaggle/input/streethazards_train/train/validation.odgt",
                                                          image_resize = shape_resize,
                                                          spatial_transforms=None,
                                                          images_only_transforms=None), 
                     outlier_dataset= voc_val, 
                     images_only_transforms= images_only_transforms)

In [None]:
# Loaders for the anomaly-injected data; only the training one is shuffled.
mix_train_dl = DataLoader(mix_train, batch_size=8, shuffle=True, num_workers=2)
mix_val_dl = DataLoader(mix_val, batch_size=8, shuffle=False, num_workers=2)

# 3 Inlier Segmentation

In this section I show the results of the inlier segmentation and the ablation study that led to them.

In [None]:
# DeepLabV3+ with an ImageNet-pretrained EfficientNet-b0 encoder and 13 output
# classes (presumably the inlier classes, len(COLORS) - 1 — TODO confirm).
# The weights of the evaluated runs are loaded into this instance by the cells below.
efficientnetb0 = smp.DeepLabV3Plus(encoder_name='efficientnet-b0', 
                                   encoder_weights='imagenet', 
                                   classes=13,
                                   activation=None).to(DEVICE)

In [None]:
# Weight decay 1e-5: average validation metrics over 3 evaluations
# (presumably because anomaly injection in mix_val_dl is random — TODO confirm).
wandb_load_weights(model= efficientnetb0, run_id = 'chri-project/ML4CV--assignment/8yuhosdd', model_name= 'CrossEntropyDice_HorizontalFlip_lower_wd') #wd = 1e-5
aupr = 0
iou = 0
for i in range(3):
    res = compute_metrics(MaxLogit(efficientnetb0), mix_val_dl)
    aupr += res['mean_aupr']
    iou += res['mean_iou']['Mean IoU']

print('mean_aupr: ', aupr/3)
print('mean_iou: ', iou/3)

In [None]:
# No weight decay: average validation metrics over 3 evaluations
# (presumably because anomaly injection in mix_val_dl is random — TODO confirm).
wandb_load_weights(model= efficientnetb0, run_id = 'chri-project/ML4CV--assignment/6rpqmc22', model_name= 'CrossEntropy_Dice_HorizontalFlip_no_weight_decay') #wd = 0

aupr = 0
iou = 0
for i in range(3):
    res = compute_metrics(MaxLogit(efficientnetb0), mix_val_dl)
    aupr += res['mean_aupr']
    iou += res['mean_iou']['Mean IoU']

print('mean_aupr: ', aupr/3)
print('mean_iou: ', iou/3)

In [None]:
# Weight decay 1e-4: average validation metrics over 3 evaluations
# (presumably because anomaly injection in mix_val_dl is random — TODO confirm).
wandb_load_weights(model= efficientnetb0, run_id = 'chri-project/ML4CV--assignment/baeu3kou', model_name= 'CrossEntropyDice_HorizontalFlip_wd_1eminus4') #wd = 1e-4

aupr = 0
iou = 0
for i in range(3):
    res = compute_metrics(MaxLogit(efficientnetb0), mix_val_dl)
    aupr += res['mean_aupr']
    iou += res['mean_iou']['Mean IoU']

print('mean_aupr: ', aupr/3)
print('mean_iou: ', iou/3)

In [None]:
# Weight decay 1e-1: single validation evaluation with the MaxLogit anomaly score.
wandb_load_weights(model= efficientnetb0, run_id = 'chri-project/ML4CV--assignment/2txwhmej', model_name= 'CrossEntropyDice_HorizontalFlip_higher_wd') # wd = 1e-1

print(compute_metrics(MaxLogit(efficientnetb0), mix_val_dl))

In [None]:
# Weight decay 1e-2: single validation evaluation with the MaxLogit anomaly score.
wandb_load_weights(model= efficientnetb0, run_id = 'chri-project/ML4CV--assignment/4h6rbhuw', model_name= 'CrossEntropyDice_HorizontalFlip_wd_1eminus2')#wd = 1e-2
print(compute_metrics(MaxLogit(efficientnetb0), mix_val_dl))

In [None]:
# Cross-entropy + Dice, weight decay 1e-3: single validation evaluation.
wandb_load_weights(model= efficientnetb0, run_id = 'chri-project/ML4CV--assignment/z2x53n4i', model_name= 'CrossEntropyDice_with_HorizontalFlip') #wd = 1e-3

print(compute_metrics(MaxLogit(efficientnetb0), mix_val_dl))

In [None]:
# Cross-entropy + Focal, weight decay 1e-3: single validation evaluation.
wandb_load_weights(model= efficientnetb0, run_id = 'chri-project/ML4CV--assignment/yxjodd6i', model_name= 'CrossEntropyFocal_HorizontalFlip') #wd = 1e-3

print(compute_metrics(MaxLogit(efficientnetb0), mix_val_dl))

In [None]:
# Cross-entropy + Jaccard, weight decay 1e-3: single validation evaluation.
wandb_load_weights(model= efficientnetb0, run_id = 'chri-project/ML4CV--assignment/a9ivthlz', model_name= 'CrossEntropyJaccard_HorizontalFlip') #wd = 1e-3

print(compute_metrics(MaxLogit(efficientnetb0), mix_val_dl))

In [None]:
# Cross-entropy + Lovasz, weight decay 1e-3: single validation evaluation.
wandb_load_weights(model= efficientnetb0, run_id = 'chri-project/ML4CV--assignment/ndoh4hkz', model_name= 'CrossEntropyLovasz_HorizontalFlip') #wd = 1e-3

print(compute_metrics(MaxLogit(efficientnetb0), mix_val_dl))

In [None]:
# NOTE(review): `err` looks like a deliberately undefined name, used to stop
# "Run All" here with a NameError. Confirm it is not defined elsewhere.
err

## 3.1 mobilenetv2 vs efficientnet-b0
- wd
- lr
- the loss used is cross-entropy

In [None]:
# Disabled: previous configuration, kept as a string literal (no effect when run).
'''cfg_original = {
    "num_epochs" : 40,
    "lr": 2e-4,
    "wd": 0.001,
    "patience": 1000,
    "model_name": "higher_lr_mobilenetv2",
    #scheduler = torch.optim.lr_scheduler.OneCycleLR
}'''

In [None]:
#efficientnetb0 = Segmenter(encoder_name= 'efficientnet-b0')

In [None]:
# Trainer configuration for inlier segmentation. Trainer calls
# cfg["scheduler"](optimizer, lr, total_steps), i.e. OneCycleLR over all training steps.
# NOTE(review): model_name mentions mobilenetv2, but the (disabled) training cell
# below passes efficientnetb0 — confirm the intended backbone/name.
cfg = {
    "num_epochs" : 40,
    "lr": 2e-4,
    "wd": 0.001,
    "patience": 1000,
    "model_name": "CrossEntropy_Dice_HorizontalFlip_mobilenetv2",
    "scheduler": torch.optim.lr_scheduler.OneCycleLR
}

In [None]:
# Disabled training run (kept as a string literal). It combines cross-entropy
# and Dice loss with the constant 0.5 / 0.5 weights of Dummy_scheduler.
'''trainer = Trainer(
    model= efficientnetb0,
    train_loader= train_dl,
    val_loader= val_dl,
    loss1 = F.cross_entropy,
    loss2 = losses.DiceLoss(mode = 'multiclass'),
    loss_scheduler = Dummy_scheduler(),
    cfg= cfg
)

#trainer.train()'''

In [None]:
# Disabled (string literal): evaluation of the weight-decay 0.1 run.
'''wandb_load_weights(model= efficientnetb0, run_id = 'chri-project/ML4CV--assignment/2txwhmej', model_name= 'CrossEntropyDice_HorizontalFlip_higher_wd') #0.1

print(compute_metrics(MaxLogit(efficientnetb0), mix_val_dl))'''

In [None]:
# Disabled (string literal): evaluation of the weight-decay 0.001 run.
'''wandb_load_weights(model= efficientnetb0, run_id = 'chri-project/ML4CV--assignment/z2x53n4i', model_name= 'CrossEntropyDice_with_HorizontalFlip') #wd= 0.001

print(compute_metrics(MaxLogit(efficientnetb0), mix_val_dl))'''

In [None]:
# Disabled (string literal): evaluation of the weight-decay 1e-5 run.
'''wandb_load_weights(model= efficientnetb0, run_id = 'chri-project/ML4CV--assignment/8yuhosdd', model_name= 'CrossEntropyDice_HorizontalFlip_lower_wd') #0.00001

print(compute_metrics(MaxLogit(efficientnetb0), mix_val_dl))'''

In [None]:
# Load the no-weight-decay segmenter into efficientnetb0; it is wrapped by RPLDeepLab below.
wandb_load_weights(model= efficientnetb0, run_id = 'chri-project/ML4CV--assignment/6rpqmc22', model_name= 'CrossEntropy_Dice_HorizontalFlip_no_weight_decay')

#print(compute_metrics(MaxLogit(efficientnetb0), mix_val_dl))

In [None]:
# RPL training configuration. Trainer builds the scheduler as
# WarmUpPolyLRScheduler(optimizer, lr, num_epochs * len(train_loader)), i.e.
# start_lr = lr and total_iters = all training steps, with the default
# warmup_steps = 0 (no warm-up).
rpl_cfg = {
    "num_epochs" : 30,
    "lr": 7.5e-5,
    "wd": 1e-4,
    "patience": 1000,
    "model_name": "trial_energy_entropy_loss",
    "scheduler": WarmUpPolyLRScheduler
}

In [None]:
# Build the RPL model from deep copies of efficientnetb0's sub-modules
# (efficientnetb0 itself is not modified).
rpl = RPLDeepLab(efficientnetb0)

In [None]:
# Train the RPL residual branch on the anomaly-injected loaders. The model returns
# a (vanilla_logits, logits) tuple, so Trainer calls
# energy_loss(logits=..., targets=..., vanilla_logits=...).
trainer = Trainer(
    model= rpl,
    train_loader= mix_train_dl,
    val_loader= mix_val_dl,
    loss1 = energy_loss,
    cfg= rpl_cfg
)

trainer.train()

In [None]:
# Test-set evaluation of the best-validation checkpoint with two anomaly scores.
# Trainer writes to the relative "ckpts" directory; this absolute path assumes
# the working directory is /kaggle/working.
load_weights(model= rpl, weight_path= f'/kaggle/working/ckpts/{rpl_cfg["model_name"]}.pt')
print(compute_metrics(EnergyEntropyScore(rpl), test_dl))
print(compute_metrics(EnergyScore(rpl), test_dl))

In [None]:
# Test-set evaluation of the last-epoch checkpoint with two anomaly scores
# (this path also assumes the working directory is /kaggle/working).
load_weights(model= rpl, weight_path= f'/kaggle/working/ckpts/{rpl_cfg["model_name"]}_lastepoch.pt')
print(compute_metrics(EnergyEntropyScore(rpl), test_dl))
print(compute_metrics(EnergyScore(rpl), test_dl))

In [None]:
#energy loss rpl
wandb_load_weights(model= rpl, run_id = 'chri-project/ML4CV--assignment/5l6w1nue', model_name= 'rpl_efficientnet_ce_dice_horizontal_flip')

print(compute_metrics(EnergyEntropyScore(rpl), test_dl))
print(compute_metrics(EnergyScore(rpl), test_dl))

In [None]:
# Deliberate stop for "Run All": halt here so the cells below are only run on purpose.
# An explicit raise is used instead of referencing an undefined name, which would
# silently stop working as soon as any cell defined a variable called `err`.
raise RuntimeError("Execution intentionally stopped here; run the following cells manually.")

### other losses

A 688 × 688 input has the same training-time FLOP count as the other (non-square) image size, but it is square.

### Dice

In [None]:
wandb_load_weights(
    model=efficientnetb0,
    run_id='chri-project/ML4CV--assignment/n6nxoni1',
    model_name='CrossentropyDice_with_HorizontalFlip_RandomCrop',
)

# Average the metrics over several evaluation passes.
# NOTE(review): presumably because the mixed validation loader is not deterministic — confirm.
n_runs = 3
aupr = 0
iou = 0
for _ in range(n_runs):
    res = compute_metrics(MaxLogit(efficientnetb0), mix_val_dl)
    aupr += res['mean_aupr']
    iou += res['mean_iou']['Mean IoU']

print('mean_aupr: ', aupr/n_runs)
print('mean_iou: ', iou/n_runs)

### Lovasz

In [None]:
wandb_load_weights(
    model=efficientnetb0,
    run_id='chri-project/ML4CV--assignment/dj7yebxb',
    model_name='CrossEntropyLovasz_HorizontalFlip_RandomCrop',
)

# Average the metrics over several evaluation passes.
# NOTE(review): presumably because the mixed validation loader is not deterministic — confirm.
n_runs = 3
aupr = 0
iou = 0
for _ in range(n_runs):
    res = compute_metrics(MaxLogit(efficientnetb0), mix_val_dl)
    aupr += res['mean_aupr']
    iou += res['mean_iou']['Mean IoU']

print('mean_aupr: ', aupr/n_runs)
print('mean_iou: ', iou/n_runs)

### Jaccard

In [None]:
wandb_load_weights(
    model=efficientnetb0,
    run_id='chri-project/ML4CV--assignment/eglcncxi',
    model_name='CrossEntropyJaccard_HorizontalFlip_RandomCrop',
)

# Average the metrics over several evaluation passes.
# NOTE(review): presumably because the mixed validation loader is not deterministic — confirm.
n_runs = 3
aupr = 0
iou = 0
for _ in range(n_runs):
    res = compute_metrics(MaxLogit(efficientnetb0), mix_val_dl)
    aupr += res['mean_aupr']
    iou += res['mean_iou']['Mean IoU']

print('mean_aupr: ', aupr/n_runs)
print('mean_iou: ', iou/n_runs)

### Focal

In [None]:
wandb_load_weights(
    model=efficientnetb0,
    run_id='chri-project/ML4CV--assignment/56rxyhvp',
    model_name='CrossEntropyFocal_HorizontalFlip_RandomCrop',
)

# Average the metrics over several evaluation passes.
# NOTE(review): presumably because the mixed validation loader is not deterministic — confirm.
n_runs = 3
aupr = 0
iou = 0
for _ in range(n_runs):
    res = compute_metrics(MaxLogit(efficientnetb0), mix_val_dl)
    aupr += res['mean_aupr']
    iou += res['mean_iou']['Mean IoU']

print('mean_aupr: ', aupr/n_runs)
print('mean_iou: ', iou/n_runs)

### higher weight decay

In [None]:
wandb_load_weights(
    model=efficientnetb0,
    run_id='chri-project/ML4CV--assignment/zzgcb1sj',
    model_name='Higher_wd_CE_Dice_HorizontalFlip_RandomCrop',
)

# Collect three evaluation passes, then report both means and their average.
aupr_runs, iou_runs = [], []
for _ in range(3):
    metrics = compute_metrics(MaxLogit(efficientnetb0), mix_val_dl)
    aupr_runs.append(metrics['mean_aupr'])
    iou_runs.append(metrics['mean_iou']['Mean IoU'])
aupr, iou = sum(aupr_runs), sum(iou_runs)

print('mean_aupr: ', aupr/3)
print('mean_iou: ', iou/3)
print((aupr/3 + iou/3)/2)

### lower weight decay

In [None]:
wandb_load_weights(
    model=efficientnetb0,
    run_id='chri-project/ML4CV--assignment/l1rcyoa5',
    model_name='Lower_wd_CE_Dice_HorizontalFlip_RandomCrop',
)

# Collect three evaluation passes, then report both means and their average.
aupr_runs, iou_runs = [], []
for _ in range(3):
    metrics = compute_metrics(MaxLogit(efficientnetb0), mix_val_dl)
    aupr_runs.append(metrics['mean_aupr'])
    iou_runs.append(metrics['mean_iou']['Mean IoU'])
aupr, iou = sum(aupr_runs), sum(iou_runs)

print('mean_aupr: ', aupr/3)
print('mean_iou: ', iou/3)
print((aupr/3 + iou/3)/2)

### no normalization

In [None]:
# Validation pipeline without the image-only transforms (i.e. without normalization),
# matching how this model was trained.
no_norm_val_dataset = StreetHazardsDataset(
    odgt_file="/kaggle/input/streethazards_train/train/validation.odgt",
    image_resize=shape_resize,
    spatial_transforms=None,
    images_only_transforms=None,
)
no_norm_mix_val = MixDataset(
    inlier_dataset=no_norm_val_dataset,
    outlier_dataset=voc_val,
    images_only_transforms=None,
)
no_norm_mix_val_dl = DataLoader(no_norm_mix_val, batch_size=8, shuffle=False, num_workers=2)

wandb_load_weights(
    model=efficientnetb0,
    run_id='chri-project/ML4CV--assignment/x068ojxv',
    model_name='CrossEntropyDice_HorizontalFlip_RandomCrop_without_Normalization',
)

aupr_runs, iou_runs = [], []
for _ in range(3):
    metrics = compute_metrics(MaxLogit(efficientnetb0), no_norm_mix_val_dl)
    aupr_runs.append(metrics['mean_aupr'])
    iou_runs.append(metrics['mean_iou']['Mean IoU'])
aupr, iou = sum(aupr_runs), sum(iou_runs)

print('mean_aupr: ', aupr/3)
print('mean_iou: ', iou/3)
# NOTE(review): unlike the sibling cells, IoU is scaled by 100 before being averaged
# with AUPR here — confirm which scale compute_metrics returns for each metric.
print((aupr/3 + iou/3*100)/2)

### only random flip

In [None]:
wandb_load_weights(
    model=efficientnetb0,
    run_id='chri-project/ML4CV--assignment/z2x53n4i',
    model_name='CrossEntropyDice_with_HorizontalFlip',
)

# Collect three evaluation passes, then report both means and their average.
aupr_runs, iou_runs = [], []
for _ in range(3):
    metrics = compute_metrics(MaxLogit(efficientnetb0), mix_val_dl)
    aupr_runs.append(metrics['mean_aupr'])
    iou_runs.append(metrics['mean_iou']['Mean IoU'])
aupr, iou = sum(aupr_runs), sum(iou_runs)

print('mean_aupr: ', aupr/3)
print('mean_iou: ', iou/3)
print((aupr/3 + iou/3)/2)

### no augmentations

In [None]:
# NOTE(review): the run name says "without_normalization", yet this cell evaluates on
# mix_val_dl rather than no_norm_mix_val_dl — confirm which pipeline the model expects.
wandb_load_weights(
    model=efficientnetb0,
    run_id='chri-project/ML4CV--assignment/aqdwcrnh',
    model_name='DiceLoss_and_cross_entropy_without_normalization',
)

# Collect three evaluation passes, then report both means and their average.
aupr_runs, iou_runs = [], []
for _ in range(3):
    metrics = compute_metrics(MaxLogit(efficientnetb0), mix_val_dl)
    aupr_runs.append(metrics['mean_aupr'])
    iou_runs.append(metrics['mean_iou']['Mean IoU'])
aupr, iou = sum(aupr_runs), sum(iou_runs)

print('mean_aupr: ', aupr/3)
print('mean_iou: ', iou/3)
print((aupr/3 + iou/3)/2)

### square images

In [None]:
wandb_load_weights(
    model=efficientnetb0,
    run_id='chri-project/ML4CV--assignment/53vpt4ew',
    model_name='CrossEntropyDice_HorizontalFlip_RandomCrop_square_images',
)

# Collect three evaluation passes, then report both means and their average.
aupr_runs, iou_runs = [], []
for _ in range(3):
    metrics = compute_metrics(MaxLogit(efficientnetb0), mix_val_dl)
    aupr_runs.append(metrics['mean_aupr'])
    iou_runs.append(metrics['mean_iou']['Mean IoU'])
aupr, iou = sum(aupr_runs), sum(iou_runs)

print('mean_aupr: ', aupr/3)
print('mean_iou: ', iou/3)
print((aupr/3 + iou/3)/2)

## 3.1.1 Data Preprocessing

### Image resizing
Using a smaller image size makes the model lose small details, as shown by the lower Mean IoU. The best choice would have been to keep the images at their original size, but due to limited computational resources I had to resize them. Furthermore, I decided to keep the original aspect ratio rather than making the images square, since this slightly improves performance.

### Data normalization
The images have been normalized with the ImageNet mean and standard deviation, since the models use weights pretrained on ImageNet. Indeed, removing normalization, or replacing it with the mean and standard deviation computed on the StreetHazards training set, worsened both Mean IoU and AUPR.

## 3.1.2 Hyperparameter tuning

### Learning rate

### Weight decay

## 3.1.3 Data augmentation

Regarding data augmentation, I found that using both horizontal flips and random crops improves not only the Mean IoU but also the AUPR.

## 3.2 losses
Following the exploration on the Cityscapes dataset in the paper *Loss Functions in the Era of Semantic Segmentation: A Survey and Outlook*, I experimented with the losses that do not require hyperparameter tuning. Each was combined with Cross Entropy through a so-called Dummy Scheduler, which combines $L_{CE}$ and $L_{var}$, where $L_{var}$ is the loss under study. The final loss is $L_{total} = 0.5 \cdot L_{CE} + 0.5 \cdot L_{var}$.

## 3.3 Results

# 4 Anomaly segmentation

In this section I will show the results of the anomaly segmentation and the ablation study that led to these results.

# 5 Final Results


## 5.1 Quantitative results

## 5.2 Qualitative results

#### Try freezing some layers (?)

# RPL implementation

In [None]:
# DeepLab segmenter with a MobileNetV2 encoder and 13 output classes, restored from
# the wandb run named "cross_entropy_mobilenetv2_without_normalization".
model = get_deeplab_model(
    encoder_name="mobilenet_v2",
    encoder_weights="imagenet",
    activation=None,
    num_classes=13,
)
mn = load_model_weights(
    model=model,
    run_id="chri-project/ML4CV--assignment/mgg66q6x",
    model_name="cross_entropy_mobilenetv2_without_normalization",
)

In [None]:
rpl = RPLDeepLab(mn)

# Layer summary for a batch of 8 RGB images of size 512 x 896.
batch_shape = (8, 3, 512, 896)
summary(rpl, input_size=batch_shape)

In [None]:
model_name = "rpl_energy_entropy_loss_mobilenetv2_crossentropy"

# Training configuration for the RPL model on top of the MobileNetV2 segmenter.
rpl_cfg = dict(
    num_epochs=40,
    lr=5e-5,
    wd=0.001,
    patience=1000,
    segmenter="mobilenetv2_crossentropy",
    run_name=model_name,
)

In [None]:
# Train the RPL model with the energy-entropy loss (energy_loss is the alternative).
rpl_trainer = Trainer(
    model=rpl,
    train_loader=mix_train_dl,
    val_loader=mix_val_dl,
    loss1=energy_entropy_loss,
    device=DEVICE,
    num_classes=len(COLORS),
    model_name=model_name,
    cfg=rpl_cfg,
    scheduler=WarmUpPolyLRScheduler,
    # resume_ckpt=torch.load(resume_ckpt),
)

rpl_trainer.train()

# JUMP HERE

In [None]:
run_id_rpl = "chri-project/ML4CV--assignment/a17kpqqs"
rpl_model_name = "rpl_training_energy_loss"

# Pass by keyword: elsewhere this helper is called as
# load_model_weights(model=..., run_id=..., model_name=...), so a positional call in a
# different order risks binding the model and the run id to the wrong parameters.
rpl = load_model_weights(model=rpl, run_id=run_id_rpl, model_name=rpl_model_name)

#### rpl with energy loss results (RPL original paper)

In [None]:
run_id_rpl = "chri-project/ML4CV--assignment/ht31svvx"
rpl_model_name = "rpl_energy_entropy_loss_over_tversky"
# NOTE(review): presumably the path of the checkpoint downloaded from the run — confirm.
directory = resume_run(run_id_rpl, rpl_model_name)
# Despite its name, `model_weights_path` holds the loaded state dict, not a path; the
# name is kept in case later cells refer to it. map_location lets a checkpoint saved
# on GPU be restored on a CPU-only runtime.
model_weights_path = torch.load(directory, map_location=DEVICE, weights_only=True)
rpl.load_state_dict(model_weights_path)

#### rpl with energy entropy loss results (RPL new paper)

In [None]:
class ContrastLoss(nn.Module, ABC):
    """Contrastive (InfoNCE-style) loss between inlier and anomaly pixel embeddings.

    Pixels labelled ``ood_idx`` form the "anomaly" group (label 1). All other
    pixels not equal to ``ignore_idx`` form the "inlier" group (label 0).
    Anchors are sampled from the first input ("city"). Contrast samples come
    from both the first and the second ("ood") input.

    Args:
        engine: stored on the instance; not used by this class.
        config: accepted for interface compatibility; not used.
    """

    def __init__(self, engine=None, config=None):
        super(ContrastLoss, self).__init__()
        self.engine = engine
        self.temperature = 0.10  # divides the dot-product similarities
        self.ignore_idx = 255    # label value excluded from sampling
        self.ood_idx = 13        # label value marking anomalous pixels
        self.max_views = 512     # upper bound on samples drawn per group

    def forward(self, city_proj, city_gt, city_pred, ood_proj, ood_gt, ood_pred):
        """Compute the contrastive loss.

        Args:
            city_proj: projection features of the first input, indexed as
                (B, C, H', W') below — TODO confirm against the caller.
            city_gt: integer label map for ``city_proj``, presumably (B, H, W);
                it is resized to (H', W') with nearest-neighbour interpolation.
            city_pred: unused.
            ood_proj: projection features of the second input (same layout).
            ood_gt: integer label map for ``ood_proj``.
            ood_pred: unused.

        Returns:
            A 0-d loss tensor. If no anchors can be sampled (one of the four
            groups is empty), returns ``tensor([0.])`` instead.
        """
        # squeeze(1) removes only the channel axis added for interpolation. A bare
        # squeeze() would also drop a size-1 batch or spatial axis, which breaks the
        # flatten(start_dim=1) in extraction_samples.
        city_gt = torch.nn.functional.interpolate(city_gt.unsqueeze(1).float(), size=city_proj.shape[2:],
                                                  mode='nearest').squeeze(1).long()

        ood_gt = torch.nn.functional.interpolate(ood_gt.unsqueeze(1).float(), size=ood_proj.shape[2:],
                                                 mode='nearest').squeeze(1).long()

        # normalise the embed results (unit L2 norm along the channel axis)
        city_proj = torch.nn.functional.normalize(city_proj, p=2, dim=1)
        ood_proj = torch.nn.functional.normalize(ood_proj, p=2, dim=1)

        # randomly extract embed samples within a batch
        anchor_embeds, anchor_labels, contrs_embeds, contrs_labels = self.extraction_samples(city_proj, city_gt,
                                                                                             ood_proj, ood_gt)

        # calculate the CoroCL
        loss = self.info_nce(anchors_=anchor_embeds, a_labels_=anchor_labels.unsqueeze(1), contras_=contrs_embeds,
                             c_labels_=contrs_labels.unsqueeze(1)) if anchor_embeds.nelement() > 0 else \
            torch.tensor([.0], device=city_proj.device)

        return loss

    # The implementation of cross-image contrastive learning is based on:
    # https://github.com/tfzhou/ContrastiveSeg/blob/287e5d3069ce6d7a1517ddf98e004c00f23f8f99/lib/loss/loss_contrast.py
    def info_nce(self, anchors_, a_labels_, contras_, c_labels_):
        """InfoNCE over anchor/contrast pairs; same label => positive pair.

        Args:
            anchors_: (N_a, C) anchor embeddings.
            a_labels_: (N_a, 1) anchor labels.
            contras_: (N_c, C) contrast embeddings.
            c_labels_: (N_c, 1) contrast labels.

        Returns:
            0-d tensor: negative mean log-probability of the positive pairs.
        """
        # calculates the binary mask: same category => 1, different categories => 0
        mask = torch.eq(a_labels_, torch.transpose(c_labels_, 0, 1)).float()

        # calculates the dot product
        anchor_dot_contrast = torch.div(torch.matmul(anchors_, torch.transpose(contras_, 0, 1)),
                                        self.temperature)

        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # calculates the negative mask
        neg_mask = 1 - mask

        # avoid the self duplicate issue: the first contrast entries are clones of
        # the anchors, so pair (i, i) compares an anchor with itself
        mask = mask.fill_diagonal_(0.)

        # sum the negative odot results
        neg_logits = torch.exp(logits) * neg_mask
        neg_logits = neg_logits.sum(1, keepdim=True)

        exp_logits = torch.exp(logits)

        # log_prob -> log(exp(x))-log(exp(x) + exp(y))
        # log_prob -> log{exp(x)/[exp(x)+exp(y)]}
        log_prob = logits - torch.log(exp_logits + neg_logits)

        # calculate the info-nce based on the positive samples (under same categories)
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
        return - mean_log_prob_pos.mean()

    def extraction_samples(self, city_embd, city_label, ood_embd, ood_label):
        """Sample anchor and contrast embeddings for the contrastive loss.

        The same number of samples is drawn from each of the four groups: city
        anomaly, city inlier, ood anomaly and ood inlier. That number is capped
        by ``max_views`` and by the smallest group.

        Returns:
            (anchor_embed, anchor_label, contrs_embed, contrs_label). Anomaly
            samples are labelled 1.0 and inlier samples 0.0.
        """

        # accept label maps without a batch axis
        if city_label.dim() == 2:
            city_label = city_label.unsqueeze(0)
        if ood_label.dim() == 2:
            ood_label = ood_label.unsqueeze(0)

        # reformat the matrix: embeddings -> (B, H*W, C), labels -> (B, H*W)
        city_embd = city_embd.flatten(start_dim=2).permute(0, 2, 1)
        city_label = city_label.flatten(start_dim=1)
        ood_embd = ood_embd.flatten(start_dim=2).permute(0, 2, 1)
        ood_label = ood_label.flatten(start_dim=1)

        # define different types of embeds
        city_positive = city_embd[city_label == self.ood_idx]
        city_negative = city_embd[(city_label != self.ood_idx) & (city_label != self.ignore_idx)]
        ood_positive = ood_embd[ood_label == self.ood_idx]
        ood_negative = ood_embd[(ood_label != self.ood_idx) & (ood_label != self.ignore_idx)]

        # define the number of choice
        sample_num = int(min(self.max_views, city_positive.shape[0], ood_positive.shape[0],
                             city_negative.shape[0], ood_negative.shape[0]))

        # randomly extract the anchor set with {city_ood, city_inlier}
        city_positive_anchor = city_positive[torch.randperm(city_positive.shape[0])][:sample_num]
        city_negative_anchor = city_negative[torch.randperm(city_negative.shape[0])][:sample_num]

        anchor_embed = torch.cat([city_positive_anchor, city_negative_anchor], dim=0)

        anchor_label = torch.cat([torch.empty(city_positive_anchor.shape[0],
                                              device=city_positive_anchor.device).fill_(1.),
                                  torch.empty(city_negative_anchor.shape[0],
                                              device=city_negative_anchor.device).fill_(0.)])

        # randomly extract the contras set with {city_ood, city_inlier, coco_ood, coco_inlier}
        city_positive_contras = city_positive_anchor.clone()
        city_negative_contras = city_negative_anchor.clone()
        ood_positive_contras = ood_positive[torch.randperm(ood_positive.shape[0])][:sample_num]
        ood_negative_contras = ood_negative[torch.randperm(ood_negative.shape[0])][:sample_num]

        contrs_embed = torch.cat([city_positive_contras, city_negative_contras,
                                  ood_positive_contras, ood_negative_contras], dim=0)

        contrs_label = torch.cat([torch.empty(city_positive_contras.shape[0],
                                              device=city_positive_contras.device).fill_(1.),
                                  torch.empty(city_negative_contras.shape[0],
                                              device=city_negative_contras.device).fill_(0.),
                                  torch.empty(ood_positive_contras.shape[0],
                                              device=ood_positive_contras.device).fill_(1.),
                                  torch.empty(ood_negative_contras.shape[0],
                                              device=ood_negative_contras.device).fill_(0.)])

        return anchor_embed, anchor_label, contrs_embed, contrs_label

In [None]:
class RPLCLDeepLab(nn.Module):
    """RPL-style anomaly model with an extra projection head for contrastive learning.

    It wraps a pretrained segmenter that exposes ``encoder``, ``decoder`` (with
    ``aspp``, ``up``, ``block1`` and ``block2``) and ``segmentation_head``. Frozen
    copies of the segmenter produce the inlier logits. A trainable copy of
    ``decoder.aspp`` + ``decoder.up`` (the residual block) produces a correction
    that is added to the decoder features before the frozen head is re-applied.

    Args:
        model: the pretrained segmenter; it is deep-copied, never modified.
    """

    def __init__(self, model):
        super().__init__()

        # Frozen copies of the pretrained segmenter.
        self.encoder = self.copy_un_freeze_params(model.encoder, False)
        self.decoder = self.copy_un_freeze_params(model.decoder, False)
        head_layers = [
            self.copy_un_freeze_params(model.decoder.block2, False),
            self.copy_un_freeze_params(model.segmentation_head, False),
        ]
        self.final = nn.Sequential(*head_layers)

        # Trainable layers: map the 256-channel residual to the 304-channel
        # concatenated decoder features, and to the contrastive embedding.
        self.atten_aspp_final = nn.Conv2d(256, 304, kernel_size=1, bias=False)
        self.projection_head = nn.Sequential(nn.Conv2d(256, 304, kernel_size=1))

        # Trainable copy of the ASPP + upsampling stage.
        residual_layers = [
            self.copy_un_freeze_params(model.decoder.aspp, True),
            self.copy_un_freeze_params(model.decoder.up, True),
        ]
        self.residual_anomaly_block = nn.Sequential(*residual_layers)

    def copy_un_freeze_params(self, layer: nn.Module, unfreeze: bool=True) -> nn.Module:
        """Return a deep copy of ``layer`` with every parameter's ``requires_grad`` set to ``unfreeze``.

        Args:
            layer: module to copy; the original is left untouched.
            unfreeze: True makes the copy trainable, False freezes it.

        Returns:
            The copied module.
        """
        return deepcopy(layer).requires_grad_(unfreeze)

    def forward(self, x):
        """Return ``(inlier_logits, anomaly_logits, projection)`` for the batch ``x``."""
        feats = self.encoder(x)
        deepest, mid_level = feats[-1], feats[2]

        # Frozen decoder path up to the feature concatenation.
        aspp_up = self.decoder.up(self.decoder.aspp(deepest))
        skip = self.decoder.block1(mid_level)
        fused = torch.cat([aspp_up, skip], dim=1)

        # Trainable residual branch on the deepest encoder features.
        residual = self.residual_anomaly_block(deepest)
        correction = self.atten_aspp_final(residual)
        projection = self.projection_head(residual)

        inlier_logits = self.final(fused)
        anomaly_logits = self.final(fused + correction)
        return inlier_logits, anomaly_logits, projection

In [None]:
# Disabled (kept as a string literal so it never runs): training of the RPL +
# contrastive-loss variant. Uses proper ''' delimiters; the previous '''' closed the
# string early and left a stray, unterminated quote (SyntaxError).
# NOTE(review): `rplcl` is not defined in the visible cells, and the other Trainer calls
# pass the main loss as `loss1` rather than `loss` — update both before re-enabling.
'''rpl_cr_trainer = Trainer(
    model= rplcl,
    train_loader= mix_train_dl,
    val_loader= mix_val_dl ,
    loss = energy_loss,
    contrastive_loss = ContrastLoss(),
    device= DEVICE,
    num_classes = len(COLORS),
    model_name = model_name,
    cfg= rpl_cfg,
    #resume_ckpt = torch.load(resume_ckpt)
)

#rpl_cr_trainer.train()'''