# Model testing via Lightning checkpoints

In [None]:
from functools import partial
import torch
import torchvision

batch_size = 64

# EMNIST split to evaluate on.
dataset_split = "mnist"
dataset = partial(torchvision.datasets.EMNIST, split=dataset_split)
data_test = dataset(
    root="../../data",
    train=False,
    download=True,
)

# Pre-process the whole test split once and keep it on the GPU:
#  - inputs: pixel values divided by 255 (assumes uint8 pixels in [0, 255]) and
#    flattened to one row per image; flatten(start_dim=1) alone gives the same
#    (N, H*W) result the previous unsqueeze(1) + flatten produced.
#  - targets: one-hot float32 vectors. num_classes is taken from the dataset so the
#    width does not depend on which labels happen to occur in this split.
dataset_test = torch.utils.data.TensorDataset(
    (data_test.data / 255).flatten(start_dim=1).cuda(),
    torch.nn.functional.one_hot(
        data_test.targets, num_classes=len(data_test.classes)
    ).to(torch.float32).cuda(),
)

# NOTE: drop_last=True discards the trailing partial batch, so reported test metrics
# exclude up to batch_size - 1 samples. It is kept as-is in case the models require
# full batches (they are constructed with batch_size) — TODO confirm.
loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=batch_size, shuffle=False, drop_last=True
)

In [None]:
import wandb
import lightning
from tqdm.notebook import tqdm

import sys
sys.path.insert(0, '..')
from custom_callbacks.time_to_convergence_callback import TimeToConvergenceCallback
from deq_modules.onematrix.litonematrixdeq import LitOneMatrixDEQ
from deq_modules.even_odd.litevenodddeq import LitEvenOddDEQ

# Handle on the public W&B API; the sweeps evaluated below are fetched through it.
api = wandb.Api()

# W&B logger for the HopDEQ project. mode="disabled" turns W&B logging off for this
# evaluation-only session.
logger = lightning.pytorch.loggers.WandbLogger(
    mode="disabled",
    entity="hopfield",
    project="HopDEQ",
)

# Single-GPU trainer used only for trainer.test(...). TimeToConvergenceCallback is
# attached during testing (presumably it supplies the "Time to convergence" metric
# read later — confirm in custom_callbacks).
trainer = lightning.Trainer(
    enable_progress_bar=False,
    callbacks=[TimeToConvergenceCallback()],
    logger=logger,
    devices=1,
    accelerator="gpu",
)

def runs_from_sweep(sweep):
    all_results = {}

    runs = sweep.runs
    for run in tqdm(runs):
        # .name is the human-readable name of the run.
        print(run.name, run.id)

        # .config contains the hyperparameters        
        key = (run.config["HAM"], run.config["EvenOdd"], run.config["AA"])

        factor = 1

        if run.config["AA"]:
            deq_kwargs = dict(
                forward_kwargs=dict(
                    solver="anderson",
                    iter=40 * factor,
                ),
                backward_kwargs=dict(
                    solver="anderson",
                    iter=8 * factor,
                    method="backprop",
                ),
                damping_factor=1.-1./factor,
            )
        else:
            deq_kwargs = dict(
                forward_kwargs=dict(
                    solver="picard",
                    iter=40 * factor,
                ),
                backward_kwargs=dict(
                    solver="picard",
                    iter=8 * factor,
                    method="backprop",
                ),
                damping_factor=1.-1./factor,
            )

        config = dict(
            batch_size=batch_size,
            lr=0.01,
            onematrix_dims=[784, 512, 10],
            deq_kwargs=deq_kwargs,
            ham = run.config["HAM"],
        )

        if run.config["EvenOdd"]:
            hop = LitEvenOddDEQ(**config)
        else:
            hop = LitOneMatrixDEQ(**config)

        model_path = f"../HopDEQ/{run.id}/checkpoints/epoch=9-step=9370.ckpt"
        test_results = trainer.test(model=hop, ckpt_path=model_path, dataloaders=loader_test, verbose=False)

        if key not in all_results:
            all_results[key] = [test_results]
        else:
            all_results[key].append(test_results)
    
    return all_results

In [None]:
import pickle

# Evaluate the main sweep plus an extra sweep, then pool their runs per
# (HAM, EvenOdd, AA) configuration. The merge runs over the union of keys, so a
# configuration present in only one sweep is neither dropped nor raises KeyError.
all_results_ham = runs_from_sweep(api.sweep("hopfield/HopDEQ/sweeps/9ixfbc41"))
extra_results_ham = runs_from_sweep(api.sweep("hopfield/HopDEQ/sweeps/o41nn6c8"))

all_results_ham = {
    k: all_results_ham.get(k, []) + extra_results_ham.get(k, [])
    for k in dict.fromkeys([*all_results_ham, *extra_results_ham])
}

In [None]:
import numpy as np

def latex_print(all_results, *, ddof=0):
    """Print one LaTeX table row per model configuration, sorted alphabetically.

    Args:
        all_results: mapping from ``(ham, even_odd, deq)`` flag tuples to a list of
            per-run results. Each result is a sequence whose first element is a dict
            with ``'test_acc'`` and ``'Time to convergence'`` entries. ``test_acc``
            is printed multiplied by 100 with a percent sign, so it is presumably a
            fraction.
        ddof: delta degrees of freedom passed to ``numpy.std`` for the +/- spread.
            The default 0 (population standard deviation) matches previous output;
            use 1 for the sample standard deviation.

    Each row has the form:
    model & mean time (+/- std) & mean accuracy % (+/- std %), terminated by a
    LaTeX line break. The model name is "HAM" or "Hop", followed by "-EO" and/or
    "-DEQ" when the corresponding flag is truthy.
    """
    rows = []
    for (ham, eo, deq), dicts in all_results.items():
        model = ("HAM" if ham else "Hop") + ("-EO" if eo else "") + ("-DEQ" if deq else "")

        # Each per-run result is a sequence of metric dicts; element 0 is used.
        test_acc = np.array([d[0]['test_acc'] for d in dicts])
        t2c = np.array([d[0]['Time to convergence'] for d in dicts])

        # Backslashes are escaped explicitly ("\\pm", "\\%", "\\\\") instead of
        # relying on invalid escape sequences such as "\p", which Python warns about.
        rows.append(
            f"{model} & {t2c.mean():.3f} ($\\pm${t2c.std(ddof=ddof):.3f}) & "
            f"{100*test_acc.mean():.1f}\\% ($\\pm${100*test_acc.std(ddof=ddof):.1f}\\%)\\\\"
        )

    for row in sorted(rows):
        print(row)

# Emit the LaTeX table rows for the pooled results of both sweeps.
latex_print(all_results=all_results_ham)