In [None]:
from datasets import Dataset, DatasetDict, Image
import os
import json
import cv2
from natsort import natsorted
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.optim import AdamW
from tqdm.auto import tqdm
import random
import wandb
import math
# Configure PyTorch's CUDA caching allocator (max_split_size_mb:128).
# NOTE(review): presumably set to reduce fragmentation-related OOMs; the variable most likely
# only takes effect if set before CUDA is first used, so keep it in this first cell — confirm.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'

In [None]:
import os
import cv2
import json
from natsort import natsorted
from datasets import Dataset, DatasetDict, Image

def get_paths(root="/home/aleximu/gunes/dinov2/project/dataset/fishency"):
    """Return the Fishency image/mask directories for the train and validation splits.

    Args:
        root: dataset root directory. The default is the original hard-coded location, so
            existing ``get_paths()`` calls are unchanged; pass another root to run the
            notebook on a different machine.

    Returns:
        ``(train_path_imgs, train_path_masks, val_path_imgs, val_path_masks)``.
    """
    # Folder names are asymmetric: training images live in "fishes", validation images in "imgs"
    train_path_imgs = os.path.join(root, "train", "fishes")
    train_path_masks = os.path.join(root, "train", "masks")
    val_path_imgs = os.path.join(root, "validation", "imgs")
    val_path_masks = os.path.join(root, "validation", "masks")

    return train_path_imgs, train_path_masks, val_path_imgs, val_path_masks

# Unpack the four dataset directories: train images/masks, then validation images/masks
(train_path_imgs, train_path_masks,
 val_path_imgs, val_path_masks) = get_paths()

def convert_masks(folder_path):
    """Binarize every mask image directly inside ``folder_path``, overwriting the files.

    Each readable image is loaded as grayscale and thresholded: pixels > 127 become 255,
    all others 0. Re-running is harmless for lossless formats because thresholding an
    already 0/255 mask leaves it unchanged. NOTE(review): if the masks are JPEGs, writing
    them back re-encodes lossily, so the saved files may not be strictly 0/255.

    Entries that are not readable images (sub-directories, hidden files, non-image files)
    are skipped with a message. Previously ``cv2.imread`` returned ``None`` for them and
    ``cv2.threshold`` crashed.

    Args:
        folder_path: directory containing the mask files (not searched recursively).
    """
    for filename in os.listdir(folder_path):
        image_path = os.path.join(folder_path, filename)
        if not os.path.isfile(image_path):
            continue
        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            print(f"Skipping unreadable file: {image_path}")
            continue
        _, black_white_image = cv2.threshold(image, 127, 255, cv2.THRESH_BINARY)
        cv2.imwrite(image_path, black_white_image)
    
# Binarize the train and validation masks in place (the files on disk are overwritten)
for masks_dir in (train_path_masks, val_path_masks):
    convert_masks(masks_dir)

def get_image_paths(folder_path, limit=None):
    """Collect every file path under ``folder_path`` (recursively) in natural sort order.

    Args:
        folder_path: directory to walk.
        limit: if truthy, keep only the first ``limit`` paths of the sorted list.

    Returns:
        List of file paths, naturally sorted (``natsorted``), at most ``limit`` long.

    The limit is applied *after* sorting. Previously the walk stopped as soon as ``limit``
    files had been seen in ``os.walk`` order, which is filesystem-dependent. The truncated
    image list and mask list could then hold different files, which broke the positional
    image/mask pairing done in ``initialize_dataset``.
    """
    image_paths = natsorted(
        os.path.join(dirpath, filename)
        for dirpath, _, filenames in os.walk(folder_path)
        for filename in filenames
    )
    return image_paths[:limit] if limit else image_paths

def create_dataset_dict(image_paths, mask_paths):
    """Build a Hugging Face ``Dataset`` with an "image" and a "label" column of file paths.

    Both columns are cast to the ``datasets`` ``Image`` feature, so rows come back as
    decoded images (the later cells call ``np.array`` / ``imshow`` on them).

    NOTE(review): this looks up the global names ``Dataset`` and ``Image``. Later cells
    rebind them (``from PIL import Image``, ``from torch.utils.data import Dataset``), so
    only call it after re-running the import line of this cell.
    """
    return (
        Dataset.from_dict({"image": image_paths, "label": mask_paths})
        .cast_column("image", Image())
        .cast_column("label", Image())
    )

def initialize_dataset():
    """Build a ``DatasetDict`` with "train" and "validation" splits of (image, label) files.

    The training images and masks are each capped at 10000 files via
    ``get_image_paths(limit=10000)``; validation uses every file. Images and masks are
    paired purely by list position, so both folders must yield matching orderings.
    """
    train_imgs_dir, train_masks_dir, val_imgs_dir, val_masks_dir = get_paths()

    train_images = get_image_paths(train_imgs_dir, limit=10000)
    train_masks = get_image_paths(train_masks_dir, limit=10000)
    val_images = get_image_paths(val_imgs_dir)
    val_masks = get_image_paths(val_masks_dir)

    return DatasetDict({
        "train": create_dataset_dict(train_images, train_masks),
        "validation": create_dataset_dict(val_images, val_masks),
    })

