In [None]:
import matplotlib.pyplot as plt
import numpy as np
from omegaconf import DictConfig
import os
import torch

from vsa_ogm.dataloaders.dl_evilog import EviLogDataLoader
from spl.mapping import OGM2D_V4

In [None]:
# Settings for the OGM2D_V4 mapper; eval_sample passes BASE_CONFIG.mapper to it.
BASE_CONFIG = DictConfig(dict(
    mapper=dict(
        axis_resolution=0.1,  # meters
        decision_thresholds=[-0.99, 0.99],
        # Use the first GPU when one is visible, otherwise fall back to the CPU.
        device="cuda:0" if torch.cuda.is_available() else "cpu",
        length_scale=0.2,
        quadrant_hierarchy=[1],
        use_query_normalization=True,
        use_query_rescaling=False,
        verbose=True,
        vsa_dimensions=32000,
        plotting=dict(plot_xy_voxels=False),
    ),
))

# Mapped region as [x_min, x_max, y_min, y_max] (x = point column 0, y = column 1).
# NOTE(review): presumably meters, matching axis_resolution — TODO confirm.
WORLD_SIZE = [0, 9, 0, 13]

In [None]:
def eval_sample(points, labels):
    """Map a single observation with a freshly built OGM2D_V4 and return its heatmaps.

    A new mapper is constructed on every call (from BASE_CONFIG.mapper and
    WORLD_SIZE, writing to "."), so no state carries over between samples.

    Args:
        points: Lidar points for one observation. NOTE(review): presumably
            (N, 2+) with x in column 0 and y in column 1 — TODO confirm.
        labels: Per-point labels matching ``points``.

    Returns:
        ``(occupied_heatmap, empty_heatmap)`` as NumPy arrays copied off the
        mapper's device.
    """
    def _as_numpy(tensor):
        # Heatmaps may live on the GPU; move to host memory before converting.
        return tensor.cpu().numpy()

    mapper = OGM2D_V4(BASE_CONFIG.mapper, WORLD_SIZE, ".")
    mapper.process_observation(points, labels)

    occupied = _as_numpy(mapper.xy_axis_occupied_heatmap)
    empty = _as_numpy(mapper.xy_axis_empty_heatmap)
    return occupied, empty

In [None]:
# Per-observation results; every list is indexed by the same observation number.
occ_heatmaps = []
empty_heatmaps = []
input_imgs = []
target_imgs = []
points_list = []
labels_list = []

# Overridable via the environment so the notebook is not tied to one machine;
# the fallback is the original evaluation directory.
EVILOG_DATA_DIR = os.environ.get(
    "EVILOG_DATA_DIR",
    "/home/ssnyde9/dev/EviLOG/model/output/2024-02-12-12-51-25/Evaluation",
)
# Factor applied to every lidar point cloud before mapping.
# NOTE(review): presumably converts loader units to the meters used by
# axis_resolution / WORLD_SIZE — TODO confirm against EviLogDataLoader.
POINT_SCALE = 0.1

loader_config = DictConfig({"data_dir": EVILOG_DATA_DIR})
loader = EviLogDataLoader(loader_config)


def _scaled(sample):
    """Return a loader sample with its lidar points multiplied by POINT_SCALE.

    Produces a new array instead of scaling in place, so arrays the loader may
    still hold on to are never modified.
    """
    lidar_points, labels, input_img, target_img = sample
    return lidar_points * POINT_SCALE, labels, input_img, target_img


lidar_points, labels, input_img, target_img = _scaled(loader.reset())

# NOTE(review): this evaluates max_steps() - 1 samples; the sample returned by
# the last step() call is loaded but never evaluated. Confirm that is intended.
for counter in range(loader.max_steps() - 1):
    occ_hm, empty_hm = eval_sample(lidar_points, labels)

    occ_heatmaps.append(occ_hm)
    empty_heatmaps.append(empty_hm)
    input_imgs.append(input_img)
    target_imgs.append(target_img)
    points_list.append(lidar_points)
    labels_list.append(labels)

    lidar_points, labels, input_img, target_img = _scaled(loader.step())

