In [2]:
import torch
from utils.EnsembleUNetTorch import UNet, UNetEnsemble

device = torch.device("cuda" if torch.cuda.is_available() else "mps")

holdout_path = "datasets/kitti_holdout/image_2"
fgsm_path = "datasets/kitti_holdout_FGSM/image_2"
pgd_path = "datasets/kitti_holdout_PGD/image_2"

single_unet = UNet(in_channels=3, out_channels=3).to(device)
single_checkpoint = torch.load("best_unet_model_3.pth", map_location=device)
single_unet.load_state_dict(single_checkpoint["model_state_dict"])
single_unet.eval()

model_paths = ["unet_model_0.pth", "unet_model_1.pth", "unet_model_2.pth"]
ensemble_unet = UNetEnsemble(model_paths=model_paths, device=device).to(device)
ensemble_unet.eval()

  single_checkpoint = torch.load("best_unet_model_3.pth", map_location=device)


UNetEnsemble(
  (models): ModuleList(
    (0-2): 3 x UNet(
      (encoder1): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU(inplace=True)
      )
      (encoder2): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU(inplace=True)
      )
      (encoder3): Sequential(
        (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU(inplace=True)
      )
      (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (middle): Sequential(
        (0): Conv2d(256, 512, kernel_s

In [14]:
from PIL import Image
import torch
import torchvision.transforms as transforms
import numpy as np

image_path = "datasets/kitti_holdout/image_2/umm_000008.png"  # Example
mask_path = "datasets/kitti_holdout/gt_image_2/umm_000008.png"
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Define a size compatible with UNet (divisible by 32)
H, W = 384, 1280  # Choose the closest multiple of 32

# Transform with resizing
transform = transforms.Compose([
    transforms.Resize((H, W)),   # Resize to ensure divisibility
    transforms.ToTensor(),
])

image = Image.open(image_path).convert("RGB")
image_tensor = transform(image).unsqueeze(0).to(device)  # Add batch dimension


print(f"Image Shape: {image_tensor.shape}")  # Debugging

Image Shape: torch.Size([1, 3, 384, 1280])


In [18]:
mask = Image.open(mask_path).convert("L")  # Load grayscale
mask_np = np.array(mask)  # Convert to NumPy
print("Original Unique Values in Mask:", np.unique(mask_np))

Original Unique Values in Mask: [  0  76 105]


In [21]:
import cv2

# Load mask as grayscale and convert to NumPy
mask = Image.open(mask_path).convert("L")
mask_np = np.array(mask)  # Convert to NumPy array
H, W = 384, 1280  # Same as the image

# Resize using OpenCV (avoids PyTorch interpolation issues)
mask_resized = cv2.resize(mask_np, (W, H), interpolation=cv2.INTER_NEAREST)

# Convert back to tensor
mask = torch.tensor(mask_resized, dtype=torch.long, device=device)

# Fix class labels
mask[mask == 76] = 1
mask[mask == 105] = 2

print(f"Fixed Mask Unique Values: {torch.unique(mask)}")

Fixed Mask Unique Values: tensor([0, 1, 2], device='mps:0')


In [22]:
# Ensure models are on the correct device
single_unet.to(device)
ensemble_unet.to(device)

# Ensure image tensor is also on the same device
image_tensor = image_tensor.to(device)


with torch.no_grad():
    single_pred = single_unet(image_tensor)  # Output: [1, num_classes, H, W]
    ensemble_pred = ensemble_unet(image_tensor)  # Output: [1, num_classes, H, W]

# Convert to Class Labels (Argmax over the channel dimension)
single_pred_mask = torch.argmax(single_pred, dim=1).squeeze().cpu()
ensemble_pred_mask = torch.argmax(ensemble_pred, dim=1).squeeze().cpu()

print(f"Single UNet Pred Mask Unique Values: {torch.unique(single_pred_mask)}")
print(f"Ensemble UNet Pred Mask Unique Values: {torch.unique(ensemble_pred_mask)}")

Single UNet Pred Mask Unique Values: tensor([0, 2])
Ensemble UNet Pred Mask Unique Values: tensor([0, 2])
