In [None]:
import copy
import torch
import glob
import json

import numpy as np
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from pandas import DataFrame
import ipywidgets as widgets

from alignn.pretrained import *

from src.utils import *
from src import botcher_hessian_alignn as hess
from src import botcher_utilities as util
from src.hessian_wag import get_hessian_wag
from src.hessian_wag import get_sample_from_normal_dist_of_models, square_model_wts, sqrt_model_wts, get_normed_model,get_stddev_model

from tests.test_hessian_wag import *

In [None]:
# Device identifier string. No other cell visible in this notebook references it.
device = 'cpu'

# Load Model

In [None]:
# 'initial' lets the description label take its natural width.
style = {'description_width': 'initial'}

# Query the model registry once and reuse the list for both the dropdown
# options and the default value, instead of calling get_all_models() twice.
available_model_names = list(get_all_models().keys())
if not available_model_names:
    # Fail with a clear message rather than an IndexError on [0] below.
    raise ValueError("get_all_models() returned no models to select from")

# NOTE: every later cell depends on the value picked here, so results are only
# reproducible if the same entry is selected before running the next cell.
config_selector = widgets.Dropdown(
    options=available_model_names,
    value=available_model_names[0],
    description='Select Model',
    style=style,
    disabled=False,
)

display(config_selector)

In [None]:
# Read the current dropdown choice. This depends on interactive widget state,
# so the loaded model changes with whatever is selected when the cell runs.
model_name = config_selector.value
print("Selected: ", model_name)
# load_pretrained_model returns two values: the model and a second object bound
# to model_wt_dict (presumably the weight dictionary; verify in src.utils).
model, model_wt_dict = load_pretrained_model(model_name, print_summary=True)
# Plot the weight distribution of the freshly loaded model as a baseline.
plot_model_wt_dist(model)

In [None]:
# Build the "zeros" reference model used by the tests below
# (presumably every weight is set to 0; verify in zero_out_model).
zeroed_model = zero_out_model(model)
plot_model_wt_dist(zeroed_model)

In [None]:
# Build the "ones" reference model from the loaded model with constant 1.0
# (presumably every weight is set to that constant; verify in make_constant_model).
ones_model = make_constant_model(model, 1.0)
plot_model_wt_dist(ones_model)

In [None]:
# Build the "negative ones" reference model from the loaded model with constant -1.0.
neg_ones_model = make_constant_model(model, -1.0)
plot_model_wt_dist(neg_ones_model)

# Tests

In [None]:
# Test square_model_wts with the ones model. Since 1**2 == 1, the plotted
# distribution should be unchanged from the ones model.
test_square_model_wts_w_ones = square_model_wts(ones_model)
plot_model_wt_dist(test_square_model_wts_w_ones)

In [None]:
# Test square_model_wts with the neg_ones model. Since (-1)**2 == 1, the
# plotted distribution should now be all ones.
test_square_model_wts_w_neg_ones = square_model_wts(neg_ones_model)
plot_model_wt_dist(test_square_model_wts_w_neg_ones)

In [None]:
# Test sqrt_model_wts with the ones model. Since sqrt(1) == 1, the plotted
# distribution should remain all ones.
test_sqrt_model_wts_w_ones = sqrt_model_wts(ones_model)
plot_model_wt_dist(test_sqrt_model_wts_w_ones)

In [None]:
# Test sqrt_model_wts with the neg_ones model. The square root of a negative
# weight is expected to come back as nan rather than a usable value.
test_sqrt_model_wts_w_neg_ones = sqrt_model_wts(neg_ones_model)
# plot_model_wt_dist(test_sqrt_model_wts_w_neg_ones) # can't plot a hist of nan values
# Instead of plotting, print the first (name, parameter) pair so the values
# can be inspected directly.
print(next(iter(test_sqrt_model_wts_w_neg_ones.named_parameters())))

In [None]:
# Test get_normed_model with ones_model and a divisor of 2. Every weight
# should come back as 0.5, because each value gets divided by 2.
scaled_model = get_normed_model(ones_model, 2)
plot_model_wt_dist(scaled_model)

To test whether the `get_sample_from_normal_dist_of_models()` function is working we need several things:

1. A model where every weight in the model represents the mean value of that model parameter in the normal distribution
2. A model where every weight in the model represents the standard deviation of that model parameter in the normal distribution

---

Choose a zeros model for the first and a ones model for the second

---

Then every model drawn from this distribution should have a weight and bias distribution that is a zero-mean Gaussian with a standard deviation of one.

In [None]:
# Per-parameter mean of the sampling distribution: constant 0.0 everywhere.
model_mu = make_constant_model(model, 0.0)
# Per-parameter standard deviation of the sampling distribution: constant 1.0 everywhere.
model_std = make_constant_model(model, 1.0)

In [None]:
# Draw one model from the per-parameter normal distribution defined above.
# NOTE(review): model_mu is passed again as the third argument; its role is not
# visible in this notebook. Confirm against get_sample_from_normal_dist_of_models.
# NOTE(review): no random seed is set in the visible cells, so the sampled
# weights (and this plot) will presumably differ from run to run.
norm_dist_model = get_sample_from_normal_dist_of_models(model_mu, model_std, model_mu)
plot_model_wt_dist(norm_dist_model)

Finally, we need to test that the `get_stddev_model()` function can create a model in which every weight represents the standard deviation of that weight in the normal distribution.  The standard deviation is calculated using the expression $\sigma = \sqrt{E[X^2] - E[X]^2}$. We need:

