In [None]:
# Run once per kernel session: `%cd ..` is relative, so re-running this cell
# moves the working directory up one more level each time.
# (Presumably the parent directory is the project root that holds the local
# modules imported below — confirm.)
%load_ext autoreload
%autoreload 2
%cd ..

# Preamble

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colormaps as cmaps
from sklearn.datasets import make_moons, make_circles
import datasets
import torch
from datasets.tabular import TabularModel, TabularModelPerturb, learning_pipeline
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
from similarity import get_top_k, top_k_sa, average_pairwise_score, ground_truth_score, average_ground_truth_score
from similarity import angle_diff, average_pairwise_score_grad, cosine_similarity
from util import State, get_weight_norm, get_weight_diff, linear_weight_interpolation
from train import get_states
from tqdm import tqdm

In [None]:
class TwoMoons(Dataset):
    """Synthetic 2-D binary classification dataset built from scikit-learn generators.

    Points are drawn with ``make_moons``, or with ``make_circles`` (``factor=0.5``)
    when ``circles`` is true. ``data`` holds the points as a float tensor and
    ``labels`` the class ids as a long tensor. ``name`` is always ``'moons'``,
    whichever generator was used.
    """

    def __init__(self, n_samples=1000, noise=0.1, random_state=0, circles=False):
        generator_kwargs = dict(n_samples=n_samples, noise=noise, random_state=random_state)
        if circles:
            points, classes = make_circles(factor=0.5, **generator_kwargs)
        else:
            points, classes = make_moons(**generator_kwargs)
        self.name = 'moons'
        self.data = torch.FloatTensor(points)
        self.labels = torch.LongTensor(classes)

    def __len__(self):
        """Number of stored samples; required by PyTorch's DataLoader."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair at ``idx``; required by PyTorch's DataLoader."""
        # A tensor index is converted to plain Python ints/lists before indexing.
        index = idx.tolist() if torch.is_tensor(idx) else idx
        return self.data[index], self.labels[index]

In [None]:
# Build the noisy train/test datasets (different random_state per split) and
# expose them as numpy arrays for the cells below.
circles = False
dataset_kwargs = dict(noise=0.4, circles=circles)
trainset = TwoMoons(n_samples=800, random_state=0, **dataset_kwargs)
testset = TwoMoons(n_samples=200, random_state=1, **dataset_kwargs)
# trainset.data[trainset.labels==1, 0] -= 0.5
# testset.data[testset.labels==1, 0] -= 0.5
X_train, y_train = trainset.data.numpy(), trainset.labels.numpy()
X_test, y_test = testset.data.numpy(), testset.labels.numpy()
n_inputs, n_features = X_test.shape

# Train set in the left panel, test set in the right panel, coloured by label.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
for panel, split in ((ax1, trainset), (ax2, testset)):
    panel.scatter(split.data[:, 0], split.data[:, 1], c=split.labels, s=2)
plt.show()

# Train Models

In [None]:
# Settings dictionary handed to get_states; later cells overwrite
# 'epochs' and 'mode_connect' before each training run.
n_models = 200
config = dict(
    n=n_models,
    optimizer='adam',
    epochs=100,
    lr=0.004,
    batch_size=16,
    loo=False,
    seed=0,
    mode_connect='',
    wandb=False,
)

In [None]:
# Baseline run: mode connection disabled; one model per State returned by get_states.
config.update(epochs=100, mode_connect='')
States = get_states(n_models, TabularModel, trainset, testset, config)
models = [learning_pipeline(state) for state in States]

In [None]:
# Same pipeline with 'mode_connect' set to 'bezier'; one result per State.
config.update(epochs=100, mode_connect='bezier')
States = get_states(n_models, TabularModel, trainset, testset, config)
mode_models = [learning_pipeline(state) for state in States]

In [None]:
# Wrap every trained model in a TabularModelPerturb constructed with
# `num_perturbations` and `sigma`; pert_models[j] wraps models[j].
num_perturbations = 50
sigma = 0.2
pert_models = [TabularModelPerturb(base_model, num_perturbations, sigma)
               for base_model in tqdm(models)]

# Postprocess

In [None]:
def get_stats(method, X, idx):
    """Return ``(logits, grads)`` on inputs ``X`` for entry ``idx`` of the chosen model family.

    Args:
        method: ``'average'`` (uses ``models``), ``'perturb'`` (uses
            ``pert_models``; outputs are averaged over their leading axis) or
            ``'mode connect'`` (uses ``mode_models`` evaluated at the
            module-level ``ts``; outputs are averaged over their leading axis).
        X: input array, passed to the model and its gradient routine.
        idx: index into the model list selected by ``method``.

    Returns:
        Tuple ``(logits, grads)`` of numpy arrays.

    Raises:
        ValueError: if ``method`` is not one of the three names above
            (previously this surfaced as an UnboundLocalError at ``return``).
    """
    if method == 'average':
        logits = models[idx](torch.FloatTensor(X)).detach().numpy()
        grads = models[idx].compute_gradients(torch.FloatTensor(X), return_numpy=True)
    elif method == 'perturb':
        logits = pert_models[idx](torch.FloatTensor(X)).detach().numpy().mean(axis=0)
        grads = pert_models[idx].compute_gradients(X, mean=True)
    elif method == 'mode connect':
        # `ts` is read from module scope at call time.
        logits = mode_models[idx].compute_logits(X, TabularModel, ts).mean(axis=0)
        grads = mode_models[idx].compute_gradients(X, TabularModel, ts).mean(axis=0)
    else:
        raise ValueError(f'Unknown method {method}')
    return logits, grads

In [None]:
# 50 evenly spaced values in [0, 1]; read as a global by get_stats and passed
# to the mode-connect compute_logits / compute_gradients calls.
ts = np.linspace(0,1,50)

In [None]:
# Per-method statistics on the test set: logits and input gradients for every
# model, the average pairwise gradient angle per input, and per-model accuracy.
methods = ['average', 'perturb', 'mode connect']
stats_shape = (len(methods), n_models, n_inputs)
logits = np.zeros(stats_shape + (2,))
grads = np.zeros(stats_shape + (n_features,))
angles = np.zeros((len(methods), n_inputs))
accs = np.zeros((len(methods), n_models))
for i, method in enumerate(tqdm(methods)):
    for j in range(len(models)):
        logits[i, j], grads[i, j] = get_stats(method, X_test, j)
    angles[i] = average_pairwise_score_grad(grads[i], angle_diff)
    predicted = logits[i].argmax(axis=2)
    accs[i] = (predicted == y_test).mean(axis=1)

In [None]:
# Box plots of the per-input gradient angles and the per-model accuracies,
# one box per method. Each boxplot call already draws every method, so it is
# issued once per panel rather than once per method.
fig, ax = plt.subplots(1, 2, figsize=(20, 5), dpi=100)
ax[0].boxplot([angles[m] for m in range(len(methods))], labels=methods)
ax[0].set_title('Angle between gradients')
ax[1].boxplot([accs[m] for m in range(len(methods))], labels=methods)
ax[1].set_title('Accuracy')
plt.show()

