In [1]:
# Rework of https://github.com/mohammadpz/pytorch_forward_forward.git

In [2]:
from tqdm import tqdm
import torch
from torch.optim import Adam
import torch.nn as nn
from torchvision import datasets, transforms

In [3]:
use_cuda = torch.cuda.is_available()
use_mps = torch.backends.mps.is_available()

if use_cuda:
    device = torch.device("cuda")
elif use_mps:
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

Using device: cuda


In [4]:
train_batch_size = 256
test_batch_size = 1000

train_kwargs = {'batch_size': train_batch_size, 'shuffle': False}
test_kwargs = {'batch_size': test_batch_size, 'shuffle': False}
cuda_kwargs = {'num_workers': 1, 'pin_memory': True}

if device.type == "cuda":
    train_kwargs |= cuda_kwargs
    test_kwargs |= cuda_kwargs

In [5]:
transform = transforms.Compose([transforms.ToTensor()])

In [6]:
dataset1 = datasets.MNIST('../data', train=True, download=True,
                    transform=transform)

dataset2 = datasets.MNIST('../data', train=False,
                    transform=transform)

train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

In [7]:
def get_dataset_labels(data_loader) -> set:
    assert data_loader is not None

    result = set()
    with torch.no_grad():
        for (_, labels) in data_loader:
            assert isinstance(labels, torch.Tensor)
            for label in labels:
                result.add(label.item())

    return result

In [8]:
dataset_labels = get_dataset_labels(train_loader)
class_num = len(dataset_labels)

In [9]:
class Layer(nn.Linear):
    def __init__(self, in_features, out_features,
                 bias=True, device=None, dtype=None):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.opt = Adam(self.parameters(), lr=0.03)
        self.threshold = 2.0
        self.num_epochs = 10

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_direction = x / (x.norm(2, 1, keepdim=True) + 1e-4)
        return torch.relu(super().forward(x_direction))            

    def train(self, x_pos: torch.Tensor, x_neg: torch.Tensor):
        for _ in range(self.num_epochs):
            g_pos = self.forward(x_pos).pow(2).mean(1)
            g_neg = self.forward(x_neg).pow(2).mean(1)
            # The following loss pushes pos (neg) samples to
            # values larger (smaller) than the self.threshold.
            loss = torch.log(1 + torch.exp(torch.cat([
                -g_pos + self.threshold,
                g_neg - self.threshold]))).mean()
            self.opt.zero_grad()
            # this backward just compute the derivative and hence
            # is not considered backpropagation.
            loss.backward()
            self.opt.step()
        return self.forward(x_pos).detach(), self.forward(x_neg).detach()

In [10]:
def overlay_y_on_x(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Replace the first 10 pixels of data [x] with one-hot-encoded label [y]
    """
    with torch.no_grad():
        x_ = x.clone()
        x_[:, :10] = 0.0
        x_[range(x.shape[0]), y] = x.max()
        return x_

In [11]:
class Net(torch.nn.Module):
    def __init__(self, dims, device):
        super().__init__()
        self.layers = [Layer(in_features=dims[d], out_features=dims[d + 1], device=device) for d in range(len(dims) - 1)]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        goodness_per_label = []
        for label in range(10):
            h = overlay_y_on_x(x, label)
            goodness = []
            for layer in self.layers:
                h = layer(h)
                assert isinstance(h, torch.Tensor)
                goodness += [h.pow(2).mean(1)]
            goodness_per_label += [sum(goodness).unsqueeze(1)]
        goodness_per_label = torch.cat(goodness_per_label, 1)
        
        output = torch.zeros(x.shape[0], 10)
        max_label = goodness_per_label.argmax(1)
        output[range(x.shape[0]), max_label] = 1.0
        return output

    def train(self, x_pos: torch.Tensor, x_neg: torch.Tensor) -> None:
        h_pos, h_neg = x_pos, x_neg
        for layer in self.layers:
            assert isinstance(layer, Layer)
            h_pos, h_neg = layer.train(h_pos, h_neg)

In [12]:
def train_net(net: Net, train_loader: torch.utils.data.DataLoader) -> None:
    for (x, y) in train_loader:
        assert isinstance(x, torch.Tensor)
        assert isinstance(y, torch.Tensor)
        
        x = x.to(device)
        y = y.to(device)

        x = x.view(x.shape[0], -1)

        x_pos = overlay_y_on_x(x, y)
        rnd = torch.randperm(x.shape[0], device=x.device)
        x_neg = overlay_y_on_x(x, y[rnd])

        net.train(x_pos, x_neg)

In [13]:
def evaluate(model, data_loader) -> float:
    assert callable(model)
    assert data_loader is not None

    with torch.no_grad():
        n_items = 0
        n_match = 0
        for (data, labels) in data_loader:
            assert isinstance(data, torch.Tensor)
            assert isinstance(labels, torch.Tensor)
            assert data.shape[0] == labels.shape[0]

            data = data.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            data = data.view(data.shape[0], -1)
            model_output = model(data)
            
            assert isinstance(model_output, torch.Tensor)
            assert model_output.shape[0] == labels.shape[0]

            for i in range(model_output.shape[0]):
                label = labels[i].item()
                assert label >= 0 and label < class_num
                max_feature_index = torch.argmax(model_output[i])
                assert max_feature_index >= 0 and max_feature_index < class_num
                if max_feature_index == label:
                    n_match += 1
                n_items += 1

        accuracy = float(n_match) / n_items
        print(f"Accuracy: {accuracy}")
        return accuracy

In [14]:
net = Net([784, 500, 500], device)

In [15]:
for _ in tqdm(range(10)):
    train_net(net, train_loader)
    evaluate(net, test_loader)

 10%|█         | 1/10 [00:22<03:18, 22.08s/it]

Accuracy: 0.8904


 20%|██        | 2/10 [00:42<02:49, 21.16s/it]

Accuracy: 0.9051


 30%|███       | 3/10 [00:59<02:14, 19.21s/it]

Accuracy: 0.9233


 40%|████      | 4/10 [01:15<01:48, 18.14s/it]

Accuracy: 0.9466


 50%|█████     | 5/10 [01:32<01:28, 17.61s/it]

Accuracy: 0.9452


 60%|██████    | 6/10 [01:49<01:09, 17.39s/it]

Accuracy: 0.947


 70%|███████   | 7/10 [02:05<00:51, 17.06s/it]

Accuracy: 0.946


 80%|████████  | 8/10 [02:22<00:33, 16.93s/it]

Accuracy: 0.957


 90%|█████████ | 9/10 [02:35<00:15, 15.67s/it]

Accuracy: 0.9495


100%|██████████| 10/10 [02:48<00:00, 16.84s/it]

Accuracy: 0.9524





In [16]:
evaluate(net, test_loader)

Accuracy: 0.9524


0.9524