In [None]:
import os
import ssl,sys

import albumentations as A
import matplotlib.pyplot as plt
import numpy as np

# Unless PYTHONHTTPSVERIFY is set (to a non-empty value), replace the default
# HTTPS context factory with the unverified one, presumably so pretrained
# weights (e.g. encoder_weights='imagenet') can be downloaded on hosts with
# broken certificate chains.
# NOTE(review): this disables certificate verification for every HTTPS request
# made through the default context in this process — prefer fixing the CA
# bundle instead.
_make_unverified = getattr(ssl, "_create_unverified_context", None)
if _make_unverified and not os.environ.get("PYTHONHTTPSVERIFY", ""):
    ssl._create_default_https_context = _make_unverified


import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models
sys.path.append("../")
import config as cfg
import data_loader as dl
import utils
import wandb

In [None]:
# Pick a compute device: the second GPU when one exists (as originally intended),
# otherwise the default GPU, otherwise CPU. A hard-coded "cuda:1" names a device
# that does not exist on single-GPU machines.
# NOTE(review): the model-loading cell calls build_model() without a device, so
# in the cells shown this value is not what the model actually runs on.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

BATCH_SIZE = 16  # the plotting cell lays out a 4x4 grid, i.e. one full batch

dataloaders = utils.load_and_process_data(
    batch_size_train=BATCH_SIZE, batch_size_val_test=BATCH_SIZE, num_workers=0
)

In [None]:
import segmentation_models_pytorch as segmentation_models

class Reshape(nn.Module):
    """Fold every non-batch dimension of ``x`` into rows of 4 values.

    Input ``(batch, ...)`` becomes ``(batch, N, 4)`` via ``Tensor.view``, so the
    non-batch element count must be divisible by 4 and the input must be
    view-compatible; otherwise ``view`` raises.
    """

    def forward(self, x):
        batch_size = x.shape[0]
        return x.view(batch_size, -1, 4)
def build_model(
    device="cpu",
    encoder_name="efficientnet-b4",
    encoder_weights="imagenet",
    num_keypoints=10,
    dropout=0.009,
):
    """Build a smp U-Net whose segmentation head is replaced by a regression head.

    The encoder/decoder of ``segmentation_models.Unet`` are kept; the
    ``segmentation_head`` is swapped for
    global average pooling -> Flatten -> Linear -> Sigmoid -> Dropout -> Reshape,
    so the model returns a tensor of shape ``(batch, num_keypoints, 4)``.

    Args:
        device: torch device (or device string) the model is moved to.
        encoder_name: smp encoder name.
        encoder_weights: pretrained encoder weights passed to smp; ``None``
            skips the download (useful when a checkpoint is loaded afterwards).
        num_keypoints: number of output rows; each row holds 4 values.
        dropout: probability of the Dropout layer in the head.

    Returns:
        The model, moved to ``device``.
    """
    model = segmentation_models.Unet(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        classes=1,  # irrelevant: the segmentation head is replaced below
        activation=None,  # irrelevant: the segmentation head is replaced below
    )

    # The stock head starts with a Conv2d over the decoder's last feature map;
    # reuse its input channel count (16 with smp's default decoder) instead of
    # hard-coding it. After global pooling this is the Linear's input size.
    decoder_out_channels = model.segmentation_head[0].in_channels

    # Keep this exact module order: the Linear must stay at index 2 so existing
    # checkpoints (state-dict keys ``segmentation_head.2.*``) still load.
    # NOTE(review): Dropout after the Sigmoid randomly zeroes/rescales final
    # outputs during training (rescaled values can exceed 1); it is a no-op
    # after model.eval(). It would normally sit before the Linear layer.
    model.segmentation_head = nn.Sequential(
        nn.AdaptiveAvgPool2d((1, 1)),  # global average pooling -> (batch, C, 1, 1)
        nn.Flatten(),  # -> (batch, C)
        nn.Linear(decoder_out_channels, num_keypoints * 4),
        nn.Sigmoid(),
        nn.Dropout(p=dropout),
        Reshape(),  # (batch, num_keypoints * 4) -> (batch, num_keypoints, 4)
    )

    return model.to(device)

