In [None]:
import matplotlib.pyplot as plt
import torch
from torchvision.transforms import transforms

from src.data.components.custom_transforms import BilinearInterpolation, NormalizeData
from src.data.components.nyu_dataset import NYUDataset
from src.models.unet_module import UNETLitModule

In [None]:
# Checkpoint file to evaluate. The path is relative, so it is resolved against
# the kernel's current working directory.
model_ckpt = "./logs/train/runs/2024-04-06_18-37-38/checkpoints/epoch_015.ckpt"

In [None]:
# Restore the UNETLitModule from the checkpoint file selected above.
model = UNETLitModule.load_from_checkpoint(model_ckpt)

In [None]:
# Switch the module to evaluation mode before running inference.
# NOTE(review): presumably the standard nn.Module.eval() behaviour (dropout off,
# batch-norm using running statistics) — confirm UNETLitModule does not override it.
model.eval()

In [None]:
# Input-image pipeline: PIL image -> tensor, then resize to 224x224.
transforms_img = transforms.Compose(
    [
        transforms.PILToTensor(),
        transforms.Resize((224, 224)),
    ]
)

# Ground-truth mask pipeline: PIL image -> tensor, normalise, then resample to 56x56.
# NOTE(review): the 10_000 * (1 / 255) factor is kept exactly as authored; it
# presumably reconciles the raw depth range with an 8-bit scale — confirm against
# NormalizeData and the training configuration.
transforms_mask = transforms.Compose(
    [transforms.PILToTensor(), NormalizeData(10_000 * (1 / 255)), BilinearInterpolation((56, 56))]
)

In [None]:
# Test split, built with the image and mask pipelines defined above.
# NOTE(review): the first two arguments look like the split CSV name and the data
# root directory — confirm against the NYUDataset signature.
test_dataset = NYUDataset("nyu2_test.csv", "data/", transforms_img, transforms_mask)

In [None]:
# Run the model on the first NUM_SAMPLES test samples and keep the raw
# predictions so they can be plotted further down.
NUM_SAMPLES = 10

outputs = []

# Inference only: disabling autograd stops every stored prediction from keeping
# its whole computation graph (and the associated activations) alive.
with torch.no_grad():
    for i in range(NUM_SAMPLES):
        # Only the image is needed to predict; the ground-truth mask is re-read
        # from the dataset when plotting.
        img = test_dataset[i][0]
        # Add a leading batch dimension and move the input to the model's device.
        img = img.unsqueeze(0).to(model.device)
        out = model(img)
        outputs.append(out)

In [None]:
def visualize_result(img, mask, out):
    """Show the input image, ground-truth mask and prediction side by side.

    Args:
        img: Image tensor; after ``squeeze()`` it is permuted with ``(1, 2, 0)``
            before display, so it is assumed to be channel-first — TODO confirm.
        mask: Ground-truth tensor, displayed after ``squeeze()``.
        out: Model output tensor; it is squeezed, detached and moved to the CPU
            before display.

    Returns:
        None. The figure is rendered with ``plt.show()``.
    """
    _, axs = plt.subplots(1, 3)

    # One (picture, title) pair per panel, in left-to-right order.
    panels = (
        (img.squeeze().permute(1, 2, 0), "Input Image"),
        (mask.squeeze(), "Ground Truth"),
        (out.squeeze().detach().cpu(), "Predicted Mask"),
    )

    for ax, (picture, title) in zip(axs, panels):
        ax.imshow(picture)
        ax.set_title(title)

    # Hide ticks and frames on every panel.
    for ax in axs:
        ax.axis("off")

    plt.show()

In [None]:
# Plot the first five predictions next to their inputs and ground truth.
# Slicing `outputs` keeps the loop within the predictions that were computed.
for i, out in enumerate(outputs[:5]):
    # Fetch the sample once: indexing the dataset loads and transforms the
    # image/mask pair, so doing it twice per iteration doubles the work.
    img, mask = test_dataset[i]
    visualize_result(img, mask, out)