In [1]:
import torch

import os, sys
import ipynbname
import time
from glob import glob
import itertools
import pprint
pp = pprint.PrettyPrinter()

import numpy as np
from scipy.sparse import csr_matrix
import matplotlib.pyplot as plt
import h5py

In [2]:
# Prefer the GPU when one is visible to torch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using {} device".format(device))

Using cuda device


In [3]:
# Timing utilities
start_time = None


def _sync_device():
    """Block until queued CUDA work finishes so wall-clock timings are accurate.

    Skipped when CUDA is unavailable: torch.cuda.synchronize() raises on
    CPU-only machines, which made these helpers unusable on the "cpu" device.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def start_timer():
    """Record the start of a timed section in the module-level ``start_time``."""
    global start_time
    _sync_device()
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()


def end_timer_and_print():
    """Print the seconds elapsed since the most recent start_timer() call.

    Raises:
        RuntimeError: if start_timer() has not been called yet.
    """
    if start_time is None:
        raise RuntimeError("start_timer() must be called before end_timer_and_print()")
    _sync_device()
    end_time = time.perf_counter()
    print("\nTotal execution time {:.3f} sec".format(end_time - start_time))

## Loading Data

In [4]:
# One read-only HDF5 handle per model, keyed by the file-name prefix before
# "_features.h5". glob() order is filesystem-dependent, so sort it: later cells
# take the EC layer size from the first model in this dict.
# NOTE: these handles are not closed in this cell.
data_h5 = {os.path.basename(path).split('_features.h5')[0]: h5py.File(path, 'r')
           for path in sorted(glob('data/*_features.h5'))}

In [5]:
# Load every dataset of every model's HDF5 file into memory as NumPy arrays,
# in the order the file's keys() reports them.
EC_patterns = {
    model: [np.array(group[name]) for name in group.keys()]
    for model, group in data_h5.items()
}

In [6]:
def map_data(dictionary, fn, *args, **kwargs):
    """Apply ``fn`` to each element of each list value in ``dictionary``.

    Args:
        dictionary: mapping of key -> iterable of items.
        fn: callable invoked as ``fn(item, *args, **kwargs)``.

    Returns:
        A new dict with the same keys (same order) whose values are lists of
        ``fn`` results, one per original item.
    """
    mapped = {}
    for key, items in dictionary.items():
        mapped[key] = [fn(item, *args, **kwargs) for item in items]
    return mapped

## Feedforward Propagation

Mouse cell counts per hemisphere from https://doi.org/10.1007/s00429-019-01940-7
- MEC LII: 66,365

Mouse cell counts per hemisphere from https://www.nature.com/articles/s41593-018-0109-1
- DG gr: 625,000
- CA3 pyr: 285,000

Rat cell counts per hemisphere from https://www.sciencedirect.com/science/article/pii/S0079612308612376?via%3Dihub
- MEC LII: 200,000
- DG gr: 1,000,000
- CA3 pyr: 330,000

Rat connectivity from https://www.sciencedirect.com/science/article/pii/S0079612308612376?via%3Dihub
- Each CA3 neuron receives 46 inputs from DG
- Each EC neuron sends output to 2% of DG
- Each DG neuron receives input from 1.4-2.2% of EC (3600 neurons)
- Each CA3 neuron receives 3600 inputs from EC

Rat activity from https://www.frontiersin.org/articles/10.3389/fnins.2013.00050 
- DG sparseness: 1-2% in an environment

Rat activity from https://www.jneurosci.org/content/28/52/14271
- CA3 sparseness: 30% in an environment

Rat activity from https://onlinelibrary.wiley.com/doi/abs/10.1002/hipo.22002
- CA3 firing rate: 0.4 Hz

Rat activity from https://www.frontiersin.org/articles/10.3389/fncir.2014.00074/full
- EC LII sparsity: 95% in an environment

Rat activity from https://www.cell.com/cell-reports/fulltext/S2211-1247(13)00401-4
- EC LII firing rate: 1 Hz
- DG firing rate: 0.1 Hz
- CA3 firing rate: 0.3 Hz

See https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1004250

In [22]:
# Layer sizes: EC is the feature dimension (columns) of the first model's first
# class array; CA3 and DG are fixed multiples of it.
EC_size = list(EC_patterns.values())[0][0].shape[1]
CA3_size = 2 * EC_size
DG_size = 8 * EC_size

print("Number of neurons")
print(f"EC: {EC_size},  DG: {DG_size},  CA3: {CA3_size}")

Number of neurons
EC: 1024,  DG: 8192,  CA3: 2048


In [23]:
# Connection probability of each pathway: DG = EC->DG, PP = EC->CA3, MF = DG->CA3.
# NOTE(review): the literature notes above cite ~1.4-2.2% EC->DG connectivity in
# rat, whereas DG_conn is 20% here — confirm this is an intentional choice.
DG_conn = 0.2
PP_conn = 0.2
MF_conn = 0.001

# Fan-in: number of presynaptic inputs each postsynaptic neuron receives.
DG_per_post, PP_per_post, MF_per_post = (
    round(DG_conn * EC_size),
    round(PP_conn * EC_size),
    round(MF_conn * DG_size),
)

print("Number of inputs per postsynaptic neuron")
print(f"DG: {DG_per_post},  PP: {PP_per_post},  MF: {MF_per_post}")

Number of inputs per postsynaptic neuron
DG: 205,  PP: 205,  MF: 8


In [24]:
# Target fraction of the postsynaptic population left active by each pathway
# (DG targets the DG layer; PP and MF both target CA3).
DG_sp = 0.005
PP_sp = 0.2
MF_sp = 0.02

# Corresponding k for the k-winners-take-all step.
DG_active, PP_active, MF_active = (
    round(DG_sp * DG_size),
    round(PP_sp * CA3_size),
    round(MF_sp * CA3_size),
)

print("Number of active neurons")
print(f"DG: {DG_active},  PP: {PP_active},  MF: {MF_active}")

Number of active neurons
DG: 41,  PP: 410,  MF: 41


In [25]:
def create_synapses(pre_size, post_size, inputs_per_post, rng=None):
    """Build a random binary connectivity matrix with fixed fan-in.

    Every postsynaptic neuron (row) connects to exactly ``inputs_per_post``
    distinct presynaptic neurons (columns), chosen uniformly without replacement.

    Args:
        pre_size: number of presynaptic neurons (matrix columns).
        post_size: number of postsynaptic neurons (matrix rows).
        inputs_per_post: fan-in per postsynaptic neuron; must be in [0, pre_size].
        rng: optional ``numpy.random.Generator``/``RandomState`` for reproducible
            wiring. When None, NumPy's global random state is used, matching the
            original behaviour.

    Returns:
        ``scipy.sparse.csr_matrix`` of shape ``(post_size, pre_size)`` with ones
        at connected positions and sorted column indices in each row.

    Raises:
        ValueError: if ``inputs_per_post`` is outside ``[0, pre_size]``.
    """
    if not 0 <= inputs_per_post <= pre_size:
        raise ValueError(
            f"inputs_per_post must be in [0, {pre_size}], got {inputs_per_post}")
    choice = np.random.choice if rng is None else rng.choice

    data = np.ones(post_size * inputs_per_post, dtype=int)
    rows = [np.sort(choice(pre_size, inputs_per_post, replace=False))
            for _ in range(post_size)]
    # np.concatenate([]) raises, so an empty layer needs an explicit empty index array.
    indices = np.concatenate(rows) if rows else np.empty(0, dtype=int)
    # Fixed fan-in: row i occupies indices[i*k:(i+1)*k].
    indptr = np.arange(post_size + 1) * inputs_per_post
    return csr_matrix((data, indices, indptr), shape=(post_size, pre_size))

In [26]:
# Random fixed-fan-in wiring for each pathway (creation order kept so the
# global random stream is consumed identically).
DG_synapses = create_synapses(pre_size=EC_size, post_size=DG_size, inputs_per_post=DG_per_post)
PP_synapses = create_synapses(pre_size=EC_size, post_size=CA3_size, inputs_per_post=PP_per_post)
MF_synapses = create_synapses(pre_size=DG_size, post_size=CA3_size, inputs_per_post=MF_per_post)

In [27]:
def kWTA(inputs, k, rng=None):
    """k-winners-take-all: set the ``k`` largest entries of ``inputs`` to 1, the rest to 0.

    Ties between equal inputs are broken uniformly at random.

    Args:
        inputs: 1-D array of input strengths (np.put below uses flat indices,
            so a 1-D array is assumed).
        k: number of winners. ``k <= 0`` gives an all-zero output; ``k`` at
            least ``inputs.size`` makes every unit a winner.
        rng: optional ``numpy.random.Generator`` for reproducible tie-breaking;
            when None, NumPy's global random state is used as before.

    Returns:
        Array shaped and typed like ``inputs`` containing 0s and 1s.
    """
    output = np.zeros_like(inputs)
    # Guard k <= 0 explicitly: inds[-0:] would select *every* index.
    if k <= 0:
        return output
    tiebreaker = np.random.rand(inputs.size) if rng is None else rng.random(inputs.size)
    # lexsort uses the last key as primary: order by input, then by random tiebreaker.
    inds = np.lexsort((tiebreaker, inputs))
    np.put(output, inds[-k:], 1)
    return output

def feedforward_WTA(pre_patterns, synapses, post_active):
    """Propagate patterns through a synapse matrix and sparsify each result with kWTA.

    Args:
        pre_patterns: 2-D array with one presynaptic pattern per row, i.e.
            shape ``(n_patterns, pre_size)``.
        synapses: matrix of shape ``(post_size, pre_size)`` (e.g. from
            create_synapses).
        post_active: ``k`` passed to kWTA for every pattern.

    Returns:
        Array of shape ``(n_patterns, post_size)``; row i is
        ``kWTA(synapses @ pre_patterns[i], post_active)``. A zero-pattern
        input yields a ``(0, post_size)`` array instead of raising (as
        np.apply_along_axis did).
    """
    post_inputs = np.asarray(synapses.dot(pre_patterns.T)).T
    # Same dtype apply_along_axis produced (kWTA returns zeros_like its input);
    # rows are processed in order, so the random tie-breaking stream is unchanged.
    outputs = np.zeros_like(post_inputs)
    for row, drive in enumerate(post_inputs):
        outputs[row] = kWTA(drive, post_active)
    return outputs

In [28]:
DG_patterns = map_data(EC_patterns, feedforward_WTA, synapses=DG_synapses, post_active=DG_active)  # EC -> DG

In [29]:
PP_patterns = map_data(EC_patterns, feedforward_WTA, synapses=PP_synapses, post_active=PP_active)  # EC -> CA3

In [30]:
MF_patterns = map_data(DG_patterns, feedforward_WTA, synapses=MF_synapses, post_active=MF_active)  # DG -> CA3

## Pattern statistics

In [31]:
# Mean activity (np.mean over all patterns and units) of each class, per model.
for stage_name, stage_patterns in [("EC", EC_patterns),
                                   ("DG", DG_patterns),
                                   ("MF", MF_patterns),
                                   ("PP", PP_patterns)]:
    print(f"{stage_name} sparsity:")
    pp.pprint(map_data(stage_patterns, np.mean))

EC sparsity:
{'conv-fashion': [0.13570785522460938,
                  0.12564468383789062,
                  0.13354873657226562],
 'linear-MNIST2': [0.10077285766601562, 0.08830642700195312, 0.103302001953125],
 'linear-fashion': [0.092010498046875,
                    0.09688186645507812,
                    0.09177398681640625]}
DG sparsity:
{'conv-fashion': [0.0050048828125, 0.0050048828125, 0.0050048828125],
 'linear-MNIST2': [0.0050048828125, 0.0050048828125, 0.0050048828125],
 'linear-fashion': [0.0050048828125, 0.0050048828125, 0.0050048828125]}
MF sparsity:
{'conv-fashion': [0.02001953125, 0.02001953125, 0.02001953125],
 'linear-MNIST2': [0.02001953125, 0.02001953125, 0.02001953125],
 'linear-fashion': [0.02001953125, 0.02001953125, 0.02001953125]}
PP sparsity:
{'conv-fashion': [0.2001953125, 0.2001953125, 0.2001953125],
 'linear-MNIST2': [0.2001953125, 0.2001953125, 0.2001953125],
 'linear-fashion': [0.2001953125, 0.2001953125, 0.2001953125]}


In [32]:
# Reference NumPy implementation (entirely commented out, never executed).
# The torch-based pearson_correlation / pairwise_correlation defined below in
# this cell are the versions the rest of the notebook calls.
# def pearson_correlation(x, y):
#     if x.shape != y.shape:
#         raise Exception("Arguments must have same shape")
#     x_centered = x - x.mean(-1, keepdims=True)
#     y_centered = y - y.mean(-1, keepdims=True)
    
#     return ( np.sum(x_centered * y_centered, -1)
#              / np.sqrt(np.sum(x_centered**2, -1) * np.sum(y_centered**2, -1)) )

# def pairwise_correlation(array_of_vectors):
#     if array_of_vectors.ndim != 2:
#         raise Exception("Argument must be a 2D tensor")
#     array_of_vectors = array_of_vectors.astype('float')
#     inds_of_pairs = np.array(list(
#         itertools.combinations(range(len(array_of_vectors)), 2)
#     )).T
#     first_vectors = np.take(array_of_vectors, inds_of_pairs[0], 0)
#     second_vectors = np.take(array_of_vectors, inds_of_pairs[1], 0)
    
#     correlations = pearson_correlation(first_vectors, second_vectors)
#     return correlations.mean()

def pearson_correlation(x, y):
    """Pearson correlation between matching vectors along the last dimension.

    Args:
        x, y: torch tensors of identical shape; each slice along the last
            dimension is one vector.

    Returns:
        Tensor with the last dimension reduced away, holding the correlation of
        each (x, y) vector pair. Zero-variance vectors give NaN (0 * inf).

    Raises:
        Exception: if ``x`` and ``y`` differ in shape.
    """
    if x.shape != y.shape:
        raise Exception("Arguments must have same shape")
    dx = x - x.mean(-1, keepdim=True)
    dy = y - y.mean(-1, keepdim=True)

    covariance = torch.sum(dx * dy, -1)
    variance_product = torch.sum(dx**2, -1) * torch.sum(dy**2, -1)
    return covariance * torch.rsqrt(variance_product)

def pairwise_correlation(array_of_vectors):
    """Mean Pearson correlation over all distinct pairs of rows.

    Computed on the module-level torch ``device``.

    Args:
        array_of_vectors: 2-D array-like (NumPy array or tensor), one vector
            per row.

    Returns:
        Python float: the mean correlation over every row pair (i, j), i < j.
        NaN when there are fewer than two rows, or when any row has zero
        variance (e.g. an all-zero pattern).

    Raises:
        Exception: if the input is not 2-D.
    """
    if array_of_vectors.ndim != 2:
        raise Exception("Argument must be a 2D tensor")
    vectors = torch.as_tensor(array_of_vectors, dtype=torch.float, device=device)
    n = vectors.shape[0]

    centered = vectors - vectors.mean(-1, keepdim=True)
    # Unit-normalise each centred row so one matrix product yields every pairwise
    # correlation. Gathering one copy of each vector per pair instead costs
    # O(n^2 * d) memory (over 1 GB per gathered tensor for 256 x 8192 patterns).
    normalized = centered * torch.rsqrt(torch.sum(centered**2, -1, keepdim=True))
    correlations = normalized @ normalized.T

    # Strict upper triangle = each unordered pair exactly once, no self-pairs.
    rows, cols = torch.triu_indices(n, n, offset=1, device=correlations.device)
    return correlations[rows, cols].mean().item()

In [33]:
# Mean pairwise correlation between patterns within each class, per model.
for stage_name, stage_patterns in [("EC", EC_patterns),
                                   ("DG", DG_patterns),
                                   ("MF", MF_patterns),
                                   ("PP", PP_patterns)]:
    print(f"{stage_name} correlation:")
    pp.pprint(map_data(stage_patterns, pairwise_correlation))

EC correlation:
{'conv-fashion': [0.10083414614200592,
                  0.11498461663722992,
                  0.12594027817249298],
 'linear-MNIST2': [0.17476320266723633,
                   0.2764626443386078,
                   0.11416282504796982],
 'linear-fashion': [0.1176111176609993,
                    0.11730506271123886,
                    0.21472689509391785]}
DG correlation:
{'conv-fashion': [0.010395663790404797,
                  0.012289708480238914,
                  0.01228069607168436],
 'linear-MNIST2': [0.018960921093821526,
                   0.064083032310009,
                   0.010714842937886715],
 'linear-fashion': [0.01964433863759041,
                    0.015757117420434952,
                    0.0328921340405941]}
MF correlation:
{'conv-fashion': [0.005485582631081343,
                  0.00575856352224946,
                  0.006026206072419882],
 'linear-MNIST2': [0.010390844196081161,
                   0.03013770282268524,
                   0.0059

In [85]:
def category_WTA(patterns, num_active):
    """Collapse a set of patterns into one prototype pattern.

    Sums ``patterns`` over axis 0 (across patterns) and keeps the
    ``num_active`` units with the largest totals via kWTA.

    Args:
        patterns: array of patterns, one per row.
        num_active: number of units active in the prototype.

    Returns:
        1-D array of 0s and 1s (the kWTA output for the column totals).
    """
    column_totals = patterns.sum(axis=0)
    return kWTA(column_totals, k=num_active)

# One category prototype per class, built from the EC -> CA3 (PP) patterns.
cat_patterns = map_data(PP_patterns, category_WTA, num_active=PP_active)
pp.pprint(cat_patterns)

{'conv-fashion': [array([0, 0, 0, ..., 0, 0, 0]),
                  array([0, 0, 0, ..., 0, 1, 0]),
                  array([0, 1, 0, ..., 0, 0, 0])],
 'linear-MNIST2': [array([0, 0, 0, ..., 0, 1, 0]),
                   array([0, 0, 0, ..., 0, 0, 0]),
                   array([1, 0, 0, ..., 0, 0, 0])],
 'linear-fashion': [array([0, 0, 0, ..., 0, 1, 1]),
                    array([0, 0, 0, ..., 0, 1, 0]),
                    array([0, 0, 1, ..., 0, 0, 0])]}


In [86]:
print("cat sparsity:")
pp.pprint(map_data(cat_patterns, fn=np.mean))  # fraction of active units per prototype

cat sparsity:
{'conv-fashion': [0.2001953125, 0.2001953125, 0.2001953125],
 'linear-MNIST2': [0.2001953125, 0.2001953125, 0.2001953125],
 'linear-fashion': [0.2001953125, 0.2001953125, 0.2001953125]}


In [87]:
def save_patterns_hdf5(pattern_dict, pattern_type, out_dir='results'):
    """Write each model's per-class patterns to ``{out_dir}/{model}_{pattern_type}.h5``.

    Each list element becomes one gzip-compressed int8 dataset named
    ``class_{i}`` (i = position in the list). Every created dataset is printed.

    Args:
        pattern_dict: mapping model name -> list of pattern arrays (one per class).
        pattern_type: suffix used in the output file name.
        out_dir: output directory; created if missing. Defaults to 'results'.
    """
    os.makedirs(out_dir, exist_ok=True)
    # Iterate the argument itself (the original looped over the unrelated global
    # EC_patterns, which only worked because all dicts share the same keys).
    for model, class_patterns in pattern_dict.items():
        path = os.path.join(out_dir, f'{model}_{pattern_type}.h5')
        # `with` guarantees the file is closed even if a write fails.
        with h5py.File(path, 'w') as h5file:
            for i, patterns in enumerate(class_patterns):
                h5data = np.asarray(patterns).astype(np.int8)
                h5dataset = h5file.create_dataset(f'class_{i}', data=h5data, compression='gzip')
                print(h5dataset)

In [63]:
save_patterns_hdf5(MF_patterns, pattern_type='sparse')  # DG -> CA3 (MF) patterns

<HDF5 dataset "class_0": shape (256, 2048), type "|i1">
<HDF5 dataset "class_1": shape (256, 2048), type "|i1">
<HDF5 dataset "class_2": shape (256, 2048), type "|i1">
<HDF5 dataset "class_0": shape (256, 2048), type "|i1">
<HDF5 dataset "class_1": shape (256, 2048), type "|i1">
<HDF5 dataset "class_2": shape (256, 2048), type "|i1">
<HDF5 dataset "class_0": shape (256, 2048), type "|i1">
<HDF5 dataset "class_1": shape (256, 2048), type "|i1">
<HDF5 dataset "class_2": shape (256, 2048), type "|i1">


In [64]:
save_patterns_hdf5(PP_patterns, pattern_type='dense')  # EC -> CA3 (PP) patterns

<HDF5 dataset "class_0": shape (256, 2048), type "|i1">
<HDF5 dataset "class_1": shape (256, 2048), type "|i1">
<HDF5 dataset "class_2": shape (256, 2048), type "|i1">
<HDF5 dataset "class_0": shape (256, 2048), type "|i1">
<HDF5 dataset "class_1": shape (256, 2048), type "|i1">
<HDF5 dataset "class_2": shape (256, 2048), type "|i1">
<HDF5 dataset "class_0": shape (256, 2048), type "|i1">
<HDF5 dataset "class_1": shape (256, 2048), type "|i1">
<HDF5 dataset "class_2": shape (256, 2048), type "|i1">


In [88]:
save_patterns_hdf5(cat_patterns, pattern_type='cat')  # per-class category prototypes

<HDF5 dataset "class_0": shape (2048,), type "|i1">
<HDF5 dataset "class_1": shape (2048,), type "|i1">
<HDF5 dataset "class_2": shape (2048,), type "|i1">
<HDF5 dataset "class_0": shape (2048,), type "|i1">
<HDF5 dataset "class_1": shape (2048,), type "|i1">
<HDF5 dataset "class_2": shape (2048,), type "|i1">
<HDF5 dataset "class_0": shape (2048,), type "|i1">
<HDF5 dataset "class_1": shape (2048,), type "|i1">
<HDF5 dataset "class_2": shape (2048,), type "|i1">


In [89]:
def save_patterns_binary(pattern_dict, pattern_type, out_dir='results'):
    """Write each model's patterns as raw bit-packed bytes to ``{out_dir}/{model}_{pattern_type}.dat``.

    Classes are written back-to-back in list order. Each class array is packed by
    ``np.packbits`` (flattened, 8 values per byte, final byte zero-padded), so
    padding is applied per class. No shape header is written; readers must know
    the per-class shapes to unpack. Each packed array is printed.

    Args:
        pattern_dict: mapping model name -> list of binary pattern arrays.
        pattern_type: suffix used in the output file name.
        out_dir: output directory; created if missing. Defaults to 'results'.
    """
    os.makedirs(out_dir, exist_ok=True)
    for model, class_patterns in pattern_dict.items():
        path = os.path.join(out_dir, f'{model}_{pattern_type}.dat')
        # Binary mode: the payload is raw bytes; text mode ('w') is not meant for that.
        with open(path, 'wb') as binfile:
            for patterns in class_patterns:
                bindata = np.packbits(patterns)
                bindata.tofile(binfile)
                print(bindata)

In [78]:
save_patterns_binary(MF_patterns, pattern_type='sparse')  # DG -> CA3 (MF) patterns

[0 0 0 ... 0 0 0]
[ 0  0  0 ...  0  0 64]
[ 0 12  0 ...  0  0  0]
[ 0  0  0 ...  0  0 16]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[ 0  1 32 ...  0  0  0]
[0 0 0 ... 0 0 0]


In [79]:
save_patterns_binary(PP_patterns, pattern_type='dense')  # EC -> CA3 (PP) patterns

[ 0  8 33 ...  0  6  0]
[  0  64  16 ... 128  16 130]
[ 72  77  64 ...   5 130  16]
[ 8  0 32 ... 80 32  0]
[ 16 194  72 ...   2  69 139]
[ 64  96   1 ...  32 139  66]
[  0  64 136 ...   0   0   0]
[16 28  4 ...  2  8 32]
[  2 136   1 ...   0  64 138]


In [90]:
save_patterns_binary(cat_patterns, pattern_type='cat')  # per-class category prototypes

[  4  41  32  16 172   0  33  98  17 132 130  34  32   0   0 129 150  32
  33  36  36 176   1 210  64   0  68  36  17   2  64  45  96  64  54   1
  38  16   2  36   0  18 228   8   6   8  20 130 198 128  21  26 128  74
   4   0   0 132 128 185   4   0 210  29 208  64   0   0   1   6  12   0
   1 128  17  22  18   4  32  17 156 113 129 176  26  15   0   2   0  68
  64  64 128   0 152 144 144   2 138   0   4  16  32   4   9   8 131  33
  16   4  14   3 128   0 104  96   2   4  26  68   6   1   1   1  15  18
  26  64 131  65   0  32   0   1   9  20  73  40  64   3   0   0 192  83
  64  16  40   0 128  29  40   0  64  13  33   0   8  66  24  16   0  34
  64  72  40  41  24  84  66   0   4   0  32   1   0 162   0  26 128 130
   0   2   0  80   0   1  20  80 104   0  96   0   1  16  16   2   2 164
   0   0  99 192  36  33   0   0  96  16   8 140   0  52   4 131   0   0
  35   0   0 224  64  64  32   1 138  64   0   4   4 156   0   4  16  65
   0  16   1  74  96   4 133  72   0   2  17   0 14