In [None]:
from torch.utils.data import Dataset
import numpy as np

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; the dataset and checkpoint paths used
# in the cells below all live under /content/drive/MyDrive.
drive.mount('/content/drive')

In [None]:
import os
from torch import nn

In [None]:
class hrgldd_dataset(Dataset):
    """Map-style dataset that pairs image arrays with their label arrays.

    Samples are read by position from ``x`` and ``y`` and returned as float32
    tensors whose last axis has been moved to the front.
    """

    def __init__(self, x, y):
        # x: indexable collection of image arrays; y: the matching label arrays.
        self.x, self.y = x, y

    def __len__(self):
        # The sample count is taken from the image collection alone.
        return len(self.x)

    @staticmethod
    def _to_channels_first(array):
        # NumPy array -> float32 tensor with axis 2 moved to position 0.
        # Assumes every sample is 3-D (presumably H, W, C) — TODO confirm.
        return torch.from_numpy(array).float().permute(2, 0, 1)

    def __getitem__(self, index):
        """Return the ``(image, label)`` tensor pair stored at ``index``."""
        return (
            self._to_channels_first(self.x[index]),
            self._to_channels_first(self.y[index]),
        )

In [None]:
# Locations of the test split on the mounted Google Drive.
path_to_testX = '/content/drive/MyDrive/major_proj/HR-GLDD/testX.npy'
path_to_testY = '/content/drive/MyDrive/major_proj/HR-GLDD/testY.npy'

# Load both arrays fully into memory; later cells wrap them in hrgldd_dataset.
# Presumably testX holds the image patches and testY the masks — confirm against the dataset.
data_testY = np.load(path_to_testY)
data_testX = np.load(path_to_testX)

In [None]:

def dice_loss(y_pred, y_true, smooth = 1e-6):
    """Return the soft Dice loss ``1 - dice`` between two tensors.

    Both tensors are cast to float and flattened, so any matching shapes are
    accepted. ``reshape`` is used rather than ``view`` so that non-contiguous
    inputs (for example permuted or transposed tensors) do not raise.

    Args:
        y_pred: predicted values; presumably probabilities in [0, 1] — the
            function itself does not enforce this.
        y_true: target values with the same number of elements as ``y_pred``.
        smooth: constant added to numerator and denominator so that two
            all-zero inputs give a dice of 1 instead of 0/0.

    Returns:
        A 0-dimensional tensor holding ``1 - dice``.
    """
    y_pred = y_pred.float().reshape(-1)
    y_true = y_true.float().reshape(-1)

    intersection = (y_pred * y_true).sum()
    union = y_pred.sum() + y_true.sum()
    # Sanity check kept from the original: a dice above 1 means the inputs
    # are outside the range in which the formula is meaningful.
    if ((2.0 * intersection) + smooth) > (union + smooth):
        print("Error !")
    dice = ((2.0 * intersection) + smooth) / (union + smooth)

    return 1 - dice

In [None]:


class DoubleConv(nn.Module):
    """Two 3x3 convolutions (padding 1), each followed by an in-place ReLU.

    With kernel 3 and padding 1 the spatial size is unchanged; only the
    channel count goes from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv_kwargs = {"kernel_size": 3, "padding": 1}
        # Layer order (and therefore the state_dict keys conv.0 / conv.2) must
        # stay as is so that saved checkpoints keep loading.
        layers = [
            nn.Conv2d(in_channels, out_channels, **conv_kwargs),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, **conv_kwargs),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both convolution + ReLU stages to ``x``."""
        features = self.conv(x)
        return features


class DownSample(nn.Module):
    """Encoder stage: a DoubleConv followed by 2x2 max pooling with stride 2."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return ``(features, pooled)``.

        ``features`` is the un-pooled DoubleConv output (used by the caller as
        a skip connection); ``pooled`` is the same tensor after max pooling.
        """
        skip = self.conv(x)
        return skip, self.pool(skip)


class UpSample(nn.Module):
    """Decoder stage: transposed-conv upsampling, skip concatenation, DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The 2x2, stride-2 transposed convolution halves the channel count.
        self.up = nn.ConvTranspose2d(
            in_channels, in_channels // 2, kernel_size=2, stride=2
        )
        # The concatenated tensor is fed to a DoubleConv expecting in_channels,
        # so the skip tensor must supply the remaining channels.
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, concatenate ``x2`` on the channel axis, then convolve.

        ``x1`` is the deeper feature map and ``x2`` the skip connection; their
        spatial sizes must match after upsampling or ``torch.cat`` raises.
        """
        upsampled = self.up(x1)
        merged = torch.cat((upsampled, x2), dim=1)
        return self.conv(merged)


