In [None]:
import matplotlib
import pickle
from wild_visual_navigation import WVN_ROOT_DIR
import os
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# Result folder of the time-adaptation ablation, relative to WVN_ROOT_DIR.
# (The original note here read "prev learning_curve" - presumably the folder used before; TODO confirm.)
base = "results/ablations/time_adaptation-1000_ge76"
scene = "forest"

# The pickle holds the list of stored evaluation records that the following cells iterate over.
steps_file = os.path.join(WVN_ROOT_DIR, base, "time_adaptation-1000_steps.pkl")
with open(steps_file, "rb") as f:
    runs = pickle.load(f)

In [None]:
# Stride between stored steps; converts a record's "steps" value into a row index.
store_every_n_steps = 1
nr_runs = int(np.array([run["run"] for run in runs]).max()) + 1
max_step = np.array([run["steps"] for run in runs]).max()
max_steps = int(max_step / store_every_n_steps) + 1

# Metrics read from run["results"][scene] for every record.
y_names = [
    "test_auroc_gt_image",
    "test_auroc_self_image",
    "test_auroc_anomaly_gt_image",
    "test_auroc_anomaly_self_image",
    "test_acc_gt_image",
    "test_acc_self_image",
    "test_acc_anomaly_gt_image",
    "test_acc_anomaly_self_image",
    "trainer_logged_metrictest_loss_reco",
    "trainer_logged_metrictest_loss_trav",
    "trainer_logged_metrictest_loss",
]

# One (step, run) table per metric. Entries that are never filled stay at 0.
y = {k: np.zeros((max_steps, nr_runs)) for k in y_names}
x_steps = np.arange(0, max_step + store_every_n_steps, store_every_n_steps)

# Checkpoint path per step; "nan" marks steps for which no record exists.
model_paths = ["nan"] * max_steps
for run in runs:
    step = int(run["steps"] / store_every_n_steps)
    for y_name in y_names:
        try:
            y[y_name][step, run["run"]] = run["results"][scene][y_name]
        except (KeyError, TypeError, ValueError):
            # Best effort: a record that lacks this scene/metric, or stores a value that
            # cannot be written into a float cell, keeps the default 0. Anything else
            # (including KeyboardInterrupt) is no longer silently swallowed.
            continue
    model_paths[step] = run["model_path"]

# NOTE(review): the traversability loss is rescaled by 16.666666 before plotting; the origin
# of this factor is not visible in this notebook - confirm and name the constant.
y["trainer_logged_metrictest_loss_trav"] *= 16.666666

In [None]:
###################################### Training curves (loss / accuracy) ####################################################
import seaborn as sns
import matplotlib.pyplot as plt
from wild_visual_navigation.visu import paper_colors_rgb_u8, paper_colors_rgba_u8
from wild_visual_navigation.visu import paper_colors_rgb_f, paper_colors_rgba_f

# Figure geometry in inches (values are divided by 25.4, presumably mm -> inch).
width_singel_inch, width_double_inch = 88.9 / 25.4, 182.0 / 25.4
height_inch = 82 / 25.4
scale = 2
sns.set_style("darkgrid")
plt.rcParams.update({"font.size": 16})

fig = plt.figure(figsize=(width_double_inch * scale, height_inch * scale), dpi=300)
# 3 x 8 layout; the first row is split into two half-width panels for the curves.
# (A previous gridspec with hspace=0.0 was created here and immediately overwritten; it was dropped.)
gs = fig.add_gridspec(nrows=3, ncols=8, left=0.035, right=0.98, top=0.97, wspace=0.9, height_ratios=[1.4, 1, 1])
ax0 = fig.add_subplot(gs[0, 0:4])
ax1 = fig.add_subplot(gs[0, 4:])

ax0.tick_params(axis="both", which="major", labelsize=12)
ax1.tick_params(axis="both", which="major", labelsize=12)


fig.set_tight_layout(True)


