<a href="https://colab.research.google.com/github/isottongloria/PMLS_Bio-Learning/blob/main/scripts/multilayer_bio_learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Unsupervised learning by competing hidden units**

### Author

- Gloria Isotton
- Master's degree in Physics of Data

### Introduction

In this notebook, we extend the approach of Krotov and Hopfield's paper *Unsupervised learning by competing hidden units* (PNAS, 2019) from a single hidden layer to a **multi-layer network**. The goal is to apply the same local, unsupervised learning rule in a deeper architecture while incorporating **static excitatory and inhibitory synaptic connections**, with the aim of making the model more biologically realistic and improving its performance.

The network is built from several layers of neurons. Each layer is trained in turn, without supervision, on the outputs of the layer below, and the connections between layers are modulated by excitatory and inhibitory synapses. This setup is closer to biological neural networks, where the balance between excitation and inhibition is important for network stability, learning efficiency, and the emergence of meaningful representations.

The following sections present the modifications made to the original framework and the design choices for synaptic regulation, and apply the multi-layer model to three benchmark datasets (MNIST, CIFAR10 and FashionMNIST). The code and experiments explore how unsupervised learning can be scaled to deeper architectures while preserving biological plausibility through synaptic balance.

In [1]:
# Importing dependencies

import os
import random

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch import save, load
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR, MultiStepLR
from torch.utils.data import DataLoader, TensorDataset, random_split
from torchvision import datasets, transforms


In [3]:
# Cuda Parameters
use_cuda = torch.cuda.is_available()
torch.cuda.empty_cache()
device = torch.device("cuda" if use_cuda else "cpu")

## Dataset loading
The integer variable `data_name` selects the dataset:
- `1`: MNIST https://pytorch.org/vision/stable/generated/torchvision.datasets.MNIST.html?ref=hackernoon.com#torchvision.datasets.MNIST
- `2`: CIFAR10 https://pytorch.org/vision/stable/generated/torchvision.datasets.CIFAR10.html?ref=hackernoon.com#torchvision.datasets.CIFAR10
- `3`: FashionMNIST https://pytorch.org/vision/stable/generated/torchvision.datasets.FashionMNIST.html?ref=hackernoon.com#torchvision.datasets.FashionMNIST
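
The transforms in the next cell turn every image into a flat vector, and the CIFAR10 and FashionMNIST pipelines also rescale pixel values from [0, 1] to [-1, 1]. As a rough illustration of the resulting layout, the sketch below reproduces those steps in plain NumPy for a single image. The helper `to_flat_vector` is hypothetical and only approximates `ToTensor` (uint8 to [0, 1], HWC to CHW), followed by `Normalize(mean=0.5, std=0.5)` and `view(-1)`.

```python
import numpy as np

def to_flat_vector(img_uint8, normalize=True):
    """Hypothetical NumPy stand-in for ToTensor -> Normalize(0.5, 0.5) -> view(-1).

    img_uint8: array of shape (H, W) for grayscale or (H, W, C) for colour images.
    """
    x = img_uint8.astype(np.float32) / 255.0   # ToTensor: scale to [0, 1]
    if x.ndim == 3:
        x = np.transpose(x, (2, 0, 1))         # ToTensor: HWC -> CHW
    if normalize:
        x = (x - 0.5) / 0.5                    # Normalize: [0, 1] -> [-1, 1]
    return x.reshape(-1)                       # view(-1): flatten, channel-first

rng = np.random.default_rng(0)
gray = rng.integers(0, 256, size=(28, 28), dtype=np.uint8)    # MNIST-like image
rgb = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)  # CIFAR10-like image

v_gray = to_flat_vector(gray)
v_rgb = to_flat_vector(rgb)
print(v_gray.shape, v_rgb.shape)  # (784,) (3072,)
```

The flattening is channel-first, so the first 1024 entries of a CIFAR10 vector come from the red channel. Any code that reshapes a weight row back into an image has to undo this CHW order.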

In [2]:
data_name = 3

# In all three cases the official training set is split at random into a training
# part and a held-out part of 10000 images; `test_loader` iterates over the held-out part.

####### MNIST dataset ########
if data_name == 1:
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Lambda(lambda x: x.view(-1))])  # flatten to 784 values
    full_train_dataset = datasets.MNIST(root="data", download=True, train=True, transform=transform)

    train_size = 50000
    val_size = 10000
    train_dataset, val_dataset = random_split(full_train_dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=28, shuffle=True)
    test_loader = DataLoader(val_dataset, batch_size=28, shuffle=False)

####### CIFAR10 dataset ########
elif data_name == 2:
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                    transforms.Lambda(lambda x: x.view(-1))])  # flatten to 3*32*32 values (channel-first)
    full_train_dataset = datasets.CIFAR10(root="data", download=True, train=True, transform=transform)

    train_size = 40000
    val_size = 10000
    train_dataset, val_dataset = random_split(full_train_dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

####### FASHION MNIST dataset ########
elif data_name == 3:
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (0.5,)),
                                    transforms.Lambda(lambda x: x.view(-1))])  # flatten to 784 values
    full_train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)

    train_size = 50000
    val_size = 10000
    train_dataset, val_dataset = random_split(full_train_dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=28, shuffle=True)
    test_loader = DataLoader(val_dataset, batch_size=28, shuffle=False)

print(f"Training set size: {len(train_loader.dataset)}")
print(f"Test set size: {len(test_loader.dataset)}")




Training set size: 50000
Test set size: 10000





In [4]:
input_data = torch.stack([data[0].flatten() for data in train_dataset]).to(device)
n_input = input_data.shape[1]  # Number of input neurons
shuffled_epoch_data = input_data[torch.randperm(input_data.shape[0]), :]
i=10
batch_size = 100
mini_batch = shuffled_epoch_data[i:i + batch_size, :].to(device)


In [6]:
 mini_batch = mini_batch.transpose(0, 1)
 mini_batch.shape

torch.Size([784, 100])
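
With the minibatch laid out as above (one example per column), a single update of the competing-hidden-units rule can be checked on a toy problem. The sketch below is a NumPy re-implementation of one minibatch step of the PyTorch loop in the next cell. The function `bio_learning_step` is hypothetical, `delta` stands for `anti_hebbian_learning_strength`, `p` for `lebesgue_norm` and `k` for `rank`. It assumes `k >= 2`, because with `k = 1` the anti-Hebbian assignment would overwrite the winner.

```python
import numpy as np

def bio_learning_step(W, batch, eps, precision=0.1, delta=0.3, p=2, k=5):
    """One minibatch update of the competing-hidden-units rule (hypothetical NumPy sketch).

    W:     (n_hidden, n_input) synapses
    batch: (n_input, batch_size) inputs, one example per column
    Returns the updated weights and the activation matrix g.
    """
    n_hidden = W.shape[0]
    bs = batch.shape[1]
    currents = (np.sign(W) * np.abs(W) ** (p - 1)) @ batch  # (n_hidden, bs)
    order = np.argsort(currents, axis=0)                     # ascending, per example
    cols = np.arange(bs)
    g = np.zeros((n_hidden, bs))
    g[order[-1], cols] = 1.0      # Hebbian term for the most active unit
    g[order[-k], cols] = -delta   # anti-Hebbian term for the k-th most active unit
    xx = np.sum(g * currents, axis=1)                        # (n_hidden,)
    ds = g @ batch.T - xx[:, None] * W
    nc = max(np.abs(ds).max(), precision)  # never divide by less than `precision`
    return W + eps * ds / nc, g

rng = np.random.default_rng(0)
W = rng.random((6, 4))                # 6 hidden units, 4 inputs, uniform like torch.rand
batch = rng.standard_normal((4, 10))  # 10 examples, one per column
W_new, g = bio_learning_step(W, batch, eps=0.01)
print(W_new.shape)                    # (6, 4)
print((g == 1.0).sum(axis=0))         # exactly one winner per example
```

Because the update is divided by its own largest entry (or by `precision`, whichever is larger), no weight moves by more than `eps` in a single step.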

In [23]:
def unsupervised_bio_learning(train_dataset, n_hidden=28, n_epochs=500, batch_size=100,
                              learning_rate=0.01, precision=0.1,
                              anti_hebbian_learning_strength=0.3,
                              lebesgue_norm=2, rank=5, skip=1):
    """
    Unsupervised bio learning function (competing hidden units).

    Parameters:
    - train_dataset: Dataset whose items are tuples with the input as first element;
      each input is flattened into a vector of n_input values.
    - n_hidden: Number of hidden units (neurons).
    - n_epochs: Number of epochs to train the model.
    - batch_size: Size of the minibatch used to update the weights.
    - learning_rate: Initial learning rate; it decays linearly to zero over the epochs.
    - precision: Lower bound on the normaliser of the weight update, so that a tiny
      update is not inflated to a full step by the division.
    - anti_hebbian_learning_strength: Strength Delta of the anti-Hebbian term applied
      to the rank-th most active hidden unit.
    - lebesgue_norm: Exponent p of the Lebesgue norm used to compute the input currents.
    - rank: Rank k of the hidden unit that receives the anti-Hebbian update
      (k = 1 would be the winner itself, so use k >= 2).
    - skip: Print the epoch number every `skip` epochs.

    Returns:
    - synapses: Tensor of shape (n_hidden, n_input) with the learned weights.
    """

    # Flatten the input data and determine input size
    input_data = torch.stack([data[0].flatten() for data in train_dataset]).to(device)
    n_samples, n_input = input_data.shape  # n_input = number of input neurons

    # Initialize synapse weights uniformly in [0, 1)
    synapses = torch.rand((n_hidden, n_input), dtype=torch.float, device=device)

    for epoch in range(n_epochs):
        if epoch % skip == 0:
            print('Epoch -->', epoch)

        eps = learning_rate * (1 - epoch / n_epochs)  # Linearly decaying learning rate

        # Shuffle the dataset once per epoch
        shuffled_epoch_data = input_data[torch.randperm(n_samples, device=device), :]

        for i in range(0, n_samples, batch_size):
            # Columns are examples: shape (n_input, current_batch)
            mini_batch = shuffled_epoch_data[i:i + batch_size, :].transpose(0, 1)
            current_batch = mini_batch.shape[1]  # the last minibatch may be smaller
            cols = torch.arange(current_batch, device=device)

            # --- Currents --- #
            # <W, v>_mu = sum_i sgn(W_mu,i) |W_mu,i|^(p-1) v_i, for every hidden unit mu
            tot_input = torch.mm(torch.sign(synapses) * torch.abs(synapses).pow(lebesgue_norm - 1), mini_batch)

            # --- Activation --- #
            # Indices of the hidden units sorted by increasing current, per example
            y = torch.argsort(tot_input, dim=0)

            yl = torch.zeros((n_hidden, current_batch), device=device)
            yl[y[n_hidden - 1, :], cols] = 1.0  # Hebbian: most active unit
            yl[y[n_hidden - rank, :], cols] = -anti_hebbian_learning_strength  # anti-Hebbian: rank-th most active unit

            # Activation-weighted currents, summed over the minibatch
            xx = torch.sum(yl * tot_input, dim=1)

            # --- Change of weights --- #
            ds = torch.mm(yl, mini_batch.transpose(0, 1)) - xx.unsqueeze(1) * synapses

            # Normalise by the largest absolute change, but never divide by less than `precision`
            nc = torch.clamp(torch.max(torch.abs(ds)), min=precision)

            # Update the synapse weights
            synapses += eps * ds / nc

    return synapses
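
For reference, the update in `unsupervised_bio_learning` can be written as follows, with the normalisation radius of the original rule fixed to 1 as in the code. Here b indexes the examples of a minibatch, v is an input column, p is `lebesgue_norm`, Δ is `anti_hebbian_learning_strength`, k is `rank`, η is `learning_rate`, t is the epoch and T is `n_epochs`. The currents I correspond to `tot_input`, g to `yl` and δW to `ds`.

```latex
\begin{aligned}
I^{(b)}_\mu &= \sum_i \operatorname{sgn}(W_{\mu i})\,|W_{\mu i}|^{p-1}\, v^{(b)}_i \\
g^{(b)}_\mu &= \begin{cases}
  1 & \text{if } I^{(b)}_\mu \text{ is the largest current} \\
  -\Delta & \text{if } I^{(b)}_\mu \text{ is the } k\text{-th largest current} \\
  0 & \text{otherwise}
\end{cases} \\
\delta W_{\mu i} &= \sum_b g^{(b)}_\mu v^{(b)}_i - \Big(\sum_b g^{(b)}_\mu I^{(b)}_\mu\Big) W_{\mu i} \\
W_{\mu i} &\leftarrow W_{\mu i} + \varepsilon_t\,
  \frac{\delta W_{\mu i}}{\max\big(\max_{\nu j}|\delta W_{\nu j}|,\ \text{precision}\big)},
  \qquad \varepsilon_t = \eta\Big(1 - \frac{t}{T}\Big)
\end{aligned}
```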


In [25]:
synapses = unsupervised_bio_learning(train_dataset, n_hidden=2000, n_epochs=500, batch_size=100,
                              learning_rate=0.01, precision=0.1,
                              anti_hebbian_learning_strength=0.3,
                              lebesgue_norm=2, rank=5, skip=1)

Epoch --> 0
Epoch --> 1
Epoch --> 2
...
Epoch --> 64
Epoch --> 65


KeyboardInterrupt: 

## Draw weights

In [21]:
def draw_weights(synapses, Kx, Ky, ax=None):
    # synapses: the weights
    Kw = int(np.sqrt(synapses.shape[1]//3)) # i.e. 32
    yy=0
    HM=np.zeros((Kw*Ky, Kw*Kx, 3))
    for y in range(Ky):
        for x in range(Kx):
            HM[y*Kw:(y+1)*Kw,x*Kw:(x+1)*Kw]=synapses[yy,:Kw*Kw*3].reshape(Kw, Kw, 3)
            yy += 1

    nc=np.amax(np.absolute(HM))
    tmp = (HM-HM.min())
    tmp /= tmp.max()
    tmp *= 255
    tmp = tmp.astype(np.uint8)
    if ax is not None:
        im = ax.imshow(tmp)
        ax.axis('off')
    else:
        plt.clf()
        im=plt.imshow(tmp.astype(np.uint8))
        plt.axis('off')
    fig.canvas.draw()

# %matplotlib inline
%matplotlib notebook
fig=plt.figure(figsize=(10, 7))
draw_weights(synapses.to('cpu'), 20, 20)
plt.tight_layout()


IndexError: index 28 is out of bounds for dimension 0 with size 28

In [22]:
synapses.shape

torch.Size([28, 784])
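
The shape above explains the `IndexError`: there are only 28 hidden units, so a 20 × 20 grid (400 tiles) runs out of rows. Each row also holds 784 = 28 × 28 grayscale values rather than 3 × 32 × 32 RGB values. A hypothetical helper that checks both conditions before tiling might look like this. It uses NumPy only and leaves plotting to `imshow`.

```python
import numpy as np

def tile_weights(W, Kx, Ky):
    """Arrange the first Kx*Ky rows of W into a Ky x Kx mosaic (hypothetical helper).

    Rows must be flattened square images: side*side values (grayscale) or
    3*side*side values in channel-first order (RGB, as produced by view(-1)).
    """
    n_hidden, n_input = W.shape
    if Kx * Ky > n_hidden:
        raise ValueError(f"Kx*Ky = {Kx * Ky} exceeds the {n_hidden} hidden units")
    side, channels = int(round(np.sqrt(n_input))), 1
    if side * side != n_input:
        side, channels = int(round(np.sqrt(n_input / 3))), 3
        if channels * side * side != n_input:
            raise ValueError("rows are not flattened square images")

    mosaic = np.zeros((Ky * side, Kx * side, channels))
    for idx in range(Kx * Ky):
        r, c = divmod(idx, Kx)  # row-major order, as in draw_weights
        tile = W[idx].reshape(channels, side, side).transpose(1, 2, 0)  # CHW -> HWC
        mosaic[r * side:(r + 1) * side, c * side:(c + 1) * side] = tile
    return mosaic[:, :, 0] if channels == 1 else mosaic

W = np.arange(28 * 784, dtype=float).reshape(28, 784)
print(tile_weights(W, 5, 5).shape)  # (140, 140)
```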

In [20]:
fig=plt.figure(figsize=(10, 7))
draw_weights(synapses.to('cpu'), 5, 5)
plt.tight_layout()
plt.savefig('FASHIONMNIST_synapses_hidden.png')
plt.show()


## Multi-layer linear bio classifier

In [None]:
def multi_layer_unsupervised_bio_learning(train_dataset, layers_config, n_epochs=2, batch_size=32,
                                          learning_rate=0.01, precision=0.1, anti_hebbian_learning_strength=0.3,
                                          lebesgue_norm=2, rank=5, skip=1):
    """
    Trains a multi-layer unsupervised model layer by layer (greedy layer-wise training).

    Parameters:
    - train_dataset: Dataset whose items are tuples with the input as first element.
    - layers_config: List with the number of hidden units for each layer.
    - All other parameters are passed unchanged to unsupervised_bio_learning.

    Returns:
    - synapses_list: List of trained weight matrices, one per layer.

    Note: BioClassifier is not defined in this section; it must be defined
    before this function is called.
    """

    synapses_list = []

    # Inputs of the first layer: the flattened examples of the dataset
    input_data = torch.stack([data[0].flatten() for data in train_dataset]).to(device)
    n_input = input_data.shape[1]

    for idx, n_hidden in enumerate(layers_config):
        print(f"Training layer {idx + 1} with {n_input} input units and {n_hidden} hidden units.")

        # Unsupervised weights of this layer. Keyword arguments are required:
        # unsupervised_bio_learning has no n_input argument, so passing values
        # positionally would shift every hyperparameter by one slot.
        synapses = unsupervised_bio_learning(train_dataset, n_hidden=n_hidden, n_epochs=n_epochs,
                                             batch_size=batch_size, learning_rate=learning_rate,
                                             precision=precision,
                                             anti_hebbian_learning_strength=anti_hebbian_learning_strength,
                                             lebesgue_norm=lebesgue_norm, rank=rank, skip=skip)
        synapses_list.append(synapses)

        # Forward pass: the outputs of this layer become the training data of the next one
        model = BioClassifier(synapses, out_features=n_hidden).to(device)
        with torch.no_grad():
            outputs = model(input_data).detach()

        n_input = outputs.shape[1]
        print('Next layer input size:', n_input)

        # Update variables for the next layer
        train_dataset = TensorDataset(outputs)
        input_data = outputs

    return synapses_list  # Return the list of all layers' weights


In [None]:
layers_config = [20, 10]
synapses_list = multi_layer_unsupervised_bio_learning(train_dataset, layers_config, n_epochs=2, batch_size=32,
                                          learning_rate=0.01, precision=0.1, anti_hebbian_learning_strength=0.3,
                                          lebesgue_norm=2, rank=5)
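
`BioClassifier` is not defined in this section, so the exact transformation between layers is not shown here. The sketch below illustrates the greedy layer-wise data flow under an explicit assumption. The representation passed to the next layer is taken to be a rectified power of the currents, h = max(W x, 0)^n, similar to the hidden-layer activation used by Krotov and Hopfield. The exponent n = 3, the helper names and the random stand-in for the training step are all hypothetical.

```python
import numpy as np

def rectified_power(W, X, n=3):
    """Hypothetical inter-layer map: h = max(W x, 0)**n for every row x of X."""
    return np.maximum(X @ W.T, 0.0) ** n

def train_layerwise(X, layers_config, train_layer, feature_map=rectified_power):
    """Greedy layer-wise training: layer l is fit on the features produced by layer l-1."""
    weights, features = [], X
    for n_hidden in layers_config:
        W = train_layer(features, n_hidden)  # unsupervised fit of one layer
        weights.append(W)
        features = feature_map(W, features)  # input of the next layer
    return weights, features

rng = np.random.default_rng(0)

def random_layer(features, n_hidden):
    """Placeholder for unsupervised_bio_learning: random (n_hidden, n_input) weights."""
    return rng.standard_normal((n_hidden, features.shape[1]))

X = rng.random((100, 784))  # 100 flattened 28x28 examples, one per row
weights, feats = train_layerwise(X, [20, 10], random_layer)
print([W.shape for W in weights], feats.shape)  # [(20, 784), (10, 20)] (100, 10)
```

Replacing `random_layer` with a wrapper around `unsupervised_bio_learning`, and `rectified_power` with the real `BioClassifier` forward pass, recovers the structure of `multi_layer_unsupervised_bio_learning`. With n > 1 the scale of the features grows quickly from layer to layer, so some normalisation between layers is likely to be needed in practice.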