In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Device setup: use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# Hyperparameters shared by the model, training loop and evaluation loop.
label_dim = 10  # number of MNIST classes / width of the label one-hot
hidden_dim = 256  # width of the recurrent hidden state h
batch_size = 128  # training mini-batch size
num_epochs = 10
timesteps = 5  # recurrent steps run per forward pass
lr = 1e-3  # Adam learning rate

# Data loading: images are only converted to float tensors of shape
# (1, 28, 28); no normalization is applied.
transform = transforms.Compose([transforms.ToTensor()])

mnist_kwargs = {"root": "./data", "download": True, "transform": transform}
train_dataset = datasets.MNIST(train=True, **mnist_kwargs)
test_dataset = datasets.MNIST(train=False, **mnist_kwargs)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256)  # evaluation batch size

# One-hot encoding


def one_hot(labels, num_classes=10):
    """Encode integer class labels as float one-hot vectors.

    Args:
        labels: Integer (``torch.long``) tensor of class indices.
        num_classes: Width of the encoding; defaults to the 10 MNIST classes.

    Returns:
        A float32 tensor with a trailing dimension of size ``num_classes``
        holding 1.0 at each label's index and 0.0 elsewhere.
    """
    encoded = F.one_hot(labels, num_classes=num_classes)
    return encoded.to(dtype=torch.float32)


# PFF with CNN front-end


class PFF_CNN(nn.Module):
    """Recurrent PFF-style model with a LeNet-like convolutional front-end.

    A stack of convolutions turns a single-channel image into a flat feature
    vector. At each recurrent step the representation pathway (``rep``) maps
    ``[features, label one-hot, previous h]`` to a new hidden state. The
    generative pathway (``gen``) maps that state to a vector sized like
    ``[features, label one-hot]``.

    Args:
        label_dim: Length of the label one-hot vector.
        hidden_size: Width of the recurrent hidden state ``h``.
        in_size: Height/width of the square single-channel input image.
            Defaults to 28 (MNIST), which gives a 120 * 2 * 2 = 480-dim
            feature vector.

    Raises:
        RuntimeError: (from the conv layers) if ``in_size`` is too small for
            the conv/pool stack.
    """

    def __init__(self, label_dim=10, hidden_size=256, in_size=28):
        super().__init__()
        # Convolutional layers (LeNet-like)
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)  # C1
        self.pool1 = nn.AvgPool2d(2, 2)  # S2
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)  # C3
        self.pool2 = nn.AvgPool2d(2, 2)  # S4
        self.conv3 = nn.Conv2d(16, 120, kernel_size=3)  # C5 - Reduced kernel size

        # Derive the flattened feature width by probing the conv stack once,
        # instead of hard-coding it. This keeps it correct for any input size
        # and for future edits to the conv layers. The probe draws no random
        # numbers, so parameter initialization order is unaffected.
        with torch.no_grad():
            probe = torch.zeros(1, 1, in_size, in_size)
            self.feature_dim = self.extract_features(probe).shape[1]
        self.hidden_size = hidden_size
        self.label_dim = label_dim

        # Representation & generative pathways
        self.rep = nn.Linear(self.feature_dim + label_dim + hidden_size, hidden_size)
        self.gen = nn.Linear(hidden_size, self.feature_dim + label_dim)

    def extract_features(self, x):
        """Run the conv/pool stack and flatten to ``(batch, feature_dim)``."""
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = F.relu(self.conv3(x))
        return x.view(x.size(0), -1)  # Flatten

    def forward_step(self, x, y_onehot, h, feats=None):
        """Run one recurrent step.

        Args:
            x: Image batch of shape (batch, 1, H, W). Only used when ``feats``
                is None.
            y_onehot: Label encoding of shape (batch, label_dim).
            h: Previous hidden state of shape (batch, hidden_size).
            feats: Optional precomputed ``extract_features(x)``. The conv
                features do not depend on ``h`` or the label, so callers that
                iterate several steps on the same ``x`` can compute them once
                and pass them in.

        Returns:
            ``(h_pred, recon)``: the new hidden state (tanh, in (-1, 1)) and
            the generative output (sigmoid, in (0, 1)) of width
            ``feature_dim + label_dim``.
        """
        if feats is None:
            feats = self.extract_features(x)
        u = torch.cat([feats, y_onehot, h], dim=1)
        h_pred = torch.tanh(self.rep(u))
        recon = torch.sigmoid(self.gen(h_pred))
        return h_pred, recon

    def compute_goodness(self, h):
        """Return per-sample goodness: the sum of squared hidden activations."""
        return torch.sum(h**2, dim=1)

    def forward(self, x, y_onehot, h, feats=None):
        """Alias for :meth:`forward_step` so the module is callable."""
        return self.forward_step(x, y_onehot, h, feats)


