# PC graph example, MNIST dataset
This notebook gives a simple example of how to use the PC graph implementation, applied to MNIST.

In [1]:
import sys
sys.path.append('/Users/6884407/PRECO')

from PRECO.utils import *
import PRECO.optim as optim
from PRECO.PCG import *

import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, SubsetRandomSampler, random_split
from tqdm import tqdm



Set folder for dataset and for saved files, and set the seed for reproducibility.

In [2]:
# Fix the random seed first so that everything below is reproducible.
seed(0)

# Folder the MNIST files are downloaded to / read from.
DATASET_PATH = './data'

# Output folder for this run.
# NOTE(review): dt_string is not defined in this notebook; presumably it is a
# date-time stamp provided by one of the PRECO star imports — confirm.
SAVE_PATH = f"output/PCG_{dt_string}"

Define dataset. Here, we use MNIST.

In [3]:
# Preprocessing pipeline: convert each image to a tensor, then normalise it
# with the given (mean,), (std,) constants.
normalize = transforms.Normalize((0.1307,), (0.3081,))
transform = transforms.Compose([transforms.ToTensor(), normalize])

# Arguments shared by both MNIST splits; files are downloaded into
# DATASET_PATH when they are not present yet.
mnist_kwargs = dict(root=DATASET_PATH, transform=transform, download=True)

train_dataset = torchvision.datasets.MNIST(train=True, **mnist_kwargs)
test_set = torchvision.datasets.MNIST(train=False, **mnist_kwargs)

Define structural parameters. 

We can use `get_mask_hierarchical()` to get a *hierarchical mask*, i.e. to make the PCG equivalent to a PCN (a feedforward neural network trained with inference learning, IL). Note that `use_input_error = False` is required to get the same updates as a PCN.

AMB means we use the prediction convention Activation-Matrix-Bias, i.e. $\mu=wf(a)+b$.
MBA means we use the prediction convention Matrix-Bias-Activation, i.e. $\mu=f(wa+b)$.

In [4]:
# Layer widths used to build the hierarchical (PCN-like) mask.
layers = [784, 32, 16, 10]

f = tanh          # activation function
use_bias = True
# Node counts: input, hidden, output.
# NOTE(review): 48 == 32 + 16, so the hidden entry presumably is the total
# number of hidden nodes across the layers above — confirm before changing either.
shape = [784, 48, 10]
mask = get_mask_hierarchical(layers)

# AMB structure, i.e. the prediction convention mu = w f(a) + b.
structure = PCG_AMB(f=f, use_bias=use_bias, shape=shape, mask=mask)

2024-07-22 12:56:51,535 - INFO - Hierarchical mask, layers: 3, using feedforward initialization and testing.


Define PC training parameters. Define the PCgraph object, an optimizer, and couple the optimizer to the PCgraph. (This is necessary in PC because for incremental mode one has to call the optimizer from within the PCgraph class.)

Compared to PCNs, PC graphs also have a T_test variable.

In [5]:
# ---- Inference settings ----
lr_x = 0.5                  # inference rate
T_train = 5                 # inference time scale
T_test = 10                 # unused for hierarchical model
incremental = True          # whether to use incremental EM or not
use_input_error = False     # whether to use errors in the input layer or not

# ---- Learning settings ----
lr_w = 1e-5                 # learning rate
batch_size = 250
weight_decay = 0
grad_clip = 1
batch_scale = False

# Build the PC graph on top of the structure defined earlier.
inference_kwargs = dict(
    lr_x=lr_x,
    T_train=T_train,
    T_test=T_test,
    incremental=incremental,
    use_input_error=use_input_error,
)
PCG = PCgraph(structure=structure, **inference_kwargs)

# Optimizer acting on the graph's parameters.
learning_kwargs = dict(
    learning_rate=lr_w,
    grad_clip=grad_clip,
    batch_scale=batch_scale,
    weight_decay=weight_decay,
)
optimizer = optim.Adam(PCG.params, **learning_kwargs)

# Couple the optimizer to the graph, so the graph itself can call it
# (needed for incremental mode).
PCG.set_optimizer(optimizer)

Define the training and validation set.

In [6]:
def full_batch_loader(dataset):
    """Return a preprocessed loader that serves `dataset` as one unshuffled batch."""
    loader = DataLoader(dataset, batch_size=len(dataset), shuffle=False, drop_last=False)
    return preprocess(loader)


