In [None]:
import os


# Root of the parametric CARLA segmentation experiments, relative to the
# current working directory; configs and trained weights live in sub-folders.
base_dir = "../../scripts/segmentation/models/carla_parametric"
config_dir, model_dir = (os.path.join(base_dir, sub) for sub in ("configs", "models"))

In [None]:
import torch
from torchvision import transforms

from avstack.config import Config
from fov.segmentation.dataset import BinaryFovDataset
from fov.segmentation.utils import get_unet_model


class ToDevice:
    """Transform that moves a tensor to a fixed device.

    Meant to be placed inside ``transforms.Compose`` so that every sample is
    moved to ``device`` before the transforms that follow it run.
    """

    def __init__(self, device):
        # Anything accepted by ``torch.Tensor.to`` (e.g. "cpu", "cuda:0",
        # or a ``torch.device``).
        self.device = device

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        """Return ``image`` on ``self.device`` (no copy if already there)."""
        return image.to(self.device)

    def __repr__(self) -> str:
        # Match torchvision transform reprs so printing a Compose shows the
        # target device instead of an opaque object address.
        return f"{self.__class__.__name__}(device={self.device!r})"


def get_model_dataset(width: int, depth: int, resol: int, device="cpu"):
    """Load one parametric configuration's trained UNet and its validation data.

    The configuration file ``width_{width}_depth_{depth}_resolution_{resol}.py``
    is read from ``config_dir``. The weights are loaded from the sub-folder of
    ``model_dir`` with the same name.

    Args:
        width: UNet width used in the configuration name.
        depth: UNet depth used in the configuration name.
        resol: resolution value used in the configuration name.
        device: device that samples are moved to and the model is loaded on.

    Returns:
        ``(model, seg_dataset)``. ``model`` is whatever ``get_unet_model``
        returns. ``seg_dataset`` is the ``"val"`` split of ``BinaryFovDataset``,
        resized to ``cfg["model_io_size"]``.
    """
    # specify the cfg to use
    cfg_name = f"width_{width}_depth_{depth}_resolution_{resol}"
    cfg_file = os.path.join(config_dir, cfg_name + ".py")
    cfg = Config.fromfile(cfg_file)

    # set up the transforms
    to_device = ToDevice(device=device)
    trans_image = transforms.Compose([
        to_device,
        transforms.Resize(size=cfg["model_io_size"]),
    ])
    # The masks are binary label maps. Nearest-neighbour resizing keeps them
    # binary, whereas the default (bilinear) interpolation used for the image
    # would smear mask edges into fractional values.
    trans_mask = transforms.Compose([
        to_device,
        transforms.Resize(
            size=cfg["model_io_size"],
            interpolation=transforms.InterpolationMode.NEAREST,
        ),
    ])

    # load the dataset
    split = "val"
    seg_dataset = BinaryFovDataset(
        cfg["data_output_dir"],
        transform=trans_image,
        transform_mask=trans_mask,
        split=split,
        max_range=cfg["max_range"],
        extent=cfg["extent"],
        img_size=cfg["img_size"],
    )

    # load a trained model
    model = get_unet_model(
        cfg=cfg,
        weight_dir=os.path.join(model_dir, cfg_name),
        device=device,
    )

    return model, seg_dataset

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors


def fill_holes(image: np.ndarray) -> np.ndarray:
    """Fill holes in a 2-D binary image.

    Every non-zero pixel is foreground. A zero pixel stays background only if
    it is 4-connected (OpenCV's default flood-fill connectivity) to the image
    border through other zero pixels. Every other zero pixel is an enclosed
    hole and is filled.

    Args:
        image: 2-D array whose non-zero entries are foreground (e.g. 0/1,
            0/255 or bool).

    Returns:
        A uint8 array of the same shape. Foreground pixels and filled holes
        are 255; background is 0.
    """
    binary = (np.asarray(image) > 0).astype(np.uint8)

    # Pad with a one-pixel background border. The seed (0, 0) is then always
    # background, and every background region touching the original border is
    # connected to it through the padding. Seeding only at the image's own
    # (0, 0) would flood the foreground when that pixel is set, and would
    # "fill" border regions cut off from the top-left corner.
    padded = np.pad(binary, 1, mode="constant", constant_values=0)

    # cv2.floodFill requires a mask 2 pixels larger than the image.
    h, w = padded.shape[:2]
    mask = np.zeros((h + 2, w + 2), np.uint8)

    # Mark the border-reachable background with 2, a value distinct from the
    # foreground (1) and from unreached holes (0).
    cv2.floodFill(padded, mask, (0, 0), 2)

    outside = padded[1:-1, 1:-1] == 2
    return np.where(outside, 0, 255).astype(np.uint8)


