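This notebook provides a minimal example of using LFP to train a simple LeNet on MNIST. The same setup can also be run with the baseline training methods selectable via `method_name` in the Parameters cell.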

For more complex examples, refer to the experiment notebooks in `./nbs`.

### Imports

In [1]:
import os
import joblib
import random
import time

import numpy as np
import torch
import torch.nn as tnn
import torcheval.metrics
import torchvision.datasets as tvisiondata
import torchvision.transforms as T
from tqdm import tqdm

from experiment_utils.model.models import ACTIVATION_MAP

from lfprop.propagation import (
    propagator_lxt as propagator,
)  # LFP propagator. Alternatively, use propagator_zennit
from lfprop.rewards import reward_functions as rewards  # Reward Functions
from lfprop.rewards import rewards as loss_fns
from torch_pso import ParticleSwarmOptimizer
from fa import *
from dladmm import dladmm 
from dladmm import input_data as dladmm_data

  from .autonotebook import tqdm as notebook_tqdm


### Parameters

In [None]:
# --- Experiment selection ---
model_name = "lenet"  # "mlp" selects the MLP variants; any other value selects the LeNet variants
method_name = "vanilla-gradient" # lfp-epsilon, vanilla-gradient, pso, fa, dladmm | TODO ldtp, ga
seed = 0
epochs = 50

# --- Paths ---
# The defaults are the original machine-specific locations. Set LFP_DATA_PATH /
# LFP_RESULTS_PATH in the environment to run the notebook elsewhere without editing it.
data_path = os.environ.get(
    "LFP_DATA_PATH",
    "/media/lweber/f3ed2aae-a7bf-4a55-b50d-ea8fb534f1f52/Datasets/mnist",
)
results_root = os.environ.get(
    "LFP_RESULTS_PATH",
    "/media/lweber/f3ed2aae-a7bf-4a55-b50d-ea8fb534f1f52/reward-backprop/resubmission-1-experiments/clocktime-comparison",
)
# One result directory per (method, model, seed, epochs) combination.
savepath = os.path.join(results_root, f"{method_name}-{model_name}-{seed}-{epochs}")
os.makedirs(savepath, exist_ok=True)

# --- Data / model dimensions ---
n_channels = 1
n_outputs = 10
batch_size = 128

# Settings merged into every method's training configuration.
general_params = {
    "n_channels": n_channels,
    "n_outputs": n_outputs,
    "batch_size": batch_size,
    "epochs": epochs
}

def set_random_seeds(seed):
    """
    Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs and configure cuDNN.

    Args:
        seed: Integer seed applied to every random number generator.
    """
    # Note: PYTHONHASHSEED is read at interpreter start-up, so setting it here
    # only affects processes launched afterwards, not the running kernel.
    os.environ["PYTHONHASHSEED"] = str(seed)

    random.seed(seed)
    np.random.seed(seed)
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    # Disable cuDNN autotuning, request deterministic kernels, and switch cuDNN off.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    cudnn.enabled = False

# Seed all RNGs with the configured seed before anything stochastic is created.
set_random_seeds(seed)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### Load Dataset

In [3]:
# Convert images to tensors, then normalize with mean 0.5 and std 0.5.
transform = T.Compose([T.ToTensor(), T.Normalize((0.5,), (0.5,))])

# Train and test splits share every setting except the `train` flag.
training_data, testing_data = (
    tvisiondata.MNIST(root=data_path, transform=transform, download=True, train=is_train_split)
    for is_train_split in (True, False)
)

# Only the training loader shuffles; the test loader keeps a fixed order.
training_loader = torch.utils.data.DataLoader(
    training_data,
    batch_size=batch_size,
    shuffle=True,
)
testing_loader = torch.utils.data.DataLoader(
    testing_data,
    batch_size=batch_size,
    shuffle=False,
)

### Load Model

In [4]:
class MLP(tnn.Module):
    """
    Small fully connected classifier (784 -> 120 -> 84 -> n_outputs).

    Args:
        n_channels: Accepted for signature parity with the other models; not used here.
        n_outputs: Number of output units.
        activation: Module class instantiated after each hidden layer.
    """

    def __init__(self, n_channels, n_outputs, activation=tnn.ReLU):
        super().__init__()

        # Layers are created in this exact order so parameter initialization
        # (and the Sequential indices 0..4) stay the same.
        stack = [
            tnn.Linear(28 * 28, 120),
            activation(),
            tnn.Linear(120, 84),
            activation(),
            tnn.Linear(84, n_outputs),
        ]
        self.classifier = tnn.Sequential(*stack)

    def forward(self, x):
        """
        Flatten everything except the batch dimension and apply the classifier.
        """
        return self.classifier(torch.flatten(x, 1))
    
class FaMLP(tnn.Module):
    """
    Small fully connected classifier built from feedback-alignment linear layers.

    Same layout as the plain MLP (784 -> 120 -> 84 -> n_outputs), but every
    linear layer is a ``LinearFA``.

    Args:
        n_channels: Accepted for signature parity with the other models; not used here.
        n_outputs: Number of output units.
        activation: Module class instantiated after each hidden layer.
    """

    def __init__(self, n_channels, n_outputs, activation=tnn.ReLU):
        super().__init__()

        # Layers are created in this exact order so parameter initialization
        # (and the Sequential indices 0..4) stay the same.
        stack = [
            LinearFA(28 * 28, 120),
            activation(),
            LinearFA(120, 84),
            activation(),
            LinearFA(84, n_outputs),
        ]
        self.classifier = tnn.Sequential(*stack)

    def forward(self, x):
        """
        Flatten everything except the batch dimension and apply the classifier.
        """
        return self.classifier(torch.flatten(x, 1))

class LeNet(tnn.Module):
    """
    Small LeNet-style CNN: two conv/pool stages, two dropout-regularized
    hidden linear layers, and a separate output layer.

    Args:
        n_channels: Number of input channels. Also selects the flattened feature
            size fed to the classifier (256 when 1, otherwise 400).
        n_outputs: Number of output units.
        activation: Module class instantiated after each conv / hidden linear layer.
    """

    def __init__(self, n_channels, n_outputs, activation=tnn.ReLU):
        super().__init__()

        # Size of the flattened feature map handed to the first linear layer.
        flat_features = 400 if n_channels != 1 else 256

        # Convolutional feature extractor
        conv_stack = [
            tnn.Conv2d(n_channels, 16, 5),
            activation(),
            tnn.MaxPool2d(2, 2),
            tnn.Conv2d(16, 16, 5),
            activation(),
            tnn.MaxPool2d(2, 2),
        ]
        self.features = tnn.Sequential(*conv_stack)

        # Hidden fully connected layers
        hidden_stack = [
            tnn.Linear(flat_features, 120),
            activation(),
            tnn.Dropout(),
            tnn.Linear(120, 84),
            activation(),
            tnn.Dropout(),
        ]
        self.classifier = tnn.Sequential(*hidden_stack)

        # Output layer, kept as its own attribute
        self.last = tnn.Linear(84, n_outputs)

    def forward(self, x):
        """
        Run features -> flatten -> classifier -> output layer.
        """
        feature_maps = self.features(x)
        hidden = self.classifier(torch.flatten(feature_maps, 1))
        return self.last(hidden)
    
class FaLeNet(tnn.Module):
    """
    Small LeNet-style CNN built from feedback-alignment layers.

    Same layout as the plain LeNet, with ``Conv2dFA`` / ``LinearFA`` in place of
    the standard convolution and linear layers.

    Args:
        n_channels: Number of input channels. Also selects the flattened feature
            size fed to the classifier (256 when 1, otherwise 400).
        n_outputs: Number of output units.
        activation: Module class instantiated after each conv / hidden linear layer.
    """

    def __init__(self, n_channels, n_outputs, activation=tnn.ReLU):
        super().__init__()

        # Size of the flattened feature map handed to the first linear layer.
        flat_features = 400 if n_channels != 1 else 256

        # Convolutional feature extractor
        conv_stack = [
            Conv2dFA(n_channels, 16, 5),
            activation(),
            tnn.MaxPool2d(2, 2),
            Conv2dFA(16, 16, 5),
            activation(),
            tnn.MaxPool2d(2, 2),
        ]
        self.features = tnn.Sequential(*conv_stack)

        # Hidden fully connected layers
        hidden_stack = [
            LinearFA(flat_features, 120),
            activation(),
            tnn.Dropout(),
            LinearFA(120, 84),
            activation(),
            tnn.Dropout(),
        ]
        self.classifier = tnn.Sequential(*hidden_stack)

        # Output layer, kept as its own attribute
        self.last = LinearFA(84, n_outputs)

    def forward(self, x):
        """
        Run features -> flatten -> classifier -> output layer.
        """
        feature_maps = self.features(x)
        hidden = self.classifier(torch.flatten(feature_maps, 1))
        return self.last(hidden)

def name_modules(module, name):
    """
    Recursively tag every descendant of ``module`` with a dotted ``tmpname``.

    Each child receives ``tmpname = "<name>.<child name>"`` (or just the child
    name when ``name`` is the empty string). ``module`` itself is not tagged.
    Intended for debugging.
    """
    prefix = "" if name == "" else f"{name}."
    for child_name, submodule in module.named_children():
        qualified = prefix + child_name
        submodule.tmpname = qualified
        name_modules(submodule, qualified)

### Evaluation and Training Helpers

In [5]:
def eval_model(model, loader, objective_func):
    """
    Evaluates the model on a single dataset.

    Args:
        model: Network to evaluate; it is switched to eval mode and left there.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        objective_func: Callable ``(outputs, labels)`` whose result is averaged
            over all batches.

    Returns:
        Dict with NumPy values under ``"objective"`` (mean objective) and
        ``"accuracy"`` (micro-averaged top-1 accuracy).
    """
    eval_metrics = {
        "objective": torcheval.metrics.Mean(device=device),
        # NOTE: the class count is fixed at 10 here, independent of the model's output size.
        "accuracy": torcheval.metrics.MulticlassAccuracy(average="micro", num_classes=10, k=1, device=device),
    }

    model.eval()

    # Iterate over Data Loader
    for inputs, labels in loader:
        inputs = inputs.to(device)
        # as_tensor passes existing tensors through unchanged; torch.tensor(tensor)
        # would copy and emit a UserWarning on every batch.
        labels = torch.as_tensor(labels).to(device)

        with torch.no_grad():
            # Get model predictions
            outputs = model(inputs)

        # The objective is evaluated with autograd enabled, presumably because some
        # objective functions rely on it internally -- TODO confirm.
        with torch.set_grad_enabled(True):
            objective = objective_func(outputs, labels)

        eval_metrics["objective"].update(objective)
        eval_metrics["accuracy"].update(outputs, labels)

    # Return evaluation as plain NumPy values
    return {m: metric.compute().detach().cpu().numpy() for m, metric in eval_metrics.items()}

def lfp_step(model, optimizer, objective_func, propagation_composite, inputs, labels):
    """
    Performs a single training step using LFP. This is quite similar to a standard gradient descent training loop.

    Args:
        model: Network to update; set to train mode for the step and back to eval mode afterwards.
        optimizer: Optimizer whose ``step()`` applies the update read from ``param.grad``.
        objective_func: Callable ``(outputs, labels)`` returning the reward tensor that is
            propagated backwards as ``grad_outputs``.
        propagation_composite: Object providing ``context(model)``, a context manager that
            yields the modified model used for the forward pass.
        inputs: Input batch (detached and marked as requiring grad inside the step).
        labels: Targets passed to ``objective_func``.
    """
    # Set Model to training mode
    model.train()

    with torch.enable_grad():
        # Zero Optimizer
        optimizer.zero_grad()

        # This applies LFP Hooks/Functions (which depends on whether lxt or zennit backend is used)
        with propagation_composite.context(model) as modified:
            inputs = inputs.detach().requires_grad_(True)
            outputs = modified(inputs)

            # Calculate reward
            # Round-trip through NumPy so the reward is a fresh tensor with no autograd history
            # (avoids tensors being kept in memory)
            reward = torch.from_numpy(objective_func(outputs, labels).detach().cpu().numpy()).to(device)

            # Calculate LFP and write into .feedback attribute of parameters.
            # The returned input gradient is discarded; the call is made for the side effect of the
            # composite's hooks, which presumably populate param.feedback -- verify against the backend.
            torch.autograd.grad((outputs,), (inputs,), grad_outputs=(reward,), retain_graph=False)[0]

            # Write LFP values into .grad attributes. Note the negative sign: LFP requires maximization instead of minimization like gradient descent
            for name, param in model.named_parameters():
                param.grad = -param.feedback

            # Update Clipping. Training may become unstable otherwise, especially in small models with large learning rates.
            # In larger models (e.g., VGG, ResNet), where smaller learning rates are generally utilized, not clipping updates may result in better performance.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0, 2.0)

            # Optimization step
            optimizer.step()

    # Set Model back to eval mode
    model.eval()
    
def grad_step(model, optimizer, objective_func, inputs, labels):
    """
    Run one gradient-descent update on a single batch.

    The model is put in train mode for the update and returned to eval mode
    afterwards. Gradients are clipped to an L2 norm of 3.0 before the optimizer step.

    Args:
        model: Network to update.
        optimizer: Optimizer holding the model's parameters.
        objective_func: Callable ``(outputs, labels)`` returning a scalar loss.
        inputs: Input batch.
        labels: Targets passed to ``objective_func``.
    """
    model.train()

    # Force autograd on, regardless of the caller's grad mode.
    with torch.enable_grad():
        optimizer.zero_grad()

        # Forward pass on a detached copy of the inputs, then backpropagate the loss.
        predictions = model(inputs.detach())
        objective_func(predictions, labels).backward()

        # Clip the update; small models with large learning rates can otherwise become unstable.
        # Larger models with smaller learning rates may do better without clipping.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3.0, norm_type=2.0)

        optimizer.step()

    model.eval()

def pso_step(model, optimizer, objective_func, inputs, labels, max_steps):
    """
    Run one particle-swarm update on a single batch.

    The optimizer receives a closure that re-evaluates the objective on this
    batch. Everything runs with autograd disabled. The model is put in train
    mode for the update and returned to eval mode afterwards.

    Args:
        model: Network to update.
        optimizer: Optimizer exposing ``zero_grad()`` and ``step(closure)``.
        objective_func: Callable ``(outputs, labels)`` returning the objective.
        inputs: Input batch.
        labels: Targets passed to ``objective_func``.
        max_steps: Currently unused; kept for a (disabled) inertial-weight schedule.
    """
    model.train()

    with torch.no_grad():
        optimizer.zero_grad()

        batch = inputs.detach()

        def evaluate_objective():
            # Drop any stale grads first, since the optimizer changes the parameters between calls.
            optimizer.zero_grad()
            # The forward pass MUST happen inside the closure so each call sees the current parameters.
            predictions = model(batch)
            return objective_func(predictions, labels)

        optimizer.step(evaluate_objective)

    model.eval()


# Training Loop
def train(model, optimizer, objective_func, propagation_composite, **kwargs):
    """
    Train ``model`` epoch by epoch, recording per-epoch metrics and wall-clock training time.

    Args:
        model: Network to train; updated in place.
        optimizer: Optimizer handed to the selected step function.
        objective_func: Callable ``(outputs, labels)`` used for the updates and for evaluation.
        propagation_composite: Selects the update rule: ``None`` -> ``grad_step``,
            the string ``"pso"`` -> ``pso_step``, anything else -> ``lfp_step``.
        **kwargs: Additional configuration entries. Only ``"epochs"`` is read (falling back
            to the notebook-level ``epochs`` when absent); all other entries are ignored.

    Returns:
        Dict of lists with one entry per epoch for ``train_accuracy``, ``train_objective``,
        ``test_accuracy``, ``test_objective`` and ``clock_time`` (seconds spent in the
        training loop of that epoch, excluding evaluation). Modules exposing a
        ``zeros_ratio`` attribute additionally get a ``zeros_<tmpname>`` list with one
        entry per training step.
    """
    # Honour the epoch count passed by the caller instead of silently using the global.
    n_epochs = kwargs.get("epochs", epochs)

    def _now():
        # CUDA kernels run asynchronously; wait for queued work so the timestamp
        # reflects completed computation. perf_counter is monotonic, unlike time.time.
        if device.type == "cuda":
            torch.cuda.synchronize()
        return time.perf_counter()

    def _evaluate_and_report(prefix):
        # Evaluate on both splits and print a one-line summary.
        train_stats = eval_model(model, training_loader, objective_func)
        test_stats = eval_model(model, testing_loader, objective_func)
        print(
            "{}: (Train Objective) {:.2f}; (Train Accuracy) {:.2f}; (Test Objective) {:.2f}; (Test Accuracy) {:.2f}".format(
                prefix,
                float(np.mean(train_stats["objective"])),
                float(train_stats["accuracy"]),
                float(np.mean(test_stats["objective"])),
                float(test_stats["accuracy"]),
            )
        )
        return train_stats, test_stats

    evals = {
        "train_accuracy": [],
        "train_objective": [],
        "test_accuracy": [],
        "test_objective": [],
        "clock_time": [],
    }

    # Performance of the untrained model
    _evaluate_and_report("INIT")

    for epoch in range(n_epochs):
        pre = _now()

        # Iterate over Data Loader
        for inputs, labels in training_loader:
            inputs = inputs.to(device)
            # as_tensor passes existing tensors through; torch.tensor(tensor) copies and warns.
            labels = torch.as_tensor(labels).to(device)

            # Perform Update Step
            if propagation_composite is None:
                grad_step(model, optimizer, objective_func, inputs, labels)
            elif propagation_composite == "pso":
                pso_step(
                    model, optimizer, objective_func, inputs, labels,
                    max_steps=len(training_loader) * n_epochs,
                )
            else:
                lfp_step(model, optimizer, objective_func, propagation_composite, inputs, labels)

            # Log zero ratios for every module that reports one
            for cname, child in model.named_modules():
                if hasattr(child, "zeros_ratio"):
                    evals.setdefault(f"zeros_{child.tmpname}", []).append(child.zeros_ratio)

        post = _now()

        # Evaluate and print performance after every epoch (not included in clock_time)
        eval_stats_train, eval_stats_test = _evaluate_and_report(f"Epoch {epoch + 1}/{n_epochs}")

        evals["train_accuracy"].append(float(eval_stats_train["accuracy"]))
        evals["train_objective"].append(float(eval_stats_train["objective"]))
        evals["test_accuracy"].append(float(eval_stats_test["accuracy"]))
        evals["test_objective"].append(float(eval_stats_test["objective"]))
        evals["clock_time"].append(post - pre)

    return evals

### Set Up Training Method

In [6]:
def _build_model(model_class):
    """
    Instantiate ``model_class`` with the notebook-level dimensions and ReLU activations,
    tag its submodules with debug names, move it to ``device`` and put it in eval mode.
    """
    net = model_class(
        n_channels=n_channels,
        n_outputs=n_outputs,
        activation=tnn.ReLU
    )
    name_modules(net, "")
    net.tmpname = "root"
    net.to(device)
    net.eval()
    return net


# Each branch defines model_class, model, training_cfg and training_func for one method.
if method_name == "lfp-epsilon":
    model_class = MLP if model_name == "mlp" else LeNet
    model = _build_model(model_class)

    training_cfg = {
        "model_class": model_class,
        "propagation_composite": propagator.LFPEpsilonComposite(),
        "objective_func": rewards.SoftmaxLossReward(device),
        "model": model,
        "optimizer": torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9),
    }
    training_func = train

elif method_name == "vanilla-gradient":
    model_class = MLP if model_name == "mlp" else LeNet
    model = _build_model(model_class)

    training_cfg = {
        "model_class": model_class,
        "propagation_composite": None,  # None selects plain gradient descent in train()
        "objective_func": loss_fns.CustomCrossEntropyLoss(),
        "model": model,
        "optimizer": torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9),
    }
    training_func = train

    print(training_cfg)

elif method_name == "pso":
    model_class = MLP if model_name == "mlp" else LeNet
    model = _build_model(model_class)

    training_cfg = {
        "model_class": model_class,
        "propagation_composite": "pso",  # the string "pso" selects pso_step in train()
        "objective_func": torch.nn.CrossEntropyLoss(),
        #"objective_func": loss_fns.CustomCrossEntropyLoss(),
        "model": model,
        "optimizer": ParticleSwarmOptimizer(
            model.parameters(), # TODO: tune hyperparams
            cognitive_coefficient=2, # Note: We decay this if particle best was not updated for a while
            social_coefficient=2, # Note: We decay this if global best was not updated for a while
            inertial_weight=0.8,
            num_particles=1000,
            max_param_value=0.1,
            min_param_value=-0.1
        )
    }
    training_func = train

elif method_name == "fa":
    # Feedback alignment uses the FA layer variants but the ordinary gradient step.
    model_class = FaMLP if model_name == "mlp" else FaLeNet
    model = _build_model(model_class)

    training_cfg = {
        "model_class": model_class,
        "propagation_composite": None,
        "objective_func": loss_fns.CustomCrossEntropyLoss(),
        "model": model,
        "optimizer": torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9),
    }
    training_func = train

# TODO: elif method_name == "ldtp": #https://arxiv.org/pdf/2201.13415

elif method_name == "dladmm":
    if model_name != "mlp":
        raise ValueError("dladmm only implemented for mlp")
    dladmm_mnist = dladmm_data.mnist(data_path)
    model_class = dladmm.DladmmNet
    model = tuple(model_class(
        images=torch.transpose(dladmm_mnist.x_train, 0, 1),
        label=torch.transpose(dladmm_mnist.y_train, 0, 1),
        num_of_neurons1=120,
        num_of_neurons2=84,
    )) # Model is just a tuple of parameters, pre-acts, and activations here

    training_cfg = {
        "model_class": model_class,
        "model": model,
        "x_train": dladmm_mnist.x_train,
        "y_train": dladmm_mnist.y_train,
        "x_test": dladmm_mnist.x_test,
        "y_test": dladmm_mnist.y_test,
    }
    training_func = dladmm.train

else:
    raise ValueError(
        f"Unknown method_name {method_name!r}; expected one of "
        "'lfp-epsilon', 'vanilla-gradient', 'pso', 'fa', 'dladmm'"
    )

# Shared settings are merged last, so they take precedence over method-specific entries.
training_cfg = {**training_cfg, **general_params}

{'model_class': <class '__main__.LeNet'>, 'propagation_composite': None, 'objective_func': CustomCrossEntropyLoss(), 'model': LeNet(
  (features): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 16, kernel_size=(5, 5), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=256, out_features=120, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=120, out_features=84, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.5, inplace=False)
  )
  (last): Linear(in_features=84, out_features=10, bias=True)
), 'optimizer': SGD (
Parameter Group 0
    dampening: 0
    differentiable: False
    foreach: None
    fused: None
    lr: 0.01
    maximize: False
    momentum: 0.9
    nesterov: False
    we

### Train, if results not available

In [7]:
# Results are cached on disk: train only when no result file exists for this configuration.
result_path = os.path.join(savepath, "result_dict.joblib")
if not os.path.exists(result_path):
    print(f"TRAINING method {method_name} with seed {seed}")
    evals = training_func(**training_cfg)
    joblib.dump(evals, result_path)
else:
    # NOTE: joblib.load unpickles the file; only load result files you created yourself.
    evals = joblib.load(result_path)

# Print the header once, then the last recorded value of every metric.
print("FINAL EPOCH")
for k, v in evals.items():
    if len(v) > 0:
        print(k, v[-1])
    else:
        # Nothing was recorded (e.g. a run with zero epochs); avoid an IndexError.
        print(k, "<no values recorded>")

TRAINING method vanilla-gradient with seed 13
INIT: (Train Objective) 2.31; (Train Accuracy) 0.10; (Test Objective) 2.31; (Test Accuracy) 0.10
Epoch 1/50: (Train Objective) 0.14; (Train Accuracy) 0.96; (Test Objective) 0.12; (Test Accuracy) 0.96
Epoch 2/50: (Train Objective) 0.08; (Train Accuracy) 0.98; (Test Objective) 0.07; (Test Accuracy) 0.98
Epoch 3/50: (Train Objective) 0.06; (Train Accuracy) 0.98; (Test Objective) 0.06; (Test Accuracy) 0.98
Epoch 4/50: (Train Objective) 0.05; (Train Accuracy) 0.98; (Test Objective) 0.05; (Test Accuracy) 0.98
Epoch 5/50: (Train Objective) 0.04; (Train Accuracy) 0.99; (Test Objective) 0.04; (Test Accuracy) 0.99
Epoch 6/50: (Train Objective) 0.04; (Train Accuracy) 0.99; (Test Objective) 0.04; (Test Accuracy) 0.99
Epoch 7/50: (Train Objective) 0.03; (Train Accuracy) 0.99; (Test Objective) 0.03; (Test Accuracy) 0.99
Epoch 8/50: (Train Objective) 0.03; (Train Accuracy) 0.99; (Test Objective) 0.04; (Test Accuracy) 0.99
Epoch 9/50: (Train Objective) 0.0