In [None]:
import sys

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.profiler import profile, ProfilerActivity

sys.path.extend(["./models"])

from pc_model import PCNET
from mlp import Autoencoder
from cnn import CNNAutoencoder

In [None]:
def run_profiling(model_type, model, optimizer, inputs, outputs, label,
                  *, alpha=0.01, inference_timestamps=35, momentum=0.9, row_limit=10):
    """Profile a single update step of ``model`` on the CPU and print a summary table.

    The profiled region covers exactly one update:

    * ``"PC"`` -- ``model.inference_phase(...)`` followed by ``optimizer.step()``.
    * ``"BP"`` -- forward pass, MSE loss against ``outputs``, ``backward()``,
      then ``optimizer.step()``.

    Args:
        model_type: ``"PC"`` (a model exposing ``inference_phase``) or ``"BP"``
            (an ordinary module trained by backprop).
        model: the ``torch.nn.Module`` to profile; moved to the CPU in place.
        optimizer: optimizer over the parameters updated by this step.
        inputs: input batch; for ``"PC"`` it is passed as ``stimuli``.
        outputs: MSE target for ``"BP"``; unused for ``"PC"``.
        label: name printed in the header (e.g. ``"MLP"`` or ``"CNN"``).
        alpha, inference_timestamps, momentum: forwarded to
            ``model.inference_phase`` for ``"PC"``. The defaults equal the values
            this function previously hard-coded.
        row_limit: number of rows shown in the profiler table.

    Raises:
        ValueError: if ``model_type`` is neither ``"PC"`` nor ``"BP"``.
    """
    if model_type not in ("PC", "BP"):
        raise ValueError(f"model_type must be 'PC' or 'BP', got {model_type!r}")

    model.to('cpu')
    inputs = inputs.to('cpu')
    outputs = outputs.to('cpu')

    # Clear stale gradients before the step. Without this, gradients left by a
    # previous call accumulate into this step and the optimizer applies their sum.
    # Such gradients can come from an earlier profiling run, or possibly from a
    # PC model that shares parameters with this one.
    optimizer.zero_grad()

    print(f"\nProfiling {label} ({model_type}):")
    with profile(activities=[ProfilerActivity.CPU],
                 profile_memory=True, record_shapes=True, with_flops=True) as prof:
        if model_type == "PC":
            # NOTE(review): assumes inference_phase populates the parameter
            # gradients that optimizer.step() consumes -- confirm in pc_model.PCNET.
            model.inference_phase(stimuli=inputs, alpha=alpha,
                                  inference_timestamps=inference_timestamps,
                                  momentum=momentum)
        else:
            loss = F.mse_loss(model(inputs), outputs)
            loss.backward()
        optimizer.step()
    print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=row_limit))


In [None]:
# MLP comparison: a backprop autoencoder versus a PC network constructed from it.
# Both optimizers use Adam with the same settings (lr=1e-3, weight_decay=1e-4).
bp_model = Autoencoder()
bp_optimizer = torch.optim.Adam(params=bp_model.parameters(), lr=1e-3, weight_decay=1e-4)

# NOTE(review): the meaning of PCNET's second argument (1) is not visible here -- see pc_model.py.
pc_model = PCNET(bp_model, 1)
pc_optimizer = torch.optim.Adam(params=pc_model.parameters(), lr=1e-3, weight_decay=1e-4)

# One all-ones row of length 28*28; it is used as both input and target.
inputs = torch.ones((1, 28 * 28))

run_profiling(model_type="PC", model=pc_model, optimizer=pc_optimizer,
              inputs=inputs, outputs=inputs, label="MLP")
run_profiling(model_type="BP", model=bp_model, optimizer=bp_optimizer,
              inputs=inputs, outputs=inputs, label="MLP")

In [None]:
# Total number of scalar parameters in each model, shown as a small aligned table.
bp_params = sum(map(torch.numel, bp_model.parameters()))
pc_params = sum(map(torch.numel, pc_model.parameters()))

print("Parameter Type".ljust(30), "Count".rjust(15))
print("-" * 50)
print("BP generative parameters".ljust(30), format(bp_params, ">15,"))
print("PC generative parameters".ljust(30), format(pc_params, ">15,"))
pc_model.module_dict["variational"]

In [None]:
# CNN comparison: a backprop convolutional autoencoder versus a PC network
# built from its decoder only (bp_model.get_decoder()).
# NOTE(review): CNNAutoencoder(3) presumably means 3 input channels -- confirm in cnn.py.
bp_model = CNNAutoencoder(3)
bp_optimizer = torch.optim.Adam(bp_model.parameters(), lr=0.001, weight_decay=0.0001)

pc_model = PCNET(bp_model.get_decoder(), 1)
pc_optimizer = torch.optim.Adam(pc_model.parameters(), lr=0.001, weight_decay=0.0001)

# One all-ones tensor of shape (1, 3, 32, 32); it is used as both input and target.
inputs = torch.ones(1, 3, 32, 32)

# These runs profile the CNN models, so label them "CNN" rather than "MLP".
run_profiling("PC", pc_model, pc_optimizer, inputs, inputs, "CNN")
run_profiling("BP", bp_model, bp_optimizer, inputs, inputs, "CNN")

In [None]:
# Total number of scalar parameters in the current bp_model / pc_model pair,
# shown as a small aligned table.
bp_params = sum(map(torch.numel, bp_model.parameters()))
pc_params = sum(map(torch.numel, pc_model.parameters()))

print("Parameter Type".ljust(30), "Count".rjust(15))
print("-" * 50)
print("BP generative parameters".ljust(30), format(bp_params, ">15,"))
print("PC generative parameters".ljust(30), format(pc_params, ">15,"))
pc_model.module_dict["variational"]