# Hold out 10000 of the 60000 training images for validation.
train_set, val_set = random_split(train_dataset, [50000, 10000])

# If a certain number of samples per class is required, set no_per_class to
# that number; 0 means all samples are used.
train_indices = train_subset_indices(train_set, 10, no_per_class=0)

# SubsetRandomSampler shuffles the selected training indices.
train_sampler = SubsetRandomSampler(train_indices)
train_loader = preprocess(
    DataLoader(train_set, batch_size=batch_size, sampler=train_sampler, drop_last=False)
)

# Validation and test data are each evaluated as a single batch.
val_loader = full_batch_loader(val_set)
test_loader = full_batch_loader(test_set)

The objective of PCGs is the *energy*, accessible via `PCG.get_energy()`. It is the sum of the MSE loss and an *internal* energy:
$$
E = \mathcal{L} +\widetilde{E}
$$
with $\mathcal{L}$ the MSE loss and $\widetilde{E}$ the internal energy. During testing the internal energy is zero, so $E$ does not need to be computed separately: it reduces to the MSE loss and can simply be computed using `torch.nn.MSELoss()`. Thus, for our early stopper we also use the MSE loss instead of the energy.

We define the MSE loss, and lists to keep track of performance metrics. 

Then we run the main training loop.

In [7]:
# MSE criterion used for validation and for the early stopper.
MSE = torch.nn.MSELoss()

# Per-epoch metric histories (train_loss / train_acc are declared but not
# filled by this loop).
train_energy, train_loss, train_acc = [], [], []
val_loss, val_acc = [], []

# Stops training once the validation MSE has not improved for `patience` epochs.
early_stopper = optim.EarlyStopper(patience=5, min_delta=0)

epochs = 30

start_time = datetime.now()

# Gradient tracking is disabled for the whole loop; parameter updates are
# performed through PCG.train_supervised and the coupled optimizer.
with torch.no_grad():
    for i in tqdm(range(epochs)):

        # ---- Training: accumulate the energy over all training batches ----
        energy = 0
        for X_batch, y_batch in train_loader:
            PCG.train_supervised(X_batch, y_batch)
            energy += PCG.get_energy()
        train_energy.append(energy/len(train_loader))

        # ---- Validation: MSE loss and accuracy ----
        loss, acc = 0, 0
        for X_batch, y_batch in val_loader:
            y_pred = PCG.test_supervised(X_batch)

            loss += MSE(y_pred, onehot(y_batch, N=10) ).item()
            acc += torch.mean(( torch.argmax(y_pred, axis=1) == y_batch ).float()).item()

        # Average both metrics over the number of validation batches, so the
        # reported loss does not depend on how the validation set is batched
        # (and matches how the test metrics are computed).
        val_acc.append(acc/len(val_loader))
        val_loss.append(loss/len(val_loader))

        print(f"\nEPOCH {i+1}/{epochs} \n #####################")
        print(f"VAL acc:   {val_acc[i]:.3f}, VAL MSE:   {val_loss[i]:.3f}, TRAIN ENERGY:   {train_energy[i]:.3f}")

        # Early stopping is driven by the validation MSE, not by the energy.
        if early_stopper.early_stop(val_loss[i]):
            print(f"\nEarly stopping at epoch {i+1}")
            break

print(f"\nTraining time: {datetime.now() - start_time}")


  3%|▎         | 1/30 [00:05<02:29,  5.16s/it]


EPOCH 1/30 
 #####################
VAL acc:   0.535, VAL MSE:   0.082, TRAIN ENERGY:   215.799
Validation objective decreased (inf --> 0.081876).


  7%|▋         | 2/30 [00:10<02:30,  5.38s/it]


EPOCH 2/30 
 #####################
VAL acc:   0.654, VAL MSE:   0.070, TRAIN ENERGY:   181.281
Validation objective decreased (0.081876 --> 0.070473).


 10%|█         | 3/30 [00:15<02:20,  5.20s/it]


EPOCH 3/30 
 #####################
VAL acc:   0.701, VAL MSE:   0.062, TRAIN ENERGY:   156.505
Validation objective decreased (0.070473 --> 0.061989).


 13%|█▎        | 4/30 [00:21<02:19,  5.36s/it]