def create_id2label():
    """Write the class-id -> class-name mapping to ``id2label.json`` in the working directory.

    JSON object keys are strings, so the file contains ``{"0": "background", "1": "fish"}``.
    Not called in the visible notebook cells.
    """
    mapping = dict(enumerate(["background", "fish"]))
    with open('id2label.json', 'w') as fp:
        json.dump(mapping, fp)

# Build the DatasetDict with "train" and "validation" splits (see initialize_dataset)
dataset = initialize_dataset()

In [None]:
# Display the DatasetDict built above
dataset

In [None]:
# Load one training example and show the image next to its segmentation map
example = dataset["train"][5]
image = example["image"]
segmentation_map = example["label"]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
for _ax, _picture, _title, _cmap in ((ax1, image, 'Image', None),
                                     (ax2, segmentation_map, 'Segmentation Map', 'viridis')):
    _ax.imshow(_picture, cmap=_cmap)
    _ax.set_title(_title)
    _ax.axis('off')
plt.show()

In [None]:
# Class-id -> name mapping used by the model config and the label listings below
id2label = dict(enumerate(["background", "fish"]))
segmentation_map = np.array(segmentation_map)
print(id2label, segmentation_map.shape, sep="\n")

In [None]:
def visualize_map(image, segmentation_map):
    """Blend a randomly colored segmentation map 50/50 over ``image`` and display it.

    Args:
        image: anything ``np.array`` turns into an (H, W, 3) array matching the map's size
            (assumed RGB — TODO confirm for grayscale inputs).
        segmentation_map: (H, W) label array, or (H, W, C) in which case only channel 0
            is used (all channels are assumed identical).

    Each distinct label value gets a random RGB color drawn from NumPy's global RNG, so
    colors differ between calls unless the RNG is seeded.
    """
    labels_2d = segmentation_map[:, :, 0] if segmentation_map.ndim == 3 else segmentation_map

    # One random color per distinct label value
    palette = {value: list(np.random.choice(range(256), size=3)) for value in np.unique(labels_2d)}

    overlay = np.zeros((labels_2d.shape[0], labels_2d.shape[1], 3), dtype=np.uint8)
    for value, rgb in palette.items():
        overlay[labels_2d == value] = rgb

    blended = np.array(image).astype(float) * 0.5 + overlay.astype(float) * 0.5
    blended = blended.astype(np.uint8)

    fig, ax = plt.subplots(figsize=(9, 6))
    ax.imshow(blended)
    ax.axis('off')  # hide axis ticks for a cleaner view
    plt.show()

# Overlay the example mask from the previous cells on its image
visualize_map(image, segmentation_map)

In [None]:
from PIL import Image
class TwoRandomApply:
    """Apply an image transform and a label transform with identical torch randomness.

    The torch CPU RNG state is captured before the image transform and restored before
    the label transform. Random ops that draw from the torch RNG in the same order in
    both pipelines therefore make the same decisions (e.g. the same flip or rotation).
    """

    def __init__(self, transform_image, transform_label):
        self.transform_image = transform_image
        self.transform_label = transform_label
        self.state = None  # kept for compatibility; __call__ does not use it

    def __call__(self, img, target):
        rng_snapshot = torch.get_rng_state()
        transformed_img = self.transform_image(img)
        # Rewind the RNG so the label pipeline replays the image pipeline's random draws
        torch.set_rng_state(rng_snapshot)
        return transformed_img, self.transform_label(target)

In [None]:
from torch.utils.data import Dataset, DataLoader

