<a href="https://colab.research.google.com/github/bghdd/Forward-Forward/blob/main/forward_forward_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import numpy as np
import torch
import math
import torch.nn as nn
import os
import random
from datetime import timedelta
import torchvision
import time
from collections import defaultdict

In [9]:
opt_seed =42
device="cuda"  # cpu or cuda

input_path= './datasets'
input_batch_size= 100


model_peer_normalization= 0.03
model_momentum= 0.9  # Momentum to use for the running mean in peer normalization loss.

model_hidden_dim= 1000
model_num_layers= 3


training_epochs= 100

training_learning_rate= 1e-3
training_weight_decay= 3e-4
training_momentum= 0.9

training_downstream_learning_rate= 1e-2
training_downstream_weight_decay= 3e-3

training_val_idx= -1  # -1: validate only once training has finished; n: validate every n epochs.
training_final_test= True  # Set to true to evaluate performance on test-set.


make the mnist dataset suitable for FF

In [4]:
class FF_MNIST(torch.utils.data.Dataset):
    def __init__(self, partition, num_classes=10):
        self.mnist = get_MNIST_partition(partition)
        self.num_classes = num_classes
        self.uniform_label = torch.ones(self.num_classes) / self.num_classes

    def __getitem__(self, index):
        pos_sample, neg_sample, neutral_sample, class_label = self._generate_sample(
            index
        )

        inputs = {
            "pos_images": pos_sample,
            "neg_images": neg_sample,
            "neutral_sample": neutral_sample,
        }
        labels = {"class_labels": class_label}
        return inputs, labels

    def __len__(self):
        return len(self.mnist)

    def _get_pos_sample(self, sample, class_label):
        one_hot_label = torch.nn.functional.one_hot(
            torch.tensor(class_label), num_classes=self.num_classes
        )
        pos_sample = sample.clone()
        pos_sample[:, 0, : self.num_classes] = one_hot_label
        return pos_sample

    def _get_neg_sample(self, sample, class_label):
        # Create randomly sampled one-hot label.
        classes = list(range(self.num_classes))
        classes.remove(class_label)  # Remove true label from possible choices.
        wrong_class_label = np.random.choice(classes)
        one_hot_label = torch.nn.functional.one_hot(
            torch.tensor(wrong_class_label), num_classes=self.num_classes
        )
        neg_sample = sample.clone()
        neg_sample[:, 0, : self.num_classes] = one_hot_label
        return neg_sample

    def _get_neutral_sample(self, z):
        z[:, 0, : self.num_classes] = self.uniform_label
        return z

    def _generate_sample(self, index):
        # Get MNIST sample.
        sample, class_label = self.mnist[index]
        pos_sample = self._get_pos_sample(sample, class_label)
        neg_sample = self._get_neg_sample(sample, class_label)
        neutral_sample = self._get_neutral_sample(sample)
        return pos_sample, neg_sample, neutral_sample, class_label

model