class UNet(nn.Module):
    """Four-level U-Net built from DownSample / DoubleConv / UpSample stages.

    Args:
        in_channels: number of channels of the input image.
        out_channels: number of channels of the returned map. This argument
            used to be ignored (the output layer was fixed to 1 channel); it
            is now honoured. Passing 1 gives exactly the previous layer
            shapes, so existing checkpoints still load.

    The forward pass returns raw, un-activated values; callers in this
    notebook apply ``torch.sigmoid`` themselves.
    """

    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()

        # Encoder: each stage doubles the channel count and halves H and W.
        self.down_conv_1 = DownSample(in_channels, 64)
        self.down_conv_2 = DownSample(64, 128)
        self.down_conv_3 = DownSample(128, 256)
        self.down_conv_4 = DownSample(256, 512)

        self.bottleneck = DoubleConv(512, 1024)

        # Decoder: mirrors the encoder and consumes the skip tensors in reverse.
        self.up_conv_1 = UpSample(1024, 512)
        self.up_conv_2 = UpSample(512, 256)
        self.up_conv_3 = UpSample(256, 128)
        self.up_conv_4 = UpSample(128, 64)

        # Output head: a 3x3 convolution (padding 1), kept as in the original
        # so that saved weights remain compatible.
        self.out = nn.Conv2d(in_channels=64, out_channels=out_channels,
                             kernel_size=3, padding=1)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder and return the output map.

        Each encoder stage returns its un-pooled features (the skip tensor)
        and the pooled tensor passed on to the next stage. With four 2x
        poolings the input height and width presumably need to be multiples
        of 16 for the skip concatenations to line up — TODO confirm.
        """
        down1, p1 = self.down_conv_1(x)
        down2, p2 = self.down_conv_2(p1)
        down3, p3 = self.down_conv_3(p2)
        down4, p4 = self.down_conv_4(p3)

        b = self.bottleneck(p4)

        up_1 = self.up_conv_1(b, down4)
        up_2 = self.up_conv_2(up_1, down3)
        up_3 = self.up_conv_3(up_2, down2)
        up_4 = self.up_conv_4(up_3, down1)

        return self.out(up_4)




In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms

# Checkpoint files of the trained U-Net variants on the mounted Drive.
# NOTE(review): no visible cell reads this list; later cells define their own
# `weights` list with the same paths.
models = ["/content/drive/MyDrive/major_proj/trained_models/unetwithbce.pth",
          "/content/drive/MyDrive/major_proj/trained_models/unetwithdiceloss.pth",
          "/content/drive/MyDrive/major_proj/trained_models/unet_aug_dice.pth",
          '/content/drive/MyDrive/major_proj/trained_models/unetauganddiceandbce.pth',
          '/content/drive/MyDrive/major_proj/trained_models/unetauganddicelovasz.pth'
          ]

def evaluate_model(weights_path='/content/drive/MyDrive/major_proj/trained_models/unet_dice.pth',
                   batch_size=16):
    """Evaluate a trained U-Net checkpoint on the test arrays and print averages.

    The model is run over ``data_testX`` / ``data_testY`` in batches. For each
    batch the soft Dice loss is computed on the sigmoid probabilities, and
    Dice, IoU, precision, recall, F1 and accuracy are computed on the
    predictions thresholded at 0.5. Every reported value is the mean of the
    per-batch values.

    The Dice and IoU lists used to be left empty (their computation was
    commented out), so the averaging step always raised ZeroDivisionError;
    both are now computed. ``torch.load`` also receives ``map_location`` so a
    checkpoint saved on GPU loads on a CPU-only runtime.

    Args:
        weights_path: path of the state_dict to load (defaults to the
            checkpoint that was previously hard-coded).
        batch_size: DataLoader batch size (default 16, as before).

    Returns:
        dict with keys 'loss', 'dice', 'iou', 'precision', 'recall', 'f1',
        'accuracy' mapping to Python floats.

    Raises:
        ValueError: if the test set yields no batches.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = UNet(4,1)
    model.load_state_dict(torch.load(weights_path, map_location=device, weights_only=True))
    model.to(device)
    model.eval()

    test_dataset = hrgldd_dataset(
        data_testX,
        data_testY
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
    )

    if len(test_loader) == 0:
        raise ValueError("The test set is empty; there is nothing to evaluate.")

    total_loss = 0.0
    dice_scores = []
    iou_scores = []
    precision_scores = []
    recall_scores = []
    f1_scores = []
    accuracy_scores = []

    # Evaluation loop
    with torch.no_grad():  # No gradients needed for evaluation
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = torch.sigmoid(model(inputs))

            # The loss uses the soft probabilities, the metrics the hard masks.
            loss = dice_loss(outputs, targets)
            total_loss += loss.item()

            predicted = (outputs > 0.5).float()
            targets = targets.float()  # presumably a {0, 1} mask — TODO confirm

            dice_scores.append(_calculate_dice(predicted, targets).item())
            iou_scores.append(_calculate_iou(predicted, targets).item())

            precision = calculate_precision(predicted, targets)
            precision_scores.append(precision.item())

            recall = calculate_recall(predicted, targets)
            recall_scores.append(recall.item())

            f1 = calculate_f1(precision, recall)
            f1_scores.append(f1.item())

            accuracy = calculate_accuracy(predicted, targets)
            accuracy_scores.append(accuracy.item())

    avg_loss = total_loss / len(test_loader)
    avg_dice = sum(dice_scores) / len(dice_scores)
    avg_iou = sum(iou_scores) / len(iou_scores)
    avg_precision = sum(precision_scores) / len(precision_scores)
    avg_recall = sum(recall_scores) / len(recall_scores)
    avg_f1 = sum(f1_scores) / len(f1_scores)
    avg_accuracy = sum(accuracy_scores) / len(accuracy_scores)

    print(f"Test Loss: {avg_loss:.4f}")
    print(f"Average Dice Score: {avg_dice:.4f}")
    print(f"Average IoU: {avg_iou:.4f}")
    print(f"Average Precision: {avg_precision:.4f}")
    print(f"Average Recall: {avg_recall:.4f}")
    print(f"Average F1 Score: {avg_f1:.4f}")
    print(f"Average Accuracy: {avg_accuracy:.4f}")

    return {
        'loss': avg_loss,
        'dice': avg_dice,
        'iou': avg_iou,
        'precision': avg_precision,
        'recall': avg_recall,
        'f1': avg_f1,
        'accuracy': avg_accuracy
    }


