In [1]:
# Rework of https://github.com/mohammadpz/pytorch_forward_forward.git

In [2]:
import math
from tqdm import tqdm
import torch
from torch.optim import Adam
import torch.nn as nn
from torchvision import datasets, transforms

In [3]:
use_cuda = torch.cuda.is_available()
use_mps = torch.backends.mps.is_available()

if use_cuda:
    device = torch.device("cuda")
elif use_mps:
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

Using device: cuda


In [4]:
train_batch_size = 256
test_batch_size = 1000

train_kwargs = {'batch_size': train_batch_size, 'shuffle': False}
test_kwargs = {'batch_size': test_batch_size, 'shuffle': False}
cuda_kwargs = {'num_workers': 1, 'pin_memory': True}

if device.type == "cuda":
    train_kwargs |= cuda_kwargs
    test_kwargs |= cuda_kwargs

In [5]:
transform = transforms.Compose([transforms.ToTensor()])

In [6]:
dataset1 = datasets.MNIST('../data', train=True, download=True,
                    transform=transform)

dataset2 = datasets.MNIST('../data', train=False,
                    transform=transform)

data_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

In [7]:
class_num = len(dataset1.classes)
data_item_size = dataset1.data.shape[1:]

In [8]:
class Layer(nn.Linear):
    def __init__(self, in_features, out_features,
                 bias=True, device=None, dtype=None):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.opt = Adam(self.parameters(), lr=0.03)
        self.threshold = 2.0
        self.num_epochs = 10

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_direction = x / (x.norm(2, 1, keepdim=True) + 1e-6)
        return torch.relu(super().forward(x_direction))            

    def train(self, x_pos: torch.Tensor, x_neg: torch.Tensor):
        for _ in range(self.num_epochs):
            g_pos = self.forward(x_pos).pow(2).mean(1)
            g_neg = self.forward(x_neg).pow(2).mean(1)
            # The following loss pushes pos (neg) samples to
            # values larger (smaller) than the self.threshold.
            loss = torch.log(1 + torch.exp(torch.cat([
                -g_pos + self.threshold,
                g_neg - self.threshold]))).mean()
            self.opt.zero_grad()
            # this backward just compute the derivative and hence
            # is not considered backpropagation.
            loss.backward()
            self.opt.step()
        return self.forward(x_pos).detach(), self.forward(x_neg).detach()

In [9]:
class Net(torch.nn.Module):
    def __init__(self, dims, class_num, device):
        assert len(dims) >= 1
        assert class_num > 0
        super().__init__()
        self.class_num = class_num
        self.layers = [Layer(in_features=(dims[n] + (class_num if n == 0 else 0)), out_features=dims[n + 1], device=device) for n in range(len(dims) - 1)]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size = x.shape[0]
        x = x.view(batch_size, -1)
        x_max = x.max()
        goodness_per_label = torch.zeros(batch_size, self.class_num, device=x.device)
        for k, label in enumerate(range(self.class_num)):
            one_hot_label = torch.zeros(batch_size, self.class_num, device=x.device)
            one_hot_label[range(batch_size), label] = x_max
            h = torch.cat((one_hot_label, x), 1)
            goodness = torch.zeros(len(self.layers), batch_size, device=x.device)
            for l, layer in enumerate(self.layers):
                h = layer(h)
                assert isinstance(h, torch.Tensor)
                goodness[l] = h.pow(2).mean(1)
            goodness_per_label[:, k] = torch.sum(goodness, dim=0)
        
        output = torch.zeros(batch_size, self.class_num)
        max_label = goodness_per_label.argmax(1)
        output[range(batch_size), max_label] = 1.0
        return output

    def train(self, x: torch.Tensor, labels: torch.Tensor) -> None:        
        assert x.shape[0] == labels.shape[0]
        batch_size = x.shape[0]
        x = x.view(batch_size, -1)
        x_max = x.max()
        one_hot_labels = torch.zeros(batch_size, self.class_num, device=x.device)
        one_hot_labels_random = torch.zeros_like(one_hot_labels, device=x.device)
        one_hot_labels[range(batch_size), labels] = x_max
        one_hot_labels_random[range(batch_size), labels[torch.randperm(batch_size, device=x.device)]] = x_max
        x_pos = torch.cat((one_hot_labels, x), 1)
        x_neg = torch.cat((one_hot_labels_random, x), 1)
        h_pos, h_neg = x_pos, x_neg
        for layer in self.layers:
            assert isinstance(layer, Layer)
            h_pos, h_neg = layer.train(h_pos, h_neg)

In [10]:
def train_net(net: Net, data_loader) -> None:
    for x, labels in data_loader:
        assert isinstance(x, torch.Tensor)
        assert isinstance(labels, torch.Tensor)
        x = x.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        net.train(x, labels)

In [11]:
def evaluate(model, data_loader) -> float:
    assert callable(model)
    assert data_loader is not None

    with torch.no_grad():
        n_items = 0
        n_match = 0
        for data, labels in data_loader:
            assert isinstance(data, torch.Tensor)
            assert isinstance(labels, torch.Tensor)
            assert data.shape[0] == labels.shape[0]

            data = data.to(device, non_blocking=True)
            model_output = model(data)
            
            assert isinstance(model_output, torch.Tensor)
            assert model_output.shape[0] == labels.shape[0]

            for i in range(model_output.shape[0]):
                label = labels[i].item()
                assert label >= 0 and label < class_num
                max_feature_index = torch.argmax(model_output[i])
                assert max_feature_index >= 0 and max_feature_index < class_num
                if max_feature_index == label:
                    n_match += 1
                n_items += 1

        accuracy = n_match / n_items
        print(f"Accuracy: {accuracy}")
        return accuracy

In [12]:
net = Net([math.prod(data_item_size), 500, 500], class_num, device)

In [13]:
for _ in tqdm(range(10)):
    train_net(net, data_loader)
    evaluate(net, test_loader)

 10%|█         | 1/10 [00:14<02:11, 14.66s/it]

Accuracy: 0.9079


 20%|██        | 2/10 [00:28<01:52, 14.00s/it]

Accuracy: 0.9096


 30%|███       | 3/10 [00:42<01:38, 14.01s/it]

Accuracy: 0.9359


 40%|████      | 4/10 [00:57<01:25, 14.33s/it]

Accuracy: 0.9406


 50%|█████     | 5/10 [01:09<01:09, 13.82s/it]

Accuracy: 0.9459


 60%|██████    | 6/10 [01:22<00:53, 13.42s/it]

Accuracy: 0.9371


 70%|███████   | 7/10 [01:34<00:39, 13.08s/it]

Accuracy: 0.9423


 80%|████████  | 8/10 [01:47<00:25, 12.93s/it]

Accuracy: 0.9459


 90%|█████████ | 9/10 [02:00<00:12, 12.94s/it]

Accuracy: 0.9499


100%|██████████| 10/10 [02:13<00:00, 13.38s/it]

Accuracy: 0.9483





In [14]:
evaluate(net, test_loader)

Accuracy: 0.9483


0.9483