In [None]:
import numpy as np
import torch
import sys, os

sys.path.append("../")

from vi_rnn.saving import load_model
from vi_rnn.datasets import Basic_dataset
from vi_rnn.utils import orthogonalise_network

from matplotlib.colors import colorConverter as cc
import matplotlib.pyplot as plt

%matplotlib inline

Note that the models for this sweep are not included in the public repo, to keep it lightweight. Run

   `train_scripts/student_teacher/test_noise.py`

to obtain the models (the cells below load them from `../data_untracked/noise_tests/`), or contact the authors.

In [None]:
# load and eval models
# ----------------
# Particles = 64
# ----------------
# For every saved model in `directory`: load it (without encoder), orthogonalise
# it, and append the student's std from `vae.rnn.std_embed_z(vae.rnn.R_z)` (as a
# numpy array) to the list matching the teacher noise level task_params["R_z"].

directory = "../data_untracked/noise_tests/noise_test/"

directory_bs = os.fsencode(directory)

Stds10 = []
Stds05 = []
Stds20 = []

# os.listdir returns entries in arbitrary order; sort them so the `[:n]`
# selection applied later picks the same models on every machine / run.
for file in sorted(os.listdir(directory_bs)):
    filename = os.fsdecode(file)
    if filename.endswith("_vae_params.pkl"):
        model_name = filename.removesuffix("_vae_params.pkl")
        print(model_name)
        vae, training_params, task_params = load_model(
            directory + model_name, load_encoder=False, backward_compat=True
        )
        vae = orthogonalise_network(vae)
        if task_params["R_z"] == 0.05:
            Stds05.append(vae.rnn.std_embed_z(vae.rnn.R_z).detach().numpy())
        elif task_params["R_z"] == 0.1:
            Stds10.append(vae.rnn.std_embed_z(vae.rnn.R_z).detach().numpy())
        elif task_params["R_z"] == 0.2:
            Stds20.append(vae.rnn.std_embed_z(vae.rnn.R_z).detach().numpy())
        else:
            # make unexpected noise levels visible instead of dropping them silently
            print(f"skipping {model_name}: R_z={task_params['R_z']} not in (0.05, 0.1, 0.2)")

In [None]:
from vi_rnn.initialize_parameterize import full_cov_embed

In [None]:
# load and eval models

# ----------------
# Particles = 10
# ----------------
# Same procedure as for 64 particles: load each saved model (without encoder),
# orthogonalise it, and collect `vae.rnn.std_embed_z(vae.rnn.R_z)` per teacher
# noise level task_params["R_z"].

directory = "../data_untracked/noise_tests/noise_test10/"

directory_bs = os.fsencode(directory)

Stds1010 = []
Stds0510 = []
Stds2010 = []

# os.listdir order is arbitrary; sort so the later `[:n]` selection is reproducible.
for file in sorted(os.listdir(directory_bs)):
    filename = os.fsdecode(file)
    if filename.endswith("_vae_params.pkl"):
        model_name = filename.removesuffix("_vae_params.pkl")
        vae, training_params, task_params = load_model(
            directory + model_name, load_encoder=False, backward_compat=True
        )
        vae = orthogonalise_network(vae)

        if task_params["R_z"] == 0.05:
            Stds0510.append(vae.rnn.std_embed_z(vae.rnn.R_z).detach().numpy())
        elif task_params["R_z"] == 0.1:
            Stds1010.append(vae.rnn.std_embed_z(vae.rnn.R_z).detach().numpy())
        elif task_params["R_z"] == 0.2:
            Stds2010.append(vae.rnn.std_embed_z(vae.rnn.R_z).detach().numpy())
        else:
            # make unexpected noise levels visible instead of dropping them silently
            print(f"skipping {model_name}: R_z={task_params['R_z']} not in (0.05, 0.1, 0.2)")

In [None]:
# load and eval models

# ----------------
# Particles = 1
# ----------------
# Same procedure as for 64 particles: load each saved model (without encoder),
# orthogonalise it, and collect `vae.rnn.std_embed_z(vae.rnn.R_z)` per teacher
# noise level task_params["R_z"]. `i` counts the model files processed.

