# PCN (hierarchical) example, MNIST dataset
This notebook gives a simple example use of the discriminative PCN implementation applied to classification on MNIST.

In [1]:
import sys
sys.path.append('/Users/6884407/PRECO')

from PRECO.utils import *
import PRECO.optim as optim
from PRECO.PCN import *

import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, SubsetRandomSampler, random_split
from tqdm import tqdm



Set folder for dataset and for saved files, and set the seed for reproducibility.

In [2]:
# Folder the MNIST files are downloaded to / read from.
DATASET_PATH = './data'

# Per-run output folder. `dt_string` is not defined in this notebook; it is
# presumably a timestamp string exported by the star-imports above -- TODO confirm.
SAVE_PATH = "output/PCG_{}".format(dt_string)

# Fix the random seed so that the run is reproducible.
seed(0)

Define dataset. Here, we use MNIST.

In [3]:
# Preprocessing pipeline: convert images to tensors, then normalise the single
# channel with the constants (0.1307, 0.3081) -- presumably the MNIST pixel
# mean / standard deviation.
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize((0.1307,), (0.3081,))
transform = transforms.Compose([to_tensor, normalize])

# Both splits share the same root folder and preprocessing, and are downloaded
# on first use.
mnist_kwargs = dict(root=DATASET_PATH, transform=transform, download=True)
train_dataset = torchvision.datasets.MNIST(train=True, **mnist_kwargs)
test_set = torchvision.datasets.MNIST(train=False, **mnist_kwargs)

Define structural parameters. 

Setting the variable upward=True refers to the use of $\mu^\ell=a^{\ell-1}f(w^{\ell-1})+b^{\ell-1}$, i.e. use of the *discriminative PCN*. Setting this to upward=False refers to the use of $\mu^\ell=a^{\ell+1}f(w^{\ell+1})+b^{\ell+1}$, i.e. use of the *generative PCN*.

In [4]:
# Structural hyperparameters of the network.
f = tanh                        # activation function
use_bias = True                 # include bias terms
upward = True                   # True: discriminative PCN, False: generative PCN
layers = [784, 32, 16, 10]      # layer widths, input (28*28 pixels) to output (10 classes)

# Bundle the structural choices into the structure object handed to the PCnet.
structure = PCN_AMB(f=f, use_bias=use_bias, upward=upward, layers=layers)



Define PC training parameters. Define the PCnet object and an optimizer, and couple the optimizer to the PCnet. (This coupling is necessary in PC because in incremental mode the optimizer has to be called from within the PCnet object.)

In [5]:
# Inference
lr_x = 0.5                  # inference rate
T_train = 5                 # inference time scale
# Whether to use incremental EM or not. This flag is now actually forwarded to
# PCnet below; previously the constructor received a hard-coded False while this
# variable said True. It is set to False to match the configuration that was
# really run -- switch to True to enable incremental mode.
incremental = False
# Whether to use errors in the input layer or not.
# NOTE(review): this value is not passed to PCnet or the optimizer in this
# notebook, so it currently has no effect -- confirm whether PCnet accepts it.
use_input_error = False
use_feedforward_init = True # passed to PCnet as `use_feedforward_init`

# Learning
lr_w = 0.0001               # learning rate
batch_size = 200
weight_decay = 0            # weight decay
grad_clip = 1               # passed to the optimizer as `grad_clip`
batch_scale = False         # passed to the optimizer as `batch_scale`

# Build the network from the structure and the inference settings above.
PCN = PCnet(structure=structure,
            lr_x=lr_x,
            T_train=T_train,
            incremental=incremental,
            use_feedforward_init=use_feedforward_init,
            )

# Optimizer over the network parameters.
optimizer = optim.Adam(
    PCN.params,
    learning_rate=lr_w,
    grad_clip=grad_clip,
    batch_scale=batch_scale,
    weight_decay=weight_decay,
)

# Couple the optimizer to the network so that the network can call it itself.
PCN.set_optimizer(optimizer)

In [6]:
# Hold out 10000 of the 60000 training images for validation.
train_set, val_set = random_split(train_dataset, [50000, 10000])

# Indices of the training samples to use. If a certain number of samples per
# class is required, set no_per_class to that number; 0 means all samples are used.
# NOTE(review): the positional 10 is presumably the number of classes -- confirm.
train_indices = train_subset_indices(train_set, 10, no_per_class=0)