In [5]:
class FF_model(torch.nn.Module):
    """The model trained with Forward-Forward (FF)."""

    def __init__(self):
        super(FF_model, self).__init__()
        self.input_batch_size=input_batch_size
        self.model_num_layers= model_num_layers
        self.device = device
        self.model_momentum = model_momentum
        self.model_peer_normalization = model_peer_normalization
        self.num_channels = [model_hidden_dim] * self.model_num_layers
        self.act_fn = ReLU_full_grad()
        # Initialize the model.
        self.model = nn.ModuleList([nn.Linear(784, self.num_channels[0])])
        for i in range(1, len(self.num_channels)):
            self.model.append(nn.Linear(self.num_channels[i - 1], self.num_channels[i]))

        # Initialize forward-forward loss.
        self.ff_loss = nn.BCEWithLogitsLoss()

        # Initialize peer normalization loss.
        self.running_means = [
            torch.zeros(self.num_channels[i], device=self.device) + 0.5
            for i in range(self.model_num_layers)
        ]

        # Initialize downstream classification loss.
        channels_for_classification_loss = sum(
            self.num_channels[-i] for i in range(self.model_num_layers - 1)
        )
        self.linear_classifier = nn.Sequential(
            nn.Linear(channels_for_classification_loss, 10, bias=False)
        )
        self.classification_loss = nn.CrossEntropyLoss()

        # Initialize weights.
        self._init_weights()

    def _init_weights(self):
        for m in self.model.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.normal_(
                    m.weight, mean=0, std=1 / math.sqrt(m.weight.shape[0])
                )
                torch.nn.init.zeros_(m.bias)

        for m in self.linear_classifier.modules():
            if isinstance(m, nn.Linear):
                nn.init.zeros_(m.weight)

    def _layer_norm(self, z, eps=1e-8):
        return z / (torch.sqrt(torch.mean(z ** 2, dim=-1, keepdim=True)) + eps)

    def _calc_peer_normalization_loss(self, idx, z):
        # Only calculate mean activity over positive samples.
        mean_activity = torch.mean(z[:self.input_batch_size], dim=0)

        self.running_means[idx] = self.running_means[
            idx
        ].detach() * self.model_momentum + mean_activity * (
            1 - self.model_momentum
        )

        peer_loss = (torch.mean(self.running_means[idx]) - self.running_means[idx]) ** 2
        return torch.mean(peer_loss)

    def _calc_ff_loss(self, z, labels):
        sum_of_squares = torch.sum(z ** 2, dim=-1)

        logits = sum_of_squares - z.shape[1]
        ff_loss = self.ff_loss(logits, labels.float())

        with torch.no_grad():
            ff_accuracy = (
                torch.sum((torch.sigmoid(logits) > 0.5) == labels)
                / z.shape[0]
            ).item()
        return ff_loss, ff_accuracy

    def forward(self, inputs, labels):
        scalar_outputs = {
            "Loss": torch.zeros(1, device=self.device),
            "Peer Normalization": torch.zeros(1, device=self.device),
        }

        # Concatenate positive and negative samples and create corresponding labels.
        z = torch.cat([inputs["pos_images"], inputs["neg_images"]], dim=0)
        posneg_labels = torch.zeros(z.shape[0], device=self.device)
        posneg_labels[: self.input_batch_size] = 1

        z = z.reshape(z.shape[0], -1)
        z = self._layer_norm(z)

        for idx, layer in enumerate(self.model):
            z = layer(z)
            z = self.act_fn.apply(z)

            if self.model_peer_normalization > 0:
                peer_loss = self._calc_peer_normalization_loss(idx, z)
                scalar_outputs["Peer Normalization"] += peer_loss
                scalar_outputs["Loss"] += self.model_peer_normalization * peer_loss

            ff_loss, ff_accuracy = self._calc_ff_loss(z, posneg_labels)
            scalar_outputs[f"loss_layer_{idx}"] = ff_loss
            scalar_outputs[f"ff_accuracy_layer_{idx}"] = ff_accuracy
            scalar_outputs["Loss"] += ff_loss
            z = z.detach()

            z = self._layer_norm(z)

        scalar_outputs = self.forward_downstream_classification_model(
            inputs, labels, scalar_outputs=scalar_outputs
        )

        return scalar_outputs

    def forward_downstream_classification_model(
        self, inputs, labels, scalar_outputs=None,
    ):
        if scalar_outputs is None:
            scalar_outputs = {
                "Loss": torch.zeros(1, device=self.device),
            }

        z = inputs["neutral_sample"]
        z = z.reshape(z.shape[0], -1)
        z = self._layer_norm(z)

        input_classification_model = []

        with torch.no_grad():
            for idx, layer in enumerate(self.model):
                z = layer(z)
                z = self.act_fn.apply(z)
                z = self._layer_norm(z)

                if idx >= 1:
                    input_classification_model.append(z)

        input_classification_model = torch.concat(input_classification_model, dim=-1)

        output = self.linear_classifier(input_classification_model.detach())
        output = output - torch.max(output, dim=-1, keepdim=True)[0]
        classification_loss = self.classification_loss(output, labels["class_labels"])
        classification_accuracy = get_accuracy(
            output.data, labels["class_labels"]
        )

        scalar_outputs["Loss"] += classification_loss
        scalar_outputs["classification_loss"] = classification_loss
        scalar_outputs["classification_accuracy"] = classification_accuracy
        return scalar_outputs


