In [None]:
import os
import sys
import time
import random

import pandas as pd

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torchvision
from tqdm import tqdm

sys.path.extend(["./models"])

from pc_model import PCNET, train_model, test_model
from mlp import Autoencoder
from cnn import CNNAutoencoder

In [None]:
# Single seed shared by every random number generator the notebook touches.
seed = 333

# Seed Python's and PyTorch's generators, then ask PyTorch for deterministic
# kernels. warn_only=True makes ops without a deterministic implementation
# emit a warning instead of raising.
random.seed(seed)
torch.manual_seed(seed)
torch.use_deterministic_algorithms(True, warn_only=True)

if torch.cuda.is_available():
    # Environment switches are written before any CUDA seeding call below.
    os.environ.update({
        "CUBLAS_WORKSPACE_CONFIG": ":4096:8",
        "CUDA_LAUNCH_BLOCKING": "1",
    })

    # Seed the current CUDA device and then all visible CUDA devices.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Force deterministic cuDNN kernels and turn off the autotuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

if torch.backends.mps.is_available():
    # Apple-silicon backend keeps its own generator.
    torch.mps.manual_seed(seed)

In [None]:
def create_loaders(dataset: str, batch_size: int = 250):
    """Build the train and test ``DataLoader`` objects for a torchvision dataset.

    Args:
        dataset: One of ``"mnist"``, ``"fmnist"`` or ``"cifar10"``.
        batch_size: Number of samples per batch for both loaders
            (defaults to 250).

    Returns:
        A ``(train_loader, test_loader)`` tuple. The train loader shuffles
        its samples; the test loader does not.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # Resolve the dataset name first, before touching the filesystem, so an
    # unsupported name does not leave an empty directory behind.
    dataset_class_names = {
        "mnist": "MNIST",
        "fmnist": "FashionMNIST",
        "cifar10": "CIFAR10",
    }
    if not isinstance(dataset, str) or dataset not in dataset_class_names:
        raise ValueError("Unsupported dataset.")
    DatasetClass = getattr(torchvision.datasets, dataset_class_names[dataset])

    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
    ])

    # Data is stored (and downloaded on first use) next to the notebook's
    # parent directory, one folder per dataset.
    dataset_path = os.path.abspath(f'../datasets/{dataset}/dataset')
    os.makedirs(dataset_path, exist_ok=True)

    train_dataset = DatasetClass(root=dataset_path, train=True, transform=transform, download=True)
    test_dataset = DatasetClass(root=dataset_path, train=False, transform=transform, download=True)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

In [None]:
def show_reconstruction(model, inputs, is_pc):
    """Plot the first sample of ``inputs`` beside the model's reconstruction.

    Args:
        model: Network whose ``forward`` output is indexed at position 0 to
            obtain the reconstruction.
        inputs: Batch of inputs; only ``inputs[0]`` is displayed.
        is_pc: When true, ``forward`` is fed
            ``model.module_dict['variational'][0]`` instead of ``inputs``
            (presumably the PC network's inferred latent state; confirm
            against PCNET).
    """
    first_input = inputs[0].cpu()
    if is_pc:
        forward_arg = model.module_dict['variational'][0]
    else:
        forward_arg = inputs
    first_output = model.forward(forward_arg)[0].detach().cpu()

    fig, axs = plt.subplots(1, 2, figsize=(6, 3))

    # A leading dimension of 3 is treated as an RGB image in channel-first
    # layout; anything else is reshaped to a 28x28 grayscale image.
    if first_input.shape[0] == 3:
        panels = [tensor.permute(1, 2, 0).numpy() for tensor in (first_input, first_output)]
        imshow_options = {}
    else:
        panels = [
            first_input.squeeze().numpy().reshape(28, 28),
            first_output.numpy().reshape(28, 28),
        ]
        imshow_options = {'cmap': 'gray'}

    for axis, image, title in zip(axs, panels, ('Original', 'Reconstruction')):
        axis.imshow(image, **imshow_options)
        axis.set_title(title)
        axis.axis('off')

    plt.tight_layout()
    plt.show()


In [None]:
def _prepare_batch(batch, device):
    """Move ``batch`` to ``device``, flattening (N, 1, 28, 28) batches to (N, 784)."""
    if batch.shape[1:] == (1, 28, 28):
        batch = batch.view(-1, 28 * 28)
    return batch.to(device)


def train_test(model, is_pc, train_loader, test_loader, epochs, optimizer, criterion, inference_lr, inference_iterations, inference_momentum, plot_reconstruction):
    """Train ``model`` for ``epochs`` epochs, evaluating on the test set after each.

    Args:
        model: Network to train; its first parameter's device is used for
            all batches.
        is_pc: Passed as ``pc_mode`` to ``train_model``/``test_model`` and as
            ``is_pc`` to ``show_reconstruction``.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        test_loader: Iterable of ``(inputs, labels)`` evaluation batches.
        epochs: Number of passes over ``train_loader``.
        optimizer: Optimizer forwarded to ``train_model``.
        criterion: Loss function forwarded to ``train_model``/``test_model``.
        inference_lr: Forwarded to ``train_model``/``test_model``.
        inference_iterations: Forwarded to ``train_model``/``test_model``.
        inference_momentum: Forwarded to ``train_model``/``test_model``.
        plot_reconstruction: When true, plot a reconstruction of the last
            train batch and the last test batch of every epoch.

    Returns:
        The lowest per-epoch mean test loss seen, or ``float("inf")`` when no
        evaluation produced a loss (zero epochs or an empty test loader).
    """
    device = next(model.parameters()).device
    # An infinite sentinel cannot be mistaken for a real loss, whatever the
    # magnitude of the values returned by test_model.
    best_loss = float("inf")

    for epoch in range(epochs):
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch + 1}/{epochs}", leave=False) as progress_bar:
            # Labels are ignored: the models reconstruct their input.
            for train_inputs, _ in train_loader:
                train_inputs = _prepare_batch(train_inputs, device)
                train_model(
                    model,
                    train_inputs,
                    optimizer,
                    criterion,
                    inference_lr,
                    inference_iterations,
                    inference_momentum,
                    pc_mode=is_pc
                )
                progress_bar.update(1)

            if plot_reconstruction:
                show_reconstruction(model, train_inputs, is_pc)

            # NOTE(review): evaluation is deliberately not wrapped in
            # torch.no_grad(); test_model in pc_mode presumably needs
            # gradients for its inference phase. Confirm before adding it.
            model.eval()
            losses = []

            for test_inputs, _ in test_loader:
                test_inputs = _prepare_batch(test_inputs, device)

                loss = test_model(
                    model=model,
                    stimuli=test_inputs,
                    criterion=criterion,
                    inference_lr=inference_lr,
                    inference_time=inference_iterations,
                    momentum=inference_momentum,
                    pc_mode=is_pc
                )
                losses.append(loss)

            if plot_reconstruction:
                show_reconstruction(model, test_inputs, is_pc)

            # An empty test loader leaves best_loss untouched rather than
            # dividing by zero.
            if losses:
                best_loss = min(best_loss, sum(losses) / len(losses))
            model.train()

    return best_loss

In [None]:
# Pick the compute device in order of preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print('Available device: ', device)

# Experiment size: epochs per training run and independent runs per dataset.
epochs, trials = 5, 5

# Reconstruction loss shared by both training methods.
criterion = torch.nn.MSELoss()

# Datasets to benchmark, and whether to plot reconstructions while training.
datasets = ['mnist', 'fmnist', 'cifar10']
plot_reconstruction = False

# Inference settings forwarded to train_model / test_model through train_test.
inference_lr, inference_iterations, inference_momentum = 0.175, 35, 0.25

In [None]:
# For every dataset, train a predictive-coding (PC) model and a
# backpropagation (BP) model `trials` times, record the best test loss and the
# wall-clock training time of each run, save them to CSV and print a summary.
for dataset in datasets:
    pc_records = []
    bp_records = []

    train_loader, test_loader = create_loaders(dataset)
    print(f'Comparing PC and BP on {dataset} dataset...')

    # Optimizer learning rate / weight decay per dataset family. The names
    # must match the entries of `datasets` ('fmnist', not 'fashionmnist');
    # a name that matches no branch would otherwise silently reuse the values
    # left behind by the previous dataset.
    if dataset in ['mnist', 'fmnist']:
        generative_lr_pc = 0.00035
        generative_wd_pc = 0.00002

        generative_lr_bp = 0.00035
        generative_wd_bp = 0.00005

    elif dataset == 'cifar10':
        generative_lr_pc = 0.001
        generative_wd_pc = 0.000025

        generative_lr_bp = 0.0005
        generative_wd_bp = 0.00003

    else:
        raise ValueError(f"No generative hyperparameters configured for dataset '{dataset}'.")

    for run in range(trials):

        bp_model = CNNAutoencoder(3) if dataset == 'cifar10' else Autoencoder()
        bp_optimizer = torch.optim.Adam(bp_model.parameters(), lr=generative_lr_bp, weight_decay=generative_wd_bp)

        # NOTE(review): the PC network is built from bp_model's decoder. If
        # get_decoder() returns the same module objects rather than a copy,
        # the PC training below also changes the weights the BP run starts
        # from. Confirm.
        # batch_size=250 matches the default batch size of create_loaders.
        pc_model = PCNET(bp_model.get_decoder(), batch_size=250)
        pc_optimizer = torch.optim.Adam(pc_model.module_dict['generative'].parameters(), lr=generative_lr_pc, weight_decay=generative_wd_pc)

        bp_model.to(device)
        pc_model.to(device)

        # Predictive coding run.
        start_time = time.time()
        pc_best_loss = train_test(pc_model, True, train_loader, test_loader, epochs, pc_optimizer, criterion, inference_lr, inference_iterations, inference_momentum, plot_reconstruction)
        pc_training_time = time.time() - start_time
        print(f"Predictive Coding best loss: {pc_best_loss:.5f} - total training time: {pc_training_time:.2f} s")
        pc_records.append({"run": run + 1, "best_loss": pc_best_loss, "training_time_sec": pc_training_time})

        # Backpropagation run.
        start_time = time.time()
        bp_best_loss = train_test(bp_model, False, train_loader, test_loader, epochs, bp_optimizer, criterion, inference_lr, inference_iterations, inference_momentum, plot_reconstruction)
        bp_training_time = time.time() - start_time
        print(f"Backpropagation best loss: {bp_best_loss:.5f} - total training time: {bp_training_time:.2f} s")
        bp_records.append({"run": run + 1, "best_loss": bp_best_loss, "training_time_sec": bp_training_time})

    # One row per run; built once from the collected records.
    pc_results = pd.DataFrame(pc_records)
    bp_results = pd.DataFrame(bp_records)

    os.makedirs("./results", exist_ok=True)
    pc_results.to_csv(f"./results/pc_performance_{dataset}.csv", index=False)
    bp_results.to_csv(f"./results/bp_performance_{dataset}.csv", index=False)

    print('-'*50)
    print("Predictive Coding Performance:")
    print(f"Loss: Mean = {pc_results['best_loss'].mean():.5f}, Std = {pc_results['best_loss'].std():.5f}")
    print(f"Training Time: Mean = {pc_results['training_time_sec'].mean():.2f} s, Std = {pc_results['training_time_sec'].std():.2f} s")

    print("\nBackpropagation Performance:")
    print(f"Loss: Mean = {bp_results['best_loss'].mean():.5f}, Std = {bp_results['best_loss'].std():.5f}")
    print(f"Training Time: Mean = {bp_results['training_time_sec'].mean():.2f} s, Std = {bp_results['training_time_sec'].std():.2f} s")
    print('-'*50)
    print('-'*50)

In [None]:
# Summarize the saved per-run results of each dataset as "mean ± std".

# Recorded training times cover a whole run; dividing by this value reports
# them per epoch.
# NOTE(review): 5 presumably equals the `epochs` value used when the CSV files
# were written. Confirm if that setting changes.
epochs_per_run = 5

for dataset in ['mnist', 'fmnist', 'cifar10']:
    pc_results = pd.read_csv(f'./results/pc_performance_{dataset}.csv')
    bp_results = pd.read_csv(f'./results/bp_performance_{dataset}.csv')

    print(f"\n--- Results on {dataset.upper()} ---")

    # Both methods go through the same format strings so the two summaries
    # are printed with identical precision.
    for method_name, results in (("Predictive Coding", pc_results), ("Backpropagation", bp_results)):
        best_loss = results['best_loss']
        time_per_epoch = results['training_time_sec'] / epochs_per_run

        print(f"\n{method_name}:")
        print(f"Loss:     {best_loss.mean():.4f} ± {best_loss.std():.5f}")
        print(f"Train time:   {time_per_epoch.mean():.2f} ± {time_per_epoch.std():.2f} s")