class SegmentationDataset(Dataset):
  """Torch ``Dataset`` over rows holding an "image" and a "label" entry.

  ``transform(image_array, label_array)`` must return an ``(image, target)`` pair (e.g. a
  ``TwoRandomApply``). ``target`` is either a float mask scaled to [0, 1] (as produced by
  ``ToTensor``), of shape (1, H, W) or (H, W), or an integer tensor of class ids.

  Each item is ``(image, target, original_image, original_segmentation_map)``. ``target``
  is an int64 (H, W) tensor of class ids; the originals are the untransformed NumPy arrays,
  kept for visualization.

  Args:
      dataset: indexable returning dicts with "image" and "label" entries.
      transform: paired transform as described above.
      threshold: float targets >= this value become class 1, the rest class 0 (binary masks).
  """
  def __init__(self, dataset, transform, threshold=0.5):
    self.dataset = dataset
    self.transform = transform
    self.threshold = threshold

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, idx):
    item = self.dataset[idx]
    original_image = np.array(item["image"])
    original_segmentation_map = np.array(item["label"])

    transformed_image, transformed_target = self.transform(original_image, original_segmentation_map)
    # as_tensor avoids the copy-construct warning torch.tensor() emits for tensor inputs
    image = torch.as_tensor(transformed_image)
    target = torch.as_tensor(transformed_target)
    if target.is_floating_point():
      # Resizing interpolates mask edges to in-between values. The former plain int cast
      # truncated all of them to background; thresholding assigns each to the nearer class.
      target = target >= self.threshold
    target = target.to(torch.int64)
    # Drop the channel dim, (1, H, W) -> (H, W); previously hard-coded as view(448, 448)
    target = target.reshape(target.shape[-2:])

    return image, target, original_image, original_segmentation_map


In [None]:
class ResizeAndPad:
    """Resize an image array, then pad it so height and width are multiples of ``multiple``.

    Args:
        target_size: passed to ``transforms.Resize`` (an ``(h, w)`` tuple resizes exactly).
        multiple: output sides are padded up to the next multiple of this value (e.g. the
            ViT patch size). Padding is split evenly, with any extra pixel on the right/bottom.
        interpolation: optional ``transforms.InterpolationMode`` for the resize. ``None``
            keeps torchvision's default; pass ``InterpolationMode.NEAREST`` for label masks
            so no in-between class values are created.

    Returns (from ``__call__``): the resized and padded PIL image.
    """

    def __init__(self, target_size, multiple, interpolation=None):
        self.target_size = target_size
        self.multiple = multiple
        self.interpolation = interpolation

    def __call__(self, image_array):
        # Use PIL under a private alias: the notebook's global `Image` is rebound to
        # `datasets.Image` whenever the data-loading cell is re-run, which broke `Image.fromarray`.
        from PIL import Image as PILImage

        image = PILImage.fromarray(image_array)
        resize_kwargs = {} if self.interpolation is None else {"interpolation": self.interpolation}
        img = transforms.Resize(self.target_size, **resize_kwargs)(image)

        # Amount needed to reach the next multiple (0 when already aligned)
        pad_width = (self.multiple - img.width % self.multiple) % self.multiple
        pad_height = (self.multiple - img.height % self.multiple) % self.multiple

        left, top = pad_width // 2, pad_height // 2
        img = transforms.Pad((left, top, pad_width - left, pad_height - top))(img)

        return img

# Output size for ResizeAndPad. 448 = 32 * 14, so padding to a multiple of 14 adds nothing,
# and the 32x32 token grid matches the classifier heads (tokenW = tokenH = 32)
target_size = (448, 448)

In [None]:
# Per-channel mean/std used by Normalize on the image tensors; the visualization cells
# undo it with x * STD + MEAN
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])

# Training pipelines. ORDER MATTERS: TwoRandomApply restores the same torch RNG state before
# running the label pipeline, so the random ops (HFlip, VFlip, Rotation) must appear in the
# same order in both pipelines for an image and its mask to receive identical augmentations.
train_transform_image = transforms.Compose([ResizeAndPad(target_size, 14),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.RandomVerticalFlip(),
                                      transforms.RandomRotation(360),
                                      transforms.ToTensor(),
                                      transforms.Normalize(mean=MEAN, std=STD)])

# Label pipeline: same geometry as above, then Grayscale(1) + ToTensor produce a single-channel
# float mask (presumably (1, H, W) in [0, 1]) that SegmentationDataset turns into int64 class ids.
# NOTE(review): ResizeAndPad resizes masks with torchvision's default interpolation, which
# blends values along mask edges; nearest-neighbour resizing would keep labels exact.
train_transform_label = transforms.Compose([ResizeAndPad(target_size, 14),
                                            transforms.RandomHorizontalFlip(),
                                            transforms.RandomVerticalFlip(),
                                            transforms.RandomRotation(360),
                                            transforms.Grayscale(num_output_channels=1),
                                            transforms.ToTensor()])

# Validation pipelines contain no random ops: only resize/pad and tensor conversion
validation_transform_image = transforms.Compose([ResizeAndPad(target_size, 14),
                                                 transforms.ToTensor(),
                                                 transforms.Normalize(mean=MEAN, std=STD)])

validation_transform_label = transforms.Compose([ResizeAndPad(target_size, 14),
                                                 transforms.Grayscale(num_output_channels=1),
                                                 transforms.ToTensor()])

# Pair the image/label pipelines and wrap the HF splits as torch datasets
train_transformation = TwoRandomApply(train_transform_image, train_transform_label)
validation_transformation = TwoRandomApply(validation_transform_image, validation_transform_label)