# Training batches are drawn through a SubsetRandomSampler, which shuffles the data.
train_sampler = SubsetRandomSampler(train_indices)
train_loader = preprocess(
    DataLoader(train_set, batch_size=batch_size, sampler=train_sampler, drop_last=False)
)

# Validation and test sets are each served as one single, unshuffled batch.
val_loader = preprocess(
    DataLoader(val_set, batch_size=len(val_set), shuffle=False, drop_last=False)
)
test_loader = preprocess(
    DataLoader(test_set, batch_size=len(test_set), shuffle=False, drop_last=False)
)

The objective of PCNs is the *energy*, accessible by `PCN.get_energy()`. This is the sum of the MSE loss and an *internal* energy:
$$
E = \mathcal{L} +\widetilde{E}
$$
with $\mathcal{L}$ the MSE loss and $\widetilde{E}$ the internal energy. The energy is not computed explicitly during testing: the internal energy is then zero, so $E$ reduces to $\mathcal{L}$ and can simply be computed using `torch.nn.MSELoss()`. Thus, for our early stopper we also use the MSE loss, instead of the energy.

We define the MSE loss, and lists to keep track of performance metrics. 

Then we get the main training loop.

In [7]:
# MSE between the network output and the one-hot targets; used as the
# validation metric and as the early-stopping objective.
MSE = torch.nn.MSELoss()

# Per-epoch histories. Only train_energy, val_loss and val_acc are filled by
# the loop below; train_loss and train_acc are created but left empty here.
train_energy, train_loss, train_acc = [], [], []
val_loss, val_acc = [], []

early_stopper = optim.EarlyStopper(patience=5, min_delta=0)

epochs = 30

start_time = datetime.now()

# Training and validation both run with autograd disabled.
with torch.no_grad():
    for i in tqdm(range(epochs)):

        # Training pass: accumulate the energy reported after each batch and
        # store its average over the batches of this epoch.
        energy = 0
        for X_batch, y_batch in train_loader:
            PCN.train_supervised(X_batch, y_batch)
            energy += PCN.get_energy()
        train_energy.append(energy/len(train_loader))

        # Validation pass: accumulate per-batch MSE and accuracy.
        loss, acc = 0, 0
        for X_batch, y_batch in val_loader:
            y_pred = PCN.test_supervised(X_batch)

            loss += MSE(y_pred, onehot(y_batch, N=10) ).item()
            acc += torch.mean(( torch.argmax(y_pred, dim=1) == y_batch ).float()).item()

        # Average BOTH metrics over the number of validation batches. The loss
        # used to be stored as a raw sum, which only coincided with the mean
        # because the validation set is served as one batch.
        val_acc.append(acc/len(val_loader))
        val_loss.append(loss/len(val_loader))

        print(f"\nEPOCH {i+1}/{epochs} \n #####################")   
        print(f"VAL acc:   {val_acc[i]:.3f}, VAL MSE:   {val_loss[i]:.3f}, TRAIN ENERGY:   {train_energy[i]:.3f}")

        # Stop when the early stopper reports that the validation loss has
        # stopped improving.
        if early_stopper.early_stop(val_loss[i]):
            print(f"\nEarly stopping at epoch {i+1}")          
            break

print(f"\nTraining time: {datetime.now() - start_time}")


  3%|▎         | 1/30 [00:00<00:26,  1.11it/s]


EPOCH 1/30 
 #####################
VAL acc:   0.735, VAL MSE:   0.068, TRAIN ENERGY:   0.028
Validation objective decreased (inf --> 0.067676).


  7%|▋         | 2/30 [00:01<00:24,  1.14it/s]


EPOCH 2/30 
 #####################
VAL acc:   0.833, VAL MSE:   0.049, TRAIN ENERGY:   0.006
Validation objective decreased (0.067676 --> 0.049452).


 10%|█         | 3/30 [00:02<00:23,  1.15it/s]


EPOCH 3/30 
 #####################
VAL acc:   0.869, VAL MSE:   0.038, TRAIN ENERGY:   0.003
Validation objective decreased (0.049452 --> 0.038203).


 13%|█▎        | 4/30 [00:03<00:22,  1.15it/s]


