In [1]:
import logging
import sys
from typing import Any, Dict, List, Iterable, Tuple, Union
import warnings
warnings.filterwarnings('ignore')

import leabra7 as lb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, mean_squared_error
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import torch
from torch import nn
from torch.optim import Adam
from tqdm.auto import tqdm

In [2]:
# Logger-like objects accepted by this module's helpers. Passing ``None``
# makes a helper fall back to the root logger (``logging.getLogger()``).
LoggerType = Union[logging.Logger, logging.LoggerAdapter, None]

In [3]:
def load_data(logger: LoggerType = None,
              threshold: int = 128,
              train_length: Union[int, None] = None,
              test_length: Union[int, None] = None
              ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Loads MNIST and returns binarized images with one-hot labels.

    Args:
        logger: The logger to use. If None, the root logger is used.
        threshold: Pixels whose raw value is strictly greater than this
            become 1; all other pixels become 0.
        train_length: If given, keep only the first ``train_length``
            training samples. None keeps all of them.
        test_length: If given, keep only the first ``test_length`` test
            samples. None keeps all of them.

    Returns:
        A ``(X_train, Y_train, X_test, Y_test)`` tuple of torch tensors.
        ``X_*`` has one row per image, with the image pixels flattened into
        the columns. ``Y_*`` holds the one-hot encoded labels.
    """
    if logger is None:
        logger = logging.getLogger()
    # Log unconditionally, like the other helpers in this module. Previously
    # the message was only emitted when a logger was passed explicitly.
    logger.info("Loading data")
    train_set = MNIST(root="./data", train=True, download=True)
    test_set = MNIST(root="./data", train=False, download=True)
    # Binarize with a hard threshold and flatten everything after dim 0.
    X_train = (train_set.data > threshold).flatten(1).int()[:train_length]
    X_test = (test_set.data > threshold).flatten(1).int()[:test_length]
    # One-hot encode the full target vectors *before* slicing, so the
    # inferred number of classes does not depend on which labels survive
    # the slice.
    Y_train = nn.functional.one_hot(train_set.targets)[:train_length]
    Y_test = nn.functional.one_hot(test_set.targets)[:test_length]
    return (X_train, Y_train, X_test, Y_test)


def build_network(input_size: int,
                  hidden_size: int,
                  output_size: int,
                  logger: LoggerType = None,
                  lrate: float = 0.02) -> lb.Net:
    """Constructs a three-layer Leabra7 network for classification.

    The network has ``input``, ``hidden`` and ``output`` layers sharing a
    single layer spec. It has projections input -> hidden and
    hidden -> output, plus an output -> hidden projection with its own spec.

    Args:
        input_size: Number of units in the input layer.
        hidden_size: Number of units in the hidden layer.
        output_size: Number of units in the output layer.
        logger: The logger to use. If None, the root logger is used.
        lrate: Learning rate passed to every projection spec.

    Returns:
        The freshly built ``lb.Net``.
    """
    log = logging.getLogger() if logger is None else logger
    log.info("Building network")
    network = lb.Net()

    # One spec shared by all three layers.
    shared_layer_spec = lb.LayerSpec(
        gi=1.5,
        ff=1,
        fb=1,
        unit_spec=lb.UnitSpec(spike_gain=0, vm_gain=0, adapt_dt=0))
    layer_plan = (
        ("input", input_size, "Input layer size: %d"),
        ("hidden", hidden_size, "Hidden layer size: %d"),
        ("output", output_size, "Output layer size: %d"),
    )
    # Create every layer first, then report the sizes.
    for layer_name, layer_size, _ in layer_plan:
        network.new_layer(layer_name, size=layer_size, spec=shared_layer_spec)
    for _, layer_size, size_msg in layer_plan:
        log.debug(size_msg, layer_size)

    # Bottom-up projections use lb.Uniform(0.25, 0.75). The top-down
    # projection uses lb.Uniform(0.25, 0.5) and wt_scale_rel=0.3.
    # (Presumably these are initial-weight ranges; see leabra7 docs.)
    feedforward_spec = lb.ProjnSpec(
        lrate=lrate,
        dist=lb.Uniform(0.25, 0.75),
        cos_diff_thr_l_mix=False,
        cos_diff_lrate=False)
    feedback_spec = lb.ProjnSpec(
        lrate=lrate,
        dist=lb.Uniform(0.25, 0.5),
        wt_scale_rel=0.3,
        cos_diff_thr_l_mix=False,
        cos_diff_lrate=False)
    projections = (
        ("input_to_hidden", "input", "hidden", feedforward_spec),
        ("hidden_to_output", "hidden", "output", feedforward_spec),
        ("output_to_hidden", "output", "hidden", feedback_spec),
    )
    for projn_name, pre_layer, post_layer, projn_spec in projections:
        network.new_projn(
            projn_name, pre=pre_layer, post=post_layer, spec=projn_spec)

    return network

def trial(network: lb.Net, input_pattern: Iterable[float],
          output_pattern: Iterable[float]) -> None:
    """Presents one training example to the network and learns from it.

    The input is clamped for a 50-cycle minus phase. The target is then
    clamped as well for a 25-cycle plus phase. Finally both clamps are
    released and ``network.learn()`` is called.

    Args:
        network: The network being trained.
        input_pattern: Pattern clamped to the "input" layer.
        output_pattern: Target pattern clamped to the "output" layer during
            the plus phase.
    """
    # Minus phase: only the input is clamped.
    network.clamp_layer("input", input_pattern)
    network.minus_phase_cycle(num_cycles=50)
    # Plus phase: the target is clamped on top of the input.
    network.clamp_layer("output", output_pattern)
    network.plus_phase_cycle(num_cycles=25)
    # Release both clamps (input first, then output) before learning.
    for clamped_layer in ("input", "output"):
        network.unclamp_layer(clamped_layer)
    network.learn()
    
def epoch(network: lb.Net,
          X: np.ndarray,
          Y: np.ndarray) -> None:
    """Runs an epoch (one pass through the whole dataset).

    Each ``(x, y)`` pair is presented with ``trial`` in dataset order.
    ``network.end_epoch()`` is then called once.

    Args:
        network: The network to train.
        X: Input patterns, one row per sample.
        Y: Target patterns, one row per sample, aligned with ``X``.

    Raises:
        ValueError: If ``X`` and ``Y`` have different numbers of rows.
    """
    if len(X) != len(Y):
        # zip() would otherwise silently drop the unmatched tail.
        raise ValueError(
            "X and Y must have the same number of samples "
            "(got {} and {})".format(len(X), len(Y)))
    for x, y in tqdm(zip(X, Y), total=len(X), leave=False):
        trial(network, x, y)
    network.end_epoch()
    
def _evaluate(network: lb.Net,
              X: np.ndarray,
              Y: np.ndarray) -> Tuple[float, float]:
    """Predicts on ``X`` and scores the predictions against ``Y``.

    Returns:
        A ``(mse, accuracy)`` tuple. ``accuracy_score`` on 2-D indicator
        rows counts exact row matches. Because ``predict`` yields one-hot
        rows, this is the fraction of samples whose predicted row equals the
        target row.
    """
    pred = predict(network, X)
    loss = mean_squared_error(Y, pred)
    acc = accuracy_score(Y, pred, normalize=True)
    return loss, acc


def train(network: lb.Net,
          X_train: np.ndarray,
          Y_train: np.ndarray,
          X_test: np.ndarray,
          Y_test: np.ndarray,
          num_epochs: int = 10,
          print_freq: int = 1,
          logger: LoggerType = None) -> pd.DataFrame:
    """Trains the network and periodically records train/test metrics.

    Args:
        network: The network to train.
        X_train: Training input patterns, one row per sample.
        Y_train: Training target patterns aligned with ``X_train``.
        X_test: Held-out input patterns, one row per sample.
        Y_test: Held-out target patterns aligned with ``X_test``.
        num_epochs: The number of epochs to run. Defaults to 10.
        print_freq: Metrics are computed and logged every ``print_freq``
            epochs (prediction is slow). Defaults to 1.
        logger: The logger to use. If None, the root logger is used.

    Returns:
        pd.DataFrame: One row per evaluated epoch, with columns ``epoch``,
        ``train_loss``, ``train_accuracy``, ``test_loss`` and
        ``test_accuracy``.

    Raises:
        ValueError: If ``print_freq`` is less than 1.
    """
    # Fail before any (expensive) training instead of hitting
    # ZeroDivisionError in ``i % print_freq`` after the first epoch.
    if print_freq < 1:
        raise ValueError(
            "print_freq must be >= 1 (got {})".format(print_freq))
    if logger is None:
        logger = logging.getLogger()
    logger.info("Begin training")

    logger.debug("Training set size: %d", X_train.shape[0])
    logger.debug("Test set size: %d", X_test.shape[0])

    data: Dict[str, List[float]] = {
        "epoch": [],
        "train_loss": [],
        "train_accuracy": [],
        "test_loss": [],
        "test_accuracy": []
    }

    for i in range(1, num_epochs + 1):
        epoch(network, X_train, Y_train)
        # Predicting is slow, so only evaluate every print_freq epochs.
        if i % print_freq != 0:
            continue
        train_loss, train_acc = _evaluate(network, X_train, Y_train)
        test_loss, test_acc = _evaluate(network, X_test, Y_test)
        data["epoch"].append(i)
        data["train_loss"].append(train_loss)
        data["train_accuracy"].append(train_acc)
        data["test_loss"].append(test_loss)
        data["test_accuracy"].append(test_acc)

        logger.info(
            "[Epoch %d]  train_loss: %.4f  train_accuracy: %.2f  "
            "test_loss: %.4f  test_accuracy: %.2f",
            i, train_loss, train_acc, test_loss, test_acc)
    logger.info("End training")
    return pd.DataFrame(data)

def output(network: lb.Net, pattern: Iterable[float],
           num_cycles: int = 50) -> List[float]:
    """Calculates a one-hot prediction for a single input pattern.

    Args:
        network: The trained network.
        pattern: The input pattern, clamped to the "input" layer.
        num_cycles: Number of settling cycles run while the input is
            clamped. Defaults to 50, the same length as the minus phase used
            during training.

    Returns:
        List[float]: The "output" layer's ``unit_act`` activations after
        settling, converted to one-hot form. The element at the (first)
        maximum is 1 and every other element is 0.
    """
    network.clamp_layer("input", pattern)
    for _ in range(num_cycles):
        network.cycle()
    network.unclamp_layer("input")
    acts = network.observe("output", "unit_act")["act"].values
    # Build a fresh array instead of overwriting ``acts`` in place. ``.values``
    # may be a view into the DataFrame returned by ``observe``.
    one_hot = np.zeros_like(acts)
    one_hot[np.argmax(acts)] = 1
    return list(one_hot)

def predict(network: lb.Net, input_patterns: np.ndarray) -> np.ndarray:
    """Runs ``output`` on every input pattern and stacks the results.

    Args:
        network: The trained network.
        input_patterns: Input patterns of shape (n_samples, n_features),
            one pattern per row.

    Returns:
        np.ndarray: For non-empty input, an array of shape
        (n_samples, n_outputs). Row ``i`` is the one-hot prediction for
        ``input_patterns[i]``.
    """
    predictions = [output(network, pattern)
                   for pattern in tqdm(input_patterns, leave=False)]
    return np.array(predictions)

In [4]:
PROJ_NAME = "mnist"
# Turn every NumPy floating-point error category into a warning.
np.seterr(all="warn")
# Send DEBUG-and-above records both to "<PROJ_NAME>_log.txt" (overwritten
# on each run) and to stdout.
logging.basicConfig(
    format="%(levelname)s %(message)s",
    level=logging.DEBUG,
    handlers=[
        logging.FileHandler("{0}_log.txt".format(PROJ_NAME), mode="w"),
        logging.StreamHandler(sys.stdout),
    ],
)

logging.info("Begin training %s", PROJ_NAME)

INFO Begin training mnist


In [5]:
X_train, Y_train, X_test, Y_test = load_data(train_length=1000, test_length=1000)
input_size = X_train.shape[1]
output_size = Y_train.shape[1]
# Hidden size is the (truncated) geometric mean of the input and output
# sizes. ``np.int`` was merely an alias for the builtin ``int``; it was
# deprecated in NumPy 1.20 and removed in 1.24, so use ``int`` directly.
hidden_size = int(np.sqrt(input_size * output_size))
net = build_network(input_size, hidden_size, output_size, lrate=0.001)

metrics = train(net, X_train, Y_train, X_test, Y_test, num_epochs=20)

# Save metrics and network for future analysis
metrics.to_csv("{0}_metrics.csv".format(PROJ_NAME), index=False)
net.save("{0}_network.pkl".format(PROJ_NAME))

INFO Building network
DEBUG Input layer size: 784
DEBUG Hidden layer size: 88
DEBUG Output layer size: 10
INFO Begin training
DEBUG Training set size: 1000
DEBUG Test set size: 1000


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

INFO [Epoch 1]  train_loss: 0.1856  train_accuracy: 0.07  test_loss: 0.1930  test_accuracy: 0.04


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

INFO [Epoch 2]  train_loss: 0.1846  train_accuracy: 0.08  test_loss: 0.1882  test_accuracy: 0.06


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

INFO [Epoch 3]  train_loss: 0.1850  train_accuracy: 0.07  test_loss: 0.1846  test_accuracy: 0.08


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

INFO [Epoch 4]  train_loss: 0.1844  train_accuracy: 0.08  test_loss: 0.1840  test_accuracy: 0.08


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

INFO [Epoch 5]  train_loss: 0.1842  train_accuracy: 0.08  test_loss: 0.1838  test_accuracy: 0.08


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

INFO [Epoch 6]  train_loss: 0.1840  train_accuracy: 0.08  test_loss: 0.1840  test_accuracy: 0.08


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

INFO [Epoch 7]  train_loss: 0.1832  train_accuracy: 0.08  test_loss: 0.1838  test_accuracy: 0.08


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

INFO [Epoch 8]  train_loss: 0.1790  train_accuracy: 0.10  test_loss: 0.1798  test_accuracy: 0.10


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

INFO [Epoch 9]  train_loss: 0.1776  train_accuracy: 0.11  test_loss: 0.1804  test_accuracy: 0.10


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

INFO [Epoch 10]  train_loss: 0.1768  train_accuracy: 0.12  test_loss: 0.1818  test_accuracy: 0.09


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

INFO [Epoch 11]  train_loss: 0.1770  train_accuracy: 0.12  test_loss: 0.1820  test_accuracy: 0.09


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

INFO [Epoch 12]  train_loss: 0.1764  train_accuracy: 0.12  test_loss: 0.1814  test_accuracy: 0.09


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

INFO [Epoch 13]  train_loss: 0.1764  train_accuracy: 0.12  test_loss: 0.1808  test_accuracy: 0.10


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

INFO [Epoch 14]  train_loss: 0.1764  train_accuracy: 0.12  test_loss: 0.1816  test_accuracy: 0.09


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

INFO [Epoch 15]  train_loss: 0.1768  train_accuracy: 0.12  test_loss: 0.1816  test_accuracy: 0.09


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

INFO [Epoch 16]  train_loss: 0.1768  train_accuracy: 0.12  test_loss: 0.1808  test_accuracy: 0.10


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

INFO [Epoch 17]  train_loss: 0.1764  train_accuracy: 0.12  test_loss: 0.1810  test_accuracy: 0.10


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

INFO [Epoch 18]  train_loss: 0.1768  train_accuracy: 0.12  test_loss: 0.1816  test_accuracy: 0.09


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

INFO [Epoch 19]  train_loss: 0.1768  train_accuracy: 0.12  test_loss: 0.1810  test_accuracy: 0.10


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

INFO [Epoch 20]  train_loss: 0.1774  train_accuracy: 0.11  test_loss: 0.1812  test_accuracy: 0.09
INFO End training