train_dataset = SegmentationDataset(dataset["train"], transform=train_transformation)
validation_dataset = SegmentationDataset(dataset["validation"], transform=validation_transformation)

In [None]:
# Sanity-check the shapes of one transformed training sample
pixel_values, target, original_image, original_segmentation_map = train_dataset[0]
print(pixel_values.shape, target.shape, sep="\n")

In [None]:
def show_transformed_images_and_labels(dataset, index):
    """Plot a dataset item before and after its transforms in a 2x2 grid.

    Top row: original image and label. Bottom row: the transformed image (unnormalized
    with MEAN/STD and clipped to [0, 1]) and the transformed label.

    Args:
        dataset: a SegmentationDataset-like object returning
            ``(pixel_values, target, original_image, original_segmentation_map)``.
        index: item to display. Random training transforms give a new view on every call.
    """
    pixel_values, target, original_image, original_segmentation_map = dataset[index]

    # CHW normalized tensor -> HWC array, unnormalized and clipped for imshow
    transformed_image = (pixel_values.permute(1, 2, 0).numpy() * STD + MEAN).clip(0, 1)
    transformed_label = target.numpy()

    panels = [
        (original_image, 'Original Image', None),
        (original_segmentation_map, 'Original Label', 'gray'),
        (transformed_image, 'Transformed Image', None),
        (transformed_label, 'Transformed Label', 'gray'),
    ]

    fig, axs = plt.subplots(2, 2, figsize=(10, 8))
    for ax, (picture, title, cmap) in zip(axs.flat, panels):
        ax.imshow(picture, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')

    plt.show()

# Example usage
# Assuming 'train_dataset' is your dataset instance and you want to inspect the first item.
# The training transforms are random, so each run shows a different augmentation of item 0.
show_transformed_images_and_labels(train_dataset, 0)


In [None]:
def collate_fn(inputs):
    """Merge SegmentationDataset items into one batch dict.

    Tensors (pixel values, labels) are stacked along a new batch dimension. The
    untransformed originals stay as Python lists, since their sizes may differ.
    """
    def column(position):
        return [sample[position] for sample in inputs]

    return {
        "pixel_values": torch.stack(column(0), dim=0),
        "labels": torch.stack(column(1), dim=0),
        "original_images": column(2),
        "original_segmentation_maps": column(3),
    }

# Shuffle only the training split; validation keeps a fixed order across epochs
train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=16, shuffle=True)
validation_dataloader = DataLoader(validation_dataset, collate_fn=collate_fn, batch_size=16, shuffle=False)

In [None]:
# Pull one training batch and list the shapes of its tensor entries
batch = next(iter(train_dataloader))
for key, value in batch.items():
    if torch.is_tensor(value):
        print(key, value.shape)

In [None]:
batch["pixel_values"].dtype

In [None]:
batch["labels"].dtype

In [None]:
# Use PIL under a private alias: the global `Image` is rebound to `datasets.Image`
# whenever the data-loading cell is re-run.
from PIL import Image as PILImage

# Undo Normalize channel-wise on the CHW array: x * std + mean
unnormalized_image = (batch["pixel_values"][5].numpy() * np.array(STD)[:, None, None]) + np.array(MEAN)[:, None, None]
# Clamp to [0, 1] before scaling: astype(np.uint8) does not saturate out-of-range floats
# (same clip as in show_transformed_images_and_labels)
unnormalized_image = (np.clip(unnormalized_image, 0, 1) * 255).astype(np.uint8)
# CHW -> HWC for PIL
unnormalized_image = np.moveaxis(unnormalized_image, 0, -1)
unnormalized_image = PILImage.fromarray(unnormalized_image)
unnormalized_image

In [None]:
# Class names present in the labels of batch item 5 (avoids shadowing the builtin `id`)
[id2label[label_id] for label_id in batch["labels"][5].unique().tolist()]

In [None]:
label_2d = batch["labels"][5].numpy()

visualize_map(unnormalized_image, label_2d)

In [None]:
from transformers import Dinov2Model, Dinov2PreTrainedModel
from transformers.modeling_outputs import SemanticSegmenterOutput

class LinearClassifier(torch.nn.Module):
    """Single 1x1-convolution head turning a grid of patch embeddings into class logits.

    Args:
        in_channels: embedding dimension of each patch token.
        tokenW, tokenH: width and height of the patch grid.
        num_labels: number of output classes.
    """

    def __init__(self, in_channels, tokenW=32, tokenH=32, num_labels=2):
        super().__init__()
        self.in_channels = in_channels
        self.width = tokenW
        self.height = tokenH
        # A 1x1 conv is a per-patch linear map: in_channels -> num_labels
        self.classifier = torch.nn.Conv2d(in_channels, num_labels, kernel_size=(1, 1))

    def forward(self, embeddings):
        """Return logits of shape (batch, num_labels, tokenH, tokenW).

        Assumes ``embeddings`` is (batch, tokenH * tokenW, in_channels) with patches in
        row-major order.
        """
        grid = embeddings.reshape(-1, self.height, self.width, self.in_channels)
        channels_first = grid.permute(0, 3, 1, 2)
        return self.classifier(channels_first)