def plot_auroc(keys, x_steps, y, lim_min, lim_max, ax, y_tags, y_axis_labels, two_axis=False):
    """Draw one mean curve with a +/- one standard deviation band per metric.

    Args:
        keys: Metric names looked up in ``y``; one curve is drawn per key.
        x_steps: x coordinates shared by all curves.
        y: Mapping from metric name to a 2-D array; mean and std are taken over axis 1.
        lim_min: Lower y limit of the primary axis (``None`` keeps autoscaling).
        lim_max: Upper y limit of the primary axis (``None`` keeps autoscaling).
        ax: Primary axis to draw on.
        y_tags: Legend label per curve. Curves are paired with tags via ``zip``,
            so only as many curves as there are tags are drawn.
        y_axis_labels: ``y_axis_labels[0]`` labels the primary axis; when ``two_axis``
            is set, ``y_axis_labels[1]`` (if present) labels the twin axis.
        two_axis: If True, the second curve (and any later one) is drawn on a twin y axis.

    Colors are taken from ``paper_colors_rgb_f``: the j-th key of that mapping for the
    line and the same key with a ``"_light"`` suffix for the band.
    """
    # Keep a handle on the primary axis: ``ax`` is rebound to the twin axis in two-axis mode.
    ax_ori = ax
    color_names = list(paper_colors_rgb_f.keys())

    # Compute each statistic once per metric instead of once per derived curve.
    curves = []
    for key in keys:
        mean = y[key].mean(axis=1)
        std = y[key].std(axis=1)
        curves.append((mean, mean - std, mean + std))

    for j, ((_y, _y_lower, _y_upper), _y_tag) in enumerate(zip(curves, y_tags)):
        color_name = color_names[j]
        line_color = paper_colors_rgb_f[color_name]
        band_color = paper_colors_rgb_f[color_name + "_light"]

        if two_axis and j == 1:
            # Proxy entry so that the legend of the primary axis also lists the twin-axis curve.
            ax.plot([0], [0], label=_y_tag, color=line_color)
            ax.legend()
            # From here on everything is drawn on the twin axis, which carries the second label.
            ax = ax.twinx()
            if len(y_axis_labels) > 1:
                ax.set_ylabel(y_axis_labels[1])

        ax.plot(x_steps, _y, label=_y_tag, color=line_color)
        ax.plot(x_steps, _y_lower, color=band_color, alpha=0.1)
        ax.plot(x_steps, _y_upper, color=band_color, alpha=0.1)
        ax.fill_between(x_steps, _y_lower, _y_upper, color=band_color, alpha=0.2)

        if not (two_axis and j == 1):
            ax.legend()

        ax.tick_params(axis="both", which="major", labelsize=12)

    # Tick positions are fixed to 0..1000 in steps of 100, independent of x_steps.
    x_ticks = [0, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]
    ax_ori.set_xlabel("Training Steps")
    ax_ori.spines["top"].set_visible(False)
    ax_ori.spines["right"].set_visible(False)
    ax_ori.set_xticks(x_ticks)
    ax_ori.set_xticklabels(x_ticks)

    # The primary axis always gets the first label (previously it received the second
    # label in two-axis mode and the twin axis stayed unlabelled).
    ax_ori.set_ylabel(y_axis_labels[0])
    ax_ori.set_ylim(lim_min, lim_max)


# Disabled variant: AUROC curves instead of the losses.
# plot_auroc(
#     ["test_auroc_gt_image", "test_auroc_self_image"],
#     x_steps,
#     y,
#     0.6,
#     1,
#     ax0,
#     ["GT AUROC", "SELF AUROC"],
#     y_axis_labels=["AUROC"],
# )

# Panel on ax0: the two test losses, autoscaled y range.
plot_auroc(
    keys=["trainer_logged_metrictest_loss_reco", "trainer_logged_metrictest_loss_trav"],
    x_steps=x_steps,
    y=y,
    lim_min=None,
    lim_max=None,
    ax=ax0,
    y_tags=["Reco", "Trav"],
    y_axis_labels=["Loss"],
)

