# Check the results of the segmentation models

In [None]:
import os


# Output directory for figures; created up front so later cells can write to it.
fig_dir = "figures"
os.makedirs(fig_dir, exist_ok=True)

# Every dataset this notebook knows about, in display order.
_DATASET_NAMES = ("carla", "nuscenes", "radnav")

# Raw sensor data root per dataset (these do not follow a common pattern).
data_dirs = dict(
    carla="/data/shared/CARLA/multi-agent-intersection",
    nuscenes="/data/shared/nuScenes",
    radnav="/data/shared/radnav/WILK_BASEMENT",
)

# Field-of-view BEV segmentation directories share one root, one subdir per dataset.
fov_dirs = {
    name: f"/data/shared/fov/fov_bev_segmentation/{name}" for name in _DATASET_NAMES
}

# Benign U-Net config file for each dataset.
cfg_paths = {
    name: f"../../config/segmentation/{name}/unet_{name}_benign.py"
    for name in _DATASET_NAMES
}

# Directory holding the trained U-Net weights for each dataset.
model_dirs = {
    name: f"../../scripts/segmentation/models/{name}/unet_{name}"
    for name in _DATASET_NAMES
}

In [None]:
import torch
from utils import get_dataset, get_unet_model


# Pick the compute device. The notebook prefers the second GPU ("cuda:1"), but
# that index does not exist on single-GPU machines, so fall back to "cuda:0"
# there, and to the CPU when CUDA is unavailable.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# load in the validation split of the selected dataset (benign, not adversarial)
dataset = "radnav"
SM, seg_dataset = get_dataset(
    device=device,
    dataset=dataset,
    data_dir=data_dirs[dataset],
    fov_dir=fov_dirs[dataset],
    is_adversarial=False,
    split="val",
)

# load in the trained U-Net for the selected dataset
model = get_unet_model(
    device=device,
    cfg_path=cfg_paths[dataset],
    model_dir=model_dirs[dataset],
)

In [None]:
import os

import matplotlib.pyplot as plt


# get the frame index to show for each dataset
idxs_show = {
    "carla": 0,
    "nuscenes": 0,
    "radnav": 10,
}
idx_show = idxs_show[dataset]

# get data for the selected frame
pc_img, gt_mask = seg_dataset[idx_show]
pc_img = torch.unsqueeze(pc_img, 0)  # add a leading batch dimension of size 1
pc_np = seg_dataset.get_pointcloud(idx_show)
metadata = seg_dataset.get_metadata(idx_show)
# NOTE(review): assumes seg_dataset already returns tensors on `device` — TODO confirm

# run inference; no_grad skips building an autograd graph we never backprop through
with torch.no_grad():
    pred = model(pc_img, pc_np, metadata).detach().cpu().squeeze()
truth = gt_mask.detach().cpu().squeeze()

# visualize the result
cmap_binary = "gray"
cmap_conf = "plasma"
fig, axs = plt.subplots(1, 2, figsize=(10, 8))
# NOTE(review): assumes the model output is a score in [0, 1] — confirm before tuning
threshold = 0.7

# left is ground truth, right is the thresholded inference
panels = (
    (axs[0], truth, "Truth"),
    (axs[1], pred > threshold, "Inference"),
)
for ax, image, title in panels:
    ax.imshow(image, cmap=cmap_binary)
    ax.tick_params(which="both", size=0, labelsize=0)  # hide ticks and tick labels
    ax.set_title(title, size=20)

plt.tight_layout()

# Set to True to write the figure into fig_dir as both PNG and PDF.
save_fig = False
if save_fig:
    for ext in ("png", "pdf"):
        plt.savefig(
            os.path.join(fig_dir, f"segout_{dataset}_frame_{idx_show}.{ext}")
        )
plt.show()