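# Comparing Krotov and Hopfield's code with mine

This notebook checks that my `LocalLearning` implementation reproduces the synaptic activations, the learning activation g(Q), and the weight increments of Krotov and Hopfield's reference code on CIFAR10.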

In [1]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from pathlib import Path
from tqdm.notebook import tqdm_notebook as tqdm

# custom imports
from context import LocalLearning

import scipy.io
import numpy as np
import matplotlib.pyplot as plt

from collections import defaultdict
from torchvision import datasets, transforms

In [2]:
# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Hyperparameters for my LocalLearningModel, meant to mirror the Krotov & Hopfield
# settings defined in the next cell.
# NOTE(review): `pSet` is the class attribute itself (no copy is made), so these
# updates also change LocalLearning.LocalLearningModel.pSet for later instances.
# NOTE(review): k=2 here while the Krotov & Hopfield cell uses k=3 — confirm whether
# the model indexes k differently or whether this should be 3.
pSet = LocalLearning.LocalLearningModel.pSet
pSet.update(
    in_size=32**2 * 3,  # CIFAR10 consists of 32x32 pixel 3 channel coloured images
    tau_l=1.0 / 0.02,   # learning rate 0.02; a learning rate of 0.04 led to NaNs in model.W
    Delta=0.4,
    p=2.0,
    k=2,
    hidden_size=2500,
)

## Krotov and Hopfield parameters and model definition

In [4]:
# Hyperparameters of Krotov & Hopfield's reference implementation (CIFAR10 variant).
num_pixel = 32**2 * 3   # 32x32 pixels x 3 colour channels
eps0 = 2e-2             # learning rate
Kx = 50
Ky = 50
num_hidden = Kx * Ky    # number of hidden units that are displayed in a Ky by Kx array
mu = 0.0                # mean of the normal weight initialisation
sigma = 1.0             # standard deviation of the normal weight initialisation
num_epochs = 1000       # number of epochs
num_batch = 1000        # size of the minibatch
prec = 1e-30
delta = 0.4             # strength of the anti-Hebbian learning
p = 2.0                 # Lebesgue norm of the weights
k = 3                   # ranking parameter, must be an integer >= 2

# The +1 (rank 1) and -delta (rank k) entries of the learning activation would
# land on the same hidden unit for k < 2, so enforce the documented constraint.
assert isinstance(k, int) and k >= 2, f"k must be an integer >= 2, got {k!r}"

In [5]:
def synaptic_activation(synapses, inputs, p_norm=None):
    """Synaptic activation as computed in Krotov & Hopfield's reference code.

    Computes ``(sign(W) * |W| ** (p - 1)) @ inputs``.

    Args:
        synapses: weight matrix; in this notebook (num_hidden, num_pixel).
        inputs: input matrix with one sample per column; in this notebook
            (num_pixel, batch).
        p_norm: Lebesgue exponent p. When None (the default), the
            notebook-level ``p`` is used, matching the original behaviour.

    Returns:
        ``synapses``-transformed activations; in this notebook (num_hidden, batch).
    """
    exponent = (p if p_norm is None else p_norm) - 1
    return (synapses.sign() * synapses.abs() ** exponent).matmul(inputs)

def learning_activation(indices):
    """Krotov & Hopfield learning activation g(Q) for one minibatch.

    For every sample (column) the strongest hidden unit gets +1 (Hebbian) and
    the k-th strongest gets ``-delta`` (anti-Hebbian); every other entry is 0.

    Args:
        indices: result of ``topk(k, dim=0)`` on the (num_hidden, batch)
            activation matrix, i.e. shape (k, batch) with hidden-unit indices
            sorted from strongest to weakest.

    Returns:
        Tensor of shape (num_hidden, batch) on the same device as ``indices``.
    """
    best_ind, best_k_ind = indices[0], indices[k - 1]
    # Take batch size and device from the input rather than the globals
    # num_batch/device, so a smaller final minibatch is handled correctly.
    batch_size = indices.shape[1]
    cols = torch.arange(batch_size, device=indices.device)
    g_i = torch.zeros(num_hidden, batch_size, device=indices.device)
    g_i[best_ind, cols] = 1.0
    g_i[best_k_ind, cols] = -delta
    return g_i

## Renormalized Dataset for Training Data

In [6]:
# CIFAR10 training set, renormalised with the same p as the model.
training_data = LocalLearning.LpUnitCIFAR10(
    root="../data/CIFAR10",
    train=True,
    transform=ToTensor(),
    p=pSet["p"],
)
num_train = len(training_data)

# Shuffled minibatches of size num_batch, loaded by 10 worker processes.
dataloader_train = DataLoader(
    training_data,
    batch_size=num_batch,
    shuffle=True,
    num_workers=10,
)

Files already downloaded and verified


### Compare Activations

In [7]:
# Random test batch and random synapses shared by both implementations.
flat = torch.nn.Flatten()
input_v = torch.rand((num_batch, 3, 32, 32))
input_flat = flat(input_v).to(device)  # (num_batch, num_pixel)
# Allocate directly on the target device instead of using the legacy
# torch.Tensor(*sizes) constructor followed by a host-to-device copy.
synapses = torch.empty(num_hidden, num_pixel, device=device).normal_(mu, sigma)

# Krotov and Hopfield: activations of shape (num_hidden, num_batch)
a_KH = synaptic_activation(synapses, input_flat.T)

# me: load the same synapses (transposed) into my model
llmodel = LocalLearning.LocalLearningModel(pSet)
llmodel.to(device)
llmodel.train()
llmodel.W = torch.nn.Parameter(synapses.T, requires_grad=False)
a_me = llmodel(input_v.to(device))

# test for exact (bit-for-bit) equality
print("Krotov and Hopfields and Konstantin's implemention is equivalent: ", torch.equal(a_KH.T, a_me))

Krotov and Hopfields and Konstantin's implemention is equivalent:  True


## Compare learning activations

In [8]:
# find indices maximizing the synapse
_, indices = a_KH.topk(k, dim=0)
# g(Q) learning activation function
g_i = learning_activation(indices)

In [9]:
from torch import Tensor

# Konstantin's implementation
def g(q: Tensor) -> Tensor:
    """Learning activation g(Q) in batch-major layout.

    Args:
        q: activations with one sample per row, i.e. (batch, num_hidden).

    Returns:
        Tensor with the same shape, dtype and device as ``q``: +1 at each row's
        strongest unit, ``-delta`` at its k-th strongest unit, 0 elsewhere.
    """
    g_q = torch.zeros_like(q)  # follow q's dtype as well as its device
    _, sorted_idxs = q.topk(k, dim=-1)
    rows = torch.arange(g_q.size(dim=0), device=q.device)
    g_q[rows, sorted_idxs[:, 0]] = 1.0
    g_q[rows, sorted_idxs[:, -1]] = -delta
    return g_q

In [10]:
g(a_KH.T).equal(g_i.T)

True

## Compare weight increments

Calculate the weight increment for the Krotov and Hopfield implementation

In [65]:
# Krotov & Hopfield increment in (hidden, pixel) layout:
# ds = g(Q) @ v  -  (sum over batch of g(Q) * <W, v>) * W
xx = torch.sum(g_i * a_KH, dim=1)
ds = g_i @ input_flat - xx[:, None] * synapses

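Calculate the weight increment with my own implementation. It works in the transposed (pixel × hidden) layout, so the matrix products are evaluated in a different order than above and floating-point sums may accumulate differently. The two results are therefore expected to agree up to rounding (`torch.isclose`), but not necessarily bit-for-bit (`torch.equal`).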

In [12]:
#v = input_flat
#h = a_KH.T
#W = synapses.T
#inc = (1.0**p) * v[..., None] - torch.mul(h[:, None, ...], W)
#inc = torch.mul(g(h)[:, None, ...], inc).sum(dim=0)

In [54]:
# Aliases in my notation:
#   v: inputs, (num_batch, num_pixel)
#   h: hidden activations, (num_batch, num_hidden)
#   W: synapses, (num_pixel, num_hidden)
#   R: presumably the target Lp-norm radius of the weights in the learning rule — TODO confirm
v = input_flat
h = a_KH.T
W = synapses.T
R = 1.0

In [80]:
g_mu = g(h)
# increment in (pixel, hidden) layout: R^p v^T g  -  (sum over batch of g * h) * W
inc = (R ** p) * torch.matmul(v.T, g_mu) - g_mu.mul(h).sum(dim=0).unsqueeze(0) * W

In [81]:
ds.T.isclose(inc, rtol=1e-5, atol=1e-6).all()

tensor(True, device='cuda:0')

In [82]:
ds.T.equal(inc)

False