In [2]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms

# Per-pixel normalisation constants; presumably the mean/std of MNIST training pixels
# after ToTensor has scaled them to [0, 1] — TODO confirm provenance.
mean = torch.tensor(0.13066045939922333)
std = torch.tensor(0.30810779333114624)

# ToTensor -> [0, 1] image tensor; Normalize standardises it; flatten turns each
# 28x28 digit into the 784-vector consumed by the fully connected layers.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((mean.item(),), (std.item(),)),  # Normalize expects Python floats
    # Pass torch.flatten itself rather than a lambda: lambdas cannot be pickled, which
    # breaks DataLoader worker processes started with the "spawn" method.
    transforms.Lambda(torch.flatten),
])

In [3]:
from torchvision import datasets

# The MNIST *training* split, downloaded into ./data/ on first use and preprocessed by
# `transform`; train and held-out subsets are both carved from it.
mnist_kwargs = dict(root='./data/', train=True, download=True, transform=transform)
full_train_dataset = datasets.MNIST(**mnist_kwargs)

In [4]:
# Split the 60k MNIST training images: the first 50k are used for fitting, the last 10k
# are held out. Note the held-out ("test") slice comes from the training split, not
# from the official MNIST test set.
TRAIN_SPLIT_END = 50_000
DATASET_END = 60_000
train_indices = [*range(TRAIN_SPLIT_END)]
test_indices = [*range(TRAIN_SPLIT_END, DATASET_END)]

In [5]:
from torch.utils.data import DataLoader, Subset

# Both subsets index into the same MNIST training split; "test" is the held-out slice.
train_subset, test_subset = (
    Subset(full_train_dataset, split_indices)
    for split_indices in (train_indices, test_indices)
)

In [6]:
# batch_size equals the subset size, so one `next(iter(train_loader))` yields the whole
# shuffled training subset for full-batch training.
train_loader = DataLoader(train_subset, batch_size=50000, shuffle=True)

In [7]:
# A single 10k-image batch holds the entire held-out subset; its order is kept fixed.
test_loader = DataLoader(test_subset, batch_size=10000, shuffle=False)

In [None]:
for split_name, split_subset in (("Training", train_subset), ("Test", test_subset)):
    print(f"{split_name} dataset size: {len(split_subset)}")

In [None]:
def add_labels(images, labels, num_classes=10):
    """Return a copy of `images` with a one-hot label written over the first `num_classes` features.

    This is how Forward-Forward inputs carry their (candidate) label: the first
    `num_classes` values of each flattened sample are replaced by a one-hot code scaled
    to the largest value in the batch.

    Args:
        images: 2-D tensor, one flattened sample per row (needs >= `num_classes` columns).
        labels: int64 class indices, one per row, on the same device as `images`.
        num_classes: length of the one-hot code (10 for MNIST digits).

    Returns:
        A new tensor; `images` itself is not modified. The scale is the max over the
        whole batch, so the embedded value depends on which samples share the batch.
    """
    modified_images = images.clone()
    if modified_images.numel() == 0:
        # images.max() raises on an empty tensor, and there is nothing to label anyway.
        return modified_images
    one_hot_labels = torch.nn.functional.one_hot(labels, num_classes=num_classes).to(images.dtype)
    modified_images[:, :num_classes] = one_hot_labels * images.max()
    return modified_images

In [None]:
from torch.optim import Adam

