
# Local network

- self-supervised learning on MNIST

- Apply feedback alignment to both layers in the PV class (fc1 and fc2), not just the second one (fc2)? 


In [1]:
import os
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

  Referenced from: <E03EDA44-89AE-3115-9796-62BA9E0E2EDE> /Users/irismarmouset-delataille/mambaforge/envs/localglobal/lib/python3.11/site-packages/torchvision/image.so
  warn(


<torch._C.Generator at 0x1161f2a10>

### Dataset

In [2]:
from dataset import get_mnist_dataset
working_directory = os.getcwd()
data_dir = os.path.join(working_directory, './dataset')
batch_size_train = 64
batch_size_test = 1000

train_data_loader, test_data_loader = get_mnist_dataset(data_dir, batch_size_train=batch_size_train, batch_size_test=batch_size_test)

print('Train dataset size: {}'.format(len(train_data_loader.dataset)))
print('Test dataset size: {}'.format(len(test_data_loader.dataset)))

Train dataset size: 60000
Test dataset size: 10000


### Model

In [3]:
from modules.network import LocalNetwork

lr = 0.001

model = LocalNetwork(input_dim=784, latent_dim=128)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.MSELoss()

print(model)

LocalNetwork(
  (PV): PV(
    (flatten): Flatten()
    (fc1): Linear(in_features=784, out_features=784, bias=True)
    (activation): Sigmoid()
    (fc2): Linear(in_features=784, out_features=128, bias=True)
  )
  (Pyr): Pyr(
    (flatten): Flatten()
    (fc1): Linear(in_features=784, out_features=128, bias=True)
    (activation): Sigmoid()
  )
  (Decoder): Decoder(
    (fc1): Linear(in_features=128, out_features=784, bias=True)
  )
)


### Training & testing

In [7]:
num_epochs = 10

train_losses = []
train_counter = []
test_losses = []
test_counter = [i*len(train_data_loader.dataset) for i in range(num_epochs + 1)]

def train(epoch, dataloader):
    model.train()
    for batch in dataloader:
        inputs, _ = batch  

        optimizer.zero_grad()
        _, _, _, recon = model(inputs, inputs)  # same inputs for both PV and pyramidal cells
        inputs = inputs.view(-1, 784)
        loss = criterion(recon, inputs)
        loss.backward()
        optimizer.step()

    print('Epoch: {} Loss: {:.6f}'.format(epoch, loss.item()))

def test(epoch, dataloader):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for batch in dataloader:
            inputs, _ = batch
            _, _, _, recon = model(inputs, inputs)
            inputs = inputs.view(-1, 784)
            test_loss += criterion(recon, inputs).item()
    test_loss /= len(dataloader.dataset)
    test_losses.append(test_loss)
    print('\n Epoch: {} Test set: Avg. loss: {:.4f}\n'.format(epoch, test_loss))

### In practice

In [9]:
test(test_data_loader)

for epoch in range(1, num_epochs + 1):
  train(epoch, train_data_loader)
  test(test_data_loader)


Test set: Avg. loss: 0.0011

Epoch: 1 Loss: 0.123183

Test set: Avg. loss: 0.0001

Epoch: 2 Loss: 0.092175

Test set: Avg. loss: 0.0001

Epoch: 3 Loss: 0.084952

Test set: Avg. loss: 0.0001

Epoch: 4 Loss: 0.061202

Test set: Avg. loss: 0.0001

Epoch: 5 Loss: 0.055561

Test set: Avg. loss: 0.0001

Epoch: 6 Loss: 0.056677

Test set: Avg. loss: 0.0001

Epoch: 7 Loss: 0.067480

Test set: Avg. loss: 0.0001

Epoch: 8 Loss: 0.056903

Test set: Avg. loss: 0.0001

Epoch: 9 Loss: 0.052013

Test set: Avg. loss: 0.0001

Epoch: 10 Loss: 0.048799

Test set: Avg. loss: 0.0001