directory = "../data_untracked/noise_tests/noise_test1/"

directory_bs = os.fsencode(directory)

Stds101 = []
Stds051 = []
Stds201 = []
i = 0
# os.listdir order is arbitrary; sort so the later `[:n]` selection is reproducible.
for file in sorted(os.listdir(directory_bs)):
    filename = os.fsdecode(file)
    if filename.endswith("_vae_params.pkl"):
        i += 1
        print(i)
        model_name = filename.removesuffix("_vae_params.pkl")
        print(model_name)
        vae, training_params, task_params = load_model(
            directory + model_name, load_encoder=False, backward_compat=True
        )
        vae = orthogonalise_network(vae)

        if task_params["R_z"] == 0.05:
            Stds051.append(vae.rnn.std_embed_z(vae.rnn.R_z).detach().numpy())
        elif task_params["R_z"] == 0.1:
            Stds101.append(vae.rnn.std_embed_z(vae.rnn.R_z).detach().numpy())
        elif task_params["R_z"] == 0.2:
            Stds201.append(vae.rnn.std_embed_z(vae.rnn.R_z).detach().numpy())
        else:
            # make unexpected noise levels visible instead of dropping them silently
            print(f"skipping {model_name}: R_z={task_params['R_z']} not in (0.05, 0.1, 0.2)")

In [None]:
# load and eval models

# ----------------
# Particles = 64 + Bootstrap Sampling
# ----------------
# Same procedure as for 64 particles: load each saved model (without encoder),
# orthogonalise it, and collect `vae.rnn.std_embed_z(vae.rnn.R_z)` per teacher
# noise level task_params["R_z"]. Also prints the loaded training settings so
# the run configuration can be checked.

directory = "../data_untracked/noise_tests/noise_test_bs/"

directory_bs = os.fsencode(directory)

Stds10bs = []
Stds05bs = []
Stds20bs = []
i = 0
# os.listdir order is arbitrary; sort so the later `[:n]` selection is reproducible.
for file in sorted(os.listdir(directory_bs)):
    filename = os.fsdecode(file)
    if filename.endswith("_vae_params.pkl"):
        i += 1
        print(i)
        model_name = filename.removesuffix("_vae_params.pkl")
        print(model_name)
        vae, training_params, task_params = load_model(
            directory + model_name, load_encoder=False, backward_compat=True
        )
        # Read training_params only after load_model: previously this print ran
        # before loading and showed the *previous* model's (or previous cell's)
        # settings, or raised NameError on a fresh kernel.
        # presumably "resample" confirms bootstrap resampling was used — verify
        print(training_params["resample"])
        vae = orthogonalise_network(vae)
        print(training_params["k"], training_params["loss_f"])
        if task_params["R_z"] == 0.05:
            Stds05bs.append(vae.rnn.std_embed_z(vae.rnn.R_z).detach().numpy())
        elif task_params["R_z"] == 0.1:
            Stds10bs.append(vae.rnn.std_embed_z(vae.rnn.R_z).detach().numpy())
        elif task_params["R_z"] == 0.2:
            Stds20bs.append(vae.rnn.std_embed_z(vae.rnn.R_z).detach().numpy())
        else:
            # make unexpected noise levels visible instead of dropping them silently
            print(f"skipping {model_name}: R_z={task_params['R_z']} not in (0.05, 0.1, 0.2)")

In [None]:
# Number of models kept per (teacher noise level, particle setting): the
# collected std lists are sliced with `[:n]` before plotting.
n = 5

In [None]:
# Pool the first `n` models of each run into one flat array of std values per box.
# Naming: dat<R_z><particles>; no particle suffix = 64 particles, "bs" = bootstrap.
def _pool_first_n(stds_per_model):
    """Stack the collected std arrays, keep the first `n` models, and flatten."""
    return np.array(stds_per_model)[:n].flatten()