def _calculate_dice(pred, target, smooth=1.0):
    """Smoothed Dice coefficient of two binary masks: (2*TP + s) / (|pred| + |target| + s)."""
    intersection = torch.sum(pred * target)
    return (2.0 * intersection + smooth) / (torch.sum(pred) + torch.sum(target) + smooth)


def _calculate_iou(pred, target, smooth=1.0):
    """Smoothed intersection-over-union of two binary masks: (TP + s) / (union + s)."""
    intersection = torch.sum(pred * target)
    union = torch.sum(pred) + torch.sum(target) - intersection
    return (intersection + smooth) / (union + smooth)

def calculate_precision(pred, target):
    """Smoothed precision: (true positives + 1) / (predicted positives + 1).

    Both arguments are tensors of the same shape; true positives are counted
    as the sum of their element-wise product.
    """
    smooth = 1.0
    numerator = torch.sum(pred * target) + smooth
    denominator = torch.sum(pred) + smooth
    return numerator / denominator

def calculate_recall(pred, target):
    """Smoothed recall: (true positives + 1) / (actual positives + 1).

    Both arguments are tensors of the same shape; true positives are counted
    as the sum of their element-wise product.
    """
    smooth = 1.0
    numerator = torch.sum(pred * target) + smooth
    denominator = torch.sum(target) + smooth
    return numerator / denominator

def calculate_f1(precision, recall):
    """Return the harmonic mean ``2 * P * R / (P + R)`` of precision and recall.

    No guard is applied for ``precision + recall == 0``.
    """
    numerator = 2 * (precision * recall)
    denominator = precision + recall
    return numerator / denominator

def calculate_accuracy(pred, target):
    """Return the fraction of elements where ``pred`` equals ``target``.

    The previous version added a smoothing constant of 1 to both the count of
    correct elements and the total. The total is only zero for an empty
    tensor, so the constant did nothing useful and biased the result upward
    (an entirely wrong prediction scored above 0). It has been removed.

    Args:
        pred: tensor of predictions.
        target: tensor of the same shape holding the reference values.

    Returns:
        A 0-dimensional float tensor in [0, 1]. For empty input the result is
        1.0, which matches what the smoothed formula returned in that case.
    """
    total = torch.numel(pred)
    if total == 0:
        return torch.ones((), device=pred.device)
    correct = torch.sum(pred == target)
    return correct.float() / total



In [None]:
# Run the evaluation and keep the returned dictionary of averaged metrics.
# NOTE(review): a later cell rebinds `metrics` to a per-model dictionary.
metrics = evaluate_model()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