# Heatmaps

In [None]:
# Per-feature min and max of the test set, then of the training set.
for features in (X_test, X_train):
    print(features[:, 0].min(), features[:, 0].max(), features[:, 1].min(), features[:, 1].max())

In [None]:
# 400 x 400 evaluation grid over the input plane. X lists the grid points as
# (x, y) rows in the row-major order of the meshgrid, so an array computed on
# X can be reshaped back with `.reshape(xx.shape)`.
x = np.linspace(-1, 2, 400)
y = np.linspace(-1.2, 1.6, 400)
xx, yy = np.meshgrid(x, y)
X = np.column_stack((xx.ravel(), yy.ravel()))
extent = [x.min(), x.max(), y.min(), y.max()]  # passed to imshow(extent=...)
ts = np.linspace(0, 1, 50)

### Predictions

In [None]:
# Class-1 output of every model on the grid, for the plain, mode-connected
# and perturbed families (the latter two averaged over their leading axis).
# The grid tensor does not change between iterations, so it is built once
# instead of twice per model.
X_grid = torch.FloatTensor(X)
Z = np.zeros((n_models, *xx.shape))
Z_mode = np.zeros((n_models, *xx.shape))
Z_pert = np.zeros((n_models, *xx.shape))
for i in tqdm(range(n_models)):
    Z[i] = models[i](X_grid).detach().numpy()[:, 1].reshape(xx.shape)
    Z_mode[i] = mode_models[i].compute_logits(X, TabularModel, ts).mean(axis=0)[:, 1].reshape(xx.shape)
    Z_pert[i] = pert_models[i](X_grid).detach().numpy().mean(axis=0)[:, 1].reshape(xx.shape)

In [None]:
# Plot heatmap of the prediction averaged over all models, one panel per family.
fig, ax = plt.subplots(1, 3, figsize=(10, 5), dpi=150)
titles = ['Standard', 'Perturbed', 'Bezier']  # same order as the arrays below
for i, z in enumerate([Z, Z_pert, Z_mode]):
    im = ax[i].imshow(z.mean(axis=0), extent=extent, origin='lower', cmap='RdYlBu', alpha=0.8)
    # Bind each colorbar to its own panel instead of relying on the default parent axes.
    fig.colorbar(im, ax=ax[i], fraction=0.035)
    # Overlay the test points, coloured by label.
    ax[i].scatter(testset.data[:, 0], testset.data[:, 1], c=testset.labels, s=4, cmap='hot', alpha=0.5)
    ax[i].set_title(titles[i])
    ax[i].set_xticks([])
    ax[i].set_yticks([])
plt.show()

In [None]:
# Colormap sweep: draw the same 2x2 grid of single-model predictions once per
# candidate colormap.
cols = ['binary', 'gist_yarg', 'gist_gray', 'gray', 'bone',
        'pink', 'spring', 'summer', 'autumn', 'winter', 'cool',
        'Wistia', 'hot', 'afmhot', 'gist_heat', 'copper']
model_idx = [94,133,64,161]#np.random.choice(n_models, 5, replace=False)
s = 4
back_alpha = 0.9

# The predictions do not depend on the colormap, so they are computed once up
# front instead of once per colormap.
# NOTE(review): assumes the model forward pass is deterministic — confirm.
X_grid = torch.FloatTensor(X)
model_preds = [models[idx](X_grid).detach().numpy()[:, 1].reshape(xx.shape)
               for idx in model_idx]