In [None]:
from skimage.filters.rank import entropy
from skimage.morphology import disk
from sklearn import metrics
from highfrost.ogm.metrics import calculate_multiple_TP_FP_rates

In [None]:
global_entropy_list = []
occ_disk = disk(2)
empty_disk = disk(4)

image_dir = os.path.join(".", "images")
os.makedirs(image_dir, exist_ok=True)


def _save_image(img, path, cmap=None, clim=None, colorbar=True):
    """Draw ``img`` in its own figure, save it to ``path`` at 500 dpi, then close it.

    Each image gets a fresh figure and the figure is always closed, so images
    never stack on a shared axes and figures do not accumulate in memory.
    """
    vmin, vmax = clim if clim is not None else (None, None)
    fig, ax = plt.subplots()
    im = ax.imshow(img, cmap=cmap, vmin=vmin, vmax=vmax)
    if colorbar:
        fig.colorbar(im, ax=ax)
    fig.savefig(path, dpi=500)
    plt.close(fig)


for i, (occ_hm, empty_hm, input_img, target_img) in enumerate(
    zip(occ_heatmaps, empty_heatmaps, input_imgs, target_imgs)
):
    observation_dir = os.path.join(image_dir, f"observation_{i}")
    os.makedirs(observation_dir, exist_ok=True)

    # Heatmaps share a fixed [-1, 1] color scale so observations are comparable.
    _save_image(occ_hm, os.path.join(observation_dir, "occ_heatmap.png"), cmap="plasma", clim=(-1, 1))
    _save_image(empty_hm, os.path.join(observation_dir, "empty_heatmap.png"), cmap="plasma", clim=(-1, 1))
    # Raw loader images are saved without a colorbar.
    _save_image(input_img, os.path.join(observation_dir, "input_img.png"), colorbar=False)
    _save_image(target_img, os.path.join(observation_dir, "target_img.png"), colorbar=False)

    # Local entropy of the squared heatmaps over disk neighbourhoods.
    # NOTE(review): skimage rank filters expect integer images; the float input
    # is converted internally — assumes the heatmaps lie in [-1, 1].
    occ_entropy = entropy(np.square(occ_hm), occ_disk)
    empty_entropy = entropy(np.square(empty_hm), empty_disk)
    global_entropy = occ_entropy - empty_entropy

    _save_image(occ_entropy, os.path.join(observation_dir, "occ_entropy.png"))
    _save_image(empty_entropy, os.path.join(observation_dir, "empty_entropy.png"))
    _save_image(global_entropy, os.path.join(observation_dir, "global_entropy.png"))

    global_entropy_list.append(global_entropy)

    print(f"Saved observation {i}")

In [None]:
def calculate_preds_w_threshold_test(ge: np.ndarray, threshold: float, lidar_points: np.ndarray, labels: np.ndarray, world_size=None, resolution: float = 0.1):
    """Threshold a per-cell score map at the grid cells hit by in-bounds lidar points.

    Points outside ``world_size`` are dropped together with their labels. The
    remaining points are binned into square cells of side ``resolution``, and
    the score of each point's cell is compared against ``threshold``.

    Args:
        ge: 2-D score map indexed as ``ge[row, col]``, row = y cell, col = x
            cell (e.g. a global-entropy map).
        threshold: Cells scoring strictly above this are predicted 1.
        lidar_points: Points with x in column 0 and y in column 1, in the same
            units as ``world_size``. Not modified.
        labels: One label per row of ``lidar_points``. Not modified.
        world_size: Inclusive bounds ``[x_min, x_max, y_min, y_max]``;
            defaults to the module-level ``WORLD_SIZE``.
        resolution: Grid cell size; defaults to the original hard-coded 0.1.

    Returns:
        ``(labels_in_bounds, preds)``: the labels of in-bounds points and a
        float array of 0/1 predictions of the same length.

    Raises:
        AssertionError: if point/label counts differ or no point is in bounds.
    """
    if world_size is None:
        world_size = WORLD_SIZE
    x_min, x_max, y_min, y_max = world_size

    points = np.asarray(lidar_points)
    labels_all = np.asarray(labels)
    assert points.shape[0] == labels_all.shape[0], "need exactly one label per point"

    # Build ONE mask from the unfiltered points and apply it to both arrays.
    # Filtering the points first and then masking the labels with the shrunken
    # points array misaligns them (or fails on a length mismatch).
    in_bounds = (
        (points[:, 0] >= x_min)
        & (points[:, 0] <= x_max)
        & (points[:, 1] >= y_min)
        & (points[:, 1] <= y_max)
    )
    points = points[in_bounds]  # boolean indexing copies; inputs stay untouched
    labels_ = labels_all[in_bounds]
    assert points.shape[0] > 0, "no points inside world_size"

    # Bin into grid cells. np.intp rather than uint8, which only holds indices
    # up to 255 and overflows on larger maps.
    cols = ((points[:, 0] - x_min) / resolution).astype(np.intp)
    rows = ((points[:, 1] - y_min) / resolution).astype(np.intp)
    # When the map has exactly (max - min) / resolution cells, a point lying on
    # the upper bound bins one past the edge; clamp it into the last cell.
    cols = np.clip(cols, 0, ge.shape[1] - 1)
    rows = np.clip(rows, 0, ge.shape[0] - 1)

    e_values = ge[rows, cols]
    preds = (e_values > threshold).astype(float)

    assert preds.shape[0] == labels_.shape[0]
    return labels_, preds

