<a href="https://colab.research.google.com/github/cindyhfls/SpatialEmbeddedEquilibriumPropagation_Neuromatch_NeuroAI_TrustworthyHeliotrope/blob/main/equilibrium_propagation_toymodel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Adapted from `run_energy_model_mnist.py` in https://github.com/smonsays/equilibrium-propagation/tree/master.

**To-do:**

*Week 1 - Make the network architecture and train a basic network; decide on the questions.*

1. First, make a fake "distance" matrix by specifying the distance between each of the 1000x1000 pairs of units.
2. Implement spatial normalization through the energy function?

*Week 2 - Calculate metrics to evaluate the network; each person picks a direction to test and produces a summary slide.*

In [None]:
# @title Clone Repository and Setup
!git clone https://github.com/smonsays/equilibrium-propagation.git

fatal: destination path 'equilibrium-propagation' already exists and is not an empty directory.


In [None]:
cd /content/equilibrium-propagation/

/content/equilibrium-propagation


In [None]:
import argparse
import json
import logging
import sys

import torch

from lib import config, data, energy, train, utils

In [None]:
# @title Install torchlens and other utilities for visualization/RSA?
!pip install torchlens --quiet
!pip install rsatoolbox --quiet

import torchlens,rsatoolbox

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/83.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m83.3/83.3 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m43.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m656.0/656.0 kB[0m [31m10.6 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
def extract_features(model, imgs, return_layers, plot='none'):
    """
    Extracts features from specified layers of the model.

    Logs one forward pass of `model` on `imgs` with torchlens and collects,
    for every requested layer, the saved `tensor_contents` after `.flatten(1)`.

    Inputs:
    - model (torch.nn.Module): The model from which to extract features.
    - imgs (torch.Tensor): Batch of input images.
    - return_layers (list): List of layer names from which to extract features.
    - plot (str): Passed to torchlens as `vis_opt`. Default is 'none'.

    Outputs:
    - model_features (dict): A dictionary with layer names as keys and extracted features as values.
    """
    # The notebook only binds the name `torchlens`, so the `tl` alias used here
    # was undefined (NameError). Importing locally keeps the function usable
    # regardless of how the package was imported at notebook level.
    import torchlens as tl

    model_history = tl.log_forward_pass(model, imgs, layers_to_save='all', vis_opt=plot)
    return {
        layer: model_history[layer].tensor_contents.flatten(1)
        for layer in return_layers
    }

In [None]:
# @title Helper functions for parsing input
def load_default_config(energy):
    """
    Read the default parameter file that belongs to an energy-based model.

    Args:
        energy: String with the energy name ("restr_hopfield" or "cond_gaussian")

    Returns:
        Dictionary of default parameters parsed from the matching JSON file

    Raises:
        ValueError: If `energy` is not one of the known energy names
    """
    known_configs = (
        ("restr_hopfield", "etc/energy_restr_hopfield.json"),
        ("cond_gaussian", "etc/energy_cond_gaussian.json"),
    )

    # Walk the table until the requested energy is found; the `else` branch
    # only runs when no entry matched.
    for model_name, json_path in known_configs:
        if energy == model_name:
            break
    else:
        raise ValueError("Energy based model \"{}\" not defined.".format(energy))

    with open(json_path) as handle:
        return json.load(handle)


def parse_shell_args(args):
    """
    Turn a list of command-line tokens into a configuration dictionary.

    Options registered with ``argparse.SUPPRESS`` as their default are left out
    of the result unless they are passed explicitly, so the dictionary holds
    only user-supplied values plus the two options that have real defaults
    (``energy`` and ``log_dir``).

    Args:
        args: List of shell arguments

    Returns:
        Dictionary of shell arguments
    """
    omit = argparse.SUPPRESS

    # (flag, add_argument keyword arguments), registered in this order.
    option_table = [
        ("--batch_size", dict(
            type=int, default=omit,
            help="Size of mini batches during training.")),
        ("--c_energy", dict(
            choices=["cross_entropy", "squared_error"], default=omit,
            help="Supervised learning cost function.")),
        ("--dimensions", dict(
            type=int, nargs="+", default=omit,
            help="Dimensions of the neural network.")),
        ("--energy", dict(
            choices=["cond_gaussian", "restr_hopfield"], default="cond_gaussian",
            help="Type of energy-based model.")),
        ("--epochs", dict(
            type=int, default=omit,
            help="Number of epochs to train.")),
        ("--fast_ff_init", dict(
            action='store_true', default=omit,
            help="Flag to enable fast feedforward initialization.")),
        ("--learning_rate", dict(
            type=float, default=omit,
            help="Learning rate of the optimizer.")),
        ("--log_dir", dict(
            type=str, default="",
            help="Subdirectory within ./log/ to store logs.")),
        ("--nonlinearity", dict(
            choices=["leaky_relu", "relu", "sigmoid", "tanh"], default=omit,
            help="Nonlinearity between network layers.")),
        ("--optimizer", dict(
            choices=["adam", "adagrad", "sgd"], default=omit,
            help="Optimizer used to train the model.")),
        ("--seed", dict(
            type=int, default=omit,
            help="Random seed for pytorch")),
    ]

    cli = argparse.ArgumentParser(
        description="Train an energy-based model on MNIST using Equilibrium Propagation."
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    namespace = cli.parse_args(args)
    return vars(namespace)

In [None]:
# Emulate a command-line invocation from inside the notebook
sys.argv = ['', '--energy', 'restr_hopfield', '--epochs', '1']

# Everything after the program name is the user-supplied configuration
user_config = parse_shell_args(sys.argv[1:])

# Start from the defaults of the selected energy-based model and let the
# user-supplied values take precedence
default_cfg = load_default_config(user_config["energy"])
cfg = {**default_cfg, **user_config}

# Name the global logger after energy, cost function and dataset
run_name = "_".join([cfg["energy"], cfg["c_energy"], cfg["dataset"]])
config.setup_logging(run_name, dir=cfg['log_dir'])

In [None]:
# @title Main function run_energy_model_mnist

# Main script. Reads every run parameter from `cfg`, the dictionary assembled
# in the configuration cell, builds the energy-based model, and then alternates
# training and testing for the configured number of epochs.

# Seed PyTorch only when a seed was configured (might slow down the model)
seed = cfg['seed']
if seed is not None:
    torch.manual_seed(seed)

# Cost function the model optimizes
c_energy = utils.create_cost(cfg['c_energy'], cfg['beta'])

# One activation function per layer, returned as a list
phi = utils.create_activations(cfg['nonlinearity'], len(cfg['dimensions']))

# Pick the energy-based model class first, then build it once
energy_name = cfg["energy"]
if energy_name == "restr_hopfield":
    model_cls = energy.RestrictedHopfield
elif energy_name == "cond_gaussian":
    model_cls = energy.ConditionalGaussian
else:
    raise ValueError(f'Energy based model \"{cfg["energy"]}\" not defined.')
model = model_cls(cfg['dimensions'], c_energy, cfg['batch_size'], phi).to(config.device)

# Optimizer for the model weights (may include l2 regularization via weight_decay)
w_optimizer = utils.create_optimizer(model, cfg['optimizer'], lr=cfg['learning_rate'])

# Torch data loaders for the MNIST train and test splits
mnist_train, mnist_test = data.create_mnist_loaders(cfg['batch_size'])

# Record the full parametrization of this run in the log
cfg_dump = json.dumps(cfg, indent=4, sort_keys=True)
logging.info("Start training with parametrization:\n{}".format(cfg_dump))

n_epochs = cfg['epochs']
for epoch in range(1, n_epochs + 1):
    # One training pass over the training loader
    train.train(model, mnist_train, cfg['dynamics'], w_optimizer, cfg["fast_ff_init"])

    # Evaluate on the test loader
    test_acc, test_energy = train.test(model, mnist_test, cfg['dynamics'], cfg["fast_ff_init"])

    # Report the epoch's test metrics
    epoch_summary = "epoch: {} \t test_acc: {:.4f} \t mean_E: {:.4f}".format(
        epoch, test_acc, test_energy)
    logging.info(epoch_summary)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:02<00:00, 4179334.59it/s]


Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 133700.46it/s]


Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:01<00:00, 1273478.78it/s]


Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 2945806.21it/s]
[INFO  11:33:08] Start training with parametrization:
{
    "batch_size": 100,
    "beta": 1,
    "c_energy": "squared_error",
    "dataset": "mnist",
    "dimensions": [
        784,
        1000,
        10
    ],
    "dynamics": {
        "dt": 0.1,
        "n_relax": 50,
        "tau": 1,
        "tol": 0
    },
    "energy": "restr_hopfield",
    "epochs": 1,
    "fast_ff_init": false,
    "learning_rate": 0.001,
    "log_dir": "",
    "nonlinearity": "sigmoid",
    "optimizer": "adam",
    "seed": null
}


Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw



[INFO  11:33:10] 0%:	E: 327.24	dE -25.53	batch_acc 0.1100
[INFO  11:33:44] 10%:	E: -1988.66	dE -21.74	batch_acc 0.1700
[INFO  11:34:17] 20%:	E: -2382.46	dE -21.02	batch_acc 0.3300
[INFO  11:34:50] 30%:	E: -2640.65	dE -20.05	batch_acc 0.3700
[INFO  11:35:24] 40%:	E: -2907.97	dE -18.68	batch_acc 0.5000
[INFO  11:35:59] 50%:	E: -3175.15	dE -18.53	batch_acc 0.5700
[INFO  11:36:32] 60%:	E: -3205.23	dE -16.03	batch_acc 0.6700
[INFO  11:37:08] 70%:	E: -3328.60	dE -15.76	batch_acc 0.6700
[INFO  11:37:41] 80%:	E: -3560.30	dE -13.93	batch_acc 0.7100
[INFO  11:38:15] 90%:	E: -3867.30	dE -12.51	batch_acc 0.7800
[INFO  11:39:17] epoch: 1 	 test_acc: 0.7475 	 mean_E: -3968.1018


In [None]:
# @title Visualize model
# Print the model's string representation to inspect its structure.
print(model)

ConditionalGaussian(
  (W): ModuleList(
    (0): Linear(in_features=784, out_features=1000, bias=True)
    (1): Linear(in_features=1000, out_features=10, bias=True)
  )
)
