This Notebook provides a minimal example for using LFP to train a simple LeNet on MNIST.

For more complex examples, refer to the experiment notebooks in ./nbs

### Imports

In [None]:
import os

import numpy as np
import torch
import torch.nn as tnn
import torcheval.metrics
import torchvision.datasets as tvisiondata
import torchvision.transforms as T
from tqdm import tqdm

from lxt import rules as lrules
from lxt.modules import INIT_MODULE_MAPPING

from zennit import types as ztypes

import open_clip
import open_clip.transformer

from lfprop.propagation import (
    propagator_lxt as propagator,
)  # LFP propagator. Alternatively, use propagator_zennit
from lfprop.rewards import reward_functions as rewards  # Reward Functions

  from .autonotebook import tqdm as notebook_tqdm


### Parameters

In [2]:
savepath = "./minimal-example-openclip-data"
os.makedirs(savepath, exist_ok=True)

batch_size = 128
n_channels = 1
n_outputs = 10

lr = 0.1
momentum = 0.9
epochs = 10

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Load Dataset

In [3]:
transform = T.Compose([T.ToTensor(), T.Normalize((0.5,), (0.5,))])
training_data = tvisiondata.MNIST(
    root=savepath,
    transform=transform,
    download=True,
    train=True,
)

validation_data = tvisiondata.MNIST(
    root=savepath,
    transform=transform,
    download=True,
    train=False,
)

training_loader = torch.utils.data.DataLoader(training_data, batch_size=batch_size, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, shuffle=False)

100%|██████████| 9.91M/9.91M [00:00<00:00, 10.3MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 274kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 2.56MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 4.95MB/s]


### Load Model

In [4]:
model, _, preprocess = open_clip.create_model_and_transforms('ViT-g-14', pretrained='laion2b_s34b_b88k')
model.eval()
model = model.to(device)

tokenizer = open_clip.get_tokenizer('ViT-g-14')

print(model)

CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 1408, kernel_size=(14, 14), stride=(14, 14), bias=False)
    (patch_dropout): Identity()
    (ln_pre): LayerNorm((1408,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): ModuleList(
        (0-39): 40 x ResidualAttentionBlock(
          (ln_1): LayerNorm((1408,), eps=1e-05, elementwise_affine=True)
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=1408, out_features=1408, bias=True)
          )
          (ls_1): Identity()
          (ln_2): LayerNorm((1408,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=1408, out_features=6144, bias=True)
            (gelu): GELU(approximate='none')
            (c_proj): Linear(in_features=6144, out_features=1408, bias=True)
          )
          (ls_2): Identity()
        )
      )
    )
    (ln_post): LayerNorm((1408,), eps=1e-05, elementwi

### Set Up LFP

In [None]:
# Initialize the LFP Composite (cf. "composites" in zennit or lxt).
# This composite is tailored to the specific model.
propagation_composite = propagator.ParameterizableComposite({
            ztypes.Activation: lrules.IdentityRule,
            ztypes.AvgPool: propagator.RuleGenerator(
                propagator.LFPEpsilon, epsilon=1e-6, norm_backward=False
            ),
            ztypes.Linear: propagator.RuleGenerator(
                propagator.LFPEpsilon, epsilon=1e-6, norm_backward=False
            ),
            ztypes.BatchNorm: propagator.RuleGenerator(
                propagator.LFPEpsilon, epsilon=1e-6, norm_backward=False
            ),
        }
)

# ({ 
#         nn.MultiheadAttention: lm.MultiheadAttention_CP,
#         # order matters! lm.LinearInProjection is inside lm.MultiheadAttention_CP
#         lm.LinearInProjection: rules.EpsilonRule,
#         lm.LinearOutProjection: rules.EpsilonRule,
#         open_clip.transformer.LayerNorm: lm.LayerNormEpsilon,
#         nn.GELU: rules.IdentityRule,
        
#         operator.add: lf.add2,
#         operator.matmul: lf.matmul,
#         F.normalize: lf.normalize,
#     })

# Initialize the Reward Function.
# Here we use the Reward Function suggested in the LFP-Paper, but check out other reward functions in ./lfp/rewards/reward_functions.py
reward_func = rewards.SoftmaxLossReward(device)

# LFP writes its updates into the .grad attribute of the model parameters, and can thus utilize standard torch optimizers
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)

### Set Up Simple Evaluation using torcheval

In [6]:
def eval_model(loader):
    """
    Evaluates the model on a single dataset
    """
    eval_metrics = {
        "reward": torcheval.metrics.Mean(device=device),
        "accuracy": torcheval.metrics.MulticlassAccuracy(average="micro", num_classes=10, k=1, device=device),
    }

    model.eval()

    # Iterate over Data Loader
    for index, (inputs, labels) in enumerate(loader):
        inputs = inputs.to(device)
        labels = torch.tensor(labels).to(device)

        with torch.no_grad():
            # Get model predictions
            outputs = model(inputs)

        with torch.set_grad_enabled(True):
            # Get rewards
            reward = reward_func(outputs, labels)

        for k, v in eval_metrics.items():
            if k == "reward":
                eval_metrics[k].update(reward)
            else:
                eval_metrics[k].update(outputs, labels)

    return_dict = {m: metric.compute().detach().cpu().numpy() for m, metric in eval_metrics.items()}

    # Return evaluation
    return return_dict