In [None]:
# Per-observation scores, each computed from a threshold sweep over the
# observation's global-entropy map.
auc_list = []
f1_list = []
precision_list = []
recall_list = []
best_threshold_list = []  # F1-maximising threshold per observation

THRESHOLD_STEP_SIZE = 0.01

for obs_idx, global_entropy in enumerate(global_entropy_list):
    # Sweep from the map's minimum up to (excluding) its maximum score.
    threshold_range = np.arange(
        np.min(global_entropy), np.max(global_entropy), THRESHOLD_STEP_SIZE
    )

    y_true: list[np.ndarray] = []
    y_pred: list[np.ndarray] = []
    for t in threshold_range:
        true, pred = calculate_preds_w_threshold_test(
            global_entropy, t, points_list[obs_idx], labels_list[obs_idx]
        )
        y_true.append(true)
        y_pred.append(pred)

    tpr_list, fpr_list = calculate_multiple_TP_FP_rates(y_true, y_pred)
    auc_list.append(metrics.auc(fpr_list, tpr_list))

    # Distinct names here: the inner loop previously reused `i`, shadowing the
    # observation index.
    f1_scores = [metrics.f1_score(yt, yp) for yt, yp in zip(y_true, y_pred)]
    best_idx = int(np.argmax(f1_scores))
    best_threshold_list.append(threshold_range[best_idx])

    f1_list.append(f1_scores[best_idx])
    precision_list.append(metrics.precision_score(y_true[best_idx], y_pred[best_idx]))
    recall_list.append(metrics.recall_score(y_true[best_idx], y_pred[best_idx]))

In [None]:
# AUC of the threshold sweep, one point per observation.
fig, ax = plt.subplots()
ax.plot(auc_list)
ax.set(title="AUC Scores", xlabel="Observation index", ylabel="AUC")
print(f"Mean AUC: {np.mean(auc_list)}")

In [None]:
# Best F1 over the threshold sweep, one point per observation.
fig, ax = plt.subplots()
ax.plot(f1_list)
ax.set(title="F1 Scores", xlabel="Observation index", ylabel="F1")
print(f"Mean F1: {np.mean(f1_list)}")

In [None]:
# Precision at the F1-maximising threshold, one point per observation.
fig, ax = plt.subplots()
ax.plot(precision_list)
ax.set(title="Precision Scores", xlabel="Observation index", ylabel="Precision")
print(f"Mean Precision: {np.mean(precision_list)}")

In [None]:
# Recall at the F1-maximising threshold, one point per observation.
fig, ax = plt.subplots()
ax.plot(recall_list)
ax.set(title="Recall Scores", xlabel="Observation index", ylabel="Recall")
print(f"Mean Recall: {np.mean(recall_list)}")