def visualize_comparison(model, test_loader, device, num_images=10):
    """Plot input, ground truth and predicted mask for the first test samples.

    The previous version only ever plotted from the first batch, indexed
    ``range(num_images)`` into it (IndexError when the batch is smaller) and
    kept running inference on further batches without showing them. Samples
    are now taken across batches until ``num_images`` have been shown, and
    iteration stops as soon as that count is reached.

    Args:
        model: network whose output is passed through a sigmoid and
            thresholded at 0.5.
        test_loader: iterable yielding ``(inputs, targets)`` batches.
        device: device on which inference runs.
        num_images: maximum number of samples to plot (fewer are shown if the
            loader runs out).
    """
    model.to(device)
    model.eval()

    shown = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            if shown >= num_images:
                break

            inputs, targets = inputs.to(device), targets.to(device)

            outputs = torch.sigmoid(model(inputs))  # Apply sigmoid to get probabilities
            predicted = (outputs > 0.5).float()  # Threshold to get binary predictions

            remaining = num_images - shown
            for j in range(min(inputs.shape[0], remaining)):
                # First three channels, converted to HWC for imshow.
                # NOTE(review): channels are shown in stored order; other cells
                # in this notebook reorder them as BGR -> RGB first — confirm.
                input_img = inputs[j].cpu().numpy()[0:3, :, :].transpose(1, 2, 0)
                target_mask = targets[j].cpu().numpy().squeeze()  # Drop the channel axis
                pred_mask = predicted[j].cpu().numpy().squeeze()  # Drop the channel axis

                # Plot original image, ground truth, and prediction
                fig, axes = plt.subplots(1, 3, figsize=(12, 4))

                axes[0].imshow(input_img)
                axes[0].set_title("Input Image")
                axes[0].axis('off')

                axes[1].imshow(target_mask, cmap='gray')
                axes[1].set_title("Ground Truth")
                axes[1].axis('off')

                axes[2].imshow(pred_mask, cmap='gray')
                axes[2].set_title("Predicted Mask")
                axes[2].axis('off')

                plt.show()
                plt.close(fig)  # Release the figure; one is created per sample.
                shown += 1

In [None]:
# Wrap the test arrays and batch them in their stored order.
test_dataset = hrgldd_dataset(
        data_testX,
        data_testY
    )

test_loader = DataLoader(
        test_dataset,
        batch_size=16,
        shuffle=False,
    )

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = UNet(4,1)

# weights_only=True matches every other torch.load call in this notebook and
# restricts unpickling to tensors / plain containers.
model.load_state_dict(torch.load("/content/drive/MyDrive/major_proj/trained_models/unet_dice.pth", map_location=device, weights_only=True))

model.to(device)

model.eval()

visualize_comparison(model, test_loader, device, num_images=10)

## Newer Code

In [None]:
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import precision_score, recall_score, f1_score, jaccard_score
import numpy as np
from torchvision import transforms
from torch.utils.data import DataLoader

# NOTE(review): `idx` does not appear to be read by any visible cell.
idx = 2

# Define function to calculate metrics
def calculate_metrics(y_true, y_pred):
    """Compute IoU, Dice, precision, recall and F1 for one pair of binary masks.

    Both arrays are flattened once and compared element-wise. Every metric
    is reported as 0 when it is undefined (no positives in the relevant
    mask): the scikit-learn scores get ``zero_division=0`` and the hand-written
    Dice is guarded the same way. Previously an empty ground truth together
    with an empty prediction made Dice evaluate 0/0 = NaN, which then turned
    the per-model ``np.mean`` into NaN.

    Args:
        y_true: NumPy array with the reference mask (values presumably 0/1 —
            TODO confirm).
        y_pred: NumPy array with the predicted mask, same number of elements.

    Returns:
        Tuple ``(iou, dice, precision, recall, f1)``.
    """
    true_flat = y_true.flatten()
    pred_flat = y_pred.flatten()

    iou = jaccard_score(true_flat, pred_flat, average='binary', zero_division=0)

    dice_denominator = np.sum(true_flat) + np.sum(pred_flat)
    if dice_denominator > 0:
        dice = 2 * np.sum(true_flat * pred_flat) / dice_denominator
    else:
        # Both masks are empty: report 0, consistent with f1_score below.
        dice = 0.0

    precision = precision_score(true_flat, pred_flat, zero_division=0)
    recall = recall_score(true_flat, pred_flat, zero_division=0)
    f1 = f1_score(true_flat, pred_flat, zero_division=0)
    return iou, dice, precision, recall, f1

def plot_comparison(images, gt, predictions, model_names):
    """Plot one row per image: input, ground truth, then one column per model.

    The number of rows used to be hard-coded to 3, which raised IndexError
    for fewer images and silently ignored any beyond the third. The grid now
    follows ``len(images)``; for three images the layout and figure size are
    the same as before.

    Args:
        images: sequence of channel-first image tensors with at least 3 channels.
        gt: sequence of ground-truth masks, aligned with ``images``.
        predictions: dict mapping each model name to a sequence of predicted
            masks; entry ``i`` is drawn in row ``i``, so it must already be
            aligned with ``images``.
        model_names: names (dict keys) in the column order to plot.

    Raises:
        ValueError: if ``images`` is empty.
    """
    n_rows = len(images)
    if n_rows == 0:
        raise ValueError("plot_comparison needs at least one image.")

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(n_rows, len(model_names) + 2,
                             figsize=(20, 4 * n_rows), squeeze=False)

    for i in range(n_rows):
        # First three channels, reversed, then converted to HxWxC.
        # The original author treats them as BGR — confirm against the dataset.
        image = images[i][:3, :, :]
        image = image[[2, 1, 0], :, :]
        image = image.permute(1, 2, 0).cpu().numpy()
        axes[i, 0].imshow(image)
        axes[i, 0].set_title(f"RGB Image {i+1}")
        axes[i, 0].axis('off')

        # Ground truth in the second column (singleton axes removed).
        ground_truth = gt[i].squeeze()
        axes[i, 1].imshow(ground_truth, cmap='gray')
        axes[i, 1].set_title("Ground Truth")
        axes[i, 1].axis('off')

        # One column per model, in the order given by model_names.
        for j, model_name in enumerate(model_names):
            prediction = predictions[model_name][i].squeeze()
            if not isinstance(prediction, np.ndarray):
                # A PyTorch tensor: move to CPU and convert for imshow.
                prediction = prediction.cpu().numpy()
            axes[i, j+2].imshow(prediction, cmap='gray')
            axes[i, j+2].set_title(f"{model_name}")
            axes[i, j+2].axis('off')

    plt.tight_layout()
    plt.show()




