In [None]:
import os
import matplotlib.pyplot as plt
import torch
from torch_geometric.transforms import ToDevice

# Font embedding: the rcParams set below make PDF/PS exports use TrueType (Type 42) fonts instead of Type 3
import matplotlib

# Embed fonts as TrueType (Type 42) instead of matplotlib's default Type 3
# in PDF/PS output.
matplotlib.rcParams["pdf.fonttype"] = 42
matplotlib.rcParams["ps.fonttype"] = 42

# Output locations for saved figures. os.makedirs creates intermediate
# directories, so this also creates fig_dir; exist_ok keeps it idempotent.
fig_dir = "figures"
save_dir = os.path.join(fig_dir, "mc_conf")
os.makedirs(save_dir, exist_ok=True)

# Font sizes used by the rc settings below.
SMALL_SIZE = 14
MEDIUM_SIZE = 16
BIGGER_SIZE = 18

# Palette for spoofed vs. benign markers (not used in this section;
# presumably consumed by later cells -- TODO confirm).
spoof_color = "#FF474C"
benign_color = "#228B22"

plt.rcParams["font.family"] = "serif"
plt.rc("font", size=SMALL_SIZE)  # controls default text sizes
plt.rc("axes", titlesize=SMALL_SIZE)  # fontsize of the axes title
plt.rc("axes", labelsize=MEDIUM_SIZE)  # fontsize of the x and y labels
plt.rc("xtick", labelsize=SMALL_SIZE)  # fontsize of the tick labels
plt.rc("ytick", labelsize=SMALL_SIZE)  # fontsize of the tick labels
plt.rc("legend", fontsize=SMALL_SIZE)  # legend fontsize
plt.rc("figure", titlesize=BIGGER_SIZE)  # fontsize of the figure title

# Seed torch's RNG so any torch randomness (e.g. model weight
# initialisation) is reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# set the device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device {device}")

# Data directories: raw CARLA input and the processed-graph root passed to
# the dataset. The defaults are the shared-mount paths; set FOV_DATA_INPUT /
# FOV_DATA_OUTPUT to run elsewhere without editing the notebook.
data_input = os.environ.get("FOV_DATA_INPUT", "/data/shared/CARLA/multi-agent-v1")
data_output = os.environ.get("FOV_DATA_OUTPUT", "/data/shared/fov/fov_bev_graph")

In [None]:
from fov.graph.dataset import CarlaFieldOfViewDataset
from fov.graph.models import GATModel


# Build the field-of-view graph dataset from the raw CARLA data in
# data_input, with processed graphs under data_output. The ToDevice
# transform moves each sample to `device` when it is accessed.
full_dataset = CarlaFieldOfViewDataset(
    carla_root_directory=data_input,
    graph_root_directory=data_output,
    include_infrastructure_agents=False,
    n_frames_max=1000,  # presumably caps the number of frames loaded -- TODO confirm
    force_reload=False,  # presumably reuses already-processed graphs -- TODO confirm
    transform=ToDevice(device=device),
)

# GAT model with a single output channel.
# NOTE(review): in_channels=-1 presumably requests lazy input-size inference
# on the first forward pass -- TODO confirm against GATModel.
model = GATModel(
    in_channels=-1,
    hidden_channels=256,
    num_layers=5,
    out_channels=1,
    v2=True,
).to(device)

# NOTE(review): no trained checkpoint is loaded in this cell, so any
# predictions come from randomly initialised weights.
# TODO: load the trained state_dict here (e.g. model.load_state_dict(...)).

# The model is only used for inference in this notebook; switch
# dropout/normalisation layers to evaluation behaviour.
model.eval()

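## Visualize Model Predictions on the Dataset

For each frame in `idx_show`, plot the thresholded per-node model prediction side by side with the ground-truth labels.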

In [None]:
# Frames to visualise and the decision threshold applied to the model output.
idx_show = [0, 1]
# NOTE(review): assumes the model outputs probabilities in [0, 1]; if it
# returns raw logits, apply a sigmoid before thresholding -- TODO confirm.
threshold = 0.7
class_colors = ["#BBF90F", "#FF796C"]  # color for class 0, color for class 1
marker = "o"
alpha = 0.2
markersize = 10

# Inference only: ensure evaluation-mode layers regardless of prior state.
model.eval()

for idx in idx_show:
    # Run inference on the selected frame (use idx, not a fixed frame 0,
    # so each figure matches its "Frame {idx}" title).
    t_input = full_dataset[idx]
    with torch.no_grad():  # no autograd graph needed for visualisation
        v_output = model(t_input).detach().cpu().numpy()
    x_input = t_input.x.detach().cpu().numpy()
    y_input = t_input.y.detach().cpu().numpy()

    # Flatten to one value per node (with out_channels=1 the output may be
    # (N, 1)), then cast to int so values can index class_colors. Indexing a
    # list with a 1-element array, np.bool_ or float raises TypeError.
    # Assumes y holds 0/1 class labels -- TODO confirm.
    v_class = (v_output.reshape(-1) >= threshold).astype(int)
    y_class = y_input.reshape(-1).astype(int)

    pred_colors = [class_colors[label] for label in v_class]
    gt_colors = [class_colors[label] for label in y_class]

    # Side-by-side scatter of the first two node features (presumably the
    # node's planar position -- TODO confirm), colored by class.
    fig, axs = plt.subplots(1, 2, figsize=(10, 6))
    panels = [("Prediction", pred_colors), ("Ground Truth", gt_colors)]
    for ax, (title, colors) in zip(axs, panels):
        ax.scatter(
            x_input[:, 0],
            x_input[:, 1],
            marker=marker,
            s=markersize,
            alpha=alpha,
            color=colors,
        )
        ax.set_title(f"Frame {idx} {title}")
    fig.tight_layout()
    plt.show()