# Model testing via Lightning checkpoints

In [None]:
from functools import partial
import torch
import torchvision

# Evaluation settings shared with the cells below.
batch_size = 64
dataset_split = "mnist"

# EMNIST constructor with the split already bound, then the held-out
# (train=False) portion, downloaded into ../../data if it is not there yet.
dataset = partial(torchvision.datasets.EMNIST, split=dataset_split)
data_test = dataset(root="../../data", train=False, download=True)

# Inputs: raw pixel values divided by 255, each sample flattened to a vector.
test_inputs = (data_test.data / 255).unsqueeze(1).flatten(start_dim=1)
# Targets: one-hot encoded labels as float32 (class count inferred by one_hot).
test_targets = torch.nn.functional.one_hot(data_test.targets).to(torch.float32)

# Both tensors are moved to the GPU up front, so batches are served from there.
dataset_test = torch.utils.data.TensorDataset(
    test_inputs.cuda(),
    test_targets.cuda(),
)

# Fixed order; a trailing batch smaller than batch_size is dropped.
loader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=batch_size,
    shuffle=False,
    drop_last=True,
)

In [None]:
import wandb
import lightning
from tqdm.notebook import tqdm

import sys
sys.path.insert(0, '..')
from custom_callbacks.track_log_relative_residual_callback import TrackLogRelativeResidualCallback
from deq_modules.onematrix.litonematrixdeq import LitOneMatrixDEQ
from deq_modules.even_odd.litevenodddeq import LitEvenOddDEQ

# Handle on the wandb public API; used further down to look up sweeps.
api = wandb.Api()

# NOTE(review): mode='disabled' is forwarded to wandb and presumably turns
# run logging off for these evaluation-only passes -- confirm.
logger = lightning.pytorch.loggers.WandbLogger(
    project="HopDEQ", entity="hopfield", mode='disabled'
)

# Single-GPU trainer without a progress bar. The residual-tracking callback
# is the only user callback; runs_from_sweep reads its results after each
# test pass.
trainer_options = dict(
    accelerator="gpu",
    devices=1,
    logger=logger,
    callbacks=[TrackLogRelativeResidualCallback()],
    enable_progress_bar=False,
)
trainer = lightning.Trainer(**trainer_options)

def runs_from_sweep(sweep, *, factor=1, ckpt_name="epoch=9-step=9370.ckpt"):
    """Test every run of a wandb sweep and collect the tracked residuals.

    For each run, a model is rebuilt from the run's config, its checkpoint is
    loaded through ``trainer.test`` on ``loader_test``, and whatever the
    ``TrackLogRelativeResidualCallback`` attached to the module-level
    ``trainer`` exposes as ``all_rel_res`` is appended to the result.

    Args:
        sweep: Object exposing ``.runs``; each run must provide ``name``,
            ``id`` and a ``config`` mapping with the keys ``"AA"``,
            ``"EvenOdd"`` and ``"HAM"``.
        factor: Multiplier for the solver iteration counts (40 forward,
            8 backward); the damping factor is ``1 - 1/factor``. Must be
            non-zero.
        ckpt_name: File name of the checkpoint inside
            ``../HopDEQ/<run.id>/checkpoints/``.

    Returns:
        dict mapping ``(run.config["EvenOdd"], run.config["AA"])`` to a list
        holding the concatenated callback results of all runs with that
        configuration.

    Raises:
        RuntimeError: If ``trainer`` has no TrackLogRelativeResidualCallback.
    """
    # Look the tracker up by type instead of relying on its list position.
    tracker = next(
        (
            cb for cb in trainer.callbacks
            if isinstance(cb, TrackLogRelativeResidualCallback)
        ),
        None,
    )
    if tracker is None:
        raise RuntimeError(
            "trainer has no TrackLogRelativeResidualCallback attached"
        )

    all_rel_res = {}

    for run in tqdm(sweep.runs):
        # .name is the human-readable name of the run.
        print(run.name, run.id)

        use_even_odd = run.config["EvenOdd"]
        use_anderson = run.config["AA"]

        # Forward and backward passes use the same solver; only its name
        # depends on the run configuration.
        solver = "anderson" if use_anderson else "picard"
        deq_kwargs = dict(
            forward_kwargs=dict(
                solver=solver,
                iter=40 * factor,
            ),
            backward_kwargs=dict(
                solver=solver,
                iter=8 * factor,
                method="backprop",
            ),
            damping_factor=1. - 1. / factor,
        )

        config = dict(
            batch_size=batch_size,
            lr=0.01,
            onematrix_dims=[784, 512, 10],
            deq_kwargs=deq_kwargs,
            ham=run.config["HAM"],
        )

        model_cls = LitEvenOddDEQ if use_even_odd else LitOneMatrixDEQ
        hop = model_cls(**config)

        model_path = f"../HopDEQ/{run.id}/checkpoints/{ckpt_name}"
        trainer.test(
            model=hop,
            ckpt_path=model_path,
            dataloaders=loader_test,
            verbose=False,
        )

        # Extend a list owned by this function rather than storing the
        # callback's own list object, so accumulating results never mutates
        # callback state.
        # NOTE(review): assumes the callback starts a fresh all_rel_res for
        # every test run -- confirm in the callback implementation.
        key = (use_even_odd, use_anderson)
        all_rel_res.setdefault(key, []).extend(tracker.all_rel_res)

    return all_rel_res

In [None]:
# Fetch the first sweep, then test all of its runs.
first_sweep = api.sweep("hopfield/HopDEQ/sweeps/9ixfbc41")
all_rel_res = runs_from_sweep(first_sweep)

In [None]:
# Fetch the second sweep, then test all of its runs.
second_sweep = api.sweep("hopfield/HopDEQ/sweeps/o41nn6c8")
all_rel_res2 = runs_from_sweep(second_sweep)

In [None]:
# Merge the results of both sweeps per configuration key. Keys are taken from
# the union of both dicts (first-sweep keys first), so a configuration that
# appears in only one sweep is kept instead of raising KeyError or being
# dropped. Each merged value is a new list; the inputs are left untouched.
all_rel_res_together = {
    key: all_rel_res.get(key, []) + all_rel_res2.get(key, [])
    for key in {**all_rel_res, **all_rel_res2}
}

In [None]:
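# Optional: uncomment to load all_rel_res_together from a pickle file
# instead of taking it from the cells above.
# import pickle
# import torch  # needed for the pickle file to load properly
# with open("all_rel_res_together.pickle", "rb") as file:
#     all_rel_res_together = pickle.load(file)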

In [None]:
import matplotlib.pyplot as plt
import torch
import numpy as np
from matplotlib.colors import LogNorm

# For each (EvenOdd, AA) configuration, draw a heat map of per-iteration
# histograms of the log relative residual, overlaid with the mean number of
# iterations needed to drop below a range of tolerances.
fig, axes = plt.subplots(2, 2, figsize=(10, 10))

# One LogNorm instance is shared by all panels, so the single colorbar at the
# end applies to every heat map.
norm = LogNorm()

bins = 400                # histogram rows per panel
log_min, log_max = -8, 0  # log-residual range covered by the histograms
highlight_tol = -4.       # tolerance marked by the dashed line and the circle


def _log_to_row(log_value):
    """Map a log-residual value to its (fractional) histogram row index."""
    return (log_value - log_min) * bins / (log_max - log_min)


highlight_row = _log_to_row(highlight_tol)  # bins // 2 for the values above

# Tolerances (log units) at which the mean convergence time is evaluated.
conv_tols = torch.arange(-8, -1.5, 0.5)

# Loop through the four configurations, one panel each.
for i, key in enumerate([(False, False), (False, True), (True, False), (True, True)]):
    ax = axes[i // 2, i % 2]

    with torch.inference_mode():
        # Dim 0 indexes iterations (it becomes the x axis); dim 1 collects the
        # trajectories of all batches. matplotlib cannot read CUDA tensors;
        # .cpu() is a no-op when the data already lives on the host.
        log_rel_residual = torch.cat(all_rel_res_together[key], dim=1).cpu()
        n_steps = log_rel_residual.size(0)

        # One histogram per iteration.
        hists = torch.stack([
            torch.histc(rr, bins=bins, min=log_min, max=log_max)
            for rr in log_rel_residual
        ])

        conv_times = []
        for tol in conv_tols:
            below = log_rel_residual < tol
            # argmax yields the first iteration below tol, but also 0 for
            # trajectories that never get there. Separate the two cases with
            # any(), so that a trajectory already below tol at iteration 0
            # is not mistaken for one that never converged.
            first_hit = torch.argmax(below.float(), dim=0)
            first_hit[~below.any(dim=0)] = n_steps  # =max number of time steps
            conv_times.append(first_hit.float().mean())

            if tol == highlight_tol:
                ax.scatter(conv_times[-1].item(), highlight_row, marker="o",
                           linewidths=2., s=100, facecolors='none',
                           edgecolors='w')

        conv_times = torch.stack(conv_times)
        print(conv_times)
        print(conv_tols)

    # Dashed line at the highlighted tolerance, spanning the image width.
    ax.hlines(highlight_row, -0.5, n_steps - 0.5, colors='w', linestyles='dashed')

    # Mean convergence time (x) against tolerance (y, in histogram rows).
    ax.plot(conv_times.tolist(), _log_to_row(conv_tols).tolist(), color='c')
    im = ax.imshow(hists.T.numpy(), origin='lower', norm=norm, cmap='magma_r',
                   aspect='auto')

    model = 'HAM'
    if key[0]:
        model += '-EO'
    if key[1]:
        model += '-DEQ'

    ax.set_title("State dynamics for "+model)

    # Only for bottom row
    if key[0]:
        ax.set_xlabel('Nr. of iterations')

    # Only for left column
    if not key[1]:
        ax.set_ylabel(r'Relative residual $\frac{||s_{t+1} - s_t||_2}{||s_{t+1}||_2}$')

    ax.set_yticks(range(0, bins, bins//4))
    ax.set_yticklabels(["$10^{"+str(x)+"}$" for x in range(-8, 0, 2)])

# Add a colorbar using the LogNorm instance
cbar = fig.colorbar(im, ax=axes, orientation='vertical', pad=0.1, norm=norm)
cbar.set_label('State trajectory density (raw counts)')

plt.show()