# Load your test dataset (use DataLoader for batch processing)
# Wrap the test arrays; batch_size=1 so that metrics are computed per image
# and predictions are stored one image at a time, in dataset order.
test_dataset = hrgldd_dataset(data_testX, data_testY)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Checkpoint paths, in the same order as `model_names` below (they are zipped).
weights = [
    "/content/drive/MyDrive/major_proj/trained_models/unetwithbce.pth",
    "/content/drive/MyDrive/major_proj/trained_models/unetwithdiceloss.pth",
    "/content/drive/MyDrive/major_proj/trained_models/unet_aug_dice.pth",
    '/content/drive/MyDrive/major_proj/trained_models/unetauganddiceandbce.pth',
    '/content/drive/MyDrive/major_proj/trained_models/unetauganddicelovasz.pth'
]

model_names = ["UNet + BCE", "UNet + Dice", "UNet + Aug + Dice", "UNet + Aug + Dice + BCE", "UNet + Aug + Dice + Lovasz"]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-model containers: one list of per-image values for each metric, and one
# list of per-image predicted masks.
# NOTE(review): this rebinds `metrics`, which an earlier cell assigned from evaluate_model().
metrics = {model_name: {"IoU": [], "Dice": [], "Precision": [], "Recall": [], "F1": []} for model_name in model_names}
predictions = {model_name: [] for model_name in model_names}

for model_name, weight_path in zip(model_names, weights):

    # A fresh network per checkpoint; map_location lets GPU-saved weights load on CPU.
    model = UNet(4,1).to(device)

    model.load_state_dict(torch.load(weight_path, map_location = device, weights_only=True))
    model.eval()

    # Iterate through the whole test set, one image per step
    for images, masks in test_loader:

        images = images.float().to(device)  # Move the batch to the selected device

        # Forward pass without gradient tracking
        with torch.no_grad():
            outputs = model(images)
            # Boolean NumPy mask: sigmoid probability above 0.5
            predicted_mask = torch.sigmoid(outputs).cpu().numpy() > 0.5

        # Calculate metrics for this single image
        iou, dice, precision, recall, f1 = calculate_metrics(masks.cpu().numpy(), predicted_mask)

        # Store metrics
        metrics[model_name]["IoU"].append(iou)
        metrics[model_name]["Dice"].append(dice)
        metrics[model_name]["Precision"].append(precision)
        metrics[model_name]["Recall"].append(recall)
        metrics[model_name]["F1"].append(f1)

        # Store predictions for the comparison plots; entry k belongs to test item k
        predictions[model_name].append(predicted_mask)

    # Print out the mean of the per-image metrics for the model
    avg_metrics = {metric: np.mean(values) for metric, values in metrics[model_name].items()}
    print(f"Metrics for {model_name}: {avg_metrics}")




In [None]:
# Test items 5, 6 and 7: element [0] of a dataset sample is the image tensor,
# element [1] the mask tensor.
example_images = [test_dataset[i][0] for i in range(5,8)]  #  3 images
example_masks = [test_dataset[i][1] for i in range(5,8)]  # 3 masks

In [None]:
# plot_comparison as defined at this point takes four arguments and draws
# predictions[name][i] in row i. The example images are test items 5..7, so
# pass the matching slice of each model's predictions instead of a fifth argument.
plot_comparison(
    example_images,
    example_masks,
    {name: predictions[name][5:8] for name in model_names},
    model_names,
)

In [None]:


# Plot comparison for the 3 example images


In [None]:
import random

# Three distinct test indices; no seed is set, so the choice changes per run.
random_indices = random.sample(range(len(test_dataset)), 3)

example_images = [test_dataset[i][0] for i in random_indices]
example_masks = [test_dataset[i][1] for i in random_indices]

# `predictions` holds one entry per test item in dataset order, while
# plot_comparison draws entry i in row i. Select the entries that belong to
# the sampled indices so that each row shows the prediction for its own image.
example_predictions = {
    name: [predictions[name][i] for i in random_indices] for name in model_names
}

