In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
from dataset.image_dataset import ImageDataset

In [None]:
# Pick the compute device with an explicit precedence: CUDA, then Apple MPS,
# otherwise fall back to the CPU. A single if/elif/else chain keeps a later
# check from silently overriding an earlier one.
if torch.cuda.is_available():
    device = 'cuda'
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    # getattr guard: not every PyTorch build exposes torch.backends.mps.
    device = 'mps'
else:
    device = 'cpu'

# Load the fully pickled model and map its tensors onto the selected device.
# weights_only=False is required to unpickle a whole model object now that
# newer PyTorch releases default to weights_only=True.
# NOTE(security): this executes arbitrary pickled code -- only load checkpoints
# from a trusted source.
model = torch.load(
    "saved_models/model_full.pth",
    map_location=device,
    weights_only=False,
)

# Set the model to evaluation mode
model.eval()

print("Model loaded onto", device)

In [None]:
# Preview ten test samples drawn at random (with replacement): inputs on the
# top row of the grid, their targets directly underneath.
image_dataset = ImageDataset("data", test=True)
rand_indices = np.random.randint(0, len(image_dataset), 10)

plt.figure(figsize=(20, 4))
for i, idx in enumerate(rand_indices):
    img, tgt = image_dataset[idx]

    # Panels 1-10 form the top row (image), panels 11-20 the bottom row (target).
    for position, tensor in ((i + 1, img), (i + 11, tgt)):
        plt.subplot(2, 10, position)
        # Only the first channel is shown, rendered in grayscale.
        plt.imshow(tensor[0].cpu().numpy(), cmap='gray')
        plt.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Load the test split and take its first sample.
image_dataset = ImageDataset("data", test=True)
image, target = image_dataset[0]

# Move both tensors to the model's device and add a leading batch dimension.
image = image.to(device).unsqueeze(0)
target = target.to(device).unsqueeze(0)

# First channel of the single batch element, as NumPy arrays for plotting.
# imshow cannot render the batched on-device tensor directly.
input_image = image[0][0].cpu().numpy()
target_image = target[0][0].cpu().numpy()

with torch.no_grad():  # Disable gradient calculation for evaluation
    # Feed the prepared tensor (not the Python builtin `input`) to the model.
    y = model(image)

y_image = y[0][0].cpu().numpy()
# Threshold at 0.5 to obtain a binary mask.
# NOTE(review): presumes the model emits probabilities in [0, 1] -- confirm.
binary_output = (y_image > 0.5).astype(np.uint8)

# Show input, probability heatmap, thresholded mask and ground truth side by side.
panels = [
    ('Input Image', input_image, 'gray'),
    ('Probability Heatmap', y_image, 'jet'),
    ('Binary Output', binary_output, 'gray'),
    ('Ground Truth', target_image, 'gray'),
]
plt.figure(figsize=(12, 6))
for position, (title, data, cmap) in enumerate(panels, start=1):
    plt.subplot(1, 4, position)
    plt.imshow(data, cmap=cmap)
    plt.title(title)
    plt.axis('off')
plt.show()

# Calculate loss
loss = torch.nn.BCELoss()
# Rebuild a batched target tensor from the plotted first-channel array.
# NOTE(review): ToTensor rescales uint8 input to [0, 1]; float input passes
# through unchanged -- verify this matches the dtype the dataset returns.
target = transforms.ToTensor()(target_image).unsqueeze(0).to(device)
loss_value = loss(y, target)
print('Loss:', loss_value.item())

