In [None]:
import copy
import os
import random
import sys
import time

import pandas as pd

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from tqdm import tqdm

sys.path.extend(["../", "./models"])

from pc_model import PCNET, train_model, test_model
from mlp import MLP
from cnn import CNN

In [None]:
# Single seed shared by every random number generator used in this notebook.
seed = 333

# Request deterministic kernels; with warn_only=True an operation that has no
# deterministic implementation emits a warning instead of raising.
torch.use_deterministic_algorithms(True, warn_only=True)

# Seed Python's and PyTorch's global generators.
torch.manual_seed(seed)
random.seed(seed)

if torch.cuda.is_available():
    # Environment switches read by the CUDA libraries for reproducible runs.
    os.environ.update({
        "CUBLAS_WORKSPACE_CONFIG": ":4096:8",
        "CUDA_LAUNCH_BLOCKING": "1",
    })

    # cuDNN: use deterministic algorithms and disable the autotuner.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Seed the current CUDA device, then every visible CUDA device.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

if torch.backends.mps.is_available():
    # Apple-silicon backend keeps its own generator.
    torch.mps.manual_seed(seed)

In [None]:
def create_loaders(dataset: str, batch_size: int = 128):
    """Build the train and test ``DataLoader`` objects for one torchvision dataset.

    The dataset is stored under ``../datasets/<dataset>/dataset`` (resolved to
    an absolute path) and is downloaded there when it is not already present.

    Args:
        dataset: One of ``"mnist"``, ``"fmnist"`` or ``"cifar10"``.
        batch_size: Batch size used by both loaders. Defaults to 128.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The train loader shuffles its
        samples; the test loader keeps the dataset order.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # Names of the torchvision dataset classes, keyed by the identifiers used
    # in this notebook.
    dataset_class_names = {
        "mnist": "MNIST",
        "fmnist": "FashionMNIST",
        "cifar10": "CIFAR10",
    }

    # Validate before touching the filesystem so that a misspelled name does
    # not leave an empty dataset directory behind.
    if dataset not in dataset_class_names:
        raise ValueError("Unsupported dataset.")
    DatasetClass = getattr(torchvision.datasets, dataset_class_names[dataset])

    # CIFAR-10 gets per-channel statistics; the other datasets are normalized
    # with a single mean/std of 0.5.
    is_cifar = dataset == "cifar10"
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465] if is_cifar else [0.5],
            std=[0.2023, 0.1994, 0.2010] if is_cifar else [0.5],
        ),
    ])

    dataset_path = os.path.abspath(f'../datasets/{dataset}/dataset')
    os.makedirs(dataset_path, exist_ok=True)

    train_dataset = DatasetClass(root=dataset_path, train=True, transform=transform, download=True)
    test_dataset = DatasetClass(root=dataset_path, train=False, transform=transform, download=True)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

In [None]:
def _prepare_batch(inputs, labels, device, num_classes):
    """Move one batch to ``device``, flattening 1x28x28 images and one-hot encoding labels."""
    if inputs.shape[1:] == (1, 28, 28):
        # Single-channel 28x28 images are fed to the model as flat vectors.
        inputs = inputs.flatten(1)
    targets = F.one_hot(labels, num_classes=num_classes).float()
    return inputs.to(device), targets.to(device)


def train_test(model, is_pc, train_loader, test_loader, epochs, optimizer, criterion, inference_lr, inference_iterations, inference_momentum, num_classes=10):
    """Train ``model`` for ``epochs`` epochs, evaluating on ``test_loader`` after each one.

    Args:
        model: Module to train; batches are moved to the device of its first parameter.
        is_pc: Forwarded to ``train_model`` as ``pc_mode``.
        train_loader: Iterable of ``(inputs, labels)`` batches; must support ``len()``.
        test_loader: Iterable of ``(inputs, labels)`` batches used for evaluation.
        epochs: Number of passes over ``train_loader``.
        optimizer: Forwarded to ``train_model``.
        criterion: Forwarded to ``train_model`` and ``test_model``.
        inference_lr: Forwarded to ``train_model``.
        inference_iterations: Forwarded to ``train_model``.
        inference_momentum: Forwarded to ``train_model``.
        num_classes: Width of the one-hot label encoding. Defaults to 10.

    Returns:
        The highest per-epoch test accuracy observed (0 when ``epochs`` is 0).

    Raises:
        ValueError: If ``test_loader`` yields no samples, since the accuracy
            would be undefined.
    """
    device = next(model.parameters()).device
    best_accuracy = 0

    for epoch in range(epochs):
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch + 1}/{epochs}", leave=False) as progress_bar:
            for inputs, labels in train_loader:
                inputs, labels = _prepare_batch(inputs, labels, device, num_classes)
                train_model(
                    model,
                    inputs,
                    labels,
                    optimizer,
                    criterion,
                    inference_lr,
                    inference_iterations,
                    inference_momentum,
                    pc_mode=is_pc
                )
                progress_bar.update(1)

            # Evaluate on the test set. Each batch accuracy is weighted by the
            # batch size so that a short final batch does not skew the mean.
            # Assumes test_model returns the batch's mean accuracy — TODO confirm.
            weighted_accuracy = 0.0
            sample_count = 0
            model.eval()
            try:
                with torch.no_grad():
                    for inputs, labels in test_loader:
                        inputs, labels = _prepare_batch(inputs, labels, device, num_classes)
                        _, acc = test_model(model, inputs, labels, criterion)
                        batch_samples = inputs.shape[0]
                        weighted_accuracy += acc * batch_samples
                        sample_count += batch_samples
            finally:
                # Always hand the model back in training mode, even if
                # evaluation raised.
                model.train()

            if sample_count == 0:
                raise ValueError("test_loader yielded no samples; cannot compute accuracy.")

            epoch_accuracy = weighted_accuracy / sample_count
            best_accuracy = max(best_accuracy, epoch_accuracy)
            print(best_accuracy, epoch_accuracy)

    return best_accuracy

In [None]:
# Preferred GPU on the multi-GPU training host.
preferred_cuda_index = 6

if torch.cuda.is_available():
    # Only use the preferred GPU when it actually exists; otherwise 'cuda:6'
    # would be an invalid device ordinal on machines with fewer GPUs.
    if preferred_cuda_index < torch.cuda.device_count():
        device = f'cuda:{preferred_cuda_index}'
    else:
        device = 'cuda:0'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print('Available device: ', device)

# Experiment protocol: every model is trained for `epochs` epochs, and the
# whole comparison is repeated `trials` times per dataset.
epochs = 25
trials = 5
criterion = torch.nn.MSELoss()

datasets = ['mnist', 'fmnist', 'cifar10']

In [None]:
# Compare predictive coding (PC) and backpropagation (BP) on every dataset:
# `trials` independent runs each, with per-run results written to CSV.
for dataset in datasets:

    # Per-run records; converted to DataFrames once all trials are done.
    pc_records = []
    bp_records = []

    # Per-dataset hyperparameters. `inference_*` are forwarded to train_test
    # for the PC model; `generative_lr_*` / `generative_wd_*` are the Adam
    # learning rate and weight decay of the PC and BP optimizers.
    if dataset in ['mnist', 'fmnist']:
        inference_lr = 0.5
        inference_iterations = 8
        inference_momentum = 0.8
        generative_lr_pc = 0.0002715
        generative_wd_pc = 0.02715
        generative_lr_bp = 0.000275
        generative_wd_bp = 0.00002

    elif dataset == 'cifar10':
        inference_lr = 0.15
        inference_iterations = 8
        inference_momentum = 0.35
        generative_lr_pc = 0.0006
        generative_wd_pc = 0.0075

        generative_lr_bp = 0.00025
        generative_wd_bp = 0.000075

    else:
        # Without this branch an unknown name would silently reuse the
        # hyperparameters left over from the previous dataset.
        raise ValueError(f"No hyperparameters defined for dataset '{dataset}'.")

    train_loader, test_loader = create_loaders(dataset)
    print(f'Comparing PC and BP on {dataset} dataset...')

    for run in range(trials):

        bp_model = CNN(3,10) if dataset == 'cifar10' else MLP(input_size=28*28, hidden_layers=3, hidden_size=128, output_size=10)
        bp_model.to(device)
        bp_optimizer = torch.optim.Adam(bp_model.parameters(), lr=generative_lr_bp, weight_decay=generative_wd_bp)

        # Hand PCNET its own copy of the freshly initialized network. Both
        # methods then start from identical weights, and PC training (which
        # runs first) cannot modify the model the BP baseline trains on,
        # whether or not PCNET copies its argument internally.
        pc_model = PCNET(copy.deepcopy(bp_model))
        pc_model.to(device)
        pc_optimizer = torch.optim.Adam(pc_model.module_dict['generative'].parameters(), lr=generative_lr_pc, weight_decay=generative_wd_pc)

        # perf_counter is monotonic, so the measured duration is unaffected by
        # system clock adjustments. The time includes the per-epoch test
        # evaluation performed inside train_test.
        start_time = time.perf_counter()
        pc_best_accuracy = train_test(pc_model, True, train_loader, test_loader, epochs, pc_optimizer, criterion, inference_lr, inference_iterations, inference_momentum)
        pc_training_time = time.perf_counter() - start_time
        print(f"Predictive Coding best accuracy: {pc_best_accuracy * 100:.2f}% - total training time: {pc_training_time:.2f} s")
        pc_records.append({"run": run + 1, "best_accuracy": pc_best_accuracy, "training_time_sec": pc_training_time})

        # The inference hyperparameters are passed as zeros for the BP run.
        start_time = time.perf_counter()
        bp_best_accuracy = train_test(bp_model, False, train_loader, test_loader, epochs, bp_optimizer, criterion, 0, 0, 0)
        bp_training_time = time.perf_counter() - start_time
        print(f"Backpropagation best accuracy: {bp_best_accuracy * 100:.2f}% - total training time: {bp_training_time:.2f} s")
        bp_records.append({"run": run + 1, "best_accuracy": bp_best_accuracy, "training_time_sec": bp_training_time})

    pc_results = pd.DataFrame(pc_records)
    bp_results = pd.DataFrame(bp_records)

    # Persist the per-run results so they can be summarized without retraining.
    os.makedirs("./results", exist_ok=True)
    pc_results.to_csv(f"./results/pc_performance_{dataset}.csv", index=False)
    bp_results.to_csv(f"./results/bp_performance_{dataset}.csv", index=False)

    print()
    print(f"{'-'*50}")
    print("Predictive Coding Performance:")
    print(f"Accuracy: Mean = {pc_results['best_accuracy'].mean() * 100:.2f}%, Std = {pc_results['best_accuracy'].std() * 100:.2f}%")
    print(f"Training Time: Mean = {pc_results['training_time_sec'].mean():.2f} s, Std = {pc_results['training_time_sec'].std():.2f} s")

    print("\nBackpropagation Performance:")
    print(f"Accuracy: Mean = {bp_results['best_accuracy'].mean() * 100:.2f}%, Std = {bp_results['best_accuracy'].std() * 100:.2f}%")
    print(f"Training Time: Mean = {bp_results['training_time_sec'].mean():.2f} s, Std = {bp_results['training_time_sec'].std():.2f} s")
    print(f"{'-'*50}")
    print(f"{'-'*50}")
    print()

In [None]:
# Summarize the saved per-run results for every dataset in the experiment.
# `datasets` and `epochs` come from the configuration cell, so the report
# stays consistent with the runs if either one is changed.
for dataset in datasets:
    pc_results = pd.read_csv(f'./results/pc_performance_{dataset}.csv')
    bp_results = pd.read_csv(f'./results/bp_performance_{dataset}.csv')

    print(f"\n--- Results on {dataset.upper()} ---")

    for method_name, results in (("Backpropagation", bp_results), ("Predictive Coding", pc_results)):
        accuracy = results['best_accuracy']
        # Recorded times cover a whole run; divide by the epoch count to
        # report the average time per epoch.
        time_per_epoch = results['training_time_sec'] / epochs

        print(f"\n{method_name}:")
        print(f"Accuracy:     {accuracy.mean() * 100:.2f} ± {accuracy.std() * 100:.2f}%")
        print(f"Train time:   {time_per_epoch.mean():.2f} ± {time_per_epoch.std():.2f} s")

