# MNIST experiment prototype (complex-valued two-layer dense model on Fashion-MNIST)

In [None]:
import torch
import numpy as np

In [None]:
import json

from pkg_resources import resource_stream

# Start from the default experiment configuration bundled with
# ``cplxpaper.mnist``; the cells below override individual entries of this
# nested dict before it is handed to ``auto.run``.
# NOTE(review): ``pkg_resources`` is deprecated by setuptools; consider
# ``importlib.resources`` once the minimum supported Python allows it.
with resource_stream("cplxpaper.mnist", "template.json") as template_file:
    options = json.loads(template_file.read())

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
from cplxpaper.auto import auto

In [None]:
# Draw a fresh seed from [0, 2**31 - 1) and display it; a value like this was
# presumably pasted into ``random_state`` below (TODO confirm).
np.random.randint(2**31 - 1)

In [None]:
# Run on the second CUDA device. This is hard-coded, so adjust it on machines
# with fewer GPUs.
options.update(device="cuda:1")

In [None]:
# Train/test splits. Despite the "mnist-*" keys (kept because other parts of
# the config may refer to them), both splits use the Fashion-MNIST dataset
# classes. ``train_size`` and ``random_state`` presumably control a
# reproducible subsample of the training set; confirm in
# cplxpaper.mnist.dataset.
options["datasets"] = {
    "mnist-train": dict(
        cls="<class 'cplxpaper.mnist.dataset.FashionMNIST_Train'>",
        root="./data",
        train_size=10_000,
        random_state=1425950960,
    ),
    "mnist-test": dict(
        cls="<class 'cplxpaper.mnist.dataset.FashionMNIST_Test'>",
        root="./data",
    ),
}

In [None]:
# Input feed: presumably a 2-D Fourier transform of each image, as complex
# values with the spectrum shifted (inferred from the class and option
# names; confirm in cplxpaper.auto.feeds).
# Alternative tried earlier: cls="<class 'cplxpaper.auto.feeds.FeedRawFeatures'>"
options["features"] = dict(
    cls="<class 'cplxpaper.auto.feeds.FeedFourierFeatures'>",
    signal_ndim=2,
    cplx=True,
    shift=True,
)

In [None]:
# Base model: two-layer dense network from the ``complex`` models module,
# with 10 outputs. The "sparsify" stage swaps in the ARD variant, and the
# "fine-tune" stage swaps in the Masked variant of the same model family.
options["model"] = dict(
    cls="<class 'cplxpaper.mnist.models.complex.TwoLayerDenseModel'>",
    n_outputs=10,
)
options["stages"]["sparsify"]["model"].update(
    cls="<class 'cplxpaper.mnist.models.complex.TwoLayerDenseModelARD'>",
)
options["stages"]["fine-tune"]["model"].update(
    cls="<class 'cplxpaper.mnist.models.complex.TwoLayerDenseModelMasked'>",
)

In [None]:
# Disable ``reset`` for the fine-tune stage. What gets reset is defined by
# cplxpaper.auto (presumably state carried over from the previous stage;
# confirm).
options["stages"]["fine-tune"].update(reset=False)

In [None]:
# Sparsification stage: 200 epochs, and the objective's ``kl_div`` entry set
# to 1.0 (presumably the weight of the KL term; confirm in cplxpaper.auto).
options["stages"]["sparsify"].update(n_epochs=200)
options["stages"]["sparsify"]["objective"].update(kl_div=1.0)

In [None]:
# options

In [None]:
# Run all configured stages. "./test" is presumably the output directory and
# "emnist-wide" a run tag. The returned ``snapshots`` sequence is read by the
# cells below.
# NOTE(review): the tag says "emnist" but the datasets configured above are
# Fashion-MNIST.
snapshots = auto.run(
    options,
    "./test",
    "emnist-wide",
)

In [None]:
from cplxpaper.auto.utils import load_snapshot

def get_terms_from_snapshots(*snapshots):
    """Collect and concatenate recorded training histories across snapshots.

    Each snapshot is loaded with ``load_snapshot`` and must provide a
    ``"history"`` mapping (term name -> array-like record) and an
    ``"early_history"`` record. Records sharing a key are concatenated along
    axis 0 in the order the snapshots are given.

    Parameters
    ----------
    *snapshots
        Anything ``load_snapshot`` accepts (presumably snapshot file paths).

    Returns
    -------
    dict
        ``{"early_history": array, <term>: array, ...}``. Keys without any
        collected records are omitted, so calling with no snapshots returns
        ``{}`` instead of failing inside ``np.concatenate``.

    NOTE(review): a term that is missing from some snapshots is concatenated
    only over the snapshots that have it, so its index does not line up with
    the other terms.
    """
    history = {"early_history": []}
    for snapshot in snapshots:
        state = load_snapshot(snapshot)

        # Per-term records from the snapshot's training history.
        for key, record in state["history"].items():
            history.setdefault(key, []).append(record)

        # The ``early_history`` record is collected separately from the
        # per-term history.
        history["early_history"].append(state["early_history"])

    # np.concatenate raises on an empty list, so skip keys with no records.
    return {
        key: np.concatenate(records, axis=0)
        for key, records in history.items()
        if records
    }

In [None]:
history = get_terms_from_snapshots(*snapshots)

# Plot every recorded term on one log-scale axis. The x axis is the position
# in the record concatenated across all snapshots.
fig, ax = plt.subplots(1, 1, figsize=(16, 5))
for name, values in history.items():
    # ``early_history`` is a separate record, not one of the plotted terms.
    if name == "early_history":
        continue

    # NOTE: semilogy cannot show non-positive values. An earlier variant
    # shifted ``kl_div`` by its minimum (``values - values.min()``) before
    # plotting.
    ax.semilogy(values, label=name, alpha=0.6)

# Without a legend, the ``label=`` arguments above are never displayed.
ax.legend()

In [None]:
# Display the "performance" entry stored in the first returned snapshot.
load_snapshot(
    snapshots[0]
)["performance"]

In [None]:
# Display the "performance" entry stored in the last returned snapshot, for
# comparison with the first one.
load_snapshot(
    snapshots[-1]
)["performance"]

In [None]:
# Guard cell: halt a "Run All" here so that nothing after this point runs by
# accident. An explicit raise is used because ``assert`` statements are
# stripped under ``python -O``, which would silently disable the guard.
raise RuntimeError("Stopping here: end of the prototype run.")

<br>