class Layer(nn.Linear):
    """Fully connected ReLU layer trained locally with the Forward-Forward objective.

    Each layer owns its own Adam optimizer and learns to give high "goodness" (mean
    squared activation) to positive inputs and low goodness to negative inputs,
    relative to `threshold`.

    Args:
        input_dim, output_dim, bias, device, dtype: forwarded to `nn.Linear`.
        lr: Adam learning rate.
        threshold: goodness value separating positive from negative samples.
        num_epochs: full-batch optimisation steps per training call.
        peer_norm_coefficient: weight of the peer-normalisation penalty.
    """

    def __init__(self, input_dim, output_dim, bias=True, device=None, dtype=None,
                 lr=0.03, threshold=2.0, num_epochs=500, peer_norm_coefficient=0.01):
        super().__init__(input_dim, output_dim, bias, device, dtype)
        self.activation = nn.ReLU()
        # The optimizer captures the Parameter objects now; a later `.to(device)` keeps
        # them valid because nn.Module updates `param.data` in place by default.
        self.optimizer = Adam(self.parameters(), lr=lr)
        self.threshold = threshold
        self.num_epochs = num_epochs
        self.peer_norm_coefficient = peer_norm_coefficient

    def forward(self, input_features):
        """L2-normalise each row, apply the affine map, then ReLU."""
        # 1e-4 keeps the division finite for all-zero rows.
        normalized_input = input_features / (input_features.norm(p=2, dim=1, keepdim=True) + 1e-4)
        # F.linear handles bias=None (a bias-free layer) as well as a bias vector.
        return self.activation(nn.functional.linear(normalized_input, self.weight, self.bias))

    def train(self, positive_features=True, negative_features=None):
        """Train this layer on (positive, negative) features, or set the module mode.

        This name overrides `nn.Module.train(mode)`, which `.train()` / `.eval()` on this
        layer or any parent module call with a single bool. So:

        * `train(positive_features, negative_features)` runs Forward-Forward training and
          returns the detached outputs for both inputs (the next layer's inputs).
        * `train()` / `train(mode)` behaves like `nn.Module.train` and returns `self`.
        """
        if negative_features is None:
            return super().train(positive_features)
        return self._train_forward_forward(positive_features, negative_features)

    def _train_forward_forward(self, positive_features, negative_features):
        """Run `num_epochs` full-batch Adam steps; return detached (positive, negative) outputs."""
        if positive_features.size(0) == 0:
            # The peer-normalisation mean over zero rows is NaN and would poison the weights.
            raise ValueError("Forward-Forward training needs at least one positive sample")

        try:
            from tqdm import tqdm
        except ImportError:  # the progress bar is optional
            def tqdm(iterable, **_):
                return iterable

        for _ in tqdm(range(self.num_epochs), desc="training layer"):
            positive_output = self.forward(positive_features)
            negative_output = self.forward(negative_features)

            goodness_positive = positive_output.pow(2).mean(dim=1)
            goodness_negative = negative_output.pow(2).mean(dim=1)

            # Positive goodness should exceed the threshold; negative should stay below it.
            margins = torch.cat([
                self.threshold - goodness_positive,
                goodness_negative - self.threshold,
            ])
            # softplus(x) == log(1 + exp(x)) without overflowing to inf (and NaN grads) for large x.
            primary_loss = nn.functional.softplus(margins).mean()

            # Peer normalisation: penalise units whose mean positive activity deviates
            # from the layer-wide mean activity.
            mean_activity = positive_output.mean(dim=0)
            peer_loss = (mean_activity - mean_activity.mean()).pow(2).mean()

            total_loss = primary_loss + self.peer_norm_coefficient * peer_loss

            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()

        # Outputs only feed the next layer, so no autograd graph is needed.
        with torch.no_grad():
            return self.forward(positive_features), self.forward(negative_features)

In [None]:
class ForwardForwardNet(nn.Module):
    """A stack of Forward-Forward `Layer`s, trained greedily one layer at a time.

    Args:
        layer_dimensions: widths `[input, hidden_1, ..., hidden_k]`; each consecutive
            pair becomes one `Layer(input_dim, output_dim)`.
        device: where the layers are placed. `None` (default) selects CUDA when it is
            available and falls back to CPU otherwise.
    """

    def __init__(self, layer_dimensions, device=None):
        super().__init__()
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device
        self.layers = nn.ModuleList([
            Layer(input_dim, output_dim).to(self.device)
            for input_dim, output_dim in zip(layer_dimensions[:-1], layer_dimensions[1:])
        ])

    @torch.no_grad()
    def predict(self, input_features):
        """Classify each row by the candidate label giving the highest total goodness.

        Each of the 10 labels is embedded in turn with `add_labels`; goodness (mean
        squared activation) is summed over all layers and the argmax label is returned.
        Runs without autograd: this is inference only, and recording a graph for
        10 x num_layers full-batch passes would waste a large amount of memory.

        Args:
            input_features: 2-D tensor of flattened samples, one per row; moved to
                `self.device` if it is elsewhere.

        Returns:
            int64 tensor of shape `(batch,)` with the predicted labels, on `self.device`.
        """
        input_features = input_features.to(self.device)
        batch_size = input_features.size(0)

        goodness_per_label = []
        for label in range(10):
            candidate_labels = torch.full((batch_size,), label, dtype=torch.long, device=self.device)
            hidden_state = add_labels(input_features, candidate_labels)
            total_goodness = 0
            for layer in self.layers:
                hidden_state = layer(hidden_state)
                total_goodness = total_goodness + hidden_state.pow(2).mean(dim=1)
            goodness_per_label.append(total_goodness)

        # (batch, 10) goodness matrix -> best label per row.
        return torch.stack(goodness_per_label, dim=1).argmax(dim=1)

    def forward_train(self, positive_features, negative_features):
        """Greedy layer-wise training.

        Each layer is trained via `layer.train(positive, negative)`, and the outputs it
        returns become the positive/negative inputs of the next layer.
        """
        for layer_index, layer in enumerate(self.layers):
            print(f'training Layer {layer_index + 1}/{len(self.layers)}...')
            positive_features, negative_features = layer.train(positive_features, negative_features)

In [None]:
# 784 inputs = one flattened 28x28 image, followed by four hidden layers of 2000 units.
ff_layer_sizes = [784] + [2000] * 4
ff_net = ForwardForwardNet(ff_layer_sizes).cuda()

In [None]:
import matplotlib.pyplot as plt

