# Train Forward-Forward Models

Code mainly adapted from: https://github.com/carloalbertobarbano/forward-forward-pytorch

1. Train the base FF model presented in the paper: https://arxiv.org/abs/2212.13345

2. Train a modified FF model.


## 0. Preparation

In [1]:
import forward_forward as ff

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import argparse
import torch.utils.tensorboard
from torch.utils.tensorboard.writer import SummaryWriter

from collections import defaultdict
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from tqdm import tqdm

import random
import numpy as np
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
## HYPERPARAMS ##

# --- model / negative sampling ---
HARD_NEGATIVES = True    # also use the linear classifier's wrong predictions as negatives
LAYER_SIZE = 2000        # width of every hidden block

# --- optimisation ---
BATCH_SIZE = 200
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0
NUM_EPOCHS = 60
STEPS_PER_BLOCK = 60     # number of epochs between successive block freezes
THETA = 10.0             # threshold the activation norms are compared against in the FF loss

# --- reproducibility / hardware ---
SEED = 42
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'


In [3]:
## UTILS ##

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Also exports ``PYTHONHASHSEED`` and configures cuDNN with
    ``benchmark=True`` / ``deterministic=False``.

    NOTE(review): with these cuDNN flags, GPU runs favour speed and may not
    be bit-for-bit reproducible even with a fixed seed -- confirm this is
    the intended trade-off.

    Args:
        seed: integer seed applied to every generator.
    """
    # Python-level generators and hashing
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)

    # CUDA generators (current device, then all devices)
    for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda(seed)

    # cuDNN: autotuner on, deterministic algorithms not enforced
    cudnn = torch.backends.cudnn
    cudnn.deterministic = False
    cudnn.benchmark = True

    # default PyTorch generator
    torch.manual_seed(seed)
    
def accuracy(output, target, topk=(1,)):
    """Return top-k accuracies of ``output`` scores against ``target`` labels.

    Args:
        output: 2-D score tensor, one row per sample and one column per class.
        target: 1-D tensor of true class indices, one per sample.
        topk: iterable of k values to evaluate.

    Returns:
        A list of Python floats (fractions in [0, 1]), one per entry of
        ``topk`` and in the same order.
    """
    with torch.no_grad():
        n_samples = target.size(0)
        widest_k = max(topk)

        # Indices of the ``widest_k`` best-scoring classes per sample, best
        # first, transposed so that row r holds every sample's rank-r guess.
        _, ranked = output.topk(widest_k, dim=1, largest=True, sorted=True)
        ranked = ranked.t()
        hits = ranked.eq(target.reshape(1, -1).expand_as(ranked))

        # A sample counts for top-k if any of its first k guesses is a hit.
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(1.0 / n_samples).item()
            for k in topk
        ]

def norm_y(y_one_hot: torch.Tensor):
    """Standardise a label tensor with the same mean/std used for the MNIST pixels.

    Args:
        y_one_hot: label tensor (typically a one-hot encoding).

    Returns:
        ``(y_one_hot - 0.1307) / 0.3081`` as a new tensor.
    """
    mean, std = 0.1307, 0.3081
    return (y_one_hot - mean) / std

In [4]:
## TRAINING ##

@torch.no_grad()
def test(network_ff, linear_cf, test_loader, verbose=False):
    """Evaluate the FF network with both prediction methods over a data loader.

    * slow method: every sample is paired with each of the 10 candidate
      labels in turn; the norms of the network's activations are averaged
      over the last dimension and the best-scoring label is the prediction.
    * quick method: every sample is paired with a neutral label (0.1 for
      every class, normalised); the activations with the first entry along
      dim 1 removed are flattened and fed to ``linear_cf``.

    Args:
        network_ff: FF network, called as ``network_ff(x)``. Presumably
            returns activations stacked per block along dim 1 -- confirm
            against the ``forward_forward`` module.
        linear_cf: linear classifier used by the quick method.
        test_loader: iterable yielding ``(images, labels)`` batches.
        verbose: unused.

    Returns:
        ``(slow_acc, fast_acc)``: top-1 accuracies as floats in [0, 1].
    """
    goodness_batches = []
    target_batches = []
    logit_batches = []

    for images, targets in test_loader:
        images = images.to(DEVICE)
        targets = targets.to(DEVICE)
        n_samples = images.shape[0]
        flat = images.view(n_samples, -1)

        # slow method: one forward pass per candidate label
        per_label = []
        for label in range(10):
            candidate = norm_y(F.one_hot(torch.full_like(targets, label), num_classes=10))
            block_acts = network_ff(torch.cat((flat, candidate), dim=1))
            per_label.append(block_acts.norm(dim=-1))

        # these act as logits; original author's note: should be BSZxLABELSxLAYERS
        goodness_batches.append(torch.stack(per_label, dim=1))
        target_batches.append(targets)

        # quick method: single forward pass with a neutral label
        neutral = norm_y(torch.full((n_samples, 10), 0.1, device=DEVICE))
        block_acts = network_ff(torch.cat((flat, neutral), dim=1))[:, 1:]
        logit_batches.append(linear_cf(block_acts.view(block_acts.shape[0], -1)))

    slow_scores = torch.cat(goodness_batches)
    all_targets = torch.cat(target_batches)
    fast_scores = torch.cat(logit_batches)

    slow_acc = accuracy(slow_scores.mean(dim=-1), all_targets, topk=(1,))[0]
    fast_acc = accuracy(fast_scores, all_targets, topk=(1,))[0]
    return slow_acc, fast_acc

def train(network_ff, optimizer, linear_cf, optimizer_cf, train_loader, start_block):
    """Run one epoch of Forward-Forward training plus linear-classifier training.

    For every batch:

    1. ``linear_cf`` is trained with cross-entropy on the network's
       activations for the image paired with a neutral label (0.1 for every
       class, normalised with ``norm_y`` exactly as in ``test``). The first
       entry along dim 1 of the activations is excluded.
    2. Negative samples are built from random labels and, when
       ``HARD_NEGATIVES`` is set, from the classifier's wrong predictions.
    3. Each block with index >= ``start_block`` is updated with its own
       optimizer on the loss
       ``softplus(THETA - ||z_pos||) + softplus(||z_neg|| - THETA)``.

    Args:
        network_ff: FF network. Called as ``network_ff(x)`` for stacked
            activations and as ``network_ff(x, cat=False)`` for an iterable
            of per-block activations.
        optimizer: sequence of optimizers, indexed by block.
        linear_cf: linear classifier trained on the activations.
        optimizer_cf: optimizer for ``linear_cf``.
        train_loader: iterable yielding ``(images, labels)`` batches; must
            support ``len()``.
        start_block: blocks with a smaller index are skipped (frozen).

    Returns:
        ``(running_loss, running_ce)``: the summed block losses and the
        classifier cross-entropy, each divided by ``len(train_loader)``.
    """
    running_loss = 0.
    running_ce = 0.

    for (x, y_pos) in train_loader:
        x, y_pos = x.to(DEVICE), y_pos.to(DEVICE)
        # Use the real batch size so a smaller final batch does not break.
        batch_size = x.shape[0]
        x = x.view(batch_size, -1)

        # positive pairs
        y_pos_one_hot = norm_y(F.one_hot(y_pos, num_classes=10))
        x_pos = torch.cat((x, y_pos_one_hot), dim=1)

        # sample negatives (and train linear cf)
        with torch.no_grad():
            # Neutral label goes through norm_y like every other label
            # encoding, so the classifier sees the same input as in test().
            neutral_label = norm_y(torch.full_like(y_pos_one_hot, 0.1))
            ys = network_ff(torch.cat((x, neutral_label), dim=1))
            # first layer should be excluded!
            ys = ys[:, 1:]

        with torch.enable_grad():
            logits = linear_cf(ys.view(ys.shape[0], -1).detach())
            ce = F.cross_entropy(logits, y_pos)
            ce.backward()
            running_ce += ce.detach()

        optimizer_cf.step()
        optimizer_cf.zero_grad()

        # negative pairs from the classifier's wrong predictions
        # (argmax of the logits equals argmax of their softmax)
        preds = logits.detach().argmax(dim=1)
        wrong = preds != y_pos
        y_hard_one_hot = norm_y(F.one_hot(preds, num_classes=10))
        x_hard = torch.cat((x, y_hard_one_hot), dim=1)[wrong]

        # negative pairs from random labels; draws that happen to equal the
        # true label are kept (original author's note: keeping positives
        # seems to work better)
        y_rand = torch.randint(0, 10, (batch_size,), device=DEVICE)
        y_rand_one_hot = norm_y(F.one_hot(y_rand, num_classes=10))
        x_rand = torch.cat((x, y_rand_one_hot), dim=1)

        x_neg = x_rand
        if HARD_NEGATIVES:
            x_neg = torch.cat((x_neg, x_hard), dim=0)

        with torch.enable_grad():
            z_pos = network_ff(x_pos, cat=False)
            z_neg = network_ff(x_neg, cat=False)

            for block_idx, (zp, zn) in enumerate(zip(z_pos, z_neg)):
                if block_idx < start_block:
                    continue

                # softplus(v) == log(1 + exp(v)) but does not overflow for
                # large activation norms.
                positive_loss = F.softplus(THETA - zp.norm(dim=-1)).mean()
                negative_loss = F.softplus(zn.norm(dim=-1) - THETA).mean()
                loss = positive_loss + negative_loss
                loss.backward()

                running_loss += loss.detach()
                optimizer[block_idx].step()
                optimizer[block_idx].zero_grad()

    running_loss /= len(train_loader)
    running_ce /= len(train_loader)

    return running_loss, running_ce

In [5]:
## PREPARE ##

# Seed every RNG before the model and loaders are created so that weight
# initialisation and shuffling start from the same state on each run.
set_seed(SEED)

# Train and test use the same preprocessing: convert to tensor, then
# standardise with mean 0.1307 and std 0.3081.
T_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

T_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# drop_last=True keeps every training batch at exactly BATCH_SIZE samples.
train_loader = DataLoader(
    MNIST("~/data", train=True, download=True, transform=T_train), 
    batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=8,
    persistent_workers=True
)

# NOTE(review): the test loader is shuffled as well; test() aggregates over
# all batches, so the order should not change the reported accuracy.
test_loader = DataLoader(
    MNIST("~/data", train=False, download=True, transform=T_test), 
    batch_size=BATCH_SIZE, shuffle=True, num_workers=8,
    persistent_workers=True
)

# Input is the flattened 28x28 image concatenated with a 10-dim label vector,
# followed by four hidden blocks of LAYER_SIZE units each.
size = LAYER_SIZE
network_ff = ff.FFBase(dims=[28*28 + 10, size, size, size, size]).to(DEVICE)
print(network_ff)

# Create one optimizer for every relu layer (block)
optimizers = [
    torch.optim.Adam(block.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
        for block in network_ff.blocks.children()
] 

# Softmax layer for predicting classes from embeddings (fast method).
# Its input width is size*(n_blocks-1) because train()/test() drop the first
# block's activations before flattening.
linear_cf = nn.Linear(size*(network_ff.n_blocks-1), 10).to(DEVICE)
optimizer_cf = torch.optim.Adam(linear_cf.parameters(), lr=0.0001)


FFBase(
  (blocks): Sequential(
    (0): BlockBase(
      (fc): Linear(in_features=794, out_features=2000, bias=False)
      (relu): ReLU(inplace=True)
    )
    (1): BlockBase(
      (fc): Linear(in_features=2000, out_features=2000, bias=False)
      (relu): ReLU(inplace=True)
    )
    (2): BlockBase(
      (fc): Linear(in_features=2000, out_features=2000, bias=False)
      (relu): ReLU(inplace=True)
    )
    (3): BlockBase(
      (fc): Linear(in_features=2000, out_features=2000, bias=False)
      (relu): ReLU(inplace=True)
    )
  )
)


## 1. Train Base Model

Train the base model from the paper.

In [6]:
## RUN ##

# Train the base FF model for NUM_EPOCHS epochs, logging loss and accuracy to
# TensorBoard and tracking the best test accuracy of each prediction method.
writer = SummaryWriter()
start_block = 0  # index of the first block that is still being trained

# Best test accuracies seen so far and the epochs at which they occurred
max_fast_acc = 0
max_fast_acc_epoch = 0
max_slow_acc = 0
max_slow_acc_epoch = 0

pbar = tqdm(range(1, NUM_EPOCHS+1), total=NUM_EPOCHS)

for step in pbar:
    running_loss, running_ce = train(network_ff, optimizers, linear_cf, optimizer_cf,
                                        train_loader, start_block)
    # Every STEPS_PER_BLOCK epochs, stop training one more block (never the last one)
    if step % STEPS_PER_BLOCK == 0:
        if start_block+1 < network_ff.n_blocks:
            start_block += 1
            print("Freezing block", start_block-1)

    writer.add_scalar("train/loss", running_loss, step)
    writer.add_scalar("train/ce", running_ce, step)

    train_slow_acc, train_fast_acc = test(network_ff, linear_cf, train_loader)
    test_slow_acc, test_fast_acc = test(network_ff, linear_cf, test_loader)

    writer.add_scalar("acc_fast/train", train_fast_acc, step)
    writer.add_scalar("acc_fast/test", test_fast_acc, step)
    writer.add_scalar("acc_slow/train", train_slow_acc, step)
    writer.add_scalar("acc_slow/test", test_slow_acc, step)

    pbar.set_postfix({
        "train_fast_acc": train_fast_acc,
        "train_slow_acc": train_slow_acc,
        "test_fast_acc": test_fast_acc,
        "test_slow_acc": test_slow_acc
    })

    if test_fast_acc > max_fast_acc:
        max_fast_acc = test_fast_acc
        max_fast_acc_epoch = step

    if test_slow_acc > max_slow_acc:
        max_slow_acc = test_slow_acc
        max_slow_acc_epoch = step

# These are the best (maximum) accuracies, so label them as such.
print("Max fast acc:", max_fast_acc, "at epoch", max_fast_acc_epoch)
print("Max slow acc:", max_slow_acc, "at epoch", max_slow_acc_epoch)

  0%|          | 0/60 [00:00<?, ?it/s]

torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
torch.Size([200, 3, 2000])
t

## 2. Train Custom Model

Train MLP-Blocked FF model.

In [None]:
# Build the custom MLP-block FF model: flattened 28x28 image plus 10-dim label
# vector as input, followed by two hidden stages of LAYER_SIZE units.
size = LAYER_SIZE
ff_MLP = ff.FFMLP(dims=[28*28 + 10, size, size, size]).to(DEVICE)
print(ff_MLP)

# Create one optimizer for every block of the model being trained.
# The optimizers must wrap ff_MLP's own parameters; building them from
# network_ff would leave ff_MLP's weights untouched during training.
# NOTE(review): assumes FFMLP exposes `.blocks` like FFBase does -- confirm.
optimizers = [
    torch.optim.Adam(block.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
        for block in ff_MLP.blocks.children()
]

# Softmax layer for predicting classes from embeddings (fast method).
# Its input width is size*(n_blocks-1) because train()/test() drop the first
# block's activations before flattening.
linear_MLP_cf = nn.Linear(size*(ff_MLP.n_blocks-1), 10).to(DEVICE)
optimizer_MLP_cf = torch.optim.Adam(linear_MLP_cf.parameters(), lr=0.0001)

In [None]:
## RUN ##

# Train the custom FF model for NUM_EPOCHS epochs, logging loss and accuracy
# to TensorBoard and tracking the best test accuracy of each prediction method.
writer = SummaryWriter()
start_block = 0  # index of the first block that is still being trained

# Best test accuracies seen so far and the epochs at which they occurred
max_fast_acc = 0
max_fast_acc_epoch = 0
max_slow_acc = 0
max_slow_acc_epoch = 0

pbar = tqdm(range(1, NUM_EPOCHS+1), total=NUM_EPOCHS)

for step in pbar:
    running_loss, running_ce = train(ff_MLP, optimizers, linear_MLP_cf, optimizer_MLP_cf,
                                        train_loader, start_block)
    # Every STEPS_PER_BLOCK epochs, stop training one more block (never the last one)
    if step % STEPS_PER_BLOCK == 0:
        if start_block+1 < ff_MLP.n_blocks:
            start_block += 1
            print("Freezing block", start_block-1)

    writer.add_scalar("train/loss", running_loss, step)
    writer.add_scalar("train/ce", running_ce, step)

    train_slow_acc, train_fast_acc = test(ff_MLP, linear_MLP_cf, train_loader)
    test_slow_acc, test_fast_acc = test(ff_MLP, linear_MLP_cf, test_loader)

    writer.add_scalar("acc_fast/train", train_fast_acc, step)
    writer.add_scalar("acc_fast/test", test_fast_acc, step)
    writer.add_scalar("acc_slow/train", train_slow_acc, step)
    writer.add_scalar("acc_slow/test", test_slow_acc, step)

    pbar.set_postfix({
        "train_fast_acc": train_fast_acc,
        "train_slow_acc": train_slow_acc,
        "test_fast_acc": test_fast_acc,
        "test_slow_acc": test_slow_acc
    })

    if test_fast_acc > max_fast_acc:
        max_fast_acc = test_fast_acc
        max_fast_acc_epoch = step

    if test_slow_acc > max_slow_acc:
        max_slow_acc = test_slow_acc
        max_slow_acc_epoch = step

# These are the best (maximum) accuracies, so label them as such.
print("Max fast acc:", max_fast_acc, "at epoch", max_fast_acc_epoch)
print("Max slow acc:", max_slow_acc, "at epoch", max_slow_acc_epoch)