In [None]:
from math import sqrt
from pathlib import Path
from random import seed

import matplotlib.pyplot as plt
import torch as th
import torch.nn as nn
import torch.nn.functional as F

# Fix both random number generators (Python's `random` and PyTorch's global
# generator) so that a fresh "Restart & Run All" reproduces the same run.
seed(0)
th.manual_seed(0)

In [None]:
def get_embeddings(n, d, norm=True):
    """
    Draw `n` random embeddings of dimension `d` with standard Gaussian entries.

    n: int
        Number of embeddings (rows of the returned matrix).
    d: int
        Embedding dimension (columns of the returned matrix).
    norm: bool
        If True, every row is rescaled to unit Euclidean norm.
        If False, every entry is divided by sqrt(d) instead.

    Returns a torch.Tensor of shape (n, d).
    """
    raw = th.randn(n, d)
    # Per-row norm (broadcast over columns) or a single global scalar.
    scale = raw.norm(dim=1, keepdim=True) if norm else sqrt(d)
    return raw / scale


class AssMem(nn.Module):
    r"""
    Associative memory: fixed input embeddings `E`, a learnable square
    matrix `W`, and fixed output unembeddings `U`.

    The logits for a token index `x` are `E[x] @ W @ U`. Only `W` is a
    trainable parameter; it is initialised to zero.
    """

    def __init__(self, E, U):
        r"""
        E: torch.Tensor
            Input embedding matrix of size $n \times d$,
            where $n$ is the number of tokens and $d$ is the embedding dimension.
        U: torch.Tensor
            Output unembedding matrix of size $d \times m$,
            where $m$ is the number of classes and $d$ is the embedding dimension.
        """
        super().__init__()
        d = E.shape[1]
        self.W = nn.Parameter(th.zeros(d, d))
        # Register the frozen matrices as buffers so that `.to(device)`,
        # `.double()`, etc. move them together with `W`. They are marked
        # non-persistent so that `state_dict()` still only contains `W`.
        self.register_buffer("E", E, persistent=False)
        self.register_buffer("U", U, persistent=False)

    def forward(self, x):
        """
        x: torch.Tensor
            Integer tensor of token indices (used to index the rows of `E`).

        Returns the logits `E[x] @ W @ U`, one row per entry of `x`.
        """
        out = self.E[x] @ self.W
        out = out @ self.U
        return out

In [None]:
# Problem sizes:
#   n -- number of input tokens
#   m -- number of output classes
#   d -- memory dimension
n, m, d = 10, 5, 5

# Exponent of the power law (x + 1) ** (-alpha) used for the token frequencies.
alpha = 1.5

In [None]:
# Token x in {0, ..., n - 1} is drawn with probability proportional to
# (x + 1) ** (-alpha).
all_x = th.arange(n)
proba = th.pow(all_x + 1., -alpha)
proba = proba / proba.sum()
# Ground-truth label of token x is x mod m.
all_y = th.remainder(all_x, m)

In [None]:
# SGD step size
lr = 0.1
# number of samples drawn per SGD step, and number of SGD steps
batch_size, nb_epoch = 1, 1000
# total number of samples seen during training
T = batch_size * nb_epoch

In [None]:
# Embeddings: E is (n, d) with Gaussian entries scaled by 1 / sqrt(d);
# U is (d, m) with unit-norm columns (rows are normalised, then transposed).
E = get_embeddings(n, d, norm=False)
U = get_embeddings(m, d, norm=True).T

# models
assoc = AssMem(E, U)
opti = th.optim.SGD(assoc.parameters(), lr=lr, momentum=0)

train_loss = []
test_loss = []

# A step is illustrated (before/after snapshots) only while the model assigns
# less than this probability to the correct class of a sampled token.
PROBA_THRESHOLD = .8

# Snapshots are written to this directory; create it up front so that
# `savefig` does not fail on a fresh checkout.
fig_dir = Path('sgd')
fig_dir.mkdir(parents=True, exist_ok=True)


def _association_probs(model):
    """Return softmax(E @ W @ U, dim=-1) as a numpy array of shape (n, m)."""
    with th.no_grad():
        logits = model.E @ model.W @ model.U
        return F.softmax(logits, dim=-1).numpy()


def _save_snapshot(mat, x, y, path):
    """
    Save a heatmap of `mat` to `path`, outlining in red every sampled
    (token, class) cell given by the index tensors `x` (rows) and `y` (columns).
    """
    fig, ax = plt.subplots(1, 1, figsize=(4, 4))
    ax.imshow(mat, aspect='auto')
    for row, col in zip(x.tolist(), y.tolist()):
        ax.add_patch(plt.Rectangle((col - .5, row - .5), 1, 1, fill=False, edgecolor='red', lw=2))
    ax.set_axis_off()
    fig.savefig(path)
    # Up to 2 * nb_epoch figures can be created: release each one once it is
    # on disk instead of keeping them all alive until the end of the cell.
    plt.close(fig)


for i in range(nb_epoch):
    # Sample a mini-batch of tokens and their labels.
    x = th.multinomial(proba, batch_size, replacement=True)
    y = x % m

    out = assoc(x)
    loss = F.cross_entropy(out, y)
    train_loss.append(loss.item())

    # "Test loss" is the probability mass of the tokens whose argmax
    # prediction is wrong, measured before this step's update.
    with th.no_grad():
        pred = assoc(all_x).argmax(dim=-1)
        test_loss.append(proba[pred != all_y].sum().item())

    # The decision to illustrate this step is taken once, from the
    # probabilities *before* the update, and reused for the "after" snapshot.
    mat = _association_probs(assoc)
    needs_snapshot = bool((mat[x.numpy(), y.numpy()] < PROBA_THRESHOLD).any())

    if needs_snapshot:
        _save_snapshot(mat, x, y, fig_dir / f'mat_step{i}_0.png')

    opti.zero_grad()
    loss.backward()
    opti.step()

    if needs_snapshot:
        _save_snapshot(_association_probs(assoc), x, y, fig_dir / f'mat_step{i}_1.png')