In [None]:
class DeepLinearClassifier(torch.nn.Module):
    """Stack of 1x1 convolutions with ReLUs (in_channels -> 128 -> 64 -> 32 -> num_labels).

    Each patch embedding is mapped to class logits by a small per-patch MLP. The layer
    layout inside ``self.classifier`` (conv/ReLU alternating, convs at indices 0, 2, 4, 6)
    determines the state_dict keys.

    Args:
        in_channels: embedding dimension of each patch token.
        tokenW, tokenH: width and height of the patch grid.
        num_labels: number of output classes.
    """

    def __init__(self, in_channels, tokenW=32, tokenH=32, num_labels=2):
        super().__init__()
        self.in_channels = in_channels
        self.width = tokenW
        self.height = tokenH

        widths = [in_channels, 128, 64, 32]
        layers = []
        for c_in, c_out in zip(widths, widths[1:]):
            # ReLU after each hidden projection adds the non-linearity
            layers += [torch.nn.Conv2d(c_in, c_out, (1, 1)), torch.nn.ReLU()]
        layers.append(torch.nn.Conv2d(widths[-1], num_labels, (1, 1)))
        self.classifier = torch.nn.Sequential(*layers)

    def forward(self, embeddings):
        """Return logits of shape (batch, num_labels, tokenH, tokenW).

        Assumes ``embeddings`` is (batch, tokenH * tokenW, in_channels) with patches in
        row-major order.
        """
        feature_map = embeddings.reshape(-1, self.height, self.width, self.in_channels).permute(0, 3, 1, 2)
        return self.classifier(feature_map)


In [None]:
import torch.nn.functional as F
class Dinov2ForSemanticSegmentation(Dinov2PreTrainedModel):
    """DINOv2 backbone with a patch-level classification head for semantic segmentation.

    The backbone's patch embeddings (CLS token dropped) go through ``DeepLinearClassifier``.
    The resulting per-patch logits are bilinearly upsampled to the input resolution. The
    head is built for a 32x32 token grid, i.e. the 448x448 inputs padded to multiples of 14
    produced by ``ResizeAndPad(target_size, 14)`` in this notebook.
    """

    def __init__(self, config):
        super().__init__(config)
        self.dinov2 = Dinov2Model(config)

        # Head predicting a label for each patch from the DINOv2 features.
        # Single 1x1-conv alternative:
        # self.classifier = LinearClassifier(config.hidden_size, 32, 32, config.num_labels)
        self.classifier = DeepLinearClassifier(config.hidden_size, 32, 32, config.num_labels)

    def forward(self, pixel_values, output_hidden_states=True, output_attentions=True, labels=None):
        """Segment ``pixel_values`` and optionally compute the cross-entropy loss.

        Args:
            pixel_values: (batch, 3, H, W) normalized images; assumed 448x448 so the patch
                grid matches the head's 32x32 token grid.
            output_hidden_states, output_attentions: forwarded to the backbone.
                NOTE(review): both default to True, so every call keeps all hidden states
                and attention maps; pass False when they are not needed, to save memory.
            labels: optional int64 class-id map of shape (batch, H, W) or (batch, 1, H, W).

        Returns:
            ``SemanticSegmenterOutput`` with ``logits`` of shape (batch, num_labels, H, W),
            ``loss`` (None without labels) and the backbone's hidden states / attentions.
        """
        # The backbone is frozen in a separate notebook cell; once that cell has run, only
        # the classifier head receives gradients.
        outputs = self.dinov2(pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions)

        # Drop the CLS token (index 0). Segmentation needs the local per-patch features, not
        # the global summary. Shape: (batch, num_patches, hidden_size).
        patch_embeddings = outputs.last_hidden_state[:, 1:, :]

        # Per-patch logits, then upsample to the input size so they align with the labels
        logits = self.classifier(patch_embeddings)
        logits = F.interpolate(logits, size=pixel_values.shape[2:], mode="bilinear", align_corners=False)

        loss = None
        if labels is not None:
            # Remove only a size-1 channel dim. The former labels.squeeze() also dropped the
            # batch dim when a batch (or a DataParallel chunk) held a single sample, which
            # gave CrossEntropyLoss mismatched shapes.
            if labels.dim() == logits.dim():
                labels = labels.squeeze(1)
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)

        return SemanticSegmenterOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)

