In [1]:
import sys

# Make the modules imported in the next cell (`dataset`, `helper`) importable
# from the kernel's parent directory. Guarded so that re-running this cell
# does not keep appending duplicate entries to `sys.path`.
if ".." not in sys.path:
    sys.path.append("..")

In [2]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image

from dataset import (
    InstanceSegmentationLazyDataset,
    _files_in_dir,
)
from helper import (
    boundaries_mirror_y,
    masks_to_boundary,
    model_masks_output,
)

In [3]:
# Input/output locations, relative to the notebook's working directory.
checkpoint_path = Path("..", "..", "out")
images_path = Path("..", "..", "data", "test")
# NOTE(review): boundaries_path and output_dir point at the same directory as
# the images. The inference cell writes `<image stem>.shp` into output_dir, and
# the evaluation cell reads boundaries from boundaries_path. If the ground-truth
# shapefiles share those stem names, inference may overwrite them. Confirm, and
# consider a separate output directory.
boundaries_path = images_path
output_dir = images_path

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [5]:
_image_suffixes = [".tif"]

# Inference

In [None]:
# Load the trained model.
# - The checkpoint stores the whole `nn.Module`, hence `weights_only=False`.
#   That unpickles arbitrary Python objects, so only load trusted checkpoints.
# - `map_location` lets a checkpoint saved on a GPU load on a CPU-only machine.
checkpoint = torch.load(
    checkpoint_path / "model" / "best_model.tar",
    map_location=device,
    weights_only=False,
)

model: torch.nn.Module = checkpoint["model"]
model.eval()
model.to(device)

# Keep only image files. The suffix comparison is case-insensitive so `.TIF`
# files are included too, and sorting makes the processing order deterministic.
wanted_suffixes = {suffix.lower() for suffix in _image_suffixes}
images = sorted(
    path
    for path in _files_in_dir(images_path)
    if path.suffix.lower() in wanted_suffixes
)

for image_path in images:
    # The context manager closes the image's file handle once we are done with it.
    with Image.open(image_path) as image:
        image_tensor = TF.to_tensor(image).unsqueeze(0).to(device)

        # Prediction only: no autograd graph is needed.
        with torch.no_grad():
            masks = model_masks_output(model, image_tensor)

        boundaries = masks_to_boundary((masks * 255).astype(np.uint8))
        # Only the saved copy goes through `boundaries_mirror_y`; the plot below
        # overlays the unmirrored boundaries on the image.
        boundaries_mirrored = boundaries_mirror_y(boundaries)
        boundaries_mirrored.to_file(output_dir / f"{image_path.stem}.shp")

        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(image)
        boundaries.boundary.plot(ax=ax, edgecolor="red")
        ax.set_title(f"Predicted boundaries: {image_path.name}")
        plt.show()
        # Release the figure so long image lists do not accumulate open figures.
        plt.close(fig)

# Instance Segmentation Evaluation

In [None]:
# Reload a fresh copy of the model onto the CPU. The dataset samples below are
# passed to the model as-is (never moved to `device`), so the model must live
# on the CPU as well. `weights_only=False` unpickles arbitrary objects, so only
# load trusted checkpoints.
checkpoint = torch.load(
    checkpoint_path / "model" / "best_model.tar",
    map_location="cpu",
    weights_only=False,
)

model: torch.nn.Module = checkpoint["model"]

dataset = InstanceSegmentationLazyDataset(images_path, boundaries_path)

# Normalization layers whose running statistics change in train mode.
_running_stat_norm_types = (
    torch.nn.BatchNorm1d,
    torch.nn.BatchNorm2d,
    torch.nn.BatchNorm3d,
    torch.nn.SyncBatchNorm,
)

for image_boundary_pair in dataset:
    # Presumably (image tensor, target dict with at least "masks"); only
    # elements 0 and 1 are used.
    image, target = image_boundary_pair[0], image_boundary_pair[1]

    with torch.no_grad():
        # The model returns its loss dict only in train mode. BatchNorm layers
        # in train mode still update their running statistics under no_grad,
        # which would alter the model used for the predictions below. Keep
        # those layers in eval mode so this pass does not modify the model.
        model.train()
        for module in model.modules():
            if isinstance(module, _running_stat_norm_types):
                module.eval()
        loss = model([image], [target])

    for k, v in loss.items():
        print(f"{k}: {v.item()}")

    real_mask = target["masks"]
    real_boundaries = masks_to_boundary(
        (real_mask * 255).cpu().numpy().astype(np.uint8)
    )

    model.eval()
    with torch.no_grad():
        model_mask = model_masks_output(model, [image])
    model_boundaries = masks_to_boundary((model_mask * 255).astype(np.uint8))

    fig, ax = plt.subplots(figsize=(10, 10))
    # Channel-first tensor -> channel-last array for imshow.
    ax.imshow(image.permute(1, 2, 0).cpu().numpy())
    real_boundaries.boundary.plot(ax=ax, edgecolor="red", alpha=0.5)
    model_boundaries.boundary.plot(ax=ax, edgecolor="blue", alpha=0.5)
    ax.set_title("Ground truth (red) vs. prediction (blue)")
    plt.show()
    plt.close(fig)