# Edge of stability (EOS) for population dynamics

In [None]:
import os
import shutil
import subprocess
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from matplotlib import rc
from scipy.linalg import eigh

# Make the project-local `model` module importable when run from this directory.
sys.path.append('.')
from model import AssociativeMemory, get_embeddings

torch.manual_seed(42)

WIDTH = 8.5              # inches (from ICML style file)
HEIGHT = WIDTH / 1.5     # 3:2 aspect ratio (1.5, i.e. not the golden ratio ~1.618)

# Global matplotlib text settings (set once; `plt.rc` and `rc` are the same call).
rc('font', family='serif', size=8)

# Render text through LaTeX only when a `pdflatex` executable is on PATH.
# `shutil.which` does the lookup in-process: it prints nothing and does not
# depend on an external `which` binary being available.
usetex = shutil.which('pdflatex') is not None
rc('text', usetex=usetex)
if usetex:
    rc('text.latex', preamble=r'\usepackage{times}')

def f(x, epsilon=0):
    """Map an input ``x`` to its target; this implementation is the identity.

    ``epsilon`` is accepted for call compatibility but is not used here.
    The very same object that was passed in is returned (no copy is made).
    """
    target = x
    return target

In [None]:
def get_data(n, d, alpha=1.):
    """Build two sets of embeddings and a power-law distribution over ``n`` items.

    Args:
        n: number of items; forwarded to ``get_embeddings`` and used as the
            support size of the distribution.
        d: embedding dimension forwarded to ``get_embeddings``.
        alpha: exponent of the power law; item ``i`` (0-based) gets weight
            proportional to ``(i + 1) ** (-alpha)``. Defaults to ``1.``, the
            value that was previously hard-coded.

    Returns:
        Tuple ``(E, U, proba)``: the two results of
        ``get_embeddings(n, d, norm=True)`` (called in that order) and a 1-D
        tensor of length ``n`` that sums to one.
    """
    # NOTE(review): get_embeddings is defined in `model`; presumably it draws
    # random normalized embeddings, so the call order matters for seeding.
    E = get_embeddings(n, d, norm=True)
    U = get_embeddings(n, d, norm=True)

    ranks = torch.arange(n) + 1.
    proba = ranks ** (-alpha)
    proba /= proba.sum()
    return E, U, proba


In [None]:
# niter = 100
# lrs = [5, 10, 20]
# k = 1

