## Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
from dataset.image_dataset import ImageDataset

## Setup

In [None]:
# Pick the compute device: prefer CUDA, then Apple-Silicon MPS, otherwise CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

# Load the full pickled model (architecture + weights) and map its tensors onto
# the selected device (which may be the CPU).
# `weights_only=False` is needed to unpickle a whole nn.Module: since PyTorch 2.6
# torch.load defaults to weights_only=True and would reject this checkpoint.
# NOTE(review): unpickling can execute arbitrary code — only load trusted files.
model = torch.load("saved_models/model_full.pth", map_location=device, weights_only=False)

# Switch to evaluation mode so layers like dropout / batch norm use inference behavior.
model.eval()

print("Model loaded onto", device)

In [None]:
# Preview a few random test samples: the input image in the left column and its
# ground-truth target in the right column (first channel of each tensor).
image_dataset = ImageDataset("data", test=True)
NUM_PREVIEW_SAMPLES = 5
fig, axes = plt.subplots(NUM_PREVIEW_SAMPLES, 2, figsize=(10, 20))
for i, (image_ax, target_ax) in enumerate(axes):
    # Draw one random sample (with replacement) per row.
    idx = np.random.randint(0, len(image_dataset))
    image, target = image_dataset[idx]

    row_panels = (
        (image_ax, image[0], f'Image {i+1}'),
        (target_ax, target[0], f'Target {i+1}'),
    )
    for ax, picture, title in row_panels:
        ax.imshow(picture, cmap='gray')
        ax.set_title(title)
        ax.axis('off')

plt.tight_layout()
plt.show()

## Make Prediction
Load a random (image, target) pair from the test set and let the model make a prediction.
Plot the original image, the probability heatmap, the binary output, and the ground truth.

In [None]:
# Run the model on one random test sample, visualize the prediction and report the loss.
image_dataset = ImageDataset("data", test=True)
# Load a random (image, target) pair and add a leading batch dimension.
image, target = image_dataset[np.random.randint(0, len(image_dataset))]
image = image.to(device).unsqueeze(0)
target = target.to(device).unsqueeze(0)

with torch.no_grad():  # Disable gradient calculation for evaluation
    # Feed the image tensor (previously the builtin `input` was passed by mistake).
    y = model(image)  # Get the model's predictions

# Extract the first channel of the single batch item as 2-D NumPy arrays for plotting.
# NOTE(review): assumes (N, C, H, W) tensors with the mask in channel 0 — confirm against ImageDataset.
input_image = image[0][0].cpu().numpy()
y_image = y[0][0].cpu().numpy()
target_image = target[0][0].cpu().numpy()
# Threshold at 0.5; like BCELoss below, this assumes the model outputs probabilities in [0, 1].
binary_output = (y_image > 0.5).astype(np.uint8)

# Show input image, probability heatmap, binary output and ground truth side by side.
panels = (
    (input_image, 'gray', 'Input Image'),
    (y_image, 'jet', 'Probability Heatmap'),
    (binary_output, 'gray', 'Binary Output'),
    (target_image, 'gray', 'Ground Truth'),
)
plt.figure(figsize=(12, 6))
for position, (data, cmap, title) in enumerate(panels, start=1):
    plt.subplot(1, 4, position)
    plt.imshow(data, cmap=cmap)
    plt.title(title)
    plt.axis('off')
plt.show()

# Calculate loss directly against the batched target already on `device`.
# Re-wrapping it via NumPy + transforms.ToTensor was redundant and would rescale a
# uint8 mask by 1/255; casting to the prediction's dtype keeps BCELoss happy.
loss = torch.nn.BCELoss()
with torch.no_grad():
    loss_value = loss(y, target.to(y.dtype))
print('Loss:', loss_value.item())