1. A model where every weight in the model represents the mean value of the original model parameter in the to-be-created normal distribution: $E[X]$
2. A model where every weight in the model represents $\sum{X^2}$, so that I can divide by the total number of models $N$ and have the term $E[X^2]$
---

Choose zeros for the first model, and ones for the second model

---
Then every weight and bias in this model should have a value of one, because we are trying to get the standard deviation of the system.

In [None]:
# Build the standard-deviation model from the zeros model (mean term) and the
# ones model (sum-of-squares term), as described in the markdown cell above.
# NOTE(review): if the third argument (2.0) is the model count N, then
# sqrt(1/2 - 0**2) is about 0.707, not the value of one the markdown predicts.
# Confirm the argument's meaning against get_stddev_model.
stddev_model = get_stddev_model(zeroed_model, ones_model, 2.0)
plot_model_wt_dist(stddev_model)

Good!

## Test loss landscapes

In [None]:
from src.hessian_wag import Metric
from torch.nn import CosineSimilarity, L1Loss

In [None]:
def get_avg_distances(model_A, model_B, metric=None):
    """Return the mean per-parameter distance between two models.

    Each parameter tensor of ``model_A`` is paired with the parameter of the
    same name in ``model_B``. Both are flattened to 1-D and compared with
    ``metric``. The per-parameter results are averaged with equal weight per
    parameter tensor (not per element).

    Args:
        model_A: ``torch.nn.Module`` whose ``named_parameters()`` drives the loop.
        model_B: ``torch.nn.Module`` that must expose a parameter for every
            name found in ``model_A``.
        metric: Optional callable ``metric(u, v)`` returning a scalar tensor
            for two 1-D tensors. Defaults to ``L1Loss()`` (mean absolute
            difference). ``CosineSimilarity(dim=0, eps=1e-6)`` is a drop-in
            alternative.

    Returns:
        NumPy scalar holding the mean of the per-parameter distances. If
        ``model_A`` has no parameters, ``np.mean`` of an empty list is
        returned (``nan``, with a NumPy RuntimeWarning).

    Raises:
        AttributeError: if ``model_B`` has no parameter with a name present in
            ``model_A`` (raised by ``get_parameter``).
    """
    if metric is None:
        metric = L1Loss()

    uv_distances = []
    for layer_name, layer_param in model_A.named_parameters():
        # named_parameters() already yields the tensor, so no second lookup on
        # model_A (and no blanket ``except`` around it) is needed.
        u = layer_param.flatten()
        v = model_B.get_parameter(layer_name).flatten()

        dist = metric(u, v)
        # Detach from autograd and move to host memory before converting.
        uv_distances.append(dist.detach().cpu().numpy())

    return np.mean(uv_distances)

In [None]:
# Sanity check: comparing a model with itself gives identical parameter
# tensors, so the averaged L1 distance should be 0.
get_avg_distances(ones_model, ones_model)

In [None]:
class Loss(Metric):
    """Metric that scores a model by its average parameter distance to a reference model.

    The score is ``get_avg_distances`` applied to the first module of each
    wrapper; no dataset or input-output pairs are involved.
    """

    def __init__(self):
        super().__init__()

    def __call__(self, model_wrapper: ModelWrapper, og_model: ModelWrapper) -> float:
        # Each wrapper exposes its underlying modules as a sequence; only the
        # first entry is compared.
        candidate_module = model_wrapper.modules[0]
        reference_module = og_model.modules[0]
        return get_avg_distances(candidate_module, reference_module)

In [None]:
from loss_landscapes.main import planar_interpolation

In [None]:
# Instantiate the parameter-distance metric defined above for use with planar_interpolation.
metric = Loss()

In [None]:
# Evaluate the metric on a 40-step planar interpolation that starts at the zeros
# model, with neg_ones_model and ones_model as the two end points.
# (Argument roles are assumed from their order; confirm against planar_interpolation.)
tmp = planar_interpolation(zeroed_model, neg_ones_model, ones_model, metric, deepcopy_model=True,steps=40)

In [None]:
# Show the interpolation result as a heat map; origin='lower' puts index [0, 0]
# at the bottom-left corner of the image.
plt.imshow(tmp, cmap='jet', origin='lower')
plt.colorbar()

In [None]:
# Same interpolation as for tmp, but with the two end-point models swapped
# (ones_model first, neg_ones_model second).
tmp2 = planar_interpolation(zeroed_model, ones_model, neg_ones_model, metric, deepcopy_model=True,steps=40)

In [None]:
# Heat map of the swapped-end-point interpolation, for comparison with the tmp plot.
plt.imshow(tmp2, cmap='jet', origin='lower')
plt.colorbar()

In [None]:
# Interpolation starting at ones_model, with neg_ones_model given as BOTH end points.
# NOTE(review): using the same model for both end points looks degenerate
# (both directions coincide); confirm this is intentional.
tmp3 = planar_interpolation(ones_model, neg_ones_model, neg_ones_model, metric, deepcopy_model=True,steps=40)

In [None]:
# Heat map of the ones -> neg_ones / neg_ones interpolation.
plt.imshow(tmp3, cmap='jet', origin='lower')
plt.colorbar()

In [None]:
# Largest metric value found anywhere on the tmp3 grid.
np.amax(tmp3)