In [1]:
import cantata
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from box import Box
from cantata.plotting import output as cp
from cantata import cfg

In [2]:
figsize = (14,5)

In [3]:
print(cfg.model_config)

1tier.yaml


In [4]:
# Extract populations & projections:
populations = Box()
drivers = Box()
projections = Box()
for name, pop in cfg.model.populations.items():
    p = pop.copy()
    projections[name] = p.targets
    del p.targets
    if p.rate >= 0:
        drivers[name] = p
    else:
        populations[name] = p

In [5]:
for name in drivers.keys():
    assert np.all([name not in t for t in projections.values()]), name

In [6]:
base_cfg = cfg.copy()
del base_cfg.model_config
base_cfg.model.populations = Box()

In [7]:
second = int(1/base_cfg.time_step)

In [8]:
# STDP check: Engineered pairs
def stdp_plot(source_name, target_name,
              n_probes = 51, probe_spacing = 0.002,
              burst = 1, burst_spacing = 0.005,
              v_post = 0.5, with_noise = False, batch_size = 32,
              plot_traces = True,
              kd = 1, kp = 1):
    stdp_cfg = base_cfg.copy()
    stdp_cfg.batch_size = batch_size
    stdp_cfg.n_inputs = 3
    stdp_cfg.n_steps = int(n_probes * probe_spacing / stdp_cfg.time_step) + 10
    stdp_cfg.model.populations.Driver_pre = Box(dict(rate=0, targets={}))
    stdp_cfg.model.populations.Driver_post = Box(dict(rate=1, targets={'Post':{}}))
    stdp_cfg.model.populations.Driver_vpost = Box(dict(rate=2, targets={'Post':{}}))
    for i in range(n_probes):
        stdp_cfg.model.populations.Driver_pre.targets[f'Pre_{i}'] = {'delay': probe_spacing*i}
    stdp_cfg.model.stdp_wmax_total = 2
    stdp_cfg.model.tau_ref = 0

    inputs = torch.zeros(stdp_cfg.batch_size, stdp_cfg.n_steps, stdp_cfg.n_inputs, **stdp_cfg.tspec)
    surefire = 1/stdp_cfg.time_step
    inputs[:,0,0] = surefire
    burst_idx = torch.ones(burst)*(1 + int(n_probes//2 * probe_spacing/stdp_cfg.time_step)) + \
        torch.arange(0,-burst,-1) * burst_spacing/stdp_cfg.time_step
    inputs[:,burst_idx.to(torch.long),1] = surefire
    if not with_noise:
        inputs[:,:,2] = surefire

    t = np.arange(stdp_cfg.n_steps) * .001/stdp_cfg.time_step
    dt = (np.arange(n_probes) - n_probes//2) * probe_spacing * 1000

    for i in range(n_probes):
        proj = projections[source_name][target_name].copy()
        proj.delay = 0
        proj.density = 1
        proj.A_p *= kp
        proj.A_d *= kd
        source = (populations[source_name] if source_name in populations else drivers[source_name]).copy()
        source.n = 1
        source.targets = {f'Post': proj}
        if not with_noise:
            source.noise_N = 0
        stdp_cfg.model.populations[f'Pre_{i}'] = source
    target = populations[target_name].copy()
    target.n = n_probes
    if not with_noise:
        target.noise_N = 0
    stdp_cfg.model.populations.Post = target

    cantata.config.load(stdp_cfg)
    m = cantata.Module()
    with torch.no_grad():
        state,epoch,record = m.forward_init(inputs, ['w_stdp', 'mem'])
        state.mem[:] = 0
        epoch.W[0][epoch.W[0]>0] = 1
        epoch.W[1][epoch.W[1]>0] = 1 + (-v_post if v_post < 0 else 0)
        epoch.W[2][epoch.W[2]>0] = v_post - v_post*m.alpha_mem
        state.mem[:,3+n_probes:] = v_post
        state.u_dep[:,3+n_probes:] = v_post
        state.u_pot[:,3+n_probes:] = v_post
        state.th_dep[:,3+n_probes:] = v_post
        epoch.W[3:3+n_probes, 3+n_probes:] = torch.eye(n_probes)*(1-v_post)/2*source.sign
        m.forward_run(state,epoch,record)
        m.forward_close(record)

    if plot_traces:
        plt.figure(figsize=figsize)
        for i,k in enumerate(range(n_probes)):
            plt.subplot(n_probes,2,2*i+1)
            plt.plot(record.mem[0,:,3 + k].cpu().numpy(), label='pre')
            plt.plot(record.mem[0,:,3 + n_probes + k].cpu().numpy(), label='post')
            if i < n_probes-1:
                plt.axis('off')
            else:
                sns.despine()
                plt.xlabel('Time (ms)')

            ax = plt.subplot(n_probes,2,2*i+2)
            plt.plot(record.w_stdp[0,:, 3+k, 3+n_probes+k].cpu().numpy())
            if i < n_probes-1:
                plt.axis('off')
            else:
                ax.get_yaxis().set_visible(False)
                sns.despine(left=True)
                plt.xlabel('Time (ms)')

    plt.figure(figsize=figsize)
    plt.axhline(0, c='lightgray', ls='--')
    first_half = np.ones(n_probes//2 + 1) * (2 + int(n_probes//2 * probe_spacing/stdp_cfg.time_step)) + 1
    second_half = np.arange(1, n_probes//2+1) * probe_spacing/stdp_cfg.time_step + first_half[0]-1
    idx_after = torch.tensor(np.concatenate((first_half, second_half)),dtype=torch.long)
    dw = record.w_stdp[:, idx_after, range(3,3+n_probes), range(3+n_probes,3+2*n_probes)] - \
         record.w_stdp[:, idx_after-1, range(3,3+n_probes), range(3+n_probes,3+2*n_probes)]
    dw_0 = np.flip(dw[0].cpu().numpy())
    plt.plot(dt, dw_0, 'g*-')
    plt.xlabel('t_post - t_pre (ms)')
    plt.ylabel('Relative weight change')
    
    if with_noise:
        dw_mean = np.flip(dw.mean(dim=0).cpu().numpy())
        plt.plot(dt, dw_mean, 'ro-')
        return dw_0, dw_mean
    else:
        return dw_0

In [9]:
def stdp_check_all(**kwargs):
    for source_name, targets in projections.items():
        for target_name, projection in targets.items():
            if projection.A_p > 0 or projection.A_d > 0:
                stdp_plot(source_name, target_name, **kwargs)

In [10]:
stdp_check_all(v_post=0.0)

BoxKeyError: "'Box' object has no attribute 'th_dep'"

In [None]:
stdp_check_all(v_post = 0.5)

In [None]:
stdp_check_all(v_post = 0.9)

In [None]:
stdp_check_all(v_post = 0.8, with_noise = True)