In [None]:
# Load pretrained DINOv2-base weights into the segmentation wrapper with the 2-class label map.
# NOTE(review): the classifier head is presumably absent from this checkpoint and starts
# freshly initialized (transformers typically warns about newly initialized weights).
model = Dinov2ForSemanticSegmentation.from_pretrained("facebook/dinov2-base", id2label=id2label, num_labels=len(id2label))


In [None]:
# Freeze the DINOv2 backbone; only parameters outside it (the classifier head) stay trainable
for name, param in model.named_parameters():
    if not name.startswith("dinov2"):
        continue
    param.requires_grad = False

In [None]:
# Smoke test: one forward pass on the batch loaded above; the model is still on the CPU here
# (it is moved to `device` in the training cell)
outputs = model(pixel_values=batch["pixel_values"], labels=batch["labels"])


In [None]:
# Standalone loss object (mean reduction). The model computes its own loss internally, so
# this is not used by the training loop.
criterion = torch.nn.CrossEntropyLoss(reduction="mean")

def pixel_accuracy_for_fish(preds, labels, fish_class=1):
    """Fraction of ground-truth fish pixels that are also predicted as fish.

    Returns NaN when ``labels`` contains no pixel of ``fish_class``.
    """
    is_fish = labels == fish_class
    fish_pixel_count = is_fish.sum().item()
    if fish_pixel_count == 0:
        return float('nan')
    hits = (preds[is_fish] == fish_class).sum().item()
    return hits / fish_pixel_count



def manual_iou(preds, labels, num_classes):
    """Intersection-over-Union for each class id, plus their mean.

    Args:
        preds: tensor of predicted class ids.
        labels: tensor of ground-truth class ids, same shape as ``preds``.
        num_classes: class ids 0 .. num_classes - 1 are scored.

    Returns:
        ``(mean_iou, iou_per_class)``. ``iou_per_class`` holds one Python float per class,
        NaN when the class appears in neither ``preds`` nor ``labels``. ``mean_iou``
        averages the non-NaN entries (NaN if none). Values are plain floats, not 0-d
        (possibly CUDA) tensors, so callers can format and log them directly.
    """
    iou_per_class = []
    for cls in range(num_classes):
        pred_inds = preds == cls
        target_inds = labels == cls
        intersection = (pred_inds & target_inds).sum().item()
        union = pred_inds.sum().item() + target_inds.sum().item() - intersection
        # union == 0 means the class is absent from both maps: IoU is undefined
        iou_per_class.append(intersection / union if union else float('nan'))
    valid_iou = [iou for iou in iou_per_class if not math.isnan(iou)]
    mean_iou = sum(valid_iou) / len(valid_iou) if valid_iou else float('nan')
    return mean_iou, iou_per_class

def pixel_accuracy(preds, labels):
    """Fraction of all pixels whose predicted class equals the ground-truth class."""
    matches = (preds == labels).sum().item()
    return matches / labels.numel()

def mean_pixel_accuracy(preds, labels, num_classes):
    """Per-class pixel accuracy (recall) and its mean over classes present in ``labels``.

    Returns:
        ``(mean_accuracy, accuracies)``. ``accuracies[c]`` is NaN when class ``c`` does
        not occur in ``labels``. ``mean_accuracy`` ignores those NaNs (NaN if all are NaN).
    """
    def _class_accuracy(cls):
        in_class = labels == cls
        count = in_class.sum().item()
        if count == 0:
            return float('nan')
        return (preds[in_class] == cls).sum().item() / count

    accuracies = [_class_accuracy(cls) for cls in range(num_classes)]
    present = [acc for acc in accuracies if not math.isnan(acc)]
    mean_accuracy = sum(present) / len(present) if present else float('nan')
    return mean_accuracy, accuracies


def precision_recall_f1(preds, labels, fish_class=1):
    """Precision, recall and F1 of ``fish_class`` treated as the positive class.

    Any ratio with a zero denominator is reported as 0.0 (not NaN), and so is F1 when
    precision and recall are both 0.
    """
    predicted = preds == fish_class
    actual = labels == fish_class

    tp = (predicted & actual).sum().item()
    fp = (predicted & ~actual).sum().item()
    fn = (~predicted & actual).sum().item()

    def _ratio(numerator, denominator):
        return numerator / denominator if denominator != 0 else 0.0

    precision = _ratio(tp, tp + fp)
    recall = _ratio(tp, tp + fn)
    if precision + recall == 0:
        return precision, recall, 0.0
    return precision, recall, 2 * (precision * recall) / (precision + recall)

def _nanmean(values):
    """Mean of ``values`` (floats or 0-d tensors) ignoring NaNs; NaN if nothing remains."""
    kept = [float(v) for v in values if not math.isnan(float(v))]
    return sum(kept) / len(kept) if kept else float('nan')


