# In[2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

# Hyperparameters
seed = 0               # fixed RNG seed: weight init, shuffling and negative-label sampling are reproducible
input_dim = 784        # 28x28 images, flattened
label_dim = 10         # 10 classes
hidden_dim = 256
batch_size = 128       # training mini-batch size
eval_batch_size = 256  # test mini-batch size (inference only, no gradients kept)
num_epochs = 5
timesteps = 5          # recurrent steps unrolled per example
lr = 1e-3

# Seed the global torch RNG before any model is built or data is shuffled.
torch.manual_seed(seed)

# Data loading and preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))  # flatten each image to a (784,) vector
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=eval_batch_size)

# Helper function to one-hot encode labels
def one_hot(labels, num_classes=10):
    """Return a float one-hot encoding of integer class ``labels``.

    Args:
        labels: integer (long) tensor of class indices.
        num_classes: width of the encoding (last dimension of the result).

    Returns:
        A float tensor shaped ``labels.shape + (num_classes,)``.
    """
    encoded = F.one_hot(labels, num_classes=num_classes)
    return encoded.float()

# PFF RNN Model
class PFF_RNN(nn.Module):
    """Recurrent cell with a representation head and a generative head.

    On each step the input ``x`` is concatenated with the previous hidden
    state ``h`` along dim 1. The result goes through ``rep`` + tanh to give
    the next hidden state. ``gen`` + sigmoid then maps that state back to a
    tensor as wide as ``x``.

    Args:
        input_size: width (dim 1) of the per-step input ``x``.
        hidden_size: width (dim 1) of the hidden state ``h``.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Create `rep` before `gen`: parameter init consumes the RNG in this
        # order, so keeping it preserves seeded initial weights.
        self.rep = nn.Linear(input_size + hidden_size, hidden_size)
        self.gen = nn.Linear(hidden_size, input_size)

    def forward_step(self, x, h):
        """Advance one step; return ``(next_hidden, reconstruction)``."""
        pre_activation = self.rep(torch.cat((x, h), dim=1))
        next_h = torch.tanh(pre_activation)
        return next_h, torch.sigmoid(self.gen(next_h))

    def compute_goodness(self, h):
        """Per-row goodness: sum of squared activations of ``h`` over dim 1."""
        return h.pow(2).sum(dim=1)

    def forward(self, x, h):
        """Same as :meth:`forward_step`; lets the module be called directly."""
        next_h, recon = self.forward_step(x, h)
        return next_h, recon

# Instantiate model and optimizers.
# The model's per-step input is image + one-hot label, so its input width is
# input_dim + label_dim, and `gen` reconstructs that concatenation.
model = PFF_RNN(input_dim + label_dim, hidden_dim).to(device)
# Two optimizers over disjoint parameter sets. Both step on the same summed
# loss, so the reconstruction loss also backpropagates into `rep` (recon is
# computed from rep's output inside forward_step).
# NOTE(review): if the two circuits are meant to be trained independently,
# this coupling needs a detach — confirm intent.
optimizer_rep = torch.optim.Adam(model.rep.parameters(), lr=lr)
optimizer_gen = torch.optim.Adam(model.gen.parameters(), lr=lr)

# Training loop
print("Training...")
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        y_pos = one_hot(y, label_dim)

        # Create incorrect (negative) labels, uniform over the label_dim - 1
        # wrong classes. Adding an offset in [1, label_dim) modulo label_dim
        # can never land back on y. (Shifting collisions by +1 would make y+1
        # twice as likely as any other wrong class.)
        offset = torch.randint(1, label_dim, y.shape, device=device)
        y_neg = one_hot((y + offset) % label_dim, label_dim)

        # The per-step input is identical at every timestep, so build it once.
        x_pos = torch.cat([x, y_pos], dim=1)
        x_neg = torch.cat([x, y_neg], dim=1)

        # Positive phase: unroll the recurrence with the correct label.
        h_pos = torch.zeros(x.size(0), hidden_dim, device=device)
        for _ in range(timesteps):
            h_pos, recon_pos = model(x_pos, h_pos)
        g_pos = model.compute_goodness(h_pos)

        # Negative phase: same image, wrong label; reconstructions are unused.
        h_neg = torch.zeros(x.size(0), hidden_dim, device=device)
        for _ in range(timesteps):
            h_neg, _ = model(x_neg, h_neg)
        g_neg = model.compute_goodness(h_neg)

        # Losses: softplus(g_neg - g_pos) rewards higher goodness for the
        # correct label than for the wrong one; the generative head is trained
        # to reconstruct the positive input from the final hidden state.
        loss_rep = F.softplus(g_neg - g_pos).mean()
        loss_gen = F.mse_loss(recon_pos, x_pos)

        total_batch_loss = loss_rep + loss_gen
        optimizer_rep.zero_grad()
        optimizer_gen.zero_grad()
        total_batch_loss.backward()
        optimizer_rep.step()
        optimizer_gen.step()

        total_loss += total_batch_loss.item()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(train_loader):.4f}")

# Evaluation loop: score each test image once per candidate label. The
# prediction is the label whose final hidden state has the highest goodness.
print("Evaluating...")
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        # Local name on purpose: reusing `batch_size` here would overwrite the
        # global training hyperparameter.
        n = x.size(0)
        goodness_scores = torch.zeros(n, label_dim, device=device)

        for label in range(label_dim):
            y_oh = one_hot(torch.full((n,), label, device=device, dtype=torch.long), label_dim)
            # Image + candidate label is constant across timesteps; build it once.
            x_in = torch.cat([x, y_oh], dim=1)
            h = torch.zeros(n, hidden_dim, device=device)
            for _ in range(timesteps):
                h, _ = model(x_in, h)
            goodness_scores[:, label] = model.compute_goodness(h)

        preds = goodness_scores.argmax(dim=1)
        correct += (preds == y).sum().item()
        total += n

accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")

# Recorded output from an earlier run of this cell:
# Using device: cpu
# Training...
# Epoch 1/5, Loss: 0.1273
# Epoch 2/5, Loss: 0.0649
# Epoch 3/5, Loss: 0.0565
# Epoch 4/5, Loss: 0.0466
# Epoch 5/5, Loss: 0.0439
# Evaluating...
# Test Accuracy: 92.82%