# Panel on ax1: the two test accuracies on a single shared y axis.
plot_auroc(
    keys=["test_acc_gt_image", "test_acc_self_image"],
    x_steps=x_steps,
    y=y,
    lim_min=None,
    lim_max=None,
    ax=ax1,
    y_tags=["GT Acc", "SELF Acc"],
    y_axis_labels=["Accuracy", "Self Acc"],
    two_axis=False,
)


###################################### Images ####################################################


# Setup dataloader
from wild_visual_navigation.learning.dataset import get_ablation_module
from wild_visual_navigation.learning.utils import load_env
from wild_visual_navigation import WVN_ROOT_DIR
from wild_visual_navigation.cfg import ExperimentParams

from wild_visual_navigation.visu import LearningVisualizer
from PIL import Image
from wild_visual_navigation.learning.model import get_model
from dataclasses import asdict
import torch

# Experiment configuration stored next to the ablation results.
# NOTE: pickle.load can execute arbitrary code - only open result files you trust.
with open(os.path.join(WVN_ROOT_DIR, base, "experiment_params.pkl"), "rb") as f:
    exp = pickle.load(f)

env = load_env()

# Data module settings for this figure: single-sample batches, no worker processes,
# restricted to `scene`; the feature key is taken over from the stored experiment.
ablation_data_module = {
    "batch_size": 1,
    "num_workers": 0,
    "env": scene,
    "feature_key": exp.ablation_data_module.feature_key,
    "test_equals_val": False,
    "val_equals_test": False,
    "test_all_datasets": False,
    "training_data_percentage": 100,
    "training_in_memory": False,
}
train_loader, val_loader, test_loader = get_ablation_module(**ablation_data_module, perugia_root=env["perugia_root"])
# test_loader is iterated and indexed below, i.e. it is a sequence of loaders.
test_scenes = [a.dataset.env for a in test_loader]
test_all_datasets = True


def load_model(model_cfg: ExperimentParams.ModelParams, checkpoint_path: str):
    """Build a model from ``model_cfg`` and load the weights stored at ``checkpoint_path``.

    Args:
        model_cfg: Model parameters (a dataclass instance); converted with ``asdict``
            and passed to ``get_model``.
        checkpoint_path: Path of a file readable by ``torch.load`` that contains a
            mapping from parameter names to values.

    Returns:
        The model returned by ``get_model`` with the checkpoint weights loaded.

    Raises:
        RuntimeError: from ``load_state_dict(strict=True)`` if the filtered checkpoint
            keys do not match the model's state dict exactly.
    """
    model = get_model(asdict(model_cfg))
    # Map to CPU so loading does not depend on the device the checkpoint was saved from;
    # load_state_dict copies the values into the model's own parameters.
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    # Drop entries whose name contains "_traversability" or "threshold" and strip
    # every "_model." occurrence from the remaining names.
    state_dict = {
        k.replace("_model.", ""): v
        for k, v in ckpt.items()
        if "_traversability" not in k and "threshold" not in k
    }
    model.load_state_dict(state_dict, strict=True)
    return model


# Text properties passed to every set_title call on the image panels.
fontdict = dict(
    fontsize=16,
    fontweight=plt.rcParams["axes.titleweight"],
    verticalalignment="baseline",
    horizontalalignment="center",
)

# Checkpoints shown per example, as 0-based indices into model_paths.
models = [0, 99, 249, 499, 999]