EPOCH 4/30 
 #####################
VAL acc:   0.747, VAL MSE:   0.055, TRAIN ENERGY:   137.362
Validation objective decreased (0.061989 --> 0.055219).


 17%|█▋        | 5/30 [00:26<02:10,  5.24s/it]


EPOCH 5/30 
 #####################
VAL acc:   0.784, VAL MSE:   0.050, TRAIN ENERGY:   121.290
Validation objective decreased (0.055219 --> 0.049563).


 20%|██        | 6/30 [00:31<02:04,  5.21s/it]


EPOCH 6/30 
 #####################
VAL acc:   0.810, VAL MSE:   0.045, TRAIN ENERGY:   108.166
Validation objective decreased (0.049563 --> 0.045128).


 23%|██▎       | 7/30 [00:36<01:58,  5.15s/it]


EPOCH 7/30 
 #####################
VAL acc:   0.828, VAL MSE:   0.042, TRAIN ENERGY:   97.867
Validation objective decreased (0.045128 --> 0.041702).


 27%|██▋       | 8/30 [00:41<01:53,  5.16s/it]


EPOCH 8/30 
 #####################
VAL acc:   0.842, VAL MSE:   0.039, TRAIN ENERGY:   89.604
Validation objective decreased (0.041702 --> 0.038929).


 30%|███       | 9/30 [00:46<01:49,  5.20s/it]


EPOCH 9/30 
 #####################
VAL acc:   0.854, VAL MSE:   0.037, TRAIN ENERGY:   82.561
Validation objective decreased (0.038929 --> 0.036519).


 33%|███▎      | 10/30 [00:52<01:46,  5.35s/it]


EPOCH 10/30 
 #####################
VAL acc:   0.862, VAL MSE:   0.034, TRAIN ENERGY:   76.176
Validation objective decreased (0.036519 --> 0.034294).


 37%|███▋      | 11/30 [00:57<01:40,  5.27s/it]


EPOCH 11/30 
 #####################
VAL acc:   0.871, VAL MSE:   0.032, TRAIN ENERGY:   70.188
Validation objective decreased (0.034294 --> 0.032195).


 40%|████      | 12/30 [01:03<01:35,  5.32s/it]


EPOCH 12/30 
 #####################
VAL acc:   0.878, VAL MSE:   0.030, TRAIN ENERGY:   64.613
Validation objective decreased (0.032195 --> 0.030257).


 43%|████▎     | 13/30 [01:08<01:32,  5.44s/it]


EPOCH 13/30 
 #####################
VAL acc:   0.883, VAL MSE:   0.029, TRAIN ENERGY:   59.604
Validation objective decreased (0.030257 --> 0.028546).


 47%|████▋     | 14/30 [01:14<01:26,  5.44s/it]


EPOCH 14/30 
 #####################
VAL acc:   0.889, VAL MSE:   0.027, TRAIN ENERGY:   55.269
Validation objective decreased (0.028546 --> 0.027090).


 50%|█████     | 15/30 [01:19<01:21,  5.46s/it]


EPOCH 15/30 
 #####################
VAL acc:   0.894, VAL MSE:   0.026, TRAIN ENERGY:   51.593
Validation objective decreased (0.027090 --> 0.025867).


 53%|█████▎    | 16/30 [01:25<01:16,  5.47s/it]


EPOCH 16/30 
 #####################
VAL acc:   0.898, VAL MSE:   0.025, TRAIN ENERGY:   48.479
Validation objective decreased (0.025867 --> 0.024842).


 57%|█████▋    | 17/30 [01:31<01:12,  5.54s/it]


EPOCH 17/30 
 #####################
VAL acc:   0.901, VAL MSE:   0.024, TRAIN ENERGY:   45.824
Validation objective decreased (0.024842 --> 0.023977).


 60%|██████    | 18/30 [01:36<01:06,  5.56s/it]


EPOCH 18/30 
 #####################
VAL acc:   0.903, VAL MSE:   0.023, TRAIN ENERGY:   43.537
Validation objective decreased (0.023977 --> 0.023240).


 63%|██████▎   | 19/30 [01:42<01:00,  5.51s/it]


EPOCH 19/30 
 #####################