In [None]:
# Rebuild the architecture and load the saved best-model weights.
CHECKPOINT_PATH = "./regression_results/unet_balmy-snowflake-7_MSE_best_model.pth"

model = build_model()  # default device="cpu"
# map_location="cpu" matches the model's device and keeps loading working on
# CPU-only machines even if the checkpoint was saved from GPU tensors.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (recent PyTorch versions accept weights_only=True for plain state dicts).
state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")
model.load_state_dict(state_dict)
model.eval()  # eval mode: Dropout (and any BatchNorm) layers switch to inference behaviour

In [None]:
# Predict on the test set. Only the first NUM_TEST_BATCHES batches are used;
# the plotting cell reads batch 0 of each list.
NUM_TEST_BATCHES = 1

inputs_lst = []
labels_septal_lst = []
labels_lateral_lst = []
outputs_lst_septal = []
outputs_lst_lateral = []

# Feed inputs to whichever device the model's parameters live on (a no-op for a
# CPU model) so this cell keeps working if the model is moved to a GPU.
model_device = next(model.parameters()).device

with torch.no_grad():  # inference only: no autograd graph needed
    for i, data in enumerate(dataloaders["test"]):
        if i >= NUM_TEST_BATCHES:
            break

        inputs = data["image"].unsqueeze(1)  # add a channel dimension
        # repeat the channel dimension to create a 3-channel image; the model
        # receives the same float tensor that is stored for plotting
        inputs = inputs.repeat(1, 3, 1, 1).float()
        original_size = inputs.shape[-2:]
        inputs_lst.append(inputs)

        labels_septal = utils.readjust_keypoints(
            data["landmarks"]["leaflet_septal"], original_size
        )
        labels_septal_lst.append(labels_septal)

        labels_lateral = utils.readjust_keypoints(
            data["landmarks"]["leaflet_lateral"], original_size
        )
        labels_lateral_lst.append(labels_lateral)

        outputs = model(inputs.to(model_device))  # shape: (batch_size, 10, 4)
        print("outputs.shape", outputs.shape)
        outputs = outputs.cpu().numpy()

        # columns 0-1 hold the septal prediction, columns 2-3 the lateral one
        outputs_septal = utils.readjust_keypoints(outputs[:, :, :2], original_size)
        outputs_lateral = utils.readjust_keypoints(outputs[:, :, 2:], original_size)

        outputs_lst_septal.append(outputs_septal)
        outputs_lst_lateral.append(outputs_lateral)


In [None]:
# Plot test images with ground-truth (green) and predicted (red) keypoints,
# one sample per grid cell, starting from sample 0 of the first test batch.
# (Subplot numbering is 1-based but sample indices are 0-based; iterating the
# axes directly avoids skipping sample 0.)
N_ROWS, N_COLS = 4, 4

images = inputs_lst[0]
n_plots = min(N_ROWS * N_COLS, len(images))

fig, axes = plt.subplots(N_ROWS, N_COLS, figsize=(10, 10))
for j, ax in enumerate(axes.flat):
    if j >= n_plots:
        ax.axis("off")  # batch has fewer samples than grid cells
        continue

    ax.imshow(images[j][0], cmap="gray")

    ax.scatter(
        labels_septal_lst[0][j][:, 0], labels_septal_lst[0][j][:, 1], s=8, marker=".", c="g", label="true"
    )
    ax.scatter(
        labels_lateral_lst[0][j][:, 0], labels_lateral_lst[0][j][:, 1], s=8, marker=".", c="g"
    )
    ax.scatter(
        outputs_lst_septal[0][j][:, 0], outputs_lst_septal[0][j][:, 1], s=8, marker=".", c="r"
    )
    ax.scatter(
        outputs_lst_lateral[0][j][:, 0], outputs_lst_lateral[0][j][:, 1], s=8, marker=".", c="r", label="predicted"
    )

    ax.legend(loc="upper right", prop={"size": 6})
    ax.tick_params(axis="both", which="major", labelsize=1)  # make tick labels tiny

fig.suptitle("U-Net keypoint regression on test images")
fig.savefig("unet-regr_test.png")
plt.show()