def _to_numpy(x) -> np.ndarray:
    """Return ``x`` as a host (CPU) numpy array; tensors are detached first."""
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def _plot_panel(data, cmap, save_prefix: str, name: str) -> None:
    """Draw one 2-D array without axes, save it as PNG and PDF, show it, and close it."""
    fig, ax = plt.subplots()
    ax.imshow(data, cmap=cmap)
    ax.axis("off")
    for ext in ("png", "pdf"):
        fig.savefig(save_prefix + name + "." + ext)
    plt.show()
    # Close explicitly so repeated calls (e.g. inside a sweep loop) neither
    # accumulate open figures nor draw onto a previous figure.
    plt.close(fig)


def show_results(idx, model, seg_dataset, save_prefix, threshold: float = 0.7):
    """Plot, save and show the input, ground truth and prediction for one sample.

    Writes ``<save_prefix>{input_image,gt_mask,inference_probs,inference_bin}``
    in both ``.png`` and ``.pdf`` format.

    Args:
        idx: index of the sample in ``seg_dataset``.
        model: segmentation model, called as
            ``model(pc_img=..., pc_np=None, metadata=None)``.
        seg_dataset: dataset whose items are ``(image, mask)`` pairs.
        save_prefix: path prefix that the figure names are appended to.
        threshold: probability above which a pixel counts as predicted
            foreground in the binary panel.
    """
    # Sets matplotlib's global default colormap, as the original plt.set_cmap
    # call did, without implicitly creating an empty figure. The panels below
    # pass their colormap explicitly.
    plt.rc("image", cmap="Greys")
    cmap_base = colors.LinearSegmentedColormap.from_list("wt", ["white", "teal"])

    # show some examples
    image, mask = seg_dataset[idx]
    print(image.shape, mask.shape)

    # -- input image (non-zero pixels shown as occupied)
    _plot_panel(_to_numpy(image)[0, ...] > 0, cmap_base, save_prefix, "input_image")

    # -- ground truth mask
    _plot_panel(_to_numpy(mask)[0, :, :], cmap_base, save_prefix, "gt_mask")

    # -- inference, probs
    # no_grad: this is inference only, so skip building the autograd graph.
    # NOTE(review): assumes get_unet_model already puts the model in eval mode
    # (dropout / batch-norm); confirm.
    with torch.no_grad():
        res = model(pc_img=image[None, ...], pc_np=None, metadata=None)
    # Move to the CPU before converting, so that CUDA devices work too.
    probs = _to_numpy(res)[0, 0, ...]
    _plot_panel(probs, cmap_base, save_prefix, "inference_probs")

    # -- inference, binary (thresholded, with enclosed holes filled)
    binary = fill_holes((probs > threshold).astype(np.uint8))
    _plot_panel(binary, cmap_base, save_prefix, "inference_bin")

In [None]:
# All figures produced below are written under this directory, which is
# created if it is missing.
fig_dir = "figures/parametric/case"
os.makedirs(name=fig_dir, exist_ok=True)

In [None]:
# Sweep over the configured resolutions with the UNet width and depth held
# fixed, plotting the same validation frame for each trained model.
idx_frame = 0
width, depth = 8, 4
device = "cpu"
for resol in (64, 128, 256, 512):
    save_prefix = os.path.join(fig_dir, f"parametric_resol_{resol}_")
    model, dataset = get_model_dataset(width=width, depth=depth, resol=resol, device=device)
    show_results(idx=idx_frame, model=model, seg_dataset=dataset, save_prefix=save_prefix)