<a href="https://colab.research.google.com/github/cindyhfls/SpatialEmbeddedEquilibriumPropagation_Neuromatch_NeuroAI_TrustworthyHeliotrope/blob/EarlyStopping/equilibrium_propagation_toymodel_Lu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Adapted from `run_energy_model_mnist.py` in the [smonsays/equilibrium-propagation](https://github.com/smonsays/equilibrium-propagation/tree/master) repository.

**To-do:**

*Week 1 - Build the network architecture, train a basic network, and decide on the questions*
1. First, build a synthetic "distance" matrix that specifies the distance between every pair of the 1000 units (a 1000x1000 matrix).
2. Implement spatial normalization through the energy function?

*Week 2 - Calculate metrics to evaluate the network; each person picks a direction to test and produces a summary slide.*

In [1]:
# @title Clone Repository and Setup
# Fetch the project repository into the current working directory; a later cell
# changes into the clone before `lib` is imported.
# NOTE(review): `git clone` refuses to write into an existing directory, so this
# cell reports an error when re-run in the same runtime.
!git clone https://github.com/cindyhfls/SpatialEmbeddedEquilibriumPropagation_Neuromatch_NeuroAI_TrustworthyHeliotrope.git

Cloning into 'SpatialEmbeddedEquilibriumPropagation_Neuromatch_NeuroAI_TrustworthyHeliotrope'...
remote: Enumerating objects: 151, done.[K
remote: Counting objects: 100% (151/151), done.[K
remote: Compressing objects: 100% (136/136), done.[K
remote: Total 151 (delta 78), reused 34 (delta 8), pack-reused 0[K
Receiving objects: 100% (151/151), 88.68 KiB | 5.54 MiB/s, done.
Resolving deltas: 100% (78/78), done.


In [2]:
cd /content/SpatialEmbeddedEquilibriumPropagation_Neuromatch_NeuroAI_TrustworthyHeliotrope/equilibrium-propagation-master/

/content/SpatialEmbeddedEquilibriumPropagation_Neuromatch_NeuroAI_TrustworthyHeliotrope/equilibrium-propagation-master


In [3]:
# Install torchlens, which the import cell below loads under the alias `tl`.
!pip install -q torchlens

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m83.3/83.3 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m48.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [4]:
import argparse
import json
import logging
import sys

import torch
import torchlens as tl

from lib import config, data, energy, train, utils

In [5]:
# @title Install torchlens and other utilities for visualization/RSA?
# NOTE(review): torchlens was already installed and imported (as `tl`) in the
# cells above, so only rsatoolbox is new here.
!pip install torchlens --quiet
!pip install rsatoolbox --quiet

# NOTE(review): neither bare module name is referenced in the visible part of the
# notebook; extract_features goes through the `tl` alias instead.
import torchlens,rsatoolbox

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/656.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━[0m [32m522.2/656.0 kB[0m [31m15.5 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m656.0/656.0 kB[0m [31m14.4 MB/s[0m eta [36m0:00:00[0m
[?25h

In [6]:
def extract_features(model, imgs, return_layers, plot='none'):
    """
    Run one logged forward pass and collect flattened activations per layer.

    Inputs:
    - model (torch.nn.Module): Network whose forward pass is recorded with torchlens.
    - imgs (torch.Tensor): Input batch handed to the model.
    - return_layers (list): Names of the logged layers to read activations from.
    - plot (str): Forwarded to torchlens as ``vis_opt``. Default is 'none'.

    Outputs:
    - model_features (dict): Maps each requested layer name to its saved
      activations, flattened from dimension 1 onward.
    """
    # Every layer is logged, even though only `return_layers` is read back.
    history = tl.log_forward_pass(model, imgs, layers_to_save='all', vis_opt=plot)

    return {
        name: history[name].tensor_contents.flatten(1)
        for name in return_layers
    }

In [7]:
# @title Helper functions for parsing input
def load_default_config(energy):
    """
    Read the default parameter file that belongs to an energy-based model.

    Args:
        energy: String with the energy name ("restr_hopfield" or "cond_gaussian")

    Returns:
        Dictionary of default parameters parsed from the matching JSON file

    Raises:
        ValueError: If no default configuration exists for the given name
    """
    known_configs = (
        ("restr_hopfield", "etc/energy_restr_hopfield.json"),
        ("cond_gaussian", "etc/energy_cond_gaussian.json"),
    )

    for name, path in known_configs:
        if energy == name:
            default_config = path
            break
    else:
        raise ValueError("Energy based model \"{}\" not defined.".format(energy))

    # The path is relative, so it resolves against the current working directory
    with open(default_config) as config_json_file:
        return json.load(config_json_file)


def parse_shell_args(args):
    """
    Turn a list of command-line style arguments into a configuration dictionary.

    Options declared with ``argparse.SUPPRESS`` as default are left out of the
    result unless they are passed explicitly; only ``energy`` and ``log_dir``
    are always present.

    Args:
        args: List of shell arguments

    Returns:
        Dictionary of shell arguments
    """
    suppress = argparse.SUPPRESS

    # (flag, add_argument keyword options), registered in this order
    option_specs = [
        ("--batch_size", dict(type=int, default=suppress,
                              help="Size of mini batches during training.")),
        ("--c_energy", dict(choices=["cross_entropy", "squared_error"],
                            default=suppress,
                            help="Supervised learning cost function.")),
        ("--dimensions", dict(type=int, nargs="+", default=suppress,
                              help="Dimensions of the neural network.")),
        ("--energy", dict(choices=["cond_gaussian", "restr_hopfield"],
                          default="cond_gaussian",
                          help="Type of energy-based model.")),
        ("--epochs", dict(type=int, default=suppress,
                          help="Number of epochs to train.")),
        ("--fast_ff_init", dict(action='store_true', default=suppress,
                                help="Flag to enable fast feedforward initialization.")),
        ("--learning_rate", dict(type=float, default=suppress,
                                 help="Learning rate of the optimizer.")),
        ("--log_dir", dict(type=str, default="",
                           help="Subdirectory within ./log/ to store logs.")),
        ("--nonlinearity", dict(choices=["leaky_relu", "relu", "sigmoid", "tanh"],
                                default=suppress,
                                help="Nonlinearity between network layers.")),
        ("--optimizer", dict(choices=["adam", "adagrad", "sgd"],
                             default=suppress,
                             help="Optimizer used to train the model.")),
        ("--seed", dict(type=int, default=suppress,
                        help="Random seed for pytorch")),
    ]

    parser = argparse.ArgumentParser(
        description="Train an energy-based model on MNIST using Equilibrium Propagation."
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    namespace = parser.parse_args(args)
    return vars(namespace)

In [8]:
# Emulate a command-line call from inside the notebook: element 0 stands in for
# the script name, the remaining entries are the options for this run
sys.argv = ['', '--energy', 'restr_hopfield', '--epochs', '1']

# Options given explicitly for this run (everything after the script name)
user_config = parse_shell_args(sys.argv[1:])

# Start from the defaults stored on disk for the chosen energy-based model and
# let the explicitly given options take precedence over them
cfg = {**load_default_config(user_config["energy"]), **user_config}

# Set up the global logger; the run is named after energy model, cost and dataset
run_name = "_".join([cfg["energy"], cfg["c_energy"], cfg["dataset"]])
config.setup_logging(run_name, dir=cfg['log_dir'])

In [9]:
# Sanity check: the --epochs value passed in the argument list should show up in cfg
print(cfg['epochs'])

1


In [10]:
# @title Main function run_energy_model_mnist

# Main script: trains the energy-based model described by `cfg` (the dictionary
# defining the parameters of the run) and stops early when the validation
# accuracy stops improving.

# Seed PyTorch only when a seed is configured (might slow down the model)
if cfg['seed'] is not None:
    torch.manual_seed(cfg['seed'])

# Cost function to be optimized by the model
c_energy = utils.create_cost(cfg['c_energy'], cfg['beta'])

# Activation functions for every layer, as a list
phi = utils.create_activations(cfg['nonlinearity'], len(cfg['dimensions']))

# Pick the energy-based model class first, then build it once on the target device
if cfg["energy"] == "restr_hopfield":
    energy_model_cls = energy.RestrictedHopfield
elif cfg["energy"] == "cond_gaussian":
    energy_model_cls = energy.ConditionalGaussian
else:
    raise ValueError('Energy based model "{}" not defined.'.format(cfg["energy"]))

model = energy_model_cls(cfg['dimensions'], c_energy, cfg['batch_size'], phi)
model = model.to(config.device)

# Optimizer for the weights (may include l2 regularization via weight_decay)
w_optimizer = utils.create_optimizer(model, cfg['optimizer'], lr=cfg['learning_rate'])

# Torch data loaders for the MNIST train, validation and test splits
mnist_train, mnist_val, mnist_test = data.create_mnist_loaders(cfg['batch_size'])

cfg_dump = json.dumps(cfg, indent=4, sort_keys=True)
logging.info("Start training with parametrization:\n{}".format(cfg_dump))

# Early stopping on validation accuracy: stop once PATIENCE consecutive epochs
# have failed to beat the best value seen so far
PATIENCE = 2
wait = 0
best_val_acc = 0.0

for epoch in range(1, cfg['epochs'] + 1):
    # One pass over the training split
    train.train(model, mnist_train, cfg['dynamics'], w_optimizer, cfg["fast_ff_init"])

    # Evaluate on the validation split, then on the test split
    val_acc, val_energy = train.test(model, mnist_val, cfg['dynamics'], cfg["fast_ff_init"])
    test_acc, test_energy = train.test(model, mnist_test, cfg['dynamics'], cfg["fast_ff_init"])

    # Per-epoch summary
    logging.info(
        "epoch: {} \t val_acc:{:.4f} \t test_acc: {:.4f} \t mean_E: {:.4f}".format(
            epoch, val_acc, test_acc, test_energy)
    )

    # An improvement resets the patience counter and moves on to the next epoch
    if val_acc > best_val_acc:
        best_val_acc, wait = val_acc, 0
        continue

    # No improvement this epoch
    wait += 1
    if wait >= PATIENCE:
        print(f'Early stopping at epoch {epoch}')
        break

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:01<00:00, 5101046.38it/s]


Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 135329.30it/s]


Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:01<00:00, 1091113.23it/s]


Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 6329079.32it/s]
[INFO  08:39:40] Start training with parametrization:
{
    "batch_size": 100,
    "beta": 1,
    "c_energy": "squared_error",
    "dataset": "mnist",
    "dimensions": [
        784,
        1000,
        10
    ],
    "dynamics": {
        "dt": 0.1,
        "n_relax": 50,
        "tau": 1,
        "tol": 0
    },
    "energy": "restr_hopfield",
    "epochs": 1,
    "fast_ff_init": false,
    "learning_rate": 0.001,
    "log_dir": "",
    "nonlinearity": "sigmoid",
    "optimizer": "adam",
    "seed": null
}


Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw



[INFO  08:39:42] 0%:	E: 387.23	dE -28.25	batch_acc 0.1200
[INFO  08:40:10] 10%:	E: -1395.42	dE -21.56	batch_acc 0.3700
[INFO  08:40:37] 20%:	E: -1895.09	dE -21.25	batch_acc 0.4000
[INFO  08:41:04] 30%:	E: -2156.34	dE -20.97	batch_acc 0.3100
[INFO  08:41:31] 40%:	E: -2286.48	dE -20.32	batch_acc 0.3600
[INFO  08:41:58] 50%:	E: -2466.76	dE -19.01	batch_acc 0.4400
[INFO  08:42:25] 60%:	E: -2660.77	dE -17.93	batch_acc 0.5300
[INFO  08:42:52] 70%:	E: -2700.86	dE -15.90	batch_acc 0.6400
[INFO  08:43:20] 80%:	E: -2799.47	dE -14.55	batch_acc 0.6500
[INFO  08:43:47] 90%:	E: -2931.02	dE -14.01	batch_acc 0.6800
[INFO  08:44:54] epoch: 1 	 val_acc:0.7135 	 test_acc: 0.7266 	 mean_E: -3217.3703


In [12]:
# @title Main function run_backprop_model_mnist (reuse the cfg before for hyperparameters, model architecture etc.)

# Seed PyTorch if a seed is configured (might slow down the model); the
# energy-model cell already does this, it is repeated here to be explicit
if cfg['seed'] is not None:
    torch.manual_seed(cfg['seed'])

# Activation functions for every layer, as a list
phi = utils.create_activations(cfg['nonlinearity'], len(cfg['dimensions']))

# Map the configured cost name onto the matching functional loss
if cfg['c_energy'] not in ('squared_error', 'cross_entropy'):
    raise ValueError("c_energy \"{}\" not defined.".format(cfg['c_energy']))
criterion = (
    torch.nn.functional.mse_loss
    if cfg['c_energy'] == 'squared_error'
    else torch.nn.functional.cross_entropy  # classification loss
)

# Backprop-trained MLP that reuses the dimensions and activations from cfg
model = energy.MLP(cfg['dimensions'], cfg['batch_size'], phi)
model = model.to(config.device)

print(model)

# Optimizer for the weights (may include l2 regularization via weight_decay)
w_optimizer = utils.create_optimizer(model, cfg['optimizer'], lr=cfg['learning_rate'])

cfg_dump = json.dumps(cfg, indent=4, sort_keys=True)
logging.info("Start training with parametrization:\n{}".format(cfg_dump))

# Early stopping on validation accuracy: stop once PATIENCE consecutive epochs
# have failed to beat the best value seen so far
PATIENCE = 2
wait = 0
best_val_acc = 0.0

for epoch in range(1, cfg['epochs'] + 1):
    # One pass over the training split (loaders come from the energy-model cell)
    train.train_backprop(model, mnist_train, criterion, w_optimizer)

    # Evaluate on the validation split, then on the test split
    val_acc, val_energy = train.test_backprop(model, mnist_val, criterion)
    test_acc, test_energy = train.test_backprop(model, mnist_test, criterion)

    # Per-epoch summary
    logging.info(
        "epoch: {} \t val_acc: {:.4f} \t test_acc: {:.4f} ".format(
            epoch, val_acc, test_acc)
    )

    # An improvement resets the patience counter and moves on to the next epoch
    if val_acc > best_val_acc:
        best_val_acc, wait = val_acc, 0
        continue

    # No improvement this epoch
    wait += 1
    if wait >= PATIENCE:
        print(f'Early stopping at epoch {epoch}')
        break

[INFO  08:48:08] Start training with parametrization:
{
    "batch_size": 100,
    "beta": 1,
    "c_energy": "squared_error",
    "dataset": "mnist",
    "dimensions": [
        784,
        1000,
        10
    ],
    "dynamics": {
        "dt": 0.1,
        "n_relax": 50,
        "tau": 1,
        "tol": 0
    },
    "energy": "restr_hopfield",
    "epochs": 1,
    "fast_ff_init": false,
    "learning_rate": 0.001,
    "log_dir": "",
    "nonlinearity": "sigmoid",
    "optimizer": "adam",
    "seed": null
}


MLP(
  (layers): ModuleList(
    (0): Linear(in_features=784, out_features=1000, bias=True)
    (1): Linear(in_features=1000, out_features=10, bias=True)
  )
)


[INFO  08:48:22] Epoch Finished: Avg. Loss: 0.0241, Accuracy: 86.39%
[INFO  08:48:23] Test Set: Avg. Loss: 0.0191, Accuracy: 89.87%
[INFO  08:48:26] Test Set: Avg. Loss: 0.0188, Accuracy: 90.24%
[INFO  08:48:26] epoch: 1 	 val_acc: 89.8667 	 test_acc: 90.2400 


In [None]:
# @title Visualize model
# Print the module structure of whichever model the last training cell left in `model`.
print(model)
# can't get this to work
# NOTE(review): likely causes, to be confirmed -
#   * SummaryWriter.add_graph expects an nn.Module plus an example input
#     (`input_to_model`) to trace; here it receives `model.W` and no input.
#   * after the backprop cell `model` is the MLP, whose printed structure shows
#     only `layers`, so `model.W` may not exist on it at all.

from torch.utils.tensorboard import SummaryWriter
# default `log_dir` is "runs" - we'll be more specific here
writer = SummaryWriter('log/example')
print(model.W)
writer.add_graph(model.W)
# Close the writer so that pending event data is flushed to log/example.
writer.close()




In [None]:
!tensorboard --logdir=log