In [None]:
import torch
import torchvision
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from model.model import ENet

# File name of the saved model checkpoint.
MODEL_PATH = "model_best.pth"

def display_image(img, size=(8, 8)):
    """Show `img` in a new figure with the axis ticks removed.

    Args:
        img: Image data accepted by matplotlib's `imshow`.
        size: Figure size passed to matplotlib as `figsize`
            (width, height in inches).
    """
    fig = plt.figure(figsize=size)
    ax = fig.add_subplot(1, 1, 1)
    # The 'gray' colormap applies to single-channel data; matplotlib ignores
    # `cmap` when the input already carries RGB(A) channels.
    ax.imshow(img, cmap='gray')
    # Empty tick lists hide the pixel-coordinate ticks on both axes.
    ax.set_xticks([])
    ax.set_yticks([])
    plt.show()

In [None]:
# The model can only run on CUDA, so the device is fixed to the GPU.
device = torch.device('cuda')

model = ENet(num_classes=1)

# Load the checkpoint through MODEL_PATH (instead of repeating the file name)
# and map its tensors directly onto `device`, independent of the device they
# were saved from.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
state = torch.load(MODEL_PATH, map_location=device)['state_dict']
model.load_state_dict(state)

# eval() switches the model to evaluation (prediction) mode.
model = model.to(device).eval()

In [None]:
# Open the demo input image from disk.
original_filename = 'demo.jpeg'
image = Image.open(original_filename)

# Note: a PIL image's `size` is (width, height), not (height, width).
print("Image shape:", image.size)

# Preview the input image before it is passed to the model.
display_image(image)

In [None]:
# ToTensor is the only preprocessing step applied to the input image.
transform_pipeline = transforms.Compose([
    transforms.ToTensor()
])

# Convert the image to a tensor, add a leading batch dimension, and move it
# to the same device as the model.
transformed_image = transform_pipeline(image).unsqueeze(0).to(device)

print("Transformed image shape:", transformed_image.shape)

# Inference only: disable autograd so the forward pass does not record a
# computation graph or keep intermediate activations alive.
with torch.no_grad():
    out = model(transformed_image)

In [None]:
# Turn the raw model output into a displayable binary mask.
mask = (
    out.squeeze()   # drop all singleton dimensions
    .sigmoid()      # map the raw output to probabilities
    .gt(0.5)        # keep pixels with probability > 0.5
    .byte()         # bool -> uint8 (0 / 1)
    .cpu()
    .numpy()
    * 255           # scale to 0 / 255 for display
)

display_image(mask)