def evaluate_model(model, dataloader, device, num_classes):
    """Evaluate a segmentation model on ``dataloader`` without gradient tracking.

    Args:
        model: module (optionally wrapped in ``DataParallel``) whose output exposes
            ``.loss`` and ``.logits`` of shape (batch, num_classes, H, W).
        dataloader: yields dicts with "pixel_values" and "labels" tensors (see collate_fn).
        device: device the batch tensors are moved to.
        num_classes: number of classes scored by IoU and mean pixel accuracy.

    Returns:
        A 9-tuple ``(avg_eval_loss, avg_mean_iou, iou_per_class, avg_pixel_accuracy,
        avg_mean_pixel_accuracy, avg_pixel_accuracy_fish, avg_precision_fish,
        avg_recall_fish, avg_f1_fish)``. Each metric is averaged over batches, skipping
        batches where it is undefined (NaN, e.g. a batch without fish pixels).
        ``iou_per_class`` is a list with one such average per class; previously it held
        only the last batch's values. With an empty dataloader every value is NaN.

    Leaves ``model`` in eval mode.
    """
    model.eval()
    losses = []
    mean_ious, pixel_accs, mean_pixel_accs, fish_accs = [], [], [], []
    precisions, recalls, f1s = [], [], []
    ious_by_class = [[] for _ in range(num_classes)]

    with torch.no_grad():
        for batch in dataloader:
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(pixel_values=pixel_values, labels=labels)
            # DataParallel gathers one loss per replica; mean() is a no-op on a scalar loss
            losses.append(outputs.loss.mean().item())

            # Use the logits tensor: the original passed the whole output object to softmax,
            # which fails. argmax of the logits equals argmax of their softmax.
            preds = outputs.logits.argmax(dim=1)

            mean_iou, iou_per_class = manual_iou(preds, labels, num_classes)
            mean_ious.append(mean_iou)
            for cls, iou in enumerate(iou_per_class):
                ious_by_class[cls].append(iou)

            pixel_accs.append(pixel_accuracy(preds, labels))
            mean_pixel_accs.append(mean_pixel_accuracy(preds, labels, num_classes)[0])
            fish_accs.append(pixel_accuracy_for_fish(preds, labels, fish_class=1))

            precision_fish, recall_fish, f1_fish = precision_recall_f1(preds, labels, fish_class=1)
            precisions.append(precision_fish)
            recalls.append(recall_fish)
            f1s.append(f1_fish)

    # NaN losses are deliberately not skipped: a diverged loss should stay visible
    avg_eval_loss = sum(losses) / len(losses) if losses else float('nan')

    return (
        avg_eval_loss,
        _nanmean(mean_ious),
        [_nanmean(class_ious) for class_ious in ious_by_class],
        _nanmean(pixel_accs),
        _nanmean(mean_pixel_accs),
        _nanmean(fish_accs),
        _nanmean(precisions),
        _nanmean(recalls),
        _nanmean(f1s),
    )

In [None]:
learning_rate = 0.000062
epochs = 50
num_classes = len(id2label)