### Training Loop

In [7]:
def lfp_step(inputs, labels):
    """
    Performs a single training step using LFP. This is quite similar to a standard gradient descent training loop.
    """
    # Set Model to training mode
    model.train()

    with torch.enable_grad():
        # Zero Optimizer
        optimizer.zero_grad()

        # This applies LFP Hooks/Functions (which depends on whether lxt or zennit backend is used)
        with propagation_composite.context(model) as modified:
            inputs = inputs.detach().requires_grad_(True)
            outputs = modified(inputs)

            # Calculate reward
            # Do like this to avoid tensors being kept in memory
            reward = torch.from_numpy(reward_func(outputs, labels).detach().cpu().numpy()).to(device)

            # Calculate LFP and write into .feedback attribute of parameters
            torch.autograd.grad((outputs,), (inputs,), grad_outputs=(reward,), retain_graph=False)[0]

            # Write LFP Values into .grad attributes. Note the negative sign: LFP requires maximization instead of minimization like gradient descent
            for name, param in model.named_parameters():
                param.grad = -param.feedback

            # Update Clipping. Training may become unstable otherwise, especially in small models with large learning rates.
            # In larger models (e.g., VGG, ResNet), where smaller learning rates are generally utilized, not clipping updates may result in better performance.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0, 2.0)

            # Optimization step
            optimizer.step()

    # Set Model back to eval mode
    model.eval()


# Training Loop
for epoch in range(epochs):
    with tqdm(total=len(training_loader)) as pbar:
        # Iterate over Data Loader
        for index, (inputs, labels) in enumerate(training_loader):
            inputs = inputs.to(device)
            labels = torch.tensor(labels).to(device)

            # Perform Update Step
            lfp_step(inputs, labels)

            # Update Progress Bar
            pbar.update(1)

    # Evaluate and print performance after every epoch
    eval_stats_train = eval_model(training_loader)
    eval_stats_val = eval_model(validation_loader)
    print(
        "Epoch {}/{}: (Train Reward) {:.2f}; (Train Accuracy) {:.2f}; (Val Reward) {:.2f}; (Val Accuracy) {:.2f}".format(
            epoch + 1,
            epochs,
            float(np.mean(eval_stats_train["reward"])),
            float(eval_stats_train["accuracy"]),
            float(np.mean(eval_stats_val["reward"])),
            float(eval_stats_val["accuracy"]),
        )
    )

  labels = torch.tensor(labels).to(device)
  warn("This functionality is not yet fully tested. Please check the model after removing the composite.")
100%|██████████| 469/469 [00:18<00:00, 25.01it/s]
  labels = torch.tensor(labels).to(device)


Epoch 1/10: (Train Reward) -0.00; (Train Accuracy) 0.94; (Val Reward) 0.00; (Val Accuracy) 0.95


100%|██████████| 469/469 [00:18<00:00, 25.47it/s]


Epoch 2/10: (Train Reward) -0.00; (Train Accuracy) 0.96; (Val Reward) -0.00; (Val Accuracy) 0.96


100%|██████████| 469/469 [00:18<00:00, 25.36it/s]


Epoch 3/10: (Train Reward) 0.00; (Train Accuracy) 0.96; (Val Reward) 0.00; (Val Accuracy) 0.96


100%|██████████| 469/469 [00:18<00:00, 25.89it/s]


Epoch 4/10: (Train Reward) -0.00; (Train Accuracy) 0.96; (Val Reward) -0.00; (Val Accuracy) 0.96


100%|██████████| 469/469 [00:18<00:00, 25.73it/s]


Epoch 5/10: (Train Reward) -0.00; (Train Accuracy) 0.95; (Val Reward) 0.00; (Val Accuracy) 0.96


100%|██████████| 469/469 [00:18<00:00, 24.94it/s]


Epoch 6/10: (Train Reward) -0.00; (Train Accuracy) 0.96; (Val Reward) -0.00; (Val Accuracy) 0.96


100%|██████████| 469/469 [00:18<00:00, 25.22it/s]


Epoch 7/10: (Train Reward) -0.00; (Train Accuracy) 0.96; (Val Reward) -0.00; (Val Accuracy) 0.96


100%|██████████| 469/469 [00:17<00:00, 26.16it/s]


Epoch 8/10: (Train Reward) -0.00; (Train Accuracy) 0.95; (Val Reward) -0.00; (Val Accuracy) 0.96


100%|██████████| 469/469 [00:17<00:00, 26.17it/s]


Epoch 9/10: (Train Reward) -0.00; (Train Accuracy) 0.96; (Val Reward) -0.00; (Val Accuracy) 0.96


100%|██████████| 469/469 [00:17<00:00, 27.25it/s]


Epoch 10/10: (Train Reward) -0.00; (Train Accuracy) 0.96; (Val Reward) 0.00; (Val Accuracy) 0.96
