In [None]:
import wild_visual_navigation.visu.paper_colors as pc
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

%matplotlib inline

# Root of the WVN data dump to analyse. Set the WVN_WORKING_DIR environment variable to point
# the notebook at another dump instead of editing this machine-specific default path.
WORKING_DIR = os.environ.get(
  "WVN_WORKING_DIR",
  "/media/matias/datasets/2023-01-19-wvn-uni-parks/wvn_data_dump/2023-01-19T14-17-58_test",
)

In [None]:
# Load the confidence-generator dumps: one `samples_<step>.pt` file per step, each holding
# the keys "x", "x_positive", "mean" and "std".
samples_directory = Path(WORKING_DIR) / "confidence_generator"

# Check confidence generator folder
if not samples_directory.exists():
  raise FileNotFoundError(f"Confidence generator path selected does not exist: {samples_directory}")

# Pair every file with its step number and sort numerically. A plain lexicographic sort
# would put samples_10 before samples_2 and scramble the time-series plotted below.
sample_files = sorted(
  ((int(p.stem.replace("samples_", "")), p) for p in samples_directory.rglob("samples_*.pt")),
  key=lambda item: item[0],
)
if not sample_files:
  raise FileNotFoundError(f"No samples_*.pt files found in: {samples_directory}")

steps = []
all_samples = []
positive_samples = []
gaussian_mean = []
gaussian_std = []

for step, s in sample_files:
  # map_location="cpu" so tensors saved during a GPU run can still be converted with .numpy().
  # NOTE: torch.load unpickles the file; only point this at trusted dumps.
  data = torch.load(s, map_location="cpu")

  steps.append(step)
  all_samples.append(data["x"].detach().numpy())
  positive_samples.append(data["x_positive"].detach().numpy())
  gaussian_mean.append(data["mean"].item())
  gaussian_std.append(data["std"].item())

steps = np.asarray(steps)
gaussian_mean = np.asarray(gaussian_mean)
gaussian_std = np.asarray(gaussian_std)

In [None]:
# Plot how the sample distribution evolves over steps:
#   - every sample at each step as a faint grey scatter,
#   - the positive samples at each step highlighted in blue,
#   - the Gaussian approximation as its mean line plus a shaded mean ± std_factor * std band.

# Figure configuration
fig_scale = 0.5   # uniform scale applied to the base 13 x 9 figure size
std_factor = 1.5  # half-width of the shaded band, in standard deviations
gaussian_color = pc.darken(pc.paper_colors_rgb_f["blue"], 0.3)


def _pair_with_steps(step_values, per_step_samples):
  """Flatten per-step sample arrays into matching (x, y) arrays for a single scatter call.

  Every value in ``per_step_samples[i]`` is paired with ``step_values[i]`` as its x coordinate.
  """
  ys = [np.ravel(v) for v in per_step_samples]
  xs = np.repeat(np.asarray(step_values), [y.size for y in ys])
  return xs, (np.concatenate(ys) if ys else np.empty(0))


fig, axs = plt.subplots(1, constrained_layout=True, figsize=(fig_scale * 13, fig_scale * 9))

# One scatter call per population (instead of one per step) keeps the number of artists
# constant and gives each population exactly one legend entry.
x_all, y_all = _pair_with_steps(steps, all_samples)
axs.scatter(x_all, y_all, color=(0.2, 0.2, 0.2), alpha=0.05, linewidths=0, label="All samples")

x_pos, y_pos = _pair_with_steps(steps, positive_samples)
axs.scatter(x_pos, y_pos, color=pc.paper_colors_rgb_f["blue"], alpha=0.3, linewidths=0, label="Positive samples")

axs.fill_between(
    steps,
    gaussian_mean - std_factor * gaussian_std,
    gaussian_mean + std_factor * gaussian_std,
    color=gaussian_color,
    alpha=0.2,
    linewidth=0,
    label=f"Mean ± {std_factor} std",
)
axs.plot(steps, gaussian_mean, label="Gaussian approximation", color=gaussian_color, linewidth=2)
axs.set_title("Confidence generator: reconstruction loss per step")
axs.set_xlabel("Steps")
axs.set_ylabel("Reconstruction loss")
axs.legend();