wandb.init(
    # set the wandb project where this run will be logged
    project="DINOv2-FineTuning",

    # Track hyperparameters; the values mirror what this cell actually uses. Previously the
    # config reported "Batch Size": 64 and SGD while training ran with 16 and AdamW.
    config={
        "Architecture": "DINOv2-Supervised-Deep-Layer",
        "Dataset": "Fishency",
        "output_hidden_states": True,
        "Batch Size": train_dataloader.batch_size,
        "Learning_Rate": learning_rate,
        "Scheduler": "StepLR(step_size=10, gamma=0.1)",
        "Epochs": epochs,
        "Optimizer": "AdamW",
    }
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Guard against double-wrapping when this cell is re-run in the same kernel
if torch.cuda.device_count() > 1 and not isinstance(model, torch.nn.DataParallel):
    print("Using ", torch.cuda.device_count(), "GPUs !!!")
    model = torch.nn.DataParallel(model)

model.to(device)

# Optimize only trainable parameters (the backbone is frozen in an earlier cell)
optimizer = AdamW([p for p in model.parameters() if p.requires_grad], lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

for epoch in range(epochs):
    model.train()
    print("Epoch:", epoch)
    running_loss = 0.0
    num_batches = 0

    for batch in tqdm(train_dataloader):
        optimizer.zero_grad()
        torch.cuda.empty_cache()

        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        # forward; DataParallel returns one loss per replica, mean() is a no-op for a scalar
        outputs = model(pixel_values, labels=labels)
        loss = outputs.loss.mean()

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    # Evaluation step. evaluate_model returns nine values; the former 3-name unpacking
    # raised ValueError after the first epoch.
    (avg_eval_loss, avg_mean_iou, iou_per_class, avg_pixel_accuracy, avg_mean_pixel_accuracy,
     avg_pixel_accuracy_fish, avg_precision_fish, avg_recall_fish, avg_f1_fish) = evaluate_model(
        model, validation_dataloader, device, num_classes)
    iou_for_fish = float(iou_per_class[1])
    avg_mean_iou = float(avg_mean_iou)

    # Record the learning rate used during this epoch, then advance the schedule
    epoch_lr = scheduler.get_last_lr()[0]
    scheduler.step()

    # A single log call per epoch keeps all metrics on the same wandb step
    wandb.log({
        "epoch": epoch,
        "learning_rate": epoch_lr,
        "Train Loss": running_loss / max(num_batches, 1),
        "Validation Loss": avg_eval_loss,
        "Average Mean IoU": avg_mean_iou,
        "IoU scores for Fish": iou_for_fish,
        "Pixel Accuracy": avg_pixel_accuracy,
        "Mean Pixel Accuracy": avg_mean_pixel_accuracy,
        "Pixel Accuracy for Fish": avg_pixel_accuracy_fish,
        "Precision for Fish": avg_precision_fish,
        "Recall for Fish": avg_recall_fish,
        "F1 for Fish": avg_f1_fish,
    })

    # The loss is not a percentage, so print it as a plain number
    print(f"Validation Loss: {avg_eval_loss:.4f}")
    print(f"Average Mean IoU: {avg_mean_iou*100:.2f}%")
    print(f"IoU scores for fish: {iou_for_fish*100:.2f}%")


wandb.finish()

In [None]:
# Use PIL under a private alias: the global `Image` is rebound to `datasets.Image`
# whenever the data-loading cell is re-run.
from PIL import Image as PILImage

model.eval()

# Predict on one random validation sample (validation transforms are deterministic)
random_index = random.randint(0, len(validation_dataset) - 1)
sample = validation_dataset[random_index]

with torch.no_grad():
    pixel_values, true_mask, original_image, original_segmentation_map = sample
    pixel_values = pixel_values.unsqueeze(0).to(device)  # add batch dimension and send to device

    outputs = model(pixel_values)
    logits = outputs.logits
    # argmax over the class dim gives the predicted class id per pixel; drop the batch dim
    predicted_mask = logits.argmax(dim=1).squeeze().cpu().numpy()

# Convert to displayable format; the prediction is made at the model's input resolution,
# so resize it back to the original image size with nearest-neighbour (keeps class ids)
original_image_display = np.array(original_image).astype(np.uint8)
true_mask_display = np.array(original_segmentation_map).astype(np.uint8)
predicted_mask_image = PILImage.fromarray(predicted_mask.astype(np.uint8))
predicted_mask_resized = predicted_mask_image.resize(original_image_display.shape[1::-1], PILImage.NEAREST)


fig, axes = plt.subplots(1, 3, figsize=(20, 10))
axes[0].imshow(original_image_display)
axes[0].set_title('Original Image')
axes[0].axis('off')

axes[1].imshow(true_mask_display, cmap='jet')
axes[1].set_title('True Mask')
axes[1].axis('off')

axes[2].imshow(predicted_mask_resized, cmap='jet')
axes[2].set_title('Predicted Mask')
axes[2].axis('off')

plt.show()

In [None]:
# NOTE: this commented-out cell is an exact copy of the inference cell above and can be deleted.
# model.eval()

# random_index = random.randint(0, len(validation_dataset) - 1)
# sample = validation_dataset[random_index]

# with torch.no_grad():
#     pixel_values, true_mask, original_image, original_segmentation_map = sample
#     pixel_values = pixel_values.unsqueeze(0).to(device)  # Add batch dimension and send to device
    
#     outputs = model(pixel_values)
#     logits = outputs.logits
#     probs = torch.softmax(logits, dim=1)
#     predicted_mask = torch.argmax(probs, dim=1).squeeze().cpu().numpy()  # Remove batch dim

# # Convert to displayable format
# original_image_display = np.array(original_image).astype(np.uint8)
# true_mask_display = np.array(original_segmentation_map).astype(np.uint8)
# predicted_mask_image = Image.fromarray(predicted_mask.astype(np.uint8))
# predicted_mask_resized = predicted_mask_image.resize(original_image_display.shape[1::-1], Image.NEAREST)


# fig, axes = plt.subplots(1, 3, figsize=(20, 10))
# axes[0].imshow(original_image_display)
# axes[0].set_title('Original Image')
# axes[0].axis('off')

# axes[1].imshow(true_mask_display, cmap='jet')
# axes[1].set_title('True Mask')
# axes[1].axis('off')

# axes[2].imshow(predicted_mask_resized, cmap='jet')
# axes[2].set_title('Predicted Mask')
# axes[2].axis('off')

# plt.show()