VAL acc:   0.906, VAL MSE:   0.023, TRAIN ENERGY:   41.540
Validation objective decreased (0.023240 --> 0.022607).


 67%|██████▋   | 20/30 [01:47<00:54,  5.47s/it]


EPOCH 20/30 
 #####################
VAL acc:   0.907, VAL MSE:   0.022, TRAIN ENERGY:   39.773
Validation objective decreased (0.022607 --> 0.022057).


 70%|███████   | 21/30 [01:52<00:49,  5.45s/it]


EPOCH 21/30 
 #####################
VAL acc:   0.908, VAL MSE:   0.022, TRAIN ENERGY:   38.189
Validation objective decreased (0.022057 --> 0.021574).


 73%|███████▎  | 22/30 [01:58<00:43,  5.44s/it]


EPOCH 22/30 
 #####################
VAL acc:   0.910, VAL MSE:   0.021, TRAIN ENERGY:   36.750
Validation objective decreased (0.021574 --> 0.021145).


 77%|███████▋  | 23/30 [02:03<00:38,  5.43s/it]


EPOCH 23/30 
 #####################
VAL acc:   0.911, VAL MSE:   0.021, TRAIN ENERGY:   35.428
Validation objective decreased (0.021145 --> 0.020760).


 80%|████████  | 24/30 [02:09<00:32,  5.44s/it]


EPOCH 24/30 
 #####################
VAL acc:   0.912, VAL MSE:   0.020, TRAIN ENERGY:   34.200
Validation objective decreased (0.020760 --> 0.020410).


 83%|████████▎ | 25/30 [02:14<00:27,  5.46s/it]


EPOCH 25/30 
 #####################
VAL acc:   0.913, VAL MSE:   0.020, TRAIN ENERGY:   33.049
Validation objective decreased (0.020410 --> 0.020090).


 87%|████████▋ | 26/30 [02:19<00:21,  5.44s/it]


EPOCH 26/30 
 #####################
VAL acc:   0.915, VAL MSE:   0.020, TRAIN ENERGY:   31.965
Validation objective decreased (0.020090 --> 0.019797).


 90%|█████████ | 27/30 [02:25<00:16,  5.41s/it]


EPOCH 27/30 
 #####################
VAL acc:   0.917, VAL MSE:   0.020, TRAIN ENERGY:   30.940
Validation objective decreased (0.019797 --> 0.019527).


 93%|█████████▎| 28/30 [02:30<00:10,  5.37s/it]


EPOCH 28/30 
 #####################
VAL acc:   0.918, VAL MSE:   0.019, TRAIN ENERGY:   29.968
Validation objective decreased (0.019527 --> 0.019279).


 97%|█████████▋| 29/30 [02:35<00:05,  5.30s/it]


EPOCH 29/30 
 #####################
VAL acc:   0.919, VAL MSE:   0.019, TRAIN ENERGY:   29.046
Validation objective decreased (0.019279 --> 0.019052).


100%|██████████| 30/30 [02:41<00:00,  5.37s/it]


EPOCH 30/30 
 #####################
VAL acc:   0.920, VAL MSE:   0.019, TRAIN ENERGY:   28.169
Validation objective decreased (0.019052 --> 0.018844).

Training time: 0:02:41.054548





Final unbiased estimate of the performance, computed on the test set:

In [8]:
# Evaluate the trained PC graph on the held-out test set.
# Gradient tracking is disabled, as in the validation loop.
with torch.no_grad():
    loss, acc = 0, 0
    for X_batch, y_batch in test_loader:
        y_pred = PCG.test_supervised(X_batch)

        loss += MSE(y_pred, onehot(y_batch,N=10) ).item()
        acc += torch.mean(( torch.argmax(y_pred, axis=1) == y_batch).float()).item()

# Average the metrics over the number of test batches.
# (No test energy is reported: the previous `test_energy` reused the stale
# `energy` accumulator from the last training epoch and was never used.)
test_acc = acc/len(test_loader)
test_loss = loss/len(test_loader)

print(f"\nTEST acc:   {test_acc:.3f}, TEST MSE:   {test_loss:.3f}")
print("Training & testing finished in %s" % str((datetime.now() - start_time)).split('.')[0])


TEST acc:   0.923, TEST MSE:   0.018
Training & testing finished in 0:02:41
