In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import albumentations as A
import cv2
from torch.utils.data import DataLoader
from model import UNet  
from dataset import ISICDataset  

# Define paths and parameters
images_path, masks_path = "data/images", "data/labels"  # image / mask directories, relative to the notebook
img_size = 256      # square side length passed to A.Resize for the test transform
num_samples = 5     # Number of images to visualize

In [None]:
# Load test dataset
# Rebuild the 80% / 10% / 10% train / val / test partition over the .jpg ids and keep
# only the trailing test slice.
# NOTE(review): os.listdir returns entries in arbitrary, filesystem-dependent order (it is
# not sorted), so this test slice only matches the training-time split if training listed
# files the same way on the same filesystem — confirm before trusting test metrics.
ids = [
    file_name[: -len(".jpg")]
    for file_name in os.listdir(images_path)
    if file_name.endswith(".jpg")
]
train_size, val_size = int(0.8 * len(ids)), int(0.1 * len(ids))
test_ids = ids[train_size + val_size:]
test_transform = A.Compose([A.Resize(height=img_size, width=img_size)])
test_dataset = ISICDataset(images_path, masks_path, test_ids, img_size, test_transform)

In [None]:
# Initial visualization of test images and true masks
# Never ask for more rows than the test split holds (avoids IndexError on small splits).
n_show = min(num_samples, len(test_dataset))
if n_show == 0:
    print("Test split is empty — nothing to visualize.")
else:
    # squeeze=False keeps `axes` 2-D even when n_show == 1, so axes[i, j] always works.
    fig, axes = plt.subplots(n_show, 2, figsize=(10, 4 * n_show), squeeze=False)
    for i in range(n_show):
        img, true_mask = test_dataset[i]
        # Move the leading (channel) axis last for matplotlib — assumes CHW tensors.
        img = img.permute(1, 2, 0).numpy()
        # NOTE(review): assumes the mask tensor is (1, H, W) — TODO confirm against ISICDataset.
        true_mask = true_mask.permute(1, 2, 0).squeeze().numpy()
        # Nearest-neighbour keeps a binary mask binary if its size ever differs from
        # the image's; cv2 takes the target size as (width, height).
        true_mask = cv2.resize(true_mask, (img.shape[1], img.shape[0]),
                               interpolation=cv2.INTER_NEAREST)

        # NOTE(review): float images outside [0, 1] (e.g. mean/std-normalized) are clipped by imshow.
        axes[i, 0].imshow(img)
        axes[i, 0].set_title('Original Image')
        axes[i, 0].axis('off')

        # Fixed 0..1 grey scale: assumes mask values lie in [0, 1].
        axes[i, 1].imshow(true_mask, cmap="gray", vmin=0, vmax=1)
        axes[i, 1].set_title('True Mask')
        axes[i, 1].axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
# Load trained model
# `device` was previously referenced here (and in the inference cell) without ever being
# defined; pick the GPU when available and fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint_path = "checkpoint.pth.tar"  # Update with actual path
model = UNet(n_channels=3, n_classes=1)
# map_location remaps tensors saved on another device (e.g. GPU) onto `device`.
# NOTE: torch.load unpickles the file — only load checkpoints from trusted sources.
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval();  # trailing ';' suppresses the model repr in the cell output

In [None]:
# Visualization with generated masks
# Binarization cut-off applied to the raw model output.
# NOTE(review): 0.5 is only meaningful if UNet.forward returns probabilities (sigmoid
# applied inside); if it returns logits, apply torch.sigmoid first or threshold at 0 —
# confirm against model.py.
mask_threshold = 0.5
# Run inference wherever the model's weights live, so this cell works however it was loaded.
model_device = next(model.parameters()).device
# Never ask for more rows than the test split holds (avoids IndexError on small splits).
n_show = min(num_samples, len(test_dataset))
if n_show == 0:
    print("Test split is empty — nothing to visualize.")
else:
    # squeeze=False keeps `axes` 2-D even when n_show == 1, so axes[i] is always a row.
    fig, axes = plt.subplots(n_show, 3, figsize=(10, 4 * n_show), squeeze=False)
    mask_style = {"cmap": "gray", "vmin": 0, "vmax": 1}  # assumes mask values lie in [0, 1]
    for i in range(n_show):
        img_t, true_mask_t = test_dataset[i]
        # Move the leading (channel) axis last for matplotlib — assumes CHW tensors.
        img = img_t.permute(1, 2, 0).numpy()
        # NOTE(review): assumes the mask tensor is (1, H, W) — TODO confirm against ISICDataset.
        true_mask = true_mask_t.permute(1, 2, 0).squeeze().numpy()

        with torch.no_grad():
            # Feed the dataset tensor directly with a batch dim instead of round-tripping
            # through numpy; .float() matches the previous torch.Tensor(...) float32 cast.
            output = model(img_t.unsqueeze(0).float().to(model_device))
        generated_mask = (output.squeeze().cpu().numpy() > mask_threshold).astype(np.float32)

        # Nearest-neighbour keeps binary masks binary if sizes ever differ;
        # cv2 takes the target size as (width, height).
        target_size = (img.shape[1], img.shape[0])
        true_mask = cv2.resize(true_mask, target_size, interpolation=cv2.INTER_NEAREST)
        generated_mask = cv2.resize(generated_mask, target_size, interpolation=cv2.INTER_NEAREST)

        panels = (
            (img, 'Original Image', {}),
            (true_mask, 'True Mask', mask_style),
            (generated_mask, 'Generated Mask', mask_style),
        )
        for ax, (panel_data, title, imshow_kwargs) in zip(axes[i], panels):
            ax.imshow(panel_data, **imshow_kwargs)
            ax.set_title(title)
            ax.axis('off')

    plt.tight_layout()
    plt.show()