In [1]:
import argparse
import gzip
import json
from collections import defaultdict
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt

from analysis_utils import Line
from analysis_utils import dict2tex


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the options for the learning-rate loss-curve comparison.

    Args:
        argv: sequence of argument strings to parse. ``None`` (the default)
            keeps the original behaviour of reading ``sys.argv[1:]``. Pass an
            explicit list when calling from a notebook, where ``sys.argv``
            typically holds the kernel's own arguments and would be rejected.

    Returns:
        Namespace with ``lr3_file`` and ``lr4_file`` (both required) and
        ``output`` (``None`` when ``-o/--output`` is not given).

    Raises:
        SystemExit: raised by argparse when a required option is missing or an
            unknown option is supplied.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lr3-file",
        help="Result file with learning rate 1e-3",
        required=True,
    )
    parser.add_argument(
        "--lr4-file",
        help="Result file with learning rate 1e-4",
        required=True,
    )
    parser.add_argument(
        "-o",
        "--output",
        help="Output file to write to (.tex)",
    )
    # argparse treats args=None as "use sys.argv[1:]", so the default call is
    # identical to the previous parser.parse_args().
    return parser.parse_args(argv)


def parse_full(filename: str) -> Tuple[List[float], List[float]]:
    """Load the train and test loss histories from a gzipped JSON result file.

    The file must decode (as UTF-8) to an object with a ``"meta"`` mapping
    containing ``"dataset"``, ``"model"`` and ``"seed"``, and a ``"results"``
    mapping providing ``results["losses"]["train"]`` and
    ``results["losses"]["test"]``.

    Args:
        filename: path to the ``.json.gz`` result file.

    Returns:
        ``(train_loss, test_loss)`` exactly as stored in the file. Callers slice
        and plot these, so they are expected to be per-epoch sequences
        (presumably lists of floats; the contents are not validated here).

    Raises:
        KeyError: if a required ``meta`` or ``results`` key is missing.
    """
    # Text mode with an explicit encoding replaces the manual read()/decode().
    with gzip.open(filename, "rt", encoding="utf-8") as fp:
        data = json.load(fp)

    metadata = data["meta"]
    results = data["results"]

    # One file describes a single run, so there is nothing to "mix" here.
    # Only check that the identifying metadata is present.
    for key in ("dataset", "model", "seed"):
        if key not in metadata:
            raise KeyError(f"result file {filename!r} has no meta[{key!r}]")

    losses = results["losses"]
    return losses["train"], losses["test"]

In [2]:
def show_plot(epochs, lr3_train, lr3_test, lr4_train=None, lr4_test=None):
    """Plot train/test loss curves for eta = 1e-3 and, optionally, eta = 1e-4.

    Solid lines are training loss and dashed lines are test loss. The 1e-3 run
    is blue and the 1e-4 run is orange. Each 1e-4 series is drawn only when it
    is supplied, so the three-argument call still shows just the 1e-3 curves.

    Args:
        epochs: x-axis values; must have the same length as every loss series.
        lr3_train: training losses for the eta = 1e-3 run.
        lr3_test: test losses for the eta = 1e-3 run.
        lr4_train: optional training losses for the eta = 1e-4 run.
        lr4_test: optional test losses for the eta = 1e-4 run.
    """
    fig, ax = plt.subplots()
    # Raw strings: "\e" is not a valid Python escape, so the plain literals
    # raised DeprecationWarning/SyntaxWarning. The label text is unchanged.
    ax.plot(epochs, lr3_train, c="tab:blue", label=r"$\eta = 10^{-3}$, train")
    ax.plot(
        epochs,
        lr3_test,
        c="tab:blue",
        ls="--",
        label=r"$\eta = 10^{-3}$, test",
    )
    if lr4_train is not None:
        ax.plot(
            epochs, lr4_train, c="tab:orange", label=r"$\eta = 10^{-4}$, train"
        )
    if lr4_test is not None:
        ax.plot(
            epochs,
            lr4_test,
            c="tab:orange",
            ls="--",
            label=r"$\eta = 10^{-4}$, test",
        )
    ax.set_title("Train/test loss")
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Loss")
    ax.legend()
    plt.show()


In [None]:
# Result file for the eta = 1e-3 run (path relative to the notebook's cwd).
lr3_file = "./results/binarized_mnist_extra_lr3BinarizedMNIST_BernoulliMLPVAE_512-256-16_seed42_full_repeat0.json.gz"
# Read from the path defined above. `args` is never created in this notebook
# (and parse_args() would likely reject the kernel's own argv), so reading
# `args.lr3_file` only worked when leftover kernel state happened to define it.
lr3_train_all, lr3_test_all = parse_full(lr3_file)
assert len(lr3_train_all) == len(lr3_test_all), "train/test loss lengths differ"

# Number of leading epochs to drop before plotting (presumably so the large
# early losses don't squash the rest of the curve; tune as needed).
BURN_IN = 5

# Epochs are numbered from 1, so after dropping BURN_IN entries the first
# plotted point is epoch BURN_IN + 1 and `epochs` matches the sliced lengths.
epochs = list(range(BURN_IN + 1, len(lr3_train_all) + 1))
lr3_train = lr3_train_all[BURN_IN:]
lr3_test = lr3_test_all[BURN_IN:]

show_plot(epochs, lr3_train, lr3_test)