def get_res(E, U, proba, lrs, niter=100):
    """Train one ``AssociativeMemory`` per learning rate and record statistics.

    For every ``lr`` in ``lrs`` a fresh model is built from ``E`` and ``U``
    (``random_init=False``), cast to float64 and trained for ``niter``
    full-batch steps of momentum-free SGD on the ``proba``-weighted
    cross-entropy over all ``n`` inputs, with targets ``f(x)``.

    Args:
        E: embeddings whose shape is unpacked as ``(n, d)``.
        U: second set of embeddings passed to ``AssociativeMemory``.
        proba: per-input weights, one entry per input ``0..n-1``.
        lrs: iterable of learning rates; each one becomes a key of the result.
        niter: number of gradient steps per learning rate.

    Returns:
        dict mapping each learning rate to a dict of per-iteration records.
        Loss, accuracies, scores and margins use the scores computed *before*
        the step of that iteration; the Hessian and progress entries are
        evaluated after ``optimizer.step()``.

        - ``'losses'``: weighted loss.
        - ``'accs'``: ``proba``-weighted accuracy.
        - ``'accs_start'``: mean accuracy over inputs ``[:d]``.
        - ``'accs_end'``: mean accuracy over inputs ``[d:2*d]`` (``nan`` when
          that slice is empty, i.e. ``n <= d``).
        - ``'all_accs'``: ``(niter, n)`` per-input 0/1 correctness.
        - ``'all_scores'``: ``(niter, n)`` diagonal of the score matrix.
        - ``'all_margins'``: ``(niter, n)`` diagonal score minus the largest
          off-diagonal score of the same row.
        - ``'eigenvals_pop'``: NumPy array of ``lr * lambda / 2`` where
          ``lambda`` is the eigenvalue of ``model.hessian(all_x, proba)`` at
          index ``d*d - 1`` (ascending order).
        - ``'progress'`` / ``'progress_orth'``: bilinear forms of ``model.W``
          with ``E[0] -/+ E[1]`` and ``UT[:, 0] - UT[:, 1]``.
    """
    n, d = E.shape
    all_x = torch.arange(n)  # integer indices; they must stay integral
    all_y = f(all_x)

    # Loop-invariant mask: +inf on the diagonal, 0 elsewhere. Subtracting it
    # from the scores removes entry (i, i) from the row-wise maximum.
    diag_mask = torch.diag(torch.inf * torch.ones(n))
    # NOTE(review): assumes the Hessian is (d*d) x (d*d), so this index selects
    # its largest eigenvalue — confirm against model.hessian.
    top = d * d - 1

    res = {}
    for lr in lrs:
        model = AssociativeMemory(E, U, random_init=False)
        model.to(torch.float64)
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0)

        losses = torch.zeros(niter)
        accs = torch.zeros(niter)
        accs_start = torch.zeros(niter)
        accs_end = torch.zeros(niter)
        all_accs = torch.zeros((niter, n))
        eigenvals_pop = np.zeros(niter)
        progress = torch.zeros(niter)
        progress_orth = torch.zeros(niter)
        all_scores = torch.zeros((niter, n))
        all_margins = torch.zeros((niter, n))

        for i in range(niter):
            optimizer.zero_grad()

            # weighted cross-entropy over the whole population
            score = model(all_x)
            loss = (proba * F.cross_entropy(score, all_y, reduction='none')).sum()

            loss.backward()
            optimizer.step()

            # statistics based on the pre-step scores
            scores = score.detach()
            correct = scores.argmax(-1) == all_y
            own_scores = scores.diag()

            losses[i] = loss.item()
            accs[i] = (proba * correct.float()).sum()
            accs_start[i] = correct[:d].float().mean()
            accs_end[i] = correct[d:2 * d].float().mean()
            all_accs[i] = correct.float()
            all_scores[i] = own_scores
            # The diagonal is used as the "correct" score, which matches the
            # identity targets produced by f.
            all_margins[i] = own_scores - (scores - diag_mask).max(-1)[0]

            # statistics based on the post-step weights
            population_hessian = model.hessian(all_x, proba)
            # eigh returns a length-1 array for a single-index subset; take the
            # scalar explicitly (implicit size-1 array -> scalar is deprecated).
            top_eig = eigh(population_hessian.numpy(), eigvals_only=True,
                           subset_by_index=[top, top])[0]
            eigenvals_pop[i] = lr * top_eig / 2.

            out_diff = model.UT[:, 0] - model.UT[:, 1]
            progress[i] = (model.E[0] - model.E[1]) @ model.W.data @ out_diff
            progress_orth[i] = (model.E[0] + model.E[1]) @ model.W.data @ out_diff

        res[lr] = {'losses': losses, 'accs': accs, 'accs_start': accs_start,
                   'accs_end': accs_end, 'all_accs': all_accs,
                   'all_scores': all_scores, 'all_margins': all_margins,
                   'eigenvals_pop': eigenvals_pop,
                   'progress': progress, 'progress_orth': progress_orth}
    return res

# Save plots

In [None]:

# Sweep over embedding dimensions and learning rates; for each dimension save
# one loss figure, and for each (dimension, learning rate) one margin figure.
N = 5
lrs = [3, 10, 20]
dims = [3, 5, 10]
n_shown = 30  # number of leading iterations drawn in every figure

# savefig does not create missing directories.
os.makedirs('figures', exist_ok=True)

for d in dims:
    torch.manual_seed(43)  # same seed for every dimension
    E, U, proba = get_data(N, d)
    res = get_res(E, U, proba, lrs, niter=100)

    # Only the last dimension of the sweep keeps its x-axis decorations.
    is_last_dim = d == dims[-1]

    fig, ax = plt.subplots(1, 1, figsize=(.2 * WIDTH, .2 * HEIGHT))
    for lr in res:
        ax.plot(res[lr]['losses'][:n_shown], label=f'$\\eta = {lr}$')
    ax.set_xlabel(r'iteration $t$', fontsize=6)
    ax.legend()
    ax.set_ylabel(r'${\cal L}(W_t)$')
    ax.set_title(fr"$d = {d}$")
    fig.savefig(f'figures/loss_N{N}_d{d}.pdf', pad_inches=0, bbox_inches='tight')

    for lr in res:
        fig, ax = plt.subplots(1, 1, figsize=(.18 * WIDTH, .18 * HEIGHT))
        # one margin curve per input
        leg = [ax.plot(res[lr]['all_margins'][:n_shown, i])[0] for i in range(N)]
        ax.set_yticks([0, 2])
        ax.set_ylabel(r'$m_t(x)$')
        ax.set_title(rf"$d = {d},\ \eta = {lr}$")
        # The legend appears once, on the panel for the last lr and last dimension.
        if lr == lrs[-1] and is_last_dim:
            ax.legend(leg, [str(i + 1) for i in range(N)], fontsize=6,
                      handlelength=1, ncol=2, frameon=True, loc='lower right')
        if not is_last_dim:
            ax.set_xticks([])
        else:
            ax.set_xlabel(r'iteration $t$')
        ax.grid(axis='y', alpha=.5)
        fig.savefig(f'figures/margins_N{N}_d{d}_lr{lr}.pdf', pad_inches=0, bbox_inches='tight')