# One image row per example: image | supervision | one prediction per checkpoint | label.
for example_idx, img_nr in enumerate([20, 32]):
    # Each example row uses its own gridspec so that the row heights can be tuned separately.
    # (The gridspecs that used to be created and immediately overwritten here were dropped.)
    if example_idx == 0:
        gs2 = fig.add_gridspec(
            nrows=3, ncols=8, left=0.03, right=0.98, top=1.0, wspace=0.1, height_ratios=[1.7, 3, 0.1]
        )
    else:
        gs2 = fig.add_gridspec(
            nrows=3, ncols=8, left=0.03, right=0.98, top=1.0, bottom=0, wspace=0.1, height_ratios=[1.3, 2.15, 2.0]
        )

    thresholds = []
    for j, training_step in enumerate(models):
        checkpoint_path = model_paths[training_step]
        model = load_model(exp.model, checkpoint_path)
        # NOTE(review): hard-coded device; this cell requires a CUDA-capable machine.
        model.to("cuda")
        graph = test_loader[0].dataset[img_nr]
        pred = model(graph)
        visualizer = LearningVisualizer()
        img = graph.img
        seg = graph.seg
        # NOTE(review): this result is never displayed; kept in case the call has side effects.
        res = visualizer.plot_detectron(img[0], graph.label[0].type(torch.long), not_log=True, max_seg=2)
        center = graph.center

        # The threshold is stored in the same checkpoint file; only the Python float is needed.
        threshold = torch.load(checkpoint_path, map_location="cpu")["threshold"].item()
        thresholds.append(threshold)

        # Piecewise-linear remap of the first prediction column so that `threshold` lands on 0.5.
        # NOTE(review): a threshold of exactly 0 or 1 would divide by zero here.
        traversability = pred[:, 0]
        m = traversability < threshold
        # Scale untraversable
        traversability[m] *= 0.5 / threshold
        # Scale traversable
        traversability[~m] -= threshold
        traversability[~m] *= 0.5 / (1 - threshold)
        traversability[~m] += 0.5
        traversability = traversability.clip(0, 1)

        # Look up the value of every pixel through its segment index.
        BS, H, W = graph.seg.shape
        seg_pixel_index = graph.seg.flatten()
        buffer_traversability = traversability[seg_pixel_index].reshape(BS, H, W)
        res_pred = visualizer.plot_detectron_classification(img[0], buffer_traversability, not_log=True)

        # Columns 2..6 hold the predictions of the individual checkpoints.
        ax = fig.add_subplot(gs2[1 + example_idx, j + 2])
        ax.imshow(res_pred)
        ax.axis("off")
        step = training_step + 1
        if example_idx == 1:
            # Only the last row carries the column captions (placed below the image).
            ax.set_title(f"Step-{step}", fontdict=fontdict, loc="center", y=0, pad=-16)

    # The panels below reuse graph / img / seg / center / visualizer from the last checkpoint
    # iteration; the graph is the same sample for every checkpoint of this example.
    res_img = visualizer.plot_detectron(
        img[0],
        graph.label[0].type(torch.long) * 99,
        not_log=True,
        max_seg=100,
        colormap="RdYlBu",
        boundary_seg=None,
        alpha=0,
    )
    res_gt = visualizer.plot_detectron(
        img[0],
        graph.label[0].type(torch.long) * 99,
        not_log=True,
        max_seg=100,
        colormap="RdYlBu",
        boundary_seg=None,
    )
    res_prop = visualizer.plot_traversability_graph_on_seg(
        graph.y, seg[0], graph, center, img[0], not_log=True, colorize_invalid_centers=True
    )

    # NOTE: ax0 / ax1 are rebound here and no longer refer to the curve panels.
    ax0 = fig.add_subplot(gs2[1 + example_idx, 0])
    ax1 = fig.add_subplot(gs2[1 + example_idx, 1])
    ax7 = fig.add_subplot(gs2[1 + example_idx, 7])

    # Rotated row caption, positioned in the data coordinates of the first panel.
    res = ax0.text(-110, 370, s=f"Example {example_idx+1}", fontsize=16, rotation=90)

    if example_idx == 1:
        ax0.set_title("Image", fontdict=fontdict, loc="center", y=0, pad=-16)
        ax1.set_title("Supervision", fontdict=fontdict, loc="center", y=0, pad=-16)
        ax7.set_title("Label", fontdict=fontdict, loc="center", y=0, pad=-16)

    ax0.axis("off")
    ax1.axis("off")
    ax7.axis("off")

    ax0.imshow(res_img)
    ax1.imshow(res_prop)
    ax7.imshow(res_gt)

# The image panels are placed with explicit gridspec margins, so tight layout is switched off before saving.
fig.set_tight_layout(False)
# NOTE(review): hard-coded output path.
fig.savefig("/tmp/img.png", dpi=300)