In [2]:
from typing import Any, List, Tuple, Union

import matplotlib
import matplotlib.pyplot as plt
import pandas as pd

from leabra7 import events as et
from leabra7 import net as nt
from leabra7 import rand as rd
from leabra7 import specs as sp

%matplotlib inline

In [3]:
## Define Layer Specs

# Unit attributes logged on every cycle. Every layer spec below logs the same
# set, so it is defined once here instead of being copy-pasted per spec.
UNIT_LOG_ATTRS = ("unit_v_m", "unit_act", "unit_i_net",
                  "unit_net", "unit_gc_i", "unit_adapt",
                  "unit_spike")

# Generic Layer Spec: only configures logging (library default inhibition).
layer_spec = sp.LayerSpec(log_on_cycle=UNIT_LOG_ATTRS)

# The remaining specs use k-winners-take-all style inhibition. kwta_pct is
# presumably the target fraction of active units and kwta_pt the threshold
# interpolation point — confirm against leabra7.specs.LayerSpec.

# EC Layer Spec
EC_layer_spec = sp.LayerSpec(
    log_on_cycle=UNIT_LOG_ATTRS,
    inhibition_type="kwta",
    kwta_pct=0.2,
    kwta_pt=0.5,
)

# DG Layer Spec (smallest kwta_pct of all layers)
DG_layer_spec = sp.LayerSpec(
    log_on_cycle=UNIT_LOG_ATTRS,
    inhibition_type="kwta_avg",
    kwta_pct=0.01,
    kwta_pt=0.9,
)

# CA3 Layer Spec
CA3_layer_spec = sp.LayerSpec(
    log_on_cycle=UNIT_LOG_ATTRS,
    inhibition_type="kwta_avg",
    kwta_pct=0.06,
    kwta_pt=0.7,
)

# CA1 Layer Spec
CA1_layer_spec = sp.LayerSpec(
    log_on_cycle=UNIT_LOG_ATTRS,
    inhibition_type="kwta_avg",
    kwta_pct=0.25,
    kwta_pt=0.7,
)

In [6]:
## Define Projections Spec
#
# Each spec sets a projection's initial weight distribution (`dist`) and,
# where needed, its connectivity and weight scaling. The commented-out
# `l_rate` lines are per-projection learning rates that are currently
# disabled; those projections presumably fall back to the ProjnSpec default.

# Input -> EC_in: one-to-one feed of the input pattern.
Input_projn_spec = sp.ProjnSpec(
    dist=rd.Uniform(low=0.25, high=0.75),
    # l_rate=0,
    projn_type="one_to_one",
)

# EC_in -> DG / CA3: sparse connectivity (sparsity 0.25).
EC_in_projn_spec = sp.ProjnSpec(
    dist=rd.Uniform(low=0.25, high=0.75),
    # l_rate=0.2,
    sparsity=0.25,
)

# DG -> CA3 (mossy fibers): narrow weight range near 0.9, relative weight
# scale 8.0, very sparse (0.05).
DG_projn_spec = sp.ProjnSpec(
    dist=rd.Uniform(low=0.89, high=0.91),
    wt_scale_rel=8.0,
    sparsity=0.05,
    # l_rate=0,
)

# CA3 -> CA3 recurrent collaterals.
CA3_Recur_projn_spec = sp.ProjnSpec(
    dist=rd.Uniform(low=0.25, high=0.75),
    # l_rate=0.2,
)

# CA3 -> CA1 (Schaffer collaterals).
CA3_CA1_projn_spec = sp.ProjnSpec(
    dist=rd.Uniform(low=0.25, high=0.75),
    # l_rate=0.05,
)

# EC_in -> CA1: absolute weight scale 3.0.
EC_in_CA1_projn_spec = sp.ProjnSpec(
    dist=rd.Uniform(low=0.25, high=0.75),
    wt_scale_abs=3.0,
    # l_rate=0.02,
)

# EC_out <-> CA1 (shared by both directions).
EC_out_CA1_projn_spec = sp.ProjnSpec(
    dist=rd.Uniform(low=0.25, high=0.75),
    # l_rate=0.02,
)

