This Notebook provides a minimal example for using LFP to train a simple LeNet on MNIST.

For more complex examples, refer to the experiment notebooks in ./nbs

### Imports

In [8]:
import os

import numpy as np
import torch
import torch.nn as tnn
import torcheval.metrics
import torchvision.datasets as tvisiondata
import torchvision.transforms as T
from tqdm import tqdm

from lfprop.propagation import (
    propagator_lxt as propagator,
)  # LFP propagator. Alternatively, use propagator_zennit
from lfprop.rewards import reward_functions as rewards  # Reward Functions

### Parameters

In [9]:
savepath = "./minimal-example-data"
os.makedirs(savepath, exist_ok=True)

batch_size = 128
n_channels = 1
n_outputs = 10

lr = 0.1
momentum = 0.9
epochs = 10

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Load Dataset

In [10]:
transform = T.Compose([T.ToTensor(), T.Normalize((0.5,), (0.5,))])
training_data = tvisiondata.MNIST(
    root=savepath,
    transform=transform,
    download=True,
    train=True,
)

validation_data = tvisiondata.MNIST(
    root=savepath,
    transform=transform,
    download=True,
    train=False,
)

training_loader = torch.utils.data.DataLoader(training_data, batch_size=batch_size, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, shuffle=False)

### Load Model

In [11]:
class LeNet(tnn.Module):
    """
    Small LeNet
    """

    def __init__(self, n_channels, n_outputs, activation=tnn.ReLU):
        super().__init__()

        # Feature extractor
        self.features = tnn.Sequential(
            tnn.Conv2d(n_channels, 16, 5),
            activation(),
            tnn.MaxPool2d(2, 2),
            tnn.Conv2d(16, 16, 5),
            activation(),
            tnn.MaxPool2d(2, 2),
        )

        # Classifier
        self.classifier = tnn.Sequential(
            tnn.Linear(256 if n_channels == 1 else 400, 120),
            activation(),
            tnn.Dropout(),
            tnn.Linear(120, 84),
            activation(),
            tnn.Dropout(),
            tnn.Linear(84, n_outputs),
        )

    def forward(self, x):
        """
        forwards input through network
        """

        # Forward through network
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)

        # Return output
        return x


model = LeNet(
    n_channels=n_channels,
    n_outputs=n_outputs,
    activation=torch.nn.ReLU,
)

def name_modules(module, name):
    """
    Recursive function to name modules for debugging 
    """
    
    for cname, child in module.named_children():
        child.tmpname = cname if name == "" else f"{name}.{cname}"
        name_modules(child, child.tmpname)

name_modules(model, "")

model.to(device)
model.eval()

LeNet(
  (features): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 16, kernel_size=(5, 5), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=256, out_features=120, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=120, out_features=84, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=84, out_features=10, bias=True)
  )
)

### Set Up LFP

In [12]:
# Initialize the LFP Composite (cf. "composites" in zennit or lxt).
# This call is the same whether the lxt or zennit backend is used (propagator_lxt and propagator_zennit).
# Currently, only LFP-Epsilon is implemented. More composites may be added in the future.
propagation_composite = propagator.LFPEpsilonComposite()
#propagation_composite = propagator.LFPHebbianEpsilonComposite(use_oja=True)
#propagation_composite = propagator.LFPGammaComposite(gamma=0.0)

# Initialize the Reward Function.
# Here we use the Reward Function suggested in the LFP-Paper, but check out other reward functions in ./lfp/rewards/reward_functions.py
reward_func = rewards.SoftmaxLossReward(device)

# LFP writes its updates into the .grad attribute of the model parameters, and can thus utilize standard torch optimizers
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)

### Set Up Simple Evaluation using torcheval