class ReLU_full_grad(torch.autograd.Function):
    """ ReLU activation function that passes through the gradient irrespective of its input value. """

    @staticmethod
    def forward(ctx, input):
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()


utils

In [6]:
def get_model_and_optimizer():
    model = FF_model()
    model = model.to(device)
    print(model, "\n")

    # Create optimizer with different hyper-parameters for the main model
    # and the downstream classification model.
    main_model_params = [
        p
        for p in model.parameters()
        if all(p is not x for x in model.classification_loss.parameters())
    ]
    optimizer = torch.optim.SGD(
        [
            {
                "params": main_model_params,
                "lr": training_learning_rate,
                "weight_decay": training_weight_decay,
                "momentum": training_momentum,
            },
            {
                "params": model.classification_loss.parameters(),
                "lr": training_downstream_learning_rate,
                "weight_decay": training_downstream_weight_decay,
                "momentum": training_momentum,
            },
        ]
    )
    return model, optimizer


def get_data(partition):
    dataset = FF_MNIST( partition)

    # Improve reproducibility in dataloader.
    g = torch.Generator()
    g.manual_seed(opt_seed)

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=input_batch_size,
        drop_last=True,
        shuffle=True,
        worker_init_fn=seed_worker,
        generator=g,
        num_workers=4,
        persistent_workers=True,
    )


def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2 ** 32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def get_MNIST_partition(partition):
    if partition in ["train", "val", "train_val"]:
        mnist = torchvision.datasets.MNIST(
            root='./datasets',
            train=True,
            download=True,
            transform=torchvision.transforms.ToTensor(),
        )
    elif partition in ["test"]:
        mnist = torchvision.datasets.MNIST(
            root='./datasets',
            train=False,
            download=True,
            transform=torchvision.transforms.ToTensor(),
        )
    else:
        raise NotImplementedError

    if partition == "train":
        mnist = torch.utils.data.Subset(mnist, range(50000))
    elif partition == "val":
        mnist = torchvision.datasets.MNIST(
            root='./datasets',
            train=True,
            download=True,
            transform=torchvision.transforms.ToTensor(),
        )
        mnist = torch.utils.data.Subset(mnist, range(50000, 60000))

    return mnist


def dict_to_cuda(dict):
    for key, value in dict.items():
        dict[key] = value.cuda(non_blocking=True)
    return dict


def preprocess_inputs(inputs, labels):
    if "cuda" in device:
        inputs = dict_to_cuda(inputs)
        labels = dict_to_cuda(labels)
    return inputs, labels