def plot_imgs_side_by_side(datasets, names, idx=0):
    """Show row `idx` of each batch as a 28x28 image, one titled panel per batch.

    Args:
        datasets: sequence of tensors whose rows are flattened 784-value images (this
            parameter shadows the torchvision `datasets` module only inside the function).
        names: panel titles, paired positionally with `datasets`.
        idx: which row of every batch to display.
    """
    panel_count = len(datasets)
    plt.figure(figsize=(4 * panel_count, 4))
    for panel_number, (batch, title) in enumerate(zip(datasets, names), start=1):
        plt.subplot(1, panel_count, panel_number)
        pixels = batch[idx].cpu().numpy().reshape(28, 28)
        plt.title(title)
        plt.imshow(pixels, cmap='viridis')
        plt.axis('off')
    plt.show()

In [4]:
def derangement(n, device=None):
    """Return a uniformly random permutation of `range(n)` with no fixed points.

    Uses rejection sampling: draw random permutations until one moves every index
    (about e ~ 2.72 draws are needed on average).

    Args:
        n: number of elements; `n == 0` yields an empty permutation.
        device: torch device for the returned int64 tensor.

    Raises:
        ValueError: if `n == 1`, for which no derangement exists (the sampling loop
            would otherwise never terminate).
    """
    if n == 1:
        raise ValueError("no derangement exists for a single element")
    identity = torch.arange(n, device=device)
    while True:
        perm = torch.randperm(n, device=device)
        if not (perm == identity).any():
            return perm

In [5]:
import torch
from tqdm import tqdm
import time

In [6]:
# Full-batch training: train_loader's batch_size covers the whole training subset, so a
# single next(iter(...)) returns all of it. Data is placed on the model's device.
device = ff_net.device
input_images, labels = next(iter(train_loader))
input_images, labels = input_images.to(device), labels.to(device)


def _predict_with_accuracy(images, targets):
    """Predict with `ff_net` without recording autograd history; return (predictions, accuracy)."""
    with torch.no_grad():
        predictions = ff_net.predict(images)
    return predictions, predictions.eq(targets).float().mean().item()


positive_samples = add_labels(input_images, labels)

# Permuting *samples* does not guarantee a different *label* (many images share a
# digit), so add a random non-zero offset mod 10 instead: every negative label is then
# wrong and drawn uniformly from the 9 incorrect classes.
label_offsets = torch.randint(1, 10, labels.shape, device=device)
negative_labels = (labels + label_offsets) % 10
assert not negative_labels.eq(labels).any(), "every negative sample must carry a wrong label"
negative_samples = add_labels(input_images, negative_labels)

plot_imgs_side_by_side([input_images, positive_samples, negative_samples], ['orig', 'pos', 'neg'])

print("initial training...")
ff_net.forward_train(positive_samples, negative_samples)

train_predictions, train_accuracy = _predict_with_accuracy(input_images, labels)
train_error = 1.0 - train_accuracy
print(f"train accuracy after first training: {train_accuracy:.4f}")
print(f"train error after first training:    {train_error:.4f}")

num_hard_passes = 10

for hard_pass in range(1, num_hard_passes + 1):
    print("\n")
    incorrect_mask = train_predictions != labels
    correct_mask = ~incorrect_mask
    if not incorrect_mask.any():
        print(f"no more incorrect predictions after pass {hard_pass - 1}. stopping early.")
        break
    if not correct_mask.any():
        # Hard positives are sampled from correctly classified images; an empty positive
        # batch cannot be trained on.
        print(f"no correct predictions to sample positives from in pass {hard_pass}. stopping early.")
        break

    # Hard negatives: misclassified images tagged with the wrong label the net preferred.
    incorrect_images = input_images[incorrect_mask]
    incorrect_pred_labels = train_predictions[incorrect_mask]
    hard_negative_samples = add_labels(incorrect_images, incorrect_pred_labels)

    # Hard positives: a random sample of correctly classified images (no more than the
    # number of hard negatives), tagged with their true labels.
    correct_images = input_images[correct_mask]
    correct_labels = labels[correct_mask]
    num_to_sample = min(correct_images.size(0), incorrect_images.size(0))
    sampled_indices = torch.randperm(correct_images.size(0), device=device)[:num_to_sample]
    hard_positive_samples = add_labels(correct_images[sampled_indices], correct_labels[sampled_indices])

    print(f"hard pass {hard_pass}: re-training with hard negatives...")
    ff_net.forward_train(hard_positive_samples, hard_negative_samples)

    train_predictions, train_accuracy = _predict_with_accuracy(input_images, labels)
    train_error = 1.0 - train_accuracy
    print(f"train accuracy after hard pass {hard_pass}: {train_accuracy:.4f}")
    print(f"train error after hard pass {hard_pass}:    {train_error:.4f}")

# NOTE: test_loader serves the held-out slice of the MNIST *training* split.
test_images, test_labels = next(iter(test_loader))
test_images, test_labels = test_images.to(device), test_labels.to(device)
test_predictions, test_accuracy = _predict_with_accuracy(test_images, test_labels)
test_error = 1.0 - test_accuracy
print(f"final test accuracy:  {test_accuracy:.4f}")
print(f"final test error:     {test_error:.4f}")

NameError: name 'train_loader' is not defined