# EC_out -> EC_in: one-to-one, weights fixed near 0.5, absolute scale 2.0,
# relative scale 0.5.
EC_out_EC_in_projn_spec = sp.ProjnSpec(
    dist=rd.Uniform(low=0.49, high=0.51),
    # l_rate=0,
    wt_scale_abs=2.0,
    wt_scale_rel=0.5,
    projn_type="one_to_one",
)

In [7]:
# Create the Network
# Starts empty; layers (net.new_layer) and projections (net.new_projn) are
# added to it in the following cells.
# NOTE(review): re-running those cells without re-running this one first
# presumably re-adds existing layers/projections — confirm and re-run from here.
net = nt.Net()

In [8]:
## Create Layers
#
# One entry per layer: (name, number of units, layer spec), added in order.
# Input feeds EC_in; EC_in / EC_out are the entorhinal cortex input and output
# layers; CA1, CA3 and DG are the hippocampal subfields.
HIPPO_LAYERS = (
    ("Input", 8, layer_spec),
    ("EC_in", 8, EC_layer_spec),
    ("EC_out", 8, EC_layer_spec),
    ("CA1", 100, CA1_layer_spec),
    ("CA3", 80, CA3_layer_spec),
    ("DG", 400, DG_layer_spec),
)

for layer_name, n_units, layer_spec_for_layer in HIPPO_LAYERS:
    net.new_layer(layer_name, n_units, layer_spec_for_layer)

In [9]:
## Create Projections
#
# One entry per projection: (projection name, sender layer, receiver layer,
# projection spec), added in order.
# NOTE(review): EC_out_EC_in_projn_spec is defined but not used here —
# confirm whether an "EC_out -> EC_in" projection is intentionally omitted.
HIPPO_PROJNS = (
    # Input feed
    ("Input: Input -> EC_in", "Input", "EC_in", Input_projn_spec),
    # MSP (monosynaptic pathway): EC <-> CA1
    ("MSP: EC_in -> CA1", "EC_in", "CA1", EC_in_CA1_projn_spec),
    ("MSP: CA1 -> EC_out", "CA1", "EC_out", EC_out_CA1_projn_spec),
    ("MSP: EC_out -> CA1", "EC_out", "CA1", EC_out_CA1_projn_spec),
    # TSP (trisynaptic pathway): EC_in -> DG -> CA3 -> CA1, plus EC_in -> CA3
    # and the CA3 recurrent collaterals
    ("TSP: EC_in -> DG", "EC_in", "DG", EC_in_projn_spec),
    ("TSP: EC_in -> CA3", "EC_in", "CA3", EC_in_projn_spec),
    ("TSP: DG -> CA3", "DG", "CA3", DG_projn_spec),
    ("TSP: CA3 -> CA3", "CA3", "CA3", CA3_Recur_projn_spec),
    ("TSP: CA3 -> CA1", "CA3", "CA1", CA3_CA1_projn_spec),
)

for projn_name, pre_layer, post_layer, projn_spec in HIPPO_PROJNS:
    net.new_projn(projn_name, pre_layer, post_layer, spec=projn_spec)

In [10]:
# Function to plot data for a certain attribute for each unit of layer
def plot_by_unit(axes: Any,
                 log: pd.DataFrame, attr: str, title: str,
                 location: Union[int, Tuple[int, ...], List[int]]):
    """Plot one logged attribute over time, one line per unit, on one subplot.

    Args:
        axes: Indexable container of matplotlib Axes, e.g. the array returned
            by ``plt.subplots(nrows, ncols)`` or a plain list of Axes.
        log: Log frame with at least the columns "unit", "time" and ``attr``;
            it is grouped by "unit" and each group becomes one line.
        attr: Column plotted on the y axis.
        title: Title given to the target subplot.
        location: Index of the target Axes inside ``axes``: an int for a 1-D
            container, or a (row, col) tuple for a 2-D numpy array. A list
            such as ``[row, col]`` is treated as the equivalent tuple.
    """
    # A list index on a numpy array triggers fancy indexing and selects
    # several Axes; a tuple selects exactly one element.
    index = tuple(location) if isinstance(location, list) else location
    # The target Axes does not change per unit, so resolve it once.
    ax = axes[index]
    for name, group in log.groupby("unit"):
        group.plot(x="time", y=attr, ax=ax,
                   title=title, label="unit " + str(name))