In [None]:
import torch
from PIL import Image
from torchvision import transforms
from models.uNet import UNet
from io import BytesIO
import requests
from metrics.segmentation_metrics import compute_metrics

In [None]:
# Rebuild the U-Net architecture and load the trained weights from a checkpoint dict
# whose weights are stored under the 'model_state_dict' key.
CHECKPOINT_PATH = r"C:\work\an 3\dl\face-segmentation\checkpoints\breezy-sweep-11_ckpt_epoch_21_mean_iou.pth"

model = UNet(n_channels=3, n_classes=3)
# map_location="cpu" lets a checkpoint that was saved from a GPU run load on a
# CPU-only machine; the inference in this notebook feeds CPU tensors anyway.
ckpt = torch.load(CHECKPOINT_PATH, map_location="cpu")
model.load_state_dict(ckpt["model_state_dict"])
# Switch to inference mode (affects layers such as dropout / batch-norm, if present).
model.eval()

In [None]:
# Preprocessing (also reused by the next cell): resize to the 256x256 network
# input size, then ToTensor() -> float tensor in [0, 1] with shape (C, H, W).
preprocess = transforms.Compose([
    transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.LANCZOS),
    transforms.ToTensor(),
])

SAMPLE_IMAGE_PATH = r"C:\work\an 3\dl\face-segmentation\data\lfw_dataset\lfw_funneled\Aaron_Eckhart\Aaron_Eckhart_0001.jpg"
TRACED_MODEL_PATH = r"C:\work\an 3\dl\face-segmentation\checkpoints\unet_scripted_model.pt"

# convert("RGB") guarantees 3 channels (the model is built with n_channels=3)
# even for grayscale/RGBA files, matching how the URL image is loaded later.
image_pil = Image.open(SAMPLE_IMAGE_PATH).convert("RGB")
image_input = preprocess(image_pil)
input_batch = image_input.unsqueeze(0)  # add batch dimension -> (1, C, H, W)

with torch.no_grad():
    logits = model(input_batch)

# Min-max rescale the whole output to [0, 1] so its 3 channels can be viewed as
# an RGB image; clamp_min avoids 0/0 -> NaN when the output is constant.
output = (logits - logits.min()) / (logits.max() - logits.min()).clamp_min(1e-8)
tensor_to_pil = transforms.ToPILImage()
pil_image = tensor_to_pil(output.squeeze(0))
pil_image.show()

# NOTE(review): the first argument here is the *input image*, not a ground-truth
# mask. Segmentation metrics normally compare a prediction against a label map,
# so these numbers are probably not meaningful — confirm compute_metrics' inputs.
mpa, m_iou, m_fw_iou = compute_metrics(input_batch, output)
print(f"MPA: {mpa}, MIoU: {m_iou}, MFWIoU: {m_fw_iou}")

# Trace the eager model with a real example input and save it as TorchScript.
scripted_model = torch.jit.trace(model, input_batch)
scripted_model.save(TRACED_MODEL_PATH)

In [None]:
# Reload the TorchScript model saved by the previous cell and run it on an image
# downloaded from the web. Uses `preprocess` from the previous cell.
TRACED_MODEL_PATH = r"C:\work\an 3\dl\face-segmentation\checkpoints\unet_scripted_model.pt"
REQUEST_TIMEOUT_S = 30

# A distinct name keeps the eager `model` from the loading cell intact, so
# re-running earlier cells does not silently pick up the scripted module.
loaded_model = torch.jit.load(TRACED_MODEL_PATH)

image_url = 'https://images.unsplash.com/photo-1542909168-82c3e7fdca5c?q=80&w=1780&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D'
# The timeout stops the cell from hanging forever on a stalled connection, and
# raise_for_status() reports HTTP errors instead of a confusing PIL decode error.
response = requests.get(image_url, timeout=REQUEST_TIMEOUT_S)
response.raise_for_status()
image = Image.open(BytesIO(response.content)).convert('RGB')

input_tensor = preprocess(image)
input_batch = input_tensor.unsqueeze(0)  # add batch dimension -> (1, C, H, W)

with torch.no_grad():
    logits = loaded_model(input_batch)

# Min-max rescale to [0, 1] for display; clamp_min avoids 0/0 -> NaN on a constant output.
output = (logits - logits.min()) / (logits.max() - logits.min()).clamp_min(1e-8)
tensor_to_pil = transforms.ToPILImage()
pil_image = tensor_to_pil(output.squeeze(0))
pil_image.show()

# NOTE(review): compute_metrics receives the input image rather than a
# ground-truth mask, so these values are probably not meaningful — confirm.
mpa, m_iou, m_fw_iou = compute_metrics(input_batch, output)

print(f"MPA: {mpa}, MIoU: {m_iou}, MFWIoU: {m_fw_iou}")