In [13]:
def eval_model(loader):
    """
    Evaluates the model on a single dataset
    """
    eval_metrics = {
        "reward": torcheval.metrics.Mean(device=device),
        "accuracy": torcheval.metrics.MulticlassAccuracy(average="micro", num_classes=10, k=1, device=device),
    }

    model.eval()

    # Iterate over Data Loader
    for index, (inputs, labels) in enumerate(loader):
        inputs = inputs.to(device)
        labels = torch.tensor(labels).to(device)

        with torch.no_grad():
            # Get model predictions
            outputs = model(inputs)

        with torch.set_grad_enabled(True):
            # Get rewards
            reward = reward_func(outputs, labels)

        for k, v in eval_metrics.items():
            if k == "reward":
                eval_metrics[k].update(reward)
            else:
                eval_metrics[k].update(outputs, labels)

    return_dict = {m: metric.compute().detach().cpu().numpy() for m, metric in eval_metrics.items()}

    # Return evaluation
    return return_dict

### Training Loop

In [None]:
def lfp_step(inputs, labels):
    """
    Performs a single training step using LFP. This is quite similar to a standard gradient descent training loop.
    """
    # Set Model to training mode
    model.train()

    with torch.enable_grad():
        # Zero Optimizer
        optimizer.zero_grad()

        # This applies LFP Hooks/Functions (which depends on whether lxt or zennit backend is used)
        with propagation_composite.context(model) as modified:
            inputs = inputs.detach().requires_grad_(True)
            outputs = modified(inputs)

            # Calculate reward
            # Do like this to avoid tensors being kept in memory
            reward = torch.from_numpy(reward_func(outputs, labels).detach().cpu().numpy()).to(device)

            # Calculate LFP and write into .feedback attribute of parameters
            torch.autograd.grad((outputs,), (inputs,), grad_outputs=(reward,), retain_graph=False)[0]

            # Write LFP Values into .grad attributes. Note the negative sign: LFP requires maximization instead of minimization like gradient descent
            for name, param in model.named_parameters():
                param.grad = -param.feedback

            # Update Clipping. Training may become unstable otherwise, especially in small models with large learning rates.
            # In larger models (e.g., VGG, ResNet), where smaller learning rates are generally utilized, not clipping updates may result in better performance.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0, 2.0)

            # Optimization step
            optimizer.step()

    # Set Model back to eval mode
    model.eval()


# Training Loop
for epoch in range(epochs):
    with tqdm(total=len(training_loader)) as pbar:
        # Iterate over Data Loader
        for index, (inputs, labels) in enumerate(training_loader):
            inputs = inputs.to(device)
            labels = torch.tensor(labels).to(device)

            # Perform Update Step
            lfp_step(inputs, labels)

            # Update Progress Bar
            pbar.update(1)

    # Evaluate and print performance after every epoch
    eval_stats_train = eval_model(training_loader)
    eval_stats_val = eval_model(validation_loader)
    print(
        "Epoch {}/{}: (Train Reward) {:.2f}; (Train Accuracy) {:.2f}; (Val Reward) {:.2f}; (Val Accuracy) {:.2f}".format(
            epoch + 1,
            epochs,
            float(np.mean(eval_stats_train["reward"])),
            float(eval_stats_train["accuracy"]),
            float(np.mean(eval_stats_val["reward"])),
            float(eval_stats_val["accuracy"]),
        )
    )

  labels = torch.tensor(labels).to(device)
  1%|          | 3/469 [00:00<00:19, 23.63it/s]

<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>


  2%|▏         | 9/469 [00:00<00:19, 24.19it/s]

<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>


  3%|▎         | 15/469 [00:00<00:18, 24.22it/s]

<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>


  4%|▍         | 18/469 [00:00<00:18, 24.16it/s]

<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>


  5%|▌         | 24/469 [00:01<00:18, 23.48it/s]

<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>


  6%|▋         | 30/469 [00:01<00:18, 24.10it/s]

<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>


  7%|▋         | 33/469 [00:01<00:18, 23.71it/s]

<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>


  8%|▊         | 39/469 [00:01<00:18, 23.36it/s]

<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>


 10%|▉         | 45/469 [00:01<00:18, 23.49it/s]

<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>


 10%|█         | 48/469 [00:02<00:18, 23.31it/s]

<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
<torch.autograd.function.epsilon_lfp_fnBackward object at 0x7f26d280df20>
