In [3]:
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.nn.modules.utils import _pair
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

  from .autonotebook import tqdm as notebook_tqdm


# BP Baseline

In [6]:
from genhebb import FastMNIST

In [20]:
class baseline(nn.Module):
    def __init__(self):
        super(baseline, self).__init__()
        self.input = nn.Linear(28 * 28, 2000)
        self.hidden = nn.Linear(2000, 2000)
        self.output = nn.Linear(2000, 10)
    
    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.input(x))
        x = F.relu(self.hidden(x))
        x = self.output(x)
        return x

In [None]:
epochs = 50
learning_rate = 0.001
batch_size = 64

In [23]:
trainset = FastMNIST('./data', train=True, download=True)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)

In [16]:
testset = FastMNIST('./data', train=False, download=True)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)

In [21]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# train baseline model (BP end-to-end)
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    for images, labels in trainloader:

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # update training statistics
        running_loss += loss.item()
    print(f'Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(trainloader)}')

In [None]:
model.eval()
running_loss = 0.
correct = 0
total = 0
# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images = images.to(device)
        labels = labels.to(device)
        # calculate outputs by running images through the network
        outputs = model(images)
        # the class with the highest energy is what we choose as prediction
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        loss = criterion(outputs, labels)
        running_loss += loss.item()

print(f'Test accuracy: {100 * correct / total} %')
print(f'test loss: {running_loss / total:.3f}')

**Results from Cajal:**

Training...
Epoch [1/50], Loss: 0.188
Epoch [2/50], Loss: 0.080
Epoch [3/50], Loss: 0.058
Epoch [4/50], Loss: 0.043
Epoch [5/50], Loss: 0.034
Epoch [6/50], Loss: 0.030
Epoch [7/50], Loss: 0.027
Epoch [8/50], Loss: 0.021
Epoch [9/50], Loss: 0.026
Epoch [10/50], Loss: 0.023
Epoch [11/50], Loss: 0.014
Epoch [12/50], Loss: 0.017
Epoch [13/50], Loss: 0.017
Epoch [14/50], Loss: 0.017
Epoch [15/50], Loss: 0.016
Epoch [16/50], Loss: 0.013
Epoch [17/50], Loss: 0.013
Epoch [18/50], Loss: 0.017
Epoch [19/50], Loss: 0.010
Epoch [20/50], Loss: 0.014
Epoch [21/50], Loss: 0.012
Epoch [22/50], Loss: 0.011
Epoch [23/50], Loss: 0.012
Epoch [24/50], Loss: 0.016
Epoch [25/50], Loss: 0.010
Epoch [26/50], Loss: 0.009
Epoch [27/50], Loss: 0.009
Epoch [28/50], Loss: 0.009
Epoch [29/50], Loss: 0.016
Epoch [30/50], Loss: 0.008
Epoch [31/50], Loss: 0.013
Epoch [32/50], Loss: 0.009
Epoch [33/50], Loss: 0.012
Epoch [34/50], Loss: 0.008
Epoch [35/50], Loss: 0.013
Epoch [36/50], Loss: 0.012
Epoch [37/50], Loss: 0.015
Epoch [38/50], Loss: 0.004
Epoch [39/50], Loss: 0.019
Epoch [40/50], Loss: 0.013
Epoch [41/50], Loss: 0.008
Epoch [42/50], Loss: 0.010
Epoch [43/50], Loss: 0.009
Epoch [44/50], Loss: 0.012
Epoch [45/50], Loss: 0.011
Epoch [46/50], Loss: 0.006
Epoch [47/50], Loss: 0.009
Epoch [48/50], Loss: 0.011
Epoch [49/50], Loss: 0.008
Epoch [50/50], Loss: 0.011
Testing...
Accuracy: 98.33 %
Loss: 0.004

# Hebb's Rule

In [201]:
class HebbLinear(nn.Module):
    def __init__(
            self,
            input_dim: int,
            output_dim: int,
    ) -> None:
        """
        ...
        """
        super(HebbLinear, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.W = nn.Parameter(torch.randn(input_dim, output_dim))

    def forward(self, x):
        # standard forward pass
        y = torch.matmul(x, self.W)

        # compute Oja's learning rule (ΔW = x * y^T) and store in grad
        if self.training:
            dW = x.unsqueeze(-1) * y.unsqueeze(-2)
            if dW.dim() > 2:
                dW = torch.sum(dW, 0)
            self.W.grad = dW

        return y

**Results from Cajal (1 epoch):**

Epoch [1/50]
Train loss: 0.610, 	 train accuracy: 81 %
Test loss: 0.080, 	 test accuracy: 86 %

Epoch [11/50]
Train loss: 0.215, 	 train accuracy: 93 %
Test loss: 0.076, 	 test accuracy: 89 %

Epoch [21/50]
Train loss: 0.177, 	 train accuracy: 94 %
Test loss: 0.094, 	 test accuracy: 89 %

Epoch [31/50]
Train loss: 0.146, 	 train accuracy: 95 %
Test loss: 0.111, 	 test accuracy: 89 %

Epoch [41/50]
Train loss: 0.132, 	 train accuracy: 96 %
Test loss: 0.132, 	 test accuracy: 89 %

Epoch [50/50]
Train loss: 0.124, 	 train accuracy: 96 %
Test loss: 0.141, 	 test accuracy: 89 %

# Alternative Baseline: Random Frozen Weights