#### Place to create loss table in

In [1]:
import torch

from tqdm import tqdm

from experiments.experiment import Experiment
from experiments.parser import get_parser
from rvae.variational_inference.train import test_vae, test_rvae
from rvae.utils.paths import SAVED_MODELS_PATH

In [None]:
BATCH_SIZE = 128

model_types = ['VAE', 'RVAE']
latent_dims = [2, 5, 10]
seeds = [0, 42, 100]

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
losses = {}

parser = get_parser()

for model_type in model_types:
    for latent_dim in latent_dims:
        for seed in seeds:
            args = parser.parse_args(
                args=[
                    "--model",
                    model_type,
                    "--dataset",
                    "fmnist",
                    "--enc_layers",
                    "300",
                    "300",
                    "--dec_layers",
                    "300",
                    "300",
                    "--latent_dim",
                    f"{latent_dim}",
                    "--num_centers",
                    "350",
                    "--num_components",
                    "1",
                    "--device",
                    "cpu",
                    "--ckpt_path",
                    f"{SAVED_MODELS_PATH}/d{latent_dim}/{seed}/{model_type}/{"fmnist_K1epoch100.ckpt" if model_type == "VAE" else "fmnist_epoch100ckpt"}",
                ]
            )
            exp = Experiment(args)
            exp.load_just_model(pretrained_path=args.ckpt_path)
            exp.model._mean_warmup = False
            losses[model_type] = losses.get(model_type, {})
            losses[model_type][latent_dim] = losses[model_type].get(latent_dim, {})
            if model_type == "VAE":
                losses[model_type][latent_dim][seed] = {key: value for key, value in zip(["test_loss", "test_rec", "test_kld"], test_vae(
                    test_loader=exp.test_loader, b_sz=BATCH_SIZE, model=exp.model, device=device))}
            else:
                losses[model_type][latent_dim][seed] = {key: value for key, value in zip(["test_loss", "test_rec", "test_kld"],
                                                                                       test_rvae(test_loader=exp.test_loader, batch_size=BATCH_SIZE, model=exp.model, device=device))}
            

DEBUG: self.rvae_save_dir='../saved_models/RVAE/' | os.path.exists(self.rvae_save_dir)=True
DEBUG: self.rvae_save_dir='../saved_models/RVAE/' | os.path.exists(self.rvae_save_dir)=True
DEBUG: self.rvae_save_dir='../saved_models/RVAE/' | os.path.exists(self.rvae_save_dir)=True
DEBUG: self.rvae_save_dir='../saved_models/RVAE/' | os.path.exists(self.rvae_save_dir)=True
DEBUG: self.rvae_save_dir='../saved_models/RVAE/' | os.path.exists(self.rvae_save_dir)=True
DEBUG: self.rvae_save_dir='../saved_models/RVAE/' | os.path.exists(self.rvae_save_dir)=True
DEBUG: self.rvae_save_dir='../saved_models/RVAE/' | os.path.exists(self.rvae_save_dir)=True
DEBUG: self.rvae_save_dir='../saved_models/RVAE/' | os.path.exists(self.rvae_save_dir)=True
DEBUG: self.rvae_save_dir='../saved_models/RVAE/' | os.path.exists(self.rvae_save_dir)=True
DEBUG: self.rvae_save_dir='../saved_models/RVAE/' | os.path.exists(self.rvae_save_dir)=True
DEBUG: self.rvae_save_dir='../saved_models/RVAE/' | os.path.exists(self.rvae_sav

---

In [22]:
losses

{'VAE': {2: {'test': {'test_loss': tensor(-296.4922),
    'test_rec': tensor(-299.3292),
    'test_kld': tensor(2.8371)}},
  5: {'test': {'test_loss': tensor(-333.7895),
    'test_rec': tensor(-337.3924),
    'test_kld': tensor(3.6030)}},
  10: {'test': {'test_loss': tensor(-332.2371),
    'test_rec': tensor(-335.8943),
    'test_kld': tensor(3.6573)}}},
 'RVAE': {2: {'test': (tensor(-471.1065), tensor(-478.0346), tensor(6.9247))},
  5: {'test': (tensor(-558.5475), tensor(-572.1154), tensor(13.5679))},
  10: {'test': (tensor(-624.0549), tensor(-646.8359), tensor(22.7810))}}}

In [21]:
torch.save(losses, "losses.pt")

---

In [66]:
vae_flat_losses = {}
for key_a, val_a in losses["VAE"].items():
    vae_flat_losses[("VAE", key_a)] = torch.stack([v for v in val_a["test"].values()])

In [72]:
rvae_flat_losses = {}
for key_a, val_a in losses["RVAE"].items():
    rvae_flat_losses[("RVAE", key_a)] = torch.stack(val_a["test"])

In [73]:
rvae_flat_losses, vae_flat_losses

({('RVAE', 2): tensor([-471.1065, -478.0346,    6.9247]),
  ('RVAE', 5): tensor([-558.5475, -572.1154,   13.5679]),
  ('RVAE', 10): tensor([-624.0549, -646.8359,   22.7810])},
 {('VAE', 2): tensor([-296.4922, -299.3292,    2.8371]),
  ('VAE', 5): tensor([-333.7895, -337.3924,    3.6030]),
  ('VAE', 10): tensor([-332.2371, -335.8943,    3.6573])})

In [74]:
import pandas as pd

df = pd.concat(
    [pd.DataFrame(data=vae_flat_losses), pd.DataFrame(data=rvae_flat_losses)], axis=1
)

In [77]:
new_df = df.T
new_df.columns = ["test_loss", "test_rec", "test_kld"]

In [78]:
new_df

Unnamed: 0,Unnamed: 1,test_loss,test_rec,test_kld
VAE,2,-296.492218,-299.329163,2.8371
VAE,5,-333.78949,-337.392395,3.60296
VAE,10,-332.237061,-335.894318,3.657283
RVAE,2,-471.106476,-478.034607,6.924698
RVAE,5,-558.547546,-572.115417,13.567932
RVAE,10,-624.054932,-646.835876,22.781002


In [57]:
new_df = df.T.reset_index()
new_df

Unnamed: 0,level_0,level_1,level_2,0
0,VAE,2,test_loss,-296.492218
1,VAE,2,test_rec,-299.329163
2,VAE,2,test_kld,2.8371
3,VAE,5,test_loss,-333.78949
4,VAE,5,test_rec,-337.392395
5,VAE,5,test_kld,3.60296
6,VAE,10,test_loss,-332.237061
7,VAE,10,test_rec,-335.894318
8,VAE,10,test_kld,3.657283
9,RVAE,2,test_loss,-471.106476


In [47]:
df.T

Unnamed: 0,Unnamed: 1,Unnamed: 2,0
VAE,2,test_loss,-296.492218
VAE,2,test_rec,-299.329163
VAE,2,test_kld,2.8371
VAE,5,test_loss,-333.78949
VAE,5,test_rec,-337.392395
VAE,5,test_kld,3.60296
VAE,10,test_loss,-332.237061
VAE,10,test_rec,-335.894318
VAE,10,test_kld,3.657283


In [33]:
# losses_df = {
#     (outer, inner, inner2): arr1
#     for outer, inner_dict in losses.items()
#     for inner, arr in inner_dict.items()
#     for inner2, arr1 in arr["test"].items()
# }

In [9]:
import pandas as pd

df = pd.DataFrame(data=flat_f1)

ValueError: Mixing dicts with non-Series may lead to ambiguous ordering.

In [None]:
df.T

In [None]:
print(df.T.to_latex())