for col in cols:
    cmap = col

    fig, ax = plt.subplots(2, 2, figsize=(10,6), dpi=150)
    fig.subplots_adjust(right=0.8)
    plt.subplots_adjust(wspace=0.035, hspace=0.1)
    for i, zi in enumerate(model_preds):
        panel = ax[i//2,i%2]
        im = panel.imshow(zi, extent=extent, origin='lower', cmap=cmap, alpha=back_alpha)
        panel.scatter(testset.data[:, 0], testset.data[:, 1], cmap='hot_r',  c=testset.labels, alpha=0.9, s=s)
        panel.set_xticks([])
        panel.set_yticks([])
    # One shared colorbar in a dedicated axes to the right of the grid.
    cbar_ax = fig.add_axes([0.815, 0.11, 0.02, 0.77])  # x, y, width, height
    fig.colorbar(im, cax=cbar_ax)
    for t in cbar_ax.get_yticklabels():
        t.set_fontsize(16)
    plt.suptitle('Model Variation due to Random Seed', fontsize=18, y=0.93, x=0.465)
    plt.show()

In [None]:
from matplotlib import colors, colormaps
from matplotlib.colors import LinearSegmentedColormap
# Sample three colours from matplotlib's RdYlBu colormap
cmap = colormaps['RdYlBu']
light_red = cmap(0.2)  # Adjust this value to get the desired shade of light red
yellow = cmap(0.5)  # Midpoint of the colormap
light_blue = cmap(0.8)  # Adjust this value to get the desired shade of light blue

# Create a custom colormap from the three sampled colours
cols = [light_red, yellow, light_blue][::-1]  # Reversed, so the list is [light_blue, yellow, light_red]
custom_cmap = LinearSegmentedColormap.from_list('custom_RdYlBu', cols)

In [None]:
# Plot a 2x2 grid of heatmaps of the class-1 output of the four models in model_idx
model_idx = [94,133,64,161]#np.random.choice(n_models, 5, replace=False)
s, alpha = 6, 0.7
back_alpha = 1
cmap = custom_cmap#'RdYlBu_r'

fig, ax = plt.subplots(2, 2, figsize=(10,6), dpi=200)
fig.subplots_adjust(right=0.8)
plt.subplots_adjust(wspace=0.035, hspace=0.06)
for i, idx in enumerate(model_idx):
    zi = models[idx](torch.FloatTensor(X)).detach().numpy()[:, 1].reshape(xx.shape)
    # NOTE(review): zi**0.3 yields NaN for negative entries; assumes the model
    # output is non-negative here (e.g. probabilities) — confirm.
    im = ax[i//2,i%2].imshow(zi**0.3, extent=extent, origin='lower', cmap=cmap, alpha=back_alpha)
    # Labels are flipped (1 - label) before being colour-mapped with RdYlBu.
    ax[i//2,i%2].scatter(testset.data[:, 0], testset.data[:, 1], cmap='RdYlBu',  c=1-testset.labels, alpha=alpha, s=s)
    ax[i//2,i%2].set_xticks([])
    ax[i//2,i%2].set_yticks([])
# Hide the right/bottom spines of ax[0,0] and the left/bottom spines of ax[0,1]
ax[0,0].spines['right'].set_visible(False)
ax[0,0].spines['bottom'].set_visible(False)
ax[0,1].spines['left'].set_visible(False)
ax[0,1].spines['bottom'].set_visible(False)
# Shared colorbar (for the last image drawn) in its own axes right of the grid
cbar_ax = fig.add_axes([0.815, 0.11, 0.02, 0.77])  # x, y, width, height
fig.colorbar(im, cax=cbar_ax)
for t in cbar_ax.get_yticklabels():
     t.set_fontsize(16)
plt.suptitle('Model Variation within the Underspecification Set', fontsize=19, y=0.935, x=0.465)
plt.show()

In [None]:
# Class-1 prediction heatmaps for the four models in model_idx: one row of
# panels each for the plain models, their perturbed wrappers and the
# mode-connected models.
model_idx = [94,133,64,161]#np.random.choice(n_models, 5, replace=False)
s = 6
back_alpha = 1.0
cmap = custom_cmap#'RdYlBu'


def _plot_prediction_row(predict):
    """Draw one figure with a heatmap per entry of ``model_idx``.

    ``predict(idx)`` must return the flat class-1 output on the grid ``X``
    for model index ``idx``; it is reshaped to the grid before plotting.
    """
    fig, ax = plt.subplots(1, len(model_idx), figsize=(len(model_idx)*4, 5))
    for i, idx in enumerate(model_idx):
        zi = predict(idx).reshape(xx.shape)
        im = ax[i].imshow(zi, extent=extent, origin='lower', cmap=cmap, alpha=back_alpha)
        fig.colorbar(im, ax=ax[i], fraction=0.035)
        ax[i].scatter(testset.data[:, 0], testset.data[:, 1], cmap='hot_r', c=testset.labels, alpha=0.5, s=s)
        ax[i].set_xticks([])
        ax[i].set_yticks([])
    plt.show()


X_grid = torch.FloatTensor(X)
# Every row indexes its model list with the selected model index `idx`. The
# perturbed and mode-connected rows previously used the loop position, which
# plotted entries 0-3 instead of the selected models.
_plot_prediction_row(lambda idx: models[idx](X_grid).detach().numpy()[:, 1])
_plot_prediction_row(lambda idx: pert_models[idx](X_grid).detach().numpy()[:, :, 1].mean(axis=0))
_plot_prediction_row(lambda idx: mode_models[idx].compute_logits(X, TabularModel, ts)[:, :, 1].mean(axis=0))

### Ensembles

In [None]:
def get_grads(method, idx):
    """Return the gradients on the grid ``X`` for entry ``idx`` of the chosen model family.

    Args:
        method: ``'average'`` (uses ``models``), ``'perturb'`` (uses
            ``pert_models`` with ``mean=True``) or ``'mode connect'`` (uses
            ``mode_models`` at the module-level ``ts``, averaged over the
            leading axis).
        idx: index into the model list selected by ``method``.

    Returns:
        The gradient array produced by the selected model's ``compute_gradients``.

    Raises:
        ValueError: if ``method`` is not one of the three names above
            (previously this surfaced as an UnboundLocalError at ``return``).
    """
    if method == 'average':
        grads = models[idx].compute_gradients(torch.FloatTensor(X), return_numpy=True)
    elif method == 'perturb':
        grads = pert_models[idx].compute_gradients(X, mean=True)
    elif method == 'mode connect':
        grads = mode_models[idx].compute_gradients(X, TabularModel, ts).mean(axis=0)
    else:
        raise ValueError(f'Unknown method {method}')
    return grads

In [None]:
# Gradients of every model on the full grid X, for each method.
# grads_grid is indexed as [method, model, grid point, feature]; the trailing
# 2 is hard-coded to the number of input features.
methods = ['average', 'perturb', 'mode connect']
grads_grid = np.zeros((len(methods), n_models, X.shape[0], 2))
from datasets.tabular import TabularModel
for i, method in enumerate(methods):
    for j in tqdm(range(n_models)):
        # Compute gradients
        grads_grid[i,j] = get_grads(method, j)

In [None]:
# Parameters
ensemble_sizes = [1, 2, 4, 6, 8, 10]
n_trials = 20
k = 1

# Compute statistics
# e_grads_grid[method, size, trial] is the mean gradient of one ensemble;
# angles_grid / cosines_grid hold the average pairwise score across trials.
e_grads_grid = np.zeros((len(methods), len(ensemble_sizes), n_trials, X.shape[0], 2))
angles_grid = np.zeros((len(methods), len(ensemble_sizes), X.shape[0]))
cosines_grid = np.zeros((len(methods), len(ensemble_sizes), X.shape[0]))
# topk = np.zeros((len(methods), len(ensemble_sizes), n_trials, X.shape[0], k))
# signs = np.zeros((len(methods), len(ensemble_sizes), n_trials, X.shape[0], k), dtype=int)
# sa = np.zeros((len(methods), len(ensemble_sizes), X.shape[0]))

from datasets.tabular import TabularModel
for i, method in enumerate(methods):
    for j, ensemble_size in enumerate(tqdm(ensemble_sizes)):
        # 'mode connect' uses half as many entries per ensemble (presumably
        # because each entry is built from two trained models — confirm).
        # Clamp to 1: for ensemble_size == 1 the halving gave 0 members, i.e.
        # an empty ensemble whose mean is NaN.
        e_size = max(1, ensemble_size//2) if method == 'mode connect' else ensemble_size
        # The first n_trials*e_size models are split into n_trials disjoint ensembles.
        e_grads_grid[i,j] = grads_grid[i, :n_trials*e_size].reshape(n_trials, e_size, *X.shape).mean(axis=1)
        angles_grid[i,j] = average_pairwise_score_grad(e_grads_grid[i,j], angle_diff)
        cosines_grid[i,j] = average_pairwise_score_grad(e_grads_grid[i,j], cosine_similarity)
        # topk[i,j], signs[i,j] = get_top_k(k, e_grads[i,j], return_sign=True)
        # sa[i,j] = average_pairwise_score(topk[i,j], signs[i,j], top_k_sa)

In [None]:
def plot_methods(sim, idx, vmin, vmax, sep, scale=2.5, cmap='RdYlGn_r'):
    """Plot one row of heatmaps per method for the similarity measure ``sim``.

    For every method a figure with one panel per entry of ``ensemble_sizes``
    is drawn: the heatmap shows ``sims[sim][method, size]`` on the grid, and
    the ensemble gradients at grid point ``idx`` are overlaid as arrows.

    Args:
        sim: key into the module-level ``sims`` dict.
        idx: flat index of the grid point (row of ``X``) whose gradients are drawn.
        vmin, vmax: colour limits of the heatmap.
        sep: step between colorbar boundaries (``np.arange(vmin, vmax, sep)``).
        scale: ``scale`` argument forwarded to ``quiver``.
        cmap: colormap of the heatmap.
    """
    v = np.arange(vmin, vmax, sep)
    for j, method in enumerate(methods):
        fig, ax = plt.subplots(1, len(ensemble_sizes), figsize=(4*len(ensemble_sizes), 3), dpi=150)
        plt.subplots_adjust(wspace=0.2)
        for i, ensemble_size in enumerate(ensemble_sizes):
            Z = sims[sim][j,i].reshape(xx.shape)
            im = ax[i].imshow(Z, vmin=vmin, vmax=vmax, extent=extent, origin='lower', cmap=cmap)
            fig.colorbar(im, ax=ax[i], fraction=0.035, boundaries=v)
            pg = e_grads_grid[j, i, :, idx]
            norm = np.linalg.norm(pg, axis=1)
            # Rescale arrow lengths to the range [0.5, 1.5]. If all norms are
            # equal the span is zero, so every arrow gets length 0.5 instead
            # of dividing by zero.
            span = norm.max() - norm.min()
            rel = (norm - norm.min()) / span if span > 0 else np.zeros_like(norm)
            length = rel + 0.5
            # Zero gradients stay zero-length arrows rather than becoming NaN.
            safe_norm = np.where(norm > 0, norm, 1.0)
            pg = pg / safe_norm[:, None] * length[:, None]
            ax[i].quiver(np.repeat(X[idx,0], n_trials), np.repeat(X[idx,1], n_trials),
                        pg[:,0], pg[:, 1], angles='xy', scale_units='xy', scale=scale, color='black')
            ax[i].set_title(f'Ensemble size: {ensemble_size}')
            ax[i].set_xticks([]); ax[i].set_yticks([])
        # `sim` already names the measure (e.g. 'angle difference'), so the
        # title no longer appends a second "difference".
        plt.suptitle(f'Average pairwise {sim} ({titles[j]})')
        plt.show()

# Display names (one per method) and the similarity arrays read by plot_methods.
# NOTE(review): `sa` is not assigned before this point in the visible notebook
# (the lines that would fill it in the ensemble cell are commented out, and the
# only assignment is in the Misc section), so this cell raises NameError on a
# fresh top-to-bottom run.
titles = ['Average', 'Perturb', 'Mode Connect']
sims = {'angle difference': angles_grid,
        'cosine similarity': cosines_grid,
        'sign agreement': sa}

In [None]:
# Distribution over grid points of the average pairwise angle for the first
# method ('average') and the first ensemble size.
plt.boxplot(angles_grid[0,0])
plt.show()

In [None]:
# Inline copy of the inner loop of plot_methods for a single method index `j`.
# NOTE(review): this cell reads `sim`, `j`, `vmin`, `vmax`, `v`, `idx`, `cmap`
# and `scale` from leftover kernel state; `scale` is never assigned at module
# level in the visible notebook, so a fresh run fails here. Consider replacing
# the cell with a plot_methods(...) call.
ensemble_sizes_plot = []  # unused in this cell
fig, ax = plt.subplots(1, len(ensemble_sizes), figsize=(4*len(ensemble_sizes), 3), dpi=150)
plt.subplots_adjust(wspace=0.2)
for i, ensemble_size in enumerate(ensemble_sizes):
    Z = sims[sim][j,i].reshape(xx.shape)
    im = ax[i].imshow(Z, vmin=vmin, vmax=vmax, extent=extent, origin='lower', cmap=cmap)
    fig.colorbar(im, ax=ax[i], fraction=0.035, boundaries=v)
    pg = e_grads_grid[j, i, :, idx]
    norm = np.linalg.norm(pg, axis=1)
    # Rescale the norms to the range [0.5, 1.5] (min-max scaling plus 0.5)
    norm = (norm - norm.min()) / (norm.max() - norm.min()) + 0.5
    pg = pg / np.linalg.norm(pg, axis=1, keepdims=True)
    pg = pg * norm[:, None]
    ax[i].quiver(np.repeat(X[idx,0], n_trials), np.repeat(X[idx,1], n_trials),
                pg[:,0], pg[:, 1], angles='xy', scale_units='xy', scale=scale, color='black')
    ax[i].set_title(f'Ensemble size: {ensemble_size}')
    ax[i].set_xticks([]); ax[i].set_yticks([])
plt.suptitle(f'Average pairwise {sim} difference ({titles[j]})')
plt.show()

In [None]:
# Same box plot as in the earlier cell: average pairwise angle over grid
# points for the first method and first ensemble size.
plt.boxplot(angles_grid[0,0])
plt.show()

In [None]:
# Heatmap of the average pairwise angle for the first method and first
# ensemble size, with the ensemble gradients at one grid point drawn as arrows.
# NOTE(review): `right_idx` is assigned in the cell that follows this one, so
# this cell depends on out-of-order execution.
idx = right_idx[3]
vmin, vmax = 0, 91
sep = 10
z = angles_grid[0,0].reshape(xx.shape)
pg = e_grads_grid[0,0, :, idx]
plt.figure(figsize=(6,6), dpi=100)
plt.imshow(z, vmin=vmin, vmax=vmax, extent=extent, origin='lower', cmap='RdYlBu_r')
plt.colorbar(fraction=0.0314, boundaries=np.arange(vmin, vmax, sep))
# One arrow per trial, all anchored at grid point idx
plt.quiver(np.repeat(X[idx,0], n_trials), np.repeat(X[idx,1], n_trials),
            pg[:,0], pg[:, 1], angles='xy', scale_units='xy', scale=8, color='black', alpha=0.75)
plt.title('Angular Difference between Ensemble Gradients', fontweight='bold')
plt.xticks([]); plt.yticks([])
# no plot border
# plt.gca().spines['right'].set_visible(False)
# plt.gca().spines['top'].set_visible(False)
# plt.gca().spines['left'].set_visible(False)
# plt.gca().spines['bottom'].set_visible(False)
plt.show()

In [None]:
# The 20 grid points with the largest average pairwise angle (first method,
# first ensemble size), in descending order, split by grid column.
idxs = np.argsort(angles_grid[0,0])[-20:][::-1]
# Rows of X follow the row-major flattening of the meshgrid, so
# `flat_index % n_columns` is the column (x position) of a grid point. The
# width is read from `xx` instead of hard-coding 400, so the split stays
# correct if the grid resolution changes.
n_grid_cols = xx.shape[1]
idx_cols = idxs % n_grid_cols
left_idx = idxs[idx_cols <= n_grid_cols // 2]
right_idx = idxs[idx_cols > n_grid_cols // 2]
left_idx, right_idx

In [None]:
# Four-panel figure: single models (method 0, first ensemble size) next to
# ensembles of size `ensemble_size` for each of the three methods, with the
# gradients at grid point `idx` overlaid as arrows.
idx = left_idx[1]
vmin, vmax, sep = 0, 91, 10
v = np.arange(vmin, vmax, sep)
ensemble_size = 4
e_idx = ensemble_sizes.index(ensemble_size)
ijs = [[0,0], [0,e_idx], [1,e_idx], [2,e_idx]]  # (method, ensemble) index pairs e.g. mode connect, size 2 would be (2,1)

fig, ax = plt.subplots(1, 4, figsize=(6*len(ijs), 6), dpi=400)
plt.subplots_adjust(wspace=0.15)

method_titles = ['Vanilla', 'Perturbed', 'Connected']
n_grads_plot = 15  # not used below

for i, (m_i, e_i) in enumerate(ijs):
    zi = angles_grid[m_i, e_i].reshape(xx.shape)
    im = ax[i].imshow(zi, vmin=vmin, vmax=vmax, extent=extent, origin='lower', cmap='RdYlBu_r')
    fig.colorbar(im, ax=ax[i], fraction=0.0425, boundaries=v)
    pg = e_grads_grid[m_i, e_i, :, idx]
    norm = np.linalg.norm(pg, axis=1)
    # Rescale the norms to the range [0.5, 1.5] (min-max scaling plus 0.5)
    norm = (norm - norm.min()) / (norm.max() - norm.min()) + 0.5
    pg = pg / np.linalg.norm(pg, axis=1, keepdims=True)
    pg = pg * norm[:, None]
    # Arrows are slightly more transparent on every panel after the first
    ax[i].quiver(np.repeat(X[idx,0], n_trials), np.repeat(X[idx,1], n_trials),
                pg[:,0], pg[:, 1], angles='xy', scale_units='xy', scale=1, color='black', alpha=0.7-0.1*(i>0))
    if e_i == 0:
        ax[i].set_title('Constituent Models', fontsize=19)
    else:
        ax[i].set_title(f'{method_titles[m_i]} Ensembles', fontsize=19)
    ax[i].set_xticks([]); ax[i].set_yticks([])
# plt.suptitle(f'Average pairwise angular difference between gradients (ensembles of size {ensemble_size})',
#              y=0.84, fontsize=16)
plt.show()

In [None]:
# Three-panel figure: ensembles of size `ensemble_size` for each method, with
# the ensemble gradients at grid point `idx` overlaid as arrows.
idx = left_idx[0]
vmin, vmax, sep = 0, 91, 10
v = np.arange(vmin, vmax, sep)
ensemble_size = 4
e_idx = ensemble_sizes.index(ensemble_size)
ijs = [[0,e_idx], [1,e_idx], [2,e_idx]]  # (method, ensemble) index pairs e.g. mode connect, size 2 would be (2,1)

fig, ax = plt.subplots(1, len(ijs), figsize=(6*len(ijs), 6), dpi=200)
plt.subplots_adjust(wspace=0.15)

method_titles = ['Vanilla', 'Weight Perturbation', 'Mode Connectivity']

for i, (m_i, e_i) in enumerate(ijs):
    zi = angles_grid[m_i, e_i].reshape(xx.shape)
    im = ax[i].imshow(zi, vmin=vmin, vmax=vmax, extent=extent, origin='lower', cmap='RdYlBu_r')
    fig.colorbar(im, ax=ax[i], fraction=0.0313, boundaries=v)
    pg = e_grads_grid[m_i, e_i, :, idx]
    norm = np.linalg.norm(pg, axis=1)
    # Rescale the norms to the range [0.5, 1.5] (min-max scaling plus 0.5)
    norm = (norm - norm.min()) / (norm.max() - norm.min()) + 0.5
    pg = pg / np.linalg.norm(pg, axis=1, keepdims=True)
    pg = pg * norm[:, None]
    ax[i].quiver(np.repeat(X[idx,0], n_trials), np.repeat(X[idx,1], n_trials),
                pg[:,0], pg[:, 1], angles='xy', scale_units='xy', scale=1, color='black')
    ax[i].set_title(f'{method_titles[m_i]} Ensembles', fontsize=14)
    ax[i].set_xticks([]); ax[i].set_yticks([])
# plt.suptitle(f'Average pairwise angular difference between gradients (ensembles of size {ensemble_size})',
#              y=0.84, fontsize=16)
plt.show()

In [None]:
# Plot heatmaps of the average pairwise angle for every method and ensemble
# size, with arrows at the first grid point in right_idx.
idx = right_idx[0]
vmin, vmax = 0, 91
sep = 5
plot_methods('angle difference', idx, vmin, vmax, sep, scale=0.9, cmap='RdYlBu_r')

In [None]:
# Plot heatmap of cosine similarities, with arrows at the grid point where the
# first method / first ensemble size has the lowest cosine similarity.
# The array is named `cosines_grid`; `cosines` is never defined in this notebook.
idx = np.argmin(cosines_grid[0,0])
vmin, vmax = 0.35, 0.51
sep = 0.05
plot_methods('cosine similarity', idx, vmin, vmax, sep, scale=5, cmap='RdYlGn')

In [None]:
# Plot heatmap of sign agreements
# NOTE(review): `sa[0,0]` needs a (method, ensemble size, grid point) array,
# but the lines that would build it in the ensemble cell are commented out;
# the only visible assignment of `sa` is in the Misc section. This cell looks
# stale — confirm before running.
idx = np.argmin(sa[0,0])
vmin, vmax = 0.4, 1.01
sep = 0.1
plot_methods('sign agreement', idx, vmin, vmax, sep, scale=5, cmap='RdYlGn')

In [None]:
# Plot heatmaps of `sa` and `sa_perturb` per ensemble size, with the
# corresponding ensemble gradients at grid point `idx` drawn as arrows.
# NOTE(review): `sa_perturb`, `e_grads` and `e_grads_perturb` are not defined
# anywhere in the visible notebook, so this cell looks stale. It also rebinds
# `titles` and replaces the `sims` dict with a list, which breaks any later
# use of `sims[sim]` (plot_methods and the quantile plot) — confirm.
titles = ['Original', 'Perturbed']
idx = np.argmin(sa[0])
sims = [sa, sa_perturb]
sim_grads = [e_grads, e_grads_perturb]
vmin, vmax = 0.4, 1.01
sep = 0.1
v = np.arange(vmin, vmax, sep)
for j, sim in enumerate(sims):
    fig, ax = plt.subplots(1, len(ensemble_sizes), figsize=(4*len(ensemble_sizes), 3), dpi=150)
    plt.subplots_adjust(wspace=0.2)
    for i, ensemble_size in enumerate(ensemble_sizes):
        Z = sim[i].reshape(xx.shape)
        im = ax[i].imshow(Z, vmin=vmin, vmax=vmax, extent=extent, origin='lower', cmap='RdYlGn')
        fig.colorbar(im, fraction=0.035, boundaries=v)
        ax[i].quiver(np.repeat(X[idx,0], n_trials), np.repeat(X[idx,1], n_trials),
                     sim_grads[j][i, :, idx, 0], sim_grads[j][i, :, idx, 1],
                     angles='xy', scale_units='xy', scale=2.5, color='black')
        ax[i].set_title(f'Ensemble size: {ensemble_size}')
        ax[i].set_xticks([]); ax[i].set_yticks([])
    plt.suptitle(f'Average pairwise cosine similarity ({titles[j]})')
    plt.show()

In [None]:
# One panel per similarity measure: median over grid points against ensemble
# size for each method, with a shaded band between the 0.4 and 0.6 quantiles.
# NOTE(review): expects `sims` to be the dict keyed by measure name and
# `titles` to hold one name per method; both are rebound by the preceding
# cell, so this depends on which cells were run — confirm.
fig, ax = plt.subplots(1, len(sims), figsize=(6*len(sims), 3), dpi=150)
for i, sim in enumerate(sims):
    for j, method in enumerate(methods):
        # Quantiles are taken across grid points, one value per ensemble size
        q = np.quantile(sims[sim][j], [0.4, 0.5, 0.6], axis=1)
        ax[i].plot(ensemble_sizes, q[1], label=titles[j])
        ax[i].fill_between(ensemble_sizes, q[0], q[2], alpha=0.2)
    ax[i].set_xlabel('Ensemble size')
    ax[i].set_title(f'Average pairwise {sim}')
    ax[i].legend()
plt.show()

### Misc

In [None]:
# Plot heatmap of `angles` on the grid, the training points, and the
# per-model gradients at grid point `idx` as arrows.
# NOTE(review): `angles` and `grads` as built in the Postprocess section are
# test-set arrays with a leading method axis, which does not match the
# reshape to the grid or the `grads[:, idx, 0]` indexing used here. The
# hard-coded extent also differs from the grid's `extent`. This cell appears
# to rely on arrays from an earlier notebook state — confirm.
idx = 3144
Z = angles.reshape(xx.shape)
plt.imshow(Z, extent=[-2, 3, -1.5, 2], origin='lower', cmap='RdBu')
plt.colorbar()
plt.scatter(trainset.data[:, 0], trainset.data[:, 1], c=trainset.labels)
plt.quiver(np.repeat(X[idx,0], n_models), np.repeat(X[idx,1], n_models), grads[:, idx, 0], grads[:, idx, 1], scale=15, color='green')
plt.show()

In [None]:
# Top-1 features and their signs from `grads`, then the average pairwise
# top-k sign agreement between models.
# NOTE(review): the expected layout of `grads` for get_top_k is not visible
# here — confirm it matches the array built in the Postprocess section.
k = 1
tk, s = get_top_k(k, grads, return_sign=True)
sa = average_pairwise_score(tk, s, top_k_sa)

In [None]:
# Plot heatmap of sa on the grid, the training points, and the per-model
# gradients at grid point `idx` as arrows.
# NOTE(review): `idx` comes from an earlier cell, and the reshape requires
# `sa` to have one value per grid point — confirm against how `sa` was built.
Z = sa.reshape(xx.shape)
plt.imshow(Z, extent=[-2, 3, -1.5, 2], origin='lower', cmap='RdBu')
plt.colorbar()
plt.scatter(trainset.data[:, 0], trainset.data[:, 1], c=trainset.labels)
plt.quiver(np.repeat(X[idx,0], n_models), np.repeat(X[idx,1], n_models), grads[:, idx, 0], grads[:, idx, 1], scale=15, color='green')
plt.show()

# Mode connectivity (modconn)

In [None]:
from modconn.curves import train_curve
from modconn import curves
from datasets.tabular import TabularModelCurve
# Constructor arguments shared by TabularModel / TabularModelCurve below:
# the number of input features and the layer spec registered under 'moons'.
layers = datasets.tabular.layers['moons']
model_args = [n_features, layers]
model_args

In [None]:
def mode_connect(models, trainloader, lr, epochs, curve_type='polychain', optim='sgd',
                 ts=np.linspace(0, 1, 101), disable_tqdm=True, fix_start=False, fix_end=False):
    """Train a curve through ``models`` and evaluate it on the test inputs.

    Args:
        models: models handed to ``curves.train_curve``.
        trainloader: data loader used to train the curve.
        lr, epochs, optim: optimisation settings forwarded to ``train_curve``.
        curve_type: ``'polychain'`` (``curves.PolyChain``) or ``'bezier'``
            (``curves.Bezier``).
        ts: curve parameters at which logits and gradients are evaluated.
        disable_tqdm, fix_start, fix_end: forwarded to ``train_curve``.

    Returns:
        Tuple ``(curve, logits, grads)`` where logits and grads are the
        curve's ``compute_logits`` / ``compute_gradients`` on ``X_test`` at ``ts``.

    Raises:
        ValueError: if ``curve_type`` is neither ``'polychain'`` nor ``'bezier'``.
    """
    if curve_type not in ('polychain', 'bezier'):
        raise ValueError(f'Unknown curve type {curve_type}')
    curve_cls = curves.PolyChain if curve_type == 'polychain' else curves.Bezier

    # Architecture comes from the module-level model_args.
    train_kwargs = dict(models=models, trainloader=trainloader,
                        curve_class=TabularModelCurve, curve=curve_cls,
                        input_size=model_args[0], hidden_layers=model_args[1],
                        fix_start=fix_start, fix_end=fix_end, optim=optim,
                        lr=lr, epochs=epochs, disable_tqdm=disable_tqdm)
    trained_curve = curves.train_curve(**train_kwargs)

    curve_logits = trained_curve.compute_logits(X_test, TabularModel, ts)
    curve_grads = trained_curve.compute_gradients(X_test, TabularModel, ts)
    return trained_curve, curve_logits, curve_grads

def init_model(idx):
    """Seed torch's RNG with ``idx`` and return a newly constructed ``TabularModel(*model_args)``."""
    torch.manual_seed(idx)
    return TabularModel(*model_args)

In [None]:
def get_curve_statistics(p_curve, ts=np.linspace(0,1,101)):
    """Evaluate the models along ``p_curve`` at each parameter value in ``ts``.

    Args:
        p_curve: trained curve exposing ``get_model_from_curve(TabularModel, t=...)``.
        ts: curve parameters to evaluate.

    Returns:
        Tuple of arrays indexed by position in ``ts``:
        ``(test loss, train loss, train predictions, train gradients,
        weight norm, weight difference to the model at t=0)``.
    """
    loss_fn = torch.nn.functional.cross_entropy
    n_points = len(ts)
    p_curve_loss = np.zeros(n_points)
    p_curve_loss_tr = np.zeros(n_points)
    p_curve_preds_tr = np.zeros((n_points, *y_train.shape))
    p_curve_grads_tr = np.zeros((n_points, *X_train.shape))
    weight_norms = np.zeros(n_points)
    weight_diffs = np.zeros(n_points)

    # Inputs and targets are the same for every t, so convert them once.
    X_test_t, y_test_t = torch.FloatTensor(X_test), torch.tensor(y_test)
    X_train_t, y_train_t = torch.FloatTensor(X_train), torch.tensor(y_train)
    reference_state = p_curve.get_model_from_curve(TabularModel, t=0).state_dict()

    # One pass over the curve: the model at each t is built once and used for
    # both the weight statistics and the loss / prediction / gradient statistics.
    for i, t in enumerate(ts):
        model = p_curve.get_model_from_curve(TabularModel, t=t)
        state = model.state_dict()
        weight_norms[i] = get_weight_norm(state)
        weight_diffs[i] = get_weight_diff(state, reference_state)
        p_curve_loss[i] = loss_fn(model.forward(X_test_t), y_test_t).item()
        p_curve_loss_tr[i] = loss_fn(model.forward(X_train_t), y_train_t).item()
        p_curve_preds_tr[i] = model.predict(X_train, return_numpy=True)
        p_curve_grads_tr[i] = model.compute_gradients(X_train, return_numpy=True)
    return p_curve_loss, p_curve_loss_tr, p_curve_preds_tr, p_curve_grads_tr, weight_norms, weight_diffs

def plot_statistics(p_curve_loss, p_curve_loss_tr,
                    p_curve_preds, p_curve_preds_tr,
                    p_curve_grads, p_curve_grads_tr,
                    weight_norms, weight_diffs, ts=np.linspace(0,1,101)):
    """Plot five panels of curve statistics against the curve parameter ``t``.

    Panels: loss, accuracy (%), gradient-norm quartiles, ground-truth
    sign-agreement quartiles (using the module-level ``gt`` / ``signs_gt``),
    and weight norm / weight difference. Accuracy compares predictions with
    the module-level ``y_test`` / ``y_train``.
    """
    fig, ax = plt.subplots(1, 5, figsize=(20, 4), dpi=150)
    loss_ax, acc_ax, grad_ax, sa_ax, weight_ax = ax

    def quartile_band(panel, values, label):
        # Median line with a shaded band between the lower and upper quartiles.
        lower, median, upper = np.quantile(values, [0.25, 0.5, 0.75], axis=1)
        panel.plot(ts, median, label=label)
        panel.fill_between(ts, lower, upper, alpha=0.2)

    for label, loss in (('Test', p_curve_loss), ('Train', p_curve_loss_tr)):
        loss_ax.plot(ts, loss, label=label)

    for label, preds, target in (('Test', p_curve_preds, y_test),
                                 ('Train', p_curve_preds_tr, y_train)):
        acc_ax.plot(ts, 100*(preds==target).mean(axis=1), label=label)

    for label, curve_grads in (('Test', p_curve_grads), ('Train', p_curve_grads_tr)):
        quartile_band(grad_ax, np.linalg.norm(curve_grads, axis=2), label)

    topk, signs = get_top_k(1, p_curve_grads, return_sign=True)
    gt_score = ground_truth_score(topk, signs, gt, signs_gt, top_k_sa)
    quartile_band(sa_ax, gt_score, 'Test')

    weight_ax.plot(ts, weight_norms, label='Weight Norm')
    weight_ax.plot(ts, weight_diffs, label='Weight Diff')

    titles = ['Loss', 'Accuracy (%)', 'Gradient Norm', 'Ground Truth SA Similarity', 'Weight Norm']
    for panel, title in zip(ax, titles):
        panel.set_xlabel('t')
        panel.legend()
        panel.set_title(title)
    plt.show()

In [None]:
# Gradients of every plain model on the test inputs, indexed as
# [model, input, feature]; the trailing 2 is hard-coded to the feature count.
grad = np.zeros((n_models, X_test.shape[0], 2))
for i in tqdm(range(n_models)):
    # Compute gradients
    grad[i] = models[i].compute_gradients(X_test, return_numpy=True)

In [None]:
# Use the gradient averaged over all models as the reference ("ground truth"):
# take its top-1 feature and sign, then score every model's top-1 feature and
# sign against it. `gt` and `signs_gt` are also read by plot_statistics.
gt, signs_gt = get_top_k(1, grad.mean(axis=0), return_sign=True)
tk, s = get_top_k(1, grad, return_sign=True)
orig_sa = average_ground_truth_score(tk, s, gt, signs_gt, top_k_sa)

In [None]:
# Build five TabularModels from the state dicts that linear_weight_interpolation
# returns for models 54 and 155 at the listed interpolation coefficients.
weights = linear_weight_interpolation(models[54].state_dict(), models[155].state_dict(), [0, 0.25, 0.5, 0.75, 1.0])
curve_models = []
for weight in weights:
    model = TabularModel(*model_args)
    model.load_state_dict(weight)
    curve_models.append(model)

In [None]:
from align import align_tabular

In [None]:
# Create a shuffled training loader and a freshly constructed model, then
# display the first three rows of the model's 'network.0.weight' tensor.
loader = torch.utils.data.DataLoader(trainset, 32, shuffle=True)
model = TabularModel(*model_args)
model.state_dict()['network.0.weight'][:3]

In [None]:
# Align model 9 to model 10 with align_tabular (using an unshuffled training
# loader) and keep the aligned model together with model 10 as curve endpoints.
ts = np.linspace(0, 1, 50)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=False)
model_0 = models[9]
model_1 = models[10]
model_a = align_tabular(model_0, model_1, trainloader, model_args)
curve_models = [model_a, model_1]

In [None]:
# Weight norms of the original, target and aligned models, then the pairwise
# weight differences between them.
# NOTE(review): get_curve_statistics calls these helpers with state dicts
# (model.state_dict()), whereas this cell passes the models themselves —
# confirm which form get_weight_norm / get_weight_diff expect.
print(get_weight_norm(model_0), get_weight_norm(model_1), get_weight_norm(model_a))
print(get_weight_diff(model_0, model_1), get_weight_diff(model_0, model_a), get_weight_diff(model_a, model_1))

In [None]:
# Train one Bezier curve per consecutive pair of models (2i, 2i+1): the first
# model is aligned to the second, then mode_connect fits the curve with the
# end point fixed. The test-set gradients along each curve are stored in
# p_curve_grads[i] and the curves themselves in p_curves.
n_curves = 100
ts = np.linspace(0, 1, 50)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=False)
p_curve_grads = np.zeros((n_curves, len(ts), X_test.shape[0], 2))
p_curves = []
for i in tqdm(range(n_curves)):
    model_0 = models[i*2]
    model_1 = models[i*2+1]
    model_a = align_tabular(model_0, model_1, trainloader, model_args)
    curve_models = [model_a, model_1]
    p_curve, p_curve_logits, p_curve_grads[i] = mode_connect(curve_models, trainloader=trainloader, lr=0.1, optim='sgd',
                                                        epochs=100, curve_type='bezier', ts=ts,
                                                        disable_tqdm=True, fix_start=False, fix_end=True)
    p_curves.append(p_curve)
    # outputs = get_curve_statistics(p_curve, ts=ts)
    # p_curve_loss, p_curve_loss_tr, p_curve_preds_tr, p_curve_grads_tr, weight_norms, weight_diffs = outputs
    # plot_statistics(p_curve_loss, p_curve_loss_tr, p_curve_logits.argmax(axis=2), p_curve_preds_tr,
    #                 p_curve_grads, p_curve_grads_tr, weight_norms, weight_diffs, ts=ts)

In [None]:
# Number of perturbed-model wrappers (one per entry of `models`).
len(pert_models)

In [None]:
# Gradients at a single grid point for the three families, reduced to 66
# ensemble gradients each for the comparison plots below.
idx = 5858
x = X[idx:idx+1]
#p_curve = mode_models[8]
# `idx` indexes the grid X, so the gradients must come from `grads_grid`;
# `grads` only covers the test inputs and is out of range for this index.
# The first 198 models are averaged in 66 groups of 3.
pg = grads_grid[0,:198,idx].reshape(3, 66, 2).mean(axis=0)
pg_mode = np.zeros((n_curves, 2))
pg_pert = np.zeros((len(pert_models), 2))
curve_ts = np.linspace(0,1,100)  # same for every curve, so built once
for i in tqdm(range(n_curves)):
    pg_mode[i] = p_curves[i].compute_gradients(x, TabularModel, ts=curve_ts)[:,0].mean(axis=0)
for i in tqdm(range(len(pert_models))):
    pg_pert[i] = pert_models[i].compute_gradients(x, mean=True)
pg_pert = pg_pert[:198].reshape(3, 66, 2).mean(axis=0)
pg_mode = pg_mode[:66]
x = x[0]

In [None]:
def x_angle(g):
    """Angle in degrees between each 2-D gradient and the x-axis.

    The angle is ``arctan(g_y / g_x)``, so values lie in [-90, 90] and a
    vector and its negation map to the same angle.

    Args:
        g: array-like of shape (n, 2) or wider; column 0 is the x component
            and column 1 the y component.

    Returns:
        float64 array of length n.
    """
    if len(g) == 0:
        return np.zeros(0)
    g = np.asarray(g)
    # Vectorised over rows instead of a Python loop.
    return np.asarray(np.arctan(g[:, 1] / g[:, 0]) * 180 / np.pi, dtype=float)

In [None]:
# Compare gradient norms and angles at the selected grid point between the
# plain ensembles (`pg`) and the mode-connected curves (`pg_mode`).
fig, ax = plt.subplots(1, 3, figsize=(20, 5), dpi=100)
# Order must match the labels / legend entries ['Original', 'Mode Connect'];
# the list was previously [pg_mode, pg], which swapped the two names.
method_grads = [pg, pg_mode]
ax[0].boxplot([np.linalg.norm(g, axis=1) for g in method_grads],
             labels=['Original', 'Mode Connect'])
ax[0].set_title('Gradient Norms')
for g in method_grads:
    ax[1].hist(np.linalg.norm(g, axis=1), bins=30, alpha=0.5)
    ax[2].hist(x_angle(g), bins=30, alpha=0.5)
ax[0].set_ylabel('Gradient Norm')
ax[1].set_ylabel('Count')
ax[2].set_ylabel('Count')
ax[1].legend(['Original', 'Mode Connect'])
ax[2].legend(['Original', 'Mode Connect'])
ax[1].set_title('Gradient Norms')
ax[2].set_title('Gradient Angles')
plt.show()

In [None]:
# Sanity check: shapes of the three per-family gradient arrays at the selected point.
pg.shape, pg_mode.shape, pg_pert.shape

In [None]:
# Shape of the mode-connect gradients at the selected point.
pg_mode.shape

In [None]:
# Arrows at the selected grid point for the plain models, the mode-connected
# curves and the perturbed ensembles.
# `idx` indexes the grid X, so the plain-model gradients are read from
# `grads_grid` (`grads` only covers the test inputs).
# NOTE(review): quiver_plots is defined in a later cell, so that cell has to
# be run before this one.
quiver_plots([grads_grid[0,:,idx], pg_mode, pg_pert], x, scale=200)

In [None]:
def quiver_plot(pg, x, scale=1.0):
    """Draw the gradients in ``pg`` as arrows anchored at point ``x``.

    Each row of ``pg`` becomes one arrow, coloured along the RdYlGn colormap
    in row order; the mean gradient is drawn on top in black.

    Args:
        pg: array of shape (n_grads, 2) holding the gradient components.
        x: the 2-D anchor point (``x[0]``, ``x[1]``).
        scale: ``scale`` argument forwarded to ``quiver``.
    """
    n_grads = pg.shape[0]
    arrow_style = dict(angles='xy', scale_units='xy', scale=scale)

    anchor_x = np.repeat(x[0], n_grads)
    anchor_y = np.repeat(x[1], n_grads)
    palette = cmaps['RdYlGn'](np.linspace(0,1,n_grads))
    plt.quiver(anchor_x, anchor_y, pg[:,0], pg[:,1], color=palette, **arrow_style)

    # Mean gradient as a single black arrow
    mean_pg = pg.mean(axis=0)
    plt.quiver(np.repeat(x[0], 1), np.repeat(x[1], 1),
               mean_pg[0], mean_pg[1], color='black', **arrow_style)

    plt.xlabel('x1')
    plt.ylabel('x2')
    plt.show()

In [None]:
def quiver_plots(pgs, x, scale=1.0):
    """Draw several sets of gradients side by side, one column per set.

    In every column each row of the gradient array becomes one arrow anchored
    at ``x``, coloured along the RdYlGn colormap in row order, with the mean
    gradient drawn on top in black.

    Args:
        pgs: sequence of arrays of shape (n_grads, 2); one panel per array.
        x: the 2-D anchor point (``x[0]``, ``x[1]``).
        scale: ``scale`` argument forwarded to ``quiver``.
    """
    n_cols = len(pgs)
    # squeeze=False always returns a 2-D array of axes, so a single gradient
    # set (n_cols == 1) works too; by default subplots would return a bare
    # Axes that cannot be indexed.
    fig, ax = plt.subplots(1, n_cols, figsize=(n_cols*7, 5), dpi=100, squeeze=False)
    for i, pg in enumerate(pgs):
        panel = ax[0, i]
        n_grads = pg.shape[0]
        panel.quiver(np.repeat(x[0], n_grads),
                     np.repeat(x[1], n_grads),
                     pg[:,0], pg[:,1], angles='xy',
                     scale_units='xy', scale=scale,
                     color=cmaps['RdYlGn'](np.linspace(0,1,n_grads)))
        # Mean gradient as a single black arrow
        mean_pg = pg.mean(axis=0)
        panel.quiver(np.repeat(x[0], 1),
                     np.repeat(x[1], 1),
                     mean_pg[0], mean_pg[1],
                     angles='xy', scale_units='xy',
                     scale=scale, color='black')
        panel.set_xlabel('x1')
        panel.set_ylabel('x2')
    plt.show()