# Plot comparison for the 3 example images
plot_comparison(example_images, example_masks, example_predictions, model_names)

In [None]:
import random
import torch
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score, jaccard_score
from torch.utils.data import DataLoader

# Select 3 random indices for the images to compare
# Three distinct test indices, in the order random.sample returns them
# (not sorted). No random.seed call is visible, so the selection differs per run.
random_indices = random.sample(range(len(test_dataset)), 3)

# Image and mask tensors for those indices, in the same order as random_indices.
example_images = [test_dataset[i][0] for i in random_indices]
example_masks = [test_dataset[i][1] for i in random_indices]

# Define function to calculate metrics
def calculate_metrics(y_true, y_pred):
    """Compute IoU, Dice, precision, recall and F1 for one pair of binary masks.

    Both arrays are flattened once and compared element-wise. Every metric
    is reported as 0 when it is undefined (no positives in the relevant
    mask): the scikit-learn scores get ``zero_division=0`` and the hand-written
    Dice is guarded the same way. Previously an empty ground truth together
    with an empty prediction made Dice evaluate 0/0 = NaN.

    Args:
        y_true: NumPy array with the reference mask (values presumably 0/1 —
            TODO confirm).
        y_pred: NumPy array with the predicted mask, same number of elements.

    Returns:
        Tuple ``(iou, dice, precision, recall, f1)``.
    """
    true_flat = y_true.flatten()
    pred_flat = y_pred.flatten()

    iou = jaccard_score(true_flat, pred_flat, average='binary', zero_division=0)

    dice_denominator = np.sum(true_flat) + np.sum(pred_flat)
    if dice_denominator > 0:
        dice = 2 * np.sum(true_flat * pred_flat) / dice_denominator
    else:
        # Both masks are empty: report 0, consistent with f1_score below.
        dice = 0.0

    precision = precision_score(true_flat, pred_flat, zero_division=0)
    recall = recall_score(true_flat, pred_flat, zero_division=0)
    f1 = f1_score(true_flat, pred_flat, zero_division=0)
    return iou, dice, precision, recall, f1

def plot_comparison(images, gt, predictions, model_names):
    """Plot one row per image: input, ground truth, then one column per model.

    The number of rows used to be hard-coded to 3, which raised IndexError
    for fewer images and silently ignored any beyond the third. The grid now
    follows ``len(images)``; for three images the layout and figure size are
    the same as before.

    Args:
        images: sequence of channel-first image tensors with at least 3 channels.
        gt: sequence of ground-truth masks, aligned with ``images``.
        predictions: dict mapping each model name to a sequence of predicted
            masks; entry ``i`` is drawn in row ``i``, so it must already be
            aligned with ``images``.
        model_names: names (dict keys) in the column order to plot.

    Raises:
        ValueError: if ``images`` is empty.
    """
    n_rows = len(images)
    if n_rows == 0:
        raise ValueError("plot_comparison needs at least one image.")

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(n_rows, len(model_names) + 2,
                             figsize=(20, 4 * n_rows), squeeze=False)

    for i in range(n_rows):
        # First three channels, reversed, then converted to HxWxC.
        # The original author treats them as BGR — confirm against the dataset.
        image = images[i][:3, :, :]
        image = image[[2, 1, 0], :, :]
        image = image.permute(1, 2, 0).cpu().numpy()
        axes[i, 0].imshow(image)
        axes[i, 0].set_title(f"RGB Image {i+1}")
        axes[i, 0].axis('off')

        # Ground truth in the second column (singleton axes removed).
        ground_truth = gt[i].squeeze()
        axes[i, 1].imshow(ground_truth, cmap='gray')
        axes[i, 1].set_title("Ground Truth")
        axes[i, 1].axis('off')

        # One column per model, in the order given by model_names.
        for j, model_name in enumerate(model_names):
            prediction = predictions[model_name][i].squeeze()
            if not isinstance(prediction, np.ndarray):
                # A PyTorch tensor: move to CPU and convert for imshow.
                prediction = prediction.cpu().numpy()
            axes[i, j+2].imshow(prediction, cmap='gray')
            axes[i, j+2].set_title(f"{model_name}")
            axes[i, j+2].axis('off')

    plt.tight_layout()
    plt.show()


# Load your test dataset (use DataLoader for batch processing)
test_dataset = hrgldd_dataset(data_testX, data_testY)
# Kept for parity with the neighbouring cells; inference below runs directly
# on the selected example images instead of scanning the whole loader.
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Checkpoint paths, in the same order as `model_names` below (they are zipped).
weights = [
    "/content/drive/MyDrive/major_proj/trained_models/unetwithbce.pth",
    "/content/drive/MyDrive/major_proj/trained_models/unetwithdiceloss.pth",
    "/content/drive/MyDrive/major_proj/trained_models/unet_aug_dice.pth",
    '/content/drive/MyDrive/major_proj/trained_models/unetauganddiceandbce.pth',
    '/content/drive/MyDrive/major_proj/trained_models/unetauganddicelovasz.pth'
]