# Instantiate model and optimizers
model = PFF_CNN(label_dim, hidden_dim).to(device)

# Only the representation and generative pathways are optimized. The conv
# front-end is never stepped by any optimizer, so it stays at its random
# initialization. Freeze it explicitly so backward() neither spends time on
# it nor accumulates (never-zeroed) gradients into its parameters.
# NOTE(review): if the front-end is meant to learn, add its parameters to an
# optimizer instead of freezing it.
for conv in (model.conv1, model.conv2, model.conv3):
    conv.requires_grad_(False)

optimizer_rep = torch.optim.Adam(model.rep.parameters(), lr=lr)
optimizer_gen = torch.optim.Adam(model.gen.parameters(), lr=lr)


def _sample_wrong_labels(labels, num_classes):
    """Draw, per sample, a label uniformly from the classes other than ``labels``."""
    # An offset in [1, num_classes - 1] taken modulo num_classes never lands
    # on the true class and reaches every other class with equal probability.
    # Bumping collisions to y + 1 would instead double that class's chance.
    offset = torch.randint(1, num_classes, labels.shape, device=labels.device)
    return (labels + offset) % num_classes


def _run_recurrence(net, x, y_onehot, steps):
    """Run ``steps`` recurrent updates starting from h = 0; return the final (h, recon)."""
    h = torch.zeros(x.size(0), net.hidden_size, device=x.device)
    recon = None
    for _ in range(steps):
        h, recon = net(x, y_onehot, h)
    return h, recon


# Training loop
print("Training...")
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        # y is already on `device`, so the one-hot encodings are too.
        y_pos = one_hot(y, label_dim)
        y_neg = one_hot(_sample_wrong_labels(y, label_dim), label_dim)

        # Positive phase (true label): its reconstruction feeds the generative loss.
        h_pos, recon_pos = _run_recurrence(model, x, y_pos, timesteps)
        g_pos = model.compute_goodness(h_pos)

        # Negative phase (wrong label): only the goodness is needed.
        h_neg, _ = _run_recurrence(model, x, y_neg, timesteps)
        g_neg = model.compute_goodness(h_neg)

        # Losses: softplus pushes positive goodness above negative goodness;
        # MSE trains `gen` to reproduce [conv features, true label].
        loss_rep = F.softplus(g_neg - g_pos).mean()
        recon_target = torch.cat([model.extract_features(x).detach(), y_pos], dim=1)
        loss_gen = F.mse_loss(recon_pos, recon_target)

        total_batch_loss = loss_rep + loss_gen
        optimizer_rep.zero_grad()
        optimizer_gen.zero_grad()
        total_batch_loss.backward()
        optimizer_rep.step()
        optimizer_gen.step()

        total_loss += total_batch_loss.item()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(train_loader):.4f}")

# Evaluation loop: for every test image, run the recurrence once per candidate
# label and predict the label whose final hidden state has the highest goodness.
print("Evaluating...")
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        # Use a local name: assigning to `batch_size` here would overwrite the
        # training hyperparameter of the same name.
        n = x.size(0)
        goodness_scores = torch.zeros(n, label_dim, device=device)

        for label in range(label_dim):
            y_oh = one_hot(
                torch.full((n,), label, device=device, dtype=torch.long), label_dim
            )
            h = torch.zeros(n, hidden_dim, device=device)
            for _step in range(timesteps):
                h, _recon = model(x, y_oh, h)
            goodness_scores[:, label] = model.compute_goodness(h)

        preds = goodness_scores.argmax(dim=1)
        correct += (preds == y).sum().item()
        total += n

accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")

Using device: cpu
Training...
Epoch 1/10, Loss: 0.2418
Epoch 2/10, Loss: 0.1150
Epoch 3/10, Loss: 0.0943
Epoch 4/10, Loss: 0.0767
Epoch 5/10, Loss: 0.0662
Epoch 6/10, Loss: 0.0654
Epoch 7/10, Loss: 0.0615
Epoch 8/10, Loss: 0.0568
Epoch 9/10, Loss: 0.0537
Epoch 10/10, Loss: 0.0498
Evaluating...
Test Accuracy: 91.33%
