In [38]:
import sys

# Make the parent directory importable, presumably so the project-local
# `helper` and `dataset` modules imported below can be found. ".." is resolved
# against the kernel's current working directory. The membership guard keeps
# this cell idempotent: re-running it no longer appends duplicate entries.
if ".." not in sys.path:
    sys.path.append("..")

In [39]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image

from helper import (
    boundaries_mirror_y,
    masks_to_boundary,
    model_masks_output,
)
from dataset import _files_in_dir

In [40]:
# Locations used below, relative to the kernel's working directory.
# The checkpoint lives under ../../out; the test images, the boundaries folder
# and the output folder for the generated shapefiles are all the same
# directory. Path is immutable, so sharing one value between them is safe.
checkpoint_path = Path("..", "..", "out")
images_path = boundaries_path = output_dir = Path("..", "..", "data", "test")

In [None]:
# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Inference

In [43]:
_image_suffixes = [".tif"]

In [None]:
# Load the trained model: the checkpoint stores the whole module object under
# the "model" key (not just a state_dict).
# NOTE(review): weights_only=False unpickles arbitrary Python objects, which can
# execute code on load — only use checkpoints from a trusted source.
# map_location lets a checkpoint saved from a GPU run be loaded on a CPU-only
# machine; without it torch.load restores tensors onto the device they were
# saved from.
checkpoint = torch.load(
    checkpoint_path / "model" / "best_model.tar",
    map_location=device,
    weights_only=False,
)

model: torch.nn.Module = checkpoint["model"]
model.eval()  # evaluation mode (affects layers such as dropout / batch-norm)
model.to(device)

# Collect the input images. Suffixes are compared case-insensitively so files
# such as "scene.TIF" are not silently skipped; sorting gives a stable,
# reproducible processing order regardless of directory listing order.
wanted_suffixes = {suffix.lower() for suffix in _image_suffixes}
images = sorted(
    path
    for path in _files_in_dir(images_path)
    if path.suffix.lower() in wanted_suffixes
)

output_dir.mkdir(parents=True, exist_ok=True)

for image_path in images:
    # The context manager closes the underlying file once this image is done
    # (Image.open keeps the file handle open otherwise).
    with Image.open(image_path) as image:
        # Prepend a size-1 (batch) dimension and move to the model's device.
        image_tensor = TF.to_tensor(image).unsqueeze(0).to(device)

        # Pure inference: no autograd graph is needed, which saves memory.
        with torch.no_grad():
            masks = model_masks_output(model, image_tensor)

        # NOTE(review): assumes `masks` is a NumPy array with values in [0, 1]
        # (it is scaled by 255 and cast to uint8) — confirm against helper.
        boundaries = masks_to_boundary((masks * 255).astype(np.uint8))
        # The mirrored boundaries are written to disk, while the unmirrored ones
        # are plotted over the image below. Presumably this converts between
        # image (row-down) and y-up coordinates — TODO confirm in helper.
        boundaries_mirrored = boundaries_mirror_y(boundaries)
        boundaries_mirrored.to_file(output_dir / f"{image_path.stem}.shp")

        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(image)
        boundaries.boundary.plot(ax=ax, edgecolor="red")
        ax.set_title(image_path.name)
        plt.show()
        plt.close(fig)  # release the figure so open figures don't accumulate