EPOCH 4/30 
 #####################
VAL acc:   0.880, VAL MSE:   0.031, TRAIN ENERGY:   0.002
Validation objective decreased (0.038203 --> 0.031086).


 17%|█▋        | 5/30 [00:04<00:22,  1.13it/s]


EPOCH 5/30 
 #####################
VAL acc:   0.889, VAL MSE:   0.027, TRAIN ENERGY:   0.001
Validation objective decreased (0.031086 --> 0.027072).


 20%|██        | 6/30 [00:05<00:21,  1.11it/s]


EPOCH 6/30 
 #####################
VAL acc:   0.898, VAL MSE:   0.024, TRAIN ENERGY:   0.001
Validation objective decreased (0.027072 --> 0.024468).


 23%|██▎       | 7/30 [00:06<00:21,  1.09it/s]


EPOCH 7/30 
 #####################
VAL acc:   0.905, VAL MSE:   0.023, TRAIN ENERGY:   0.001
Validation objective decreased (0.024468 --> 0.022578).


 27%|██▋       | 8/30 [00:07<00:20,  1.08it/s]


EPOCH 8/30 
 #####################
VAL acc:   0.909, VAL MSE:   0.021, TRAIN ENERGY:   0.001
Validation objective decreased (0.022578 --> 0.021173).


 30%|███       | 9/30 [00:08<00:19,  1.09it/s]


EPOCH 9/30 
 #####################
VAL acc:   0.915, VAL MSE:   0.020, TRAIN ENERGY:   0.001
Validation objective decreased (0.021173 --> 0.020101).


 33%|███▎      | 10/30 [00:09<00:18,  1.09it/s]


EPOCH 10/30 
 #####################
VAL acc:   0.918, VAL MSE:   0.019, TRAIN ENERGY:   0.001
Validation objective decreased (0.020101 --> 0.019244).


 37%|███▋      | 11/30 [00:10<00:17,  1.07it/s]


EPOCH 11/30 
 #####################
VAL acc:   0.920, VAL MSE:   0.019, TRAIN ENERGY:   0.001
Validation objective decreased (0.019244 --> 0.018532).


 40%|████      | 12/30 [00:10<00:16,  1.07it/s]


EPOCH 12/30 
 #####################
VAL acc:   0.922, VAL MSE:   0.018, TRAIN ENERGY:   0.001
Validation objective decreased (0.018532 --> 0.017928).


 43%|████▎     | 13/30 [00:11<00:15,  1.06it/s]


EPOCH 13/30 
 #####################
VAL acc:   0.923, VAL MSE:   0.017, TRAIN ENERGY:   0.001
Validation objective decreased (0.017928 --> 0.017413).


 47%|████▋     | 14/30 [00:12<00:14,  1.08it/s]


EPOCH 14/30 
 #####################
VAL acc:   0.925, VAL MSE:   0.017, TRAIN ENERGY:   0.001
Validation objective decreased (0.017413 --> 0.016971).


 50%|█████     | 15/30 [00:14<00:15,  1.01s/it]


EPOCH 15/30 
 #####################
VAL acc:   0.926, VAL MSE:   0.017, TRAIN ENERGY:   0.001
Validation objective decreased (0.016971 --> 0.016581).


 53%|█████▎    | 16/30 [00:14<00:13,  1.00it/s]


EPOCH 16/30 
 #####################
VAL acc:   0.926, VAL MSE:   0.016, TRAIN ENERGY:   0.001
Validation objective decreased (0.016581 --> 0.016231).


 57%|█████▋    | 17/30 [00:15<00:12,  1.03it/s]


EPOCH 17/30 
 #####################
VAL acc:   0.927, VAL MSE:   0.016, TRAIN ENERGY:   0.001
Validation objective decreased (0.016231 --> 0.015914).


 60%|██████    | 18/30 [00:16<00:11,  1.03it/s]


EPOCH 18/30 
 #####################
VAL acc:   0.928, VAL MSE:   0.016, TRAIN ENERGY:   0.001
Validation objective decreased (0.015914 --> 0.015631).


 63%|██████▎   | 19/30 [00:17<00:10,  1.05it/s]


EPOCH 19/30 
 #####################