model_names = ["UNet + BCE", "UNet + Dice", "UNet + Aug + Dice", "UNet + Aug + Dice + BCE", "UNet + Aug + Dice + Lovasz"]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One list of predicted masks per model, aligned with example_images.
predictions = {model_name: [] for model_name in model_names}

for model_name, weight_path in zip(model_names, weights):
    model = UNet(4, 1).to(device)
    model.load_state_dict(torch.load(weight_path, map_location=device, weights_only=True))
    model.eval()

    # The earlier version looped over the whole test set and, for every
    # selected image, appended the current mask to EVERY model's list (also
    # overwriting `model_name`). Each model then held 15 masks and the first
    # three of every list came from the first checkpoint. It also collected
    # masks in dataset order rather than in the order of random_indices.
    # Predicting directly on example_images fixes both problems.
    with torch.no_grad():
        for image in example_images:
            outputs = model(image.unsqueeze(0).float().to(device))
            predicted_mask = torch.sigmoid(outputs).cpu().numpy() > 0.5  # Thresholding
            predictions[model_name].append(predicted_mask)

# Plot the comparison of RGB images, ground truth, and model predictions
plot_comparison(example_images, example_masks, predictions, model_names)


In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader

def plot_comparison(images, gt, predictions, model_names, index):
    """Plot a single row for entry ``index``: input, ground truth, each model.

    ``index`` is applied to ``images``, ``gt`` and to every list in
    ``predictions``, so all three must be aligned. A ValueError is raised
    when ``index`` is outside ``0 .. len(images) - 1``.
    """
    # Reject positions that do not exist in `images`.
    if index >= len(images) or index < 0:
        raise ValueError(f"Index {index} is out of range. Please enter an index between 0 and {len(images)-1}.")

    # One column for the image, one for the ground truth, one per model.
    fig, axes = plt.subplots(1, len(model_names) + 2, figsize=(20, 6))
    image_ax, truth_ax = axes[0], axes[1]

    # First three channels, reversed, then converted to HxWxC.
    # The original author treats them as BGR — confirm against the dataset.
    rgb = images[index][:3, :, :][[2, 1, 0], :, :]
    image_ax.imshow(rgb.permute(1, 2, 0).cpu().numpy())
    image_ax.set_title(f"RGB Image {index+1}")
    image_ax.axis('off')

    # Ground truth with singleton axes removed.
    truth_ax.imshow(gt[index].squeeze(), cmap='gray')
    truth_ax.set_title("Ground Truth")
    truth_ax.axis('off')

    # Remaining columns: one prediction per model, in model_names order.
    for ax, model_name in zip(axes[2:], model_names):
        mask = predictions[model_name][index].squeeze()
        if isinstance(mask, np.ndarray):
            ax.imshow(mask, cmap='gray')
        else:
            # A PyTorch tensor: move to CPU and convert for imshow.
            ax.imshow(mask.cpu().numpy(), cmap='gray')
        ax.set_title(f"{model_name}")
        ax.axis('off')

    plt.tight_layout()
    plt.show()


# Load your test dataset (use DataLoader for batch processing)
test_dataset = hrgldd_dataset(data_testX, data_testY)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Checkpoint paths, in the same order as `model_names` below (they are zipped).
weights = [
    "/content/drive/MyDrive/major_proj/trained_models/unetwithbce.pth",
    "/content/drive/MyDrive/major_proj/trained_models/unetwithdiceloss.pth",
    "/content/drive/MyDrive/major_proj/trained_models/unet_aug_dice.pth",
    '/content/drive/MyDrive/major_proj/trained_models/unetauganddiceandbce.pth',
    '/content/drive/MyDrive/major_proj/trained_models/unetauganddicelovasz.pth'
]

model_names = ["UNet + BCE", "UNet + Dice", "UNet + Aug + Dice", "UNet + Aug + Dice + BCE", "UNet + Aug + Dice + Lovasz"]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One list per model; entry k is the predicted mask of test item k.
predictions = {model_name: [] for model_name in model_names}

# Load models and store predictions
for model_name, weight_path in zip(model_names, weights):
    model = UNet(4,1).to(device)
    model.load_state_dict(torch.load(weight_path, map_location=device, weights_only=True))
    model.eval()

    # Iterate through the test set
    for images, masks in test_loader:
        images = images.float().to(device)  # Move the batch to the selected device

        # Forward pass
        with torch.no_grad():
            outputs = model(images)
            predicted_mask = torch.sigmoid(outputs).cpu().numpy() > 0.5  # Thresholding

        # Store predictions for comparison plot
        predictions[model_name].append(predicted_mask)