def get_linear_cooldown_lr(epoch, lr):
    if epoch > (training_epochs // 2):
        return lr * 2 * (1 + training_epochs - epoch) / training_epochs
    else:
        return lr


def update_learning_rate(optimizer, epoch):
    optimizer.param_groups[0]["lr"] = get_linear_cooldown_lr(
         epoch, training_learning_rate
    )
    optimizer.param_groups[1]["lr"] = get_linear_cooldown_lr(
         epoch, training_downstream_learning_rate
    )
    return optimizer


def get_accuracy(output, target):
    """Computes the accuracy."""
    with torch.no_grad():
        prediction = torch.argmax(output, dim=1)
        return (prediction == target).sum() / input_batch_size


def print_results(partition, iteration_time, scalar_outputs, epoch=None):
    if epoch is not None:
        print(f"Epoch {epoch} \t", end="")

    print(
        f"{partition} \t \t"
        f"Time: {timedelta(seconds=iteration_time)} \t",
        end="",
    )
    if scalar_outputs is not None:
        for key, value in scalar_outputs.items():
            print(f"{key}: {value:.4f} \t", end="")
    print()


def log_results(result_dict, scalar_outputs, num_steps):
    for key, value in scalar_outputs.items():
        if isinstance(value, float):
            result_dict[key] += value / num_steps
        else:
            result_dict[key] += value.item() / num_steps
    return result_dict


In [7]:
def train( model, optimizer):
    start_time = time.time()
    train_loader = get_data("train")
    num_steps_per_epoch = len(train_loader)

    for epoch in range(training_epochs):
        train_results = defaultdict(float)
        optimizer = update_learning_rate(optimizer, epoch)

        for inputs, labels in train_loader:
            inputs, labels = preprocess_inputs(inputs, labels)

            optimizer.zero_grad()

            scalar_outputs = model(inputs, labels)
            scalar_outputs["Loss"].backward()

            optimizer.step()
            train_results = log_results(
                train_results, scalar_outputs, num_steps_per_epoch
            )

        print_results("train", time.time() - start_time, train_results, epoch)
        start_time = time.time()

        # Validate.
        if epoch % training_val_idx == 0 and training_val_idx != -1:
            validate_or_test( model, "val", epoch=epoch)

    return model



In [10]:

def validate_or_test( model, partition, epoch=None):
    test_time = time.time()
    test_results = defaultdict(float)

    data_loader = get_data(partition)
    num_steps_per_epoch = len(data_loader)

    model.eval()
    print(partition)
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = preprocess_inputs(inputs, labels)

            scalar_outputs = model.forward_downstream_classification_model(
                inputs, labels
            )
            test_results = log_results(
                test_results, scalar_outputs, num_steps_per_epoch
            )

    print_results(partition, time.time() - test_time, test_results, epoch=epoch)
    model.train()


def my_main():
    model, optimizer = get_model_and_optimizer()
    model = train(model, optimizer)
    validate_or_test(model, "val")

    if training_final_test:
        validate_or_test( model, "test")

my_main()

FF_model(
  (model): ModuleList(
    (0): Linear(in_features=784, out_features=1000, bias=True)
    (1-2): 2 x Linear(in_features=1000, out_features=1000, bias=True)
  )
  (ff_loss): BCEWithLogitsLoss()
  (linear_classifier): Sequential(
    (0): Linear(in_features=2000, out_features=10, bias=False)
  )
  (classification_loss): CrossEntropyLoss()
) 





Epoch 0 	train 	 	Time: 0:00:18.177967 	Loss: 149.2321 	Peer Normalization: 0.6602 	loss_layer_0: 22.2709 	ff_accuracy_layer_0: 0.7688 	loss_layer_1: 58.9507 	ff_accuracy_layer_1: 0.7426 	loss_layer_2: 67.4063 	ff_accuracy_layer_2: 0.7088 	classification_loss: 0.5844 	classification_accuracy: 0.8096 	
Epoch 1 	train 	 	Time: 0:00:18.908163 	Loss: 33.2938 	Peer Normalization: 0.5907 	loss_layer_0: 6.2794 	ff_accuracy_layer_0: 0.8773 	loss_layer_1: 9.4464 	ff_accuracy_layer_1: 0.8987 	loss_layer_2: 17.2900 	ff_accuracy_layer_2: 0.8843 	classification_loss: 0.2602 	classification_accuracy: 0.9205 	
Epoch 2 	train 	 	Time: 0:00:18.021239 	Loss: 20.6114 	Peer Normalization: 0.5551 	loss_layer_0: 4.6982 	ff_accuracy_layer_0: 0.9025 	loss_layer_1: 6.3514 	ff_accuracy_layer_1: 0.9295 	loss_layer_2: 9.3454 	ff_accuracy_layer_2: 0.9217 	classification_loss: 0.1997 	classification_accuracy: 0.9395 	
Epoch 3 	train 	 	Time: 0:00:18.589355 	Loss: 14.6266 	Peer Normalization: 0.5485 	loss_layer_0: 3