dat05, dat10, dat20 = map(_pool_first_n, (Stds05, Stds10, Stds20))
dat0510, dat1010, dat2010 = map(_pool_first_n, (Stds0510, Stds1010, Stds2010))
dat051, dat101, dat201 = map(_pool_first_n, (Stds051, Stds101, Stds201))
dat05bs, dat10bs, dat20bs = map(_pool_first_n, (Stds05bs, Stds10bs, Stds20bs))

In [None]:
# Plot box plots
# One box per (training setting, teacher noise level). Within each teacher-noise
# group (R_z = .05 / .1 / .2) the four settings sit side by side; consecutive
# groups are separated by an extra gap `sep`.
alpha = 0.5  # face transparency of the boxes
fig, ax = plt.subplots(1, 1, figsize=(2, 1), dpi=300)

c1 = "darkturquoise"
c2 = "teal"
c3 = "darkslategray"
c4 = "slateblue"

sep = 0.5
# group g starts at 4 * g + g * sep:
# pos_64=[0, 4.5, 9], pos_10=[1, 5.5, 10], pos_1=[2, 6.5, 11], pos_bs=[3, 7.5, 12]
pos_64, pos_10, pos_1, pos_bs = (
    [offset + 4 * g + g * sep for g in range(3)] for offset in range(4)
)

# (data for R_z = .05/.1/.2, x positions, colour, legend label, zorder)
# zorder=None keeps matplotlib's default stacking.
box_series = [
    ([dat05, dat10, dat20], pos_64, c1, "k=64", 10),
    ([dat0510, dat1010, dat2010], pos_10, c2, "k=10", 5),
    ([dat051, dat101, dat201], pos_1, c3, "k=1", None),
    ([dat05bs, dat10bs, dat20bs], pos_bs, c4, "Bootstrap (k=64)", None),
]
for data, positions, color, label, zorder in box_series:
    ax.boxplot(
        data,
        positions=positions,
        widths=0.6,
        patch_artist=True,
        boxprops=dict(facecolor=cc.to_rgba(color, alpha=alpha), color=color),
        capprops=dict(color=color),
        whiskerprops=dict(color=color),
        medianprops=dict(color=color),
        flierprops={
            "marker": "o",
            "markersize": 1,
            "markerfacecolor": color,
            "markeredgecolor": color,
        },
        label=label,
        zorder=zorder,
    )

# labels
ax.set_xticks([])
ax.set_yticks([0.05, 0.1, 0.2])
ax.set_yticklabels([".05", ".1", ".2"])
ax.set_ylabel(r"student $\sigma$")
ax.spines[["bottom"]].set_visible(False)

# teacher sigma: one dashed horizontal reference line per noise-level group
teacher_sigmas = [0.05, 0.1, 0.2]
half_len = 2  # half length of each teacher line, in x data units
lb = 1  # line width of the teacher lines
ct = "violet"
dashes = (2, 0.5)
for g, sigma in enumerate(teacher_sigmas):
    center = (pos_64[g] + pos_bs[g]) / 2  # middle of the group's four boxes
    ax.plot(
        [center - half_len, center + half_len],
        [sigma, sigma],
        color=ct,
        zorder=-10,
        label=r"teacher $\sigma$" if g == 0 else None,  # single legend entry
        lw=lb,
        ls="--",
        dashes=dashes,
    )

# legend: entries appear in artist order (4 box series, then teacher line);
# colour each label like its artist
legend = ax.legend(
    loc="upper right",
    bbox_to_anchor=(1.65, 1),
    fontsize=6,
    handlelength=0.55,
    handleheight=0.5,
)
legend_colors = [c1, c2, c3, c4, ct]
for text, color in zip(legend.get_texts(), legend_colors):
    text.set_color(color)

# The legend is anchored at x=1.65 in axes coordinates, i.e. past the right edge
# of the 2-inch-wide figure. Without bbox_inches="tight" it is clipped in the
# saved files (the inline notebook display crops tightly, which hides this).
fig.savefig("../figures/noise_comps.svg", bbox_inches="tight")
fig.savefig("../figures/noise_comps.png", bbox_inches="tight")