# Check the length of example_images
print(f"Length of example_images: {len(example_images)}")
print(f"Length of example_masks: {len(example_masks)}")

# example_images / example_masks were built from `random_indices`, whereas
# `predictions` is in test-set order. plot_comparison applies the same index
# to all of them, so re-order the predictions to match the example images;
# otherwise the masks shown belong to a different test item.
example_predictions = {
    name: [predictions[name][i] for i in random_indices] for name in model_names
}

# Now you can input an index to plot the image, ground truth, and predictions
index = 0  # Position within example_images; replace with the one you want
plot_comparison(example_images, example_masks, example_predictions, model_names, index)


In [None]:
import random
import torch
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader

def run_model_inference(test_dataset, start_index=0, num_images=3):
    """
    Run inference on consecutive images starting from the given index
    and plot the results for all models.

    Args:
    - test_dataset: The test dataset; item ``k`` must be an ``(image, mask)`` pair
    - start_index: Starting index for selecting images (default is 0)
    - num_images: How many consecutive images to compare (default is 3,
      the number that used to be hard-coded)

    Returns:
    - Matplotlib figure with image comparisons

    Raises:
    - ValueError: if ``num_images`` is smaller than 1
    """
    if num_images < 1:
        raise ValueError(f"num_images must be at least 1, got {num_images}.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Checkpoint paths, in the same order as model_names (they are zipped).
    weights = [
        "/content/drive/MyDrive/major_proj/trained_models/unetwithbce.pth",
        "/content/drive/MyDrive/major_proj/trained_models/unetwithdiceloss.pth",
        "/content/drive/MyDrive/major_proj/trained_models/unet_aug_dice.pth",
        '/content/drive/MyDrive/major_proj/trained_models/unetauganddiceandbce.pth',
        '/content/drive/MyDrive/major_proj/trained_models/unetauganddicelovasz.pth'
    ]

    model_names = [
        "UNet + BCE",
        "UNet + Dice",
        "UNet + Aug + Dice",
        "UNet + Aug + Dice + BCE",
        "UNet + Aug + Dice + Lovasz"
    ]

    # Select consecutive samples starting from start_index
    selected = [test_dataset[start_index + i] for i in range(num_images)]
    example_images = [sample[0] for sample in selected]
    example_masks = [sample[1] for sample in selected]

    # One list of predicted masks per model, aligned with example_images
    predictions = {model_name: [] for model_name in model_names}

    # Run inference for each model
    for model_name, weight_path in zip(model_names, weights):
        # Initialize and load model
        model = UNet(4, 1).to(device)
        model.load_state_dict(torch.load(weight_path, map_location=device, weights_only=True))
        model.eval()

        with torch.no_grad():
            for image in example_images:
                # Add a batch axis, predict, threshold the sigmoid at 0.5
                input_image = image.unsqueeze(0).float().to(device)
                outputs = model(input_image)
                predicted_mask = torch.sigmoid(outputs).cpu().numpy() > 0.5
                predictions[model_name].append(predicted_mask[0])  # Remove batch dimension

    def plot_comparison(images, gt, predictions, model_names):
        # One row per image; columns: input, ground truth, one per model.
        # squeeze=False keeps `axes` two-dimensional even for a single row.
        n_rows = len(images)
        fig, axes = plt.subplots(n_rows, len(model_names) + 2,
                                 figsize=(20, 4 * n_rows), squeeze=False)

        for i in range(n_rows):
            # First three channels, reversed, then converted to HxWxC.
            # The original author treats them as BGR — confirm against the dataset.
            image = images[i][:3, :, :]
            image = image[[2, 1, 0], :, :]
            image = image.permute(1, 2, 0).cpu().numpy()
            axes[i, 0].imshow(image)
            axes[i, 0].set_title(f"RGB Image")
            axes[i, 0].axis('off')

            # Ground truth in the second column (singleton axes removed)
            ground_truth = gt[i].squeeze()
            axes[i, 1].imshow(ground_truth, cmap='gray')
            axes[i, 1].set_title("Ground Truth")
            axes[i, 1].axis('off')

            # Predictions for each model in the remaining columns
            for j, model_name in enumerate(model_names):
                prediction = predictions[model_name][i].squeeze()
                if not isinstance(prediction, np.ndarray):
                    # A PyTorch tensor: move to CPU and convert for imshow
                    prediction = prediction.cpu().numpy()
                axes[i, j+2].imshow(prediction, cmap='gray')
                axes[i, j+2].set_title(f"{model_name}")
                axes[i, j+2].axis('off')

        plt.tight_layout()
        plt.show()
        return fig

    # Call the plot comparison function
    return plot_comparison(example_images, example_masks, predictions, model_names)

# Usage example: the function already draws the figure via plt.show() and then
# returns it. Binding the result keeps the notebook from rendering the same
# figure a second time as the cell's output value.
comparison_fig = run_model_inference(test_dataset, start_index=55)