In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sys
import os
import torch
from PIL import Image
from dataset.preprocessing import sample_pixels, get_surface_pixels, pil_to_binary, binary_to_image
import torchvision.transforms as transforms

# Import the UNet class
from model.simple_unet import UNet


In [None]:
# Generate sparse "sample" inputs from the full segmentation maps.
# Disabled by default so that Restart & Run All does not rewrite the files the rest of
# the notebook reads from ./overfit_data/samples; set RUN_SAMPLING = True to regenerate.
RUN_SAMPLING = False


def generate_samples(input_dir='./overfit_data', output_dir='./overfit_data/samples', n_points=15):
    """Write a sampled-surface image for every PNG segmentation map in ``input_dir``.

    For each ``*.png`` file directly inside ``input_dir``, the image is passed through
    ``pil_to_binary``. The result goes to ``get_surface_pixels``, and ``n_points`` of
    those pixels are drawn with ``sample_pixels``. The sample is saved with
    ``plt.imsave`` (gray colormap) under the same file name in ``output_dir``.
    ``output_dir`` is created if it does not exist.

    Args:
        input_dir: directory containing the source segmentation maps.
        output_dir: directory that receives the sampled images.
        n_points: number of surface points passed to ``sample_pixels``.

    Returns:
        list[str]: paths of the files written, in sorted file-name order.
    """
    os.makedirs(output_dir, exist_ok=True)
    written = []
    # sorted() makes the processing order deterministic across platforms.
    for filename in sorted(os.listdir(input_dir)):
        if not filename.endswith('.png'):
            continue
        image_path = os.path.join(input_dir, filename)
        # Context manager releases the file handle that Image.open keeps open.
        with Image.open(image_path) as image:
            binary_image = pil_to_binary(image)
            surface_pixels = get_surface_pixels(binary_image)
            sampled_points = sample_pixels(surface_pixels, n_points)
        output_path = os.path.join(output_dir, filename)
        plt.imsave(output_path, sampled_points, cmap='gray')
        written.append(output_path)

    print("Sampling completed and saved to", output_dir)
    return written


if RUN_SAMPLING:
    generate_samples()

In [None]:
# Pick the compute device: CUDA if available, otherwise Apple-silicon MPS, otherwise CPU.
# An explicit elif chain makes the precedence obvious (no later check silently overrides).
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

MODEL_PATH = "model_full_overfitting.pth"

# The checkpoint is a fully pickled model object (it is used directly via .eval()), not a
# state_dict, so unpickling needs the model class importable (UNet is imported in the
# import cell). Since PyTorch 2.6, torch.load defaults to weights_only=True, which refuses
# arbitrary pickled objects, so weights_only=False must be passed explicitly.
# SECURITY: weights_only=False runs pickle code - only load checkpoints from a trusted source.
model = torch.load(MODEL_PATH, map_location=device, weights_only=False)

# Evaluation mode: disables dropout and freezes batch-norm statistics.
model.eval()

print("Model loaded onto", device)

In [None]:
# Run the model on one sampled input, visualise the prediction, and score it against
# the ground-truth segmentation map.
sample_id = 8
image_path = f"./overfit_data/samples/{sample_id}.png"
target_image_path = f"./overfit_data/{sample_id}.png"
THRESHOLD = 0.5  # probability cut-off used to binarise the prediction


def load_grayscale(path):
    """Load an image file, convert it to single-channel mode 'L', and return it as a numpy array."""
    with Image.open(path) as img:
        return np.array(img.convert('L'))


image = load_grayscale(image_path)
target_image = load_grayscale(target_image_path)

# ToTensor maps the HxW uint8 array to a 1xHxW float tensor scaled to [0, 1];
# unsqueeze(0) adds the batch dimension -> 1x1xHxW.
# Named input_tensor (not `input`) to avoid shadowing the Python builtin.
to_tensor = transforms.ToTensor()
input_tensor = to_tensor(image).unsqueeze(0).to(device)

model.eval()  # disable dropout / batch-norm updates
with torch.no_grad():  # inference only: no autograd graph needed
    y = model(input_tensor)

print(y.shape)

# First item of the batch, first channel -> 2-D prediction map.
y_image = y[0][0].cpu().numpy()
binary_output = (y_image > THRESHOLD).astype(np.uint8)

# Show input, probability heatmap, thresholded prediction and ground truth side by side.
panels = [
    (image, 'gray', 'Input Image'),
    (y_image, 'jet', 'Probability Heatmap'),
    (binary_output, 'gray', 'Binary Output'),
    (target_image, 'gray', 'Ground Truth'),
]
fig, axes = plt.subplots(1, len(panels), figsize=(16, 4))
for ax, (panel, cmap, title) in zip(axes, panels):
    im = ax.imshow(panel, cmap=cmap)
    ax.set_title(title)
    ax.axis('off')
    if title == 'Probability Heatmap':
        # Without a colorbar the jet colours carry no readable scale.
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
plt.show()

# Binary cross-entropy against the ground-truth mask (scaled to [0, 1] by ToTensor).
# NOTE(review): BCELoss requires y to already be probabilities in [0, 1], which
# presumably means the model ends with a sigmoid - confirm against model.simple_unet.
criterion = torch.nn.BCELoss()
target = to_tensor(target_image).unsqueeze(0).to(device)
assert target.shape == y.shape, f"target shape {tuple(target.shape)} != prediction shape {tuple(y.shape)}"
loss_value = criterion(y, target)
print('Loss:', loss_value.item())