VAL acc:   0.929, VAL MSE:   0.015, TRAIN ENERGY:   0.001
Validation objective decreased (0.015631 --> 0.015377).


 67%|██████▋   | 20/30 [00:18<00:09,  1.04it/s]


EPOCH 20/30 
 #####################
VAL acc:   0.930, VAL MSE:   0.015, TRAIN ENERGY:   0.001
Validation objective decreased (0.015377 --> 0.015150).


 70%|███████   | 21/30 [00:19<00:08,  1.04it/s]


EPOCH 21/30 
 #####################
VAL acc:   0.931, VAL MSE:   0.015, TRAIN ENERGY:   0.001
Validation objective decreased (0.015150 --> 0.014946).


 73%|███████▎  | 22/30 [00:20<00:07,  1.05it/s]


EPOCH 22/30 
 #####################
VAL acc:   0.931, VAL MSE:   0.015, TRAIN ENERGY:   0.001
Validation objective decreased (0.014946 --> 0.014763).


 77%|███████▋  | 23/30 [00:21<00:06,  1.07it/s]


EPOCH 23/30 
 #####################
VAL acc:   0.932, VAL MSE:   0.015, TRAIN ENERGY:   0.001
Validation objective decreased (0.014763 --> 0.014598).


 80%|████████  | 24/30 [00:22<00:05,  1.07it/s]


EPOCH 24/30 
 #####################
VAL acc:   0.932, VAL MSE:   0.014, TRAIN ENERGY:   0.001
Validation objective decreased (0.014598 --> 0.014450).


 83%|████████▎ | 25/30 [00:23<00:04,  1.07it/s]


EPOCH 25/30 
 #####################
VAL acc:   0.933, VAL MSE:   0.014, TRAIN ENERGY:   0.001
Validation objective decreased (0.014450 --> 0.014317).


 87%|████████▋ | 26/30 [00:24<00:03,  1.06it/s]


EPOCH 26/30 
 #####################
VAL acc:   0.933, VAL MSE:   0.014, TRAIN ENERGY:   0.000
Validation objective decreased (0.014317 --> 0.014198).


 90%|█████████ | 27/30 [00:25<00:02,  1.08it/s]


EPOCH 27/30 
 #####################
VAL acc:   0.933, VAL MSE:   0.014, TRAIN ENERGY:   0.000
Validation objective decreased (0.014198 --> 0.014090).


 93%|█████████▎| 28/30 [00:26<00:01,  1.09it/s]


EPOCH 28/30 
 #####################
VAL acc:   0.933, VAL MSE:   0.014, TRAIN ENERGY:   0.000
Validation objective decreased (0.014090 --> 0.013993).


 97%|█████████▋| 29/30 [00:27<00:00,  1.11it/s]


EPOCH 29/30 
 #####################
VAL acc:   0.933, VAL MSE:   0.014, TRAIN ENERGY:   0.000
Validation objective decreased (0.013993 --> 0.013903).


100%|██████████| 30/30 [00:27<00:00,  1.08it/s]


EPOCH 30/30 
 #####################
VAL acc:   0.933, VAL MSE:   0.014, TRAIN ENERGY:   0.000
Validation objective decreased (0.013903 --> 0.013821).

Training time: 0:00:27.902513





In [8]:
# Final evaluation on the held-out test set, using the same metrics as the
# validation pass (MSE against one-hot targets, and accuracy).
loss, acc = 0, 0
# Disable autograd for evaluation, as in the validation pass.
with torch.no_grad():
    for X_batch, y_batch in test_loader:
        y_pred = PCN.test_supervised(X_batch)

        loss += MSE(y_pred, onehot(y_batch, N=10) ).item()
        acc += torch.mean(( torch.argmax(y_pred, dim=1) == y_batch ).float()).item()

# Average over the number of test batches.
# No test "energy" is reported: the previous `test_energy` was derived from the
# leftover training-loop accumulator `energy`, not from any test-time quantity.
test_acc = acc/len(test_loader)
test_loss = loss/len(test_loader)

print(f"\nTEST acc:   {test_acc:.3f}, TEST MSE:   {test_loss:.3f}")
# start_time was set before the training loop, so this covers training + testing.
print("Training & testing finished in %s" % str((datetime.now() - start_time)).split('.')[0])


TEST acc:   0.936, TEST MSE:   0.013
Training & testing finished in 0:00:27
