In [1]:
import cantata
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from cantata.plotting import output as cp
import scipy.stats

In [2]:
figsize = (14,5)

batch_size = 1
dt = 1e-3
device = torch.device('cuda')

In [3]:
conf = cantata.config.read('/home/felix/projects/cantata/cantata/configs/1tier.yaml')
conf.pop('_prototypes')
print(conf.to_yaml())

input:
  n_channels: 1
  populations:
    Thal:
      n: 2
      channel: 0
      targets:
        L1.Exc:
          spatial: true
          density: 0.5
          delay: 0.001
          sigma: 0.5
          STDP_frac: 0.0
          A_p: 0.0
          A_d: 0.0
          wmax: 1.0
        L1.Inh:
          spatial: true
          density: 0.5
          delay: 0.001
          sigma: 0.5
          STDP_frac: 0.0
          A_p: 0.0
          A_d: 0.0
          wmax: 1.0
areas:
  L1:
    populations:
      Exc:
        sign: 1
        n: 200
        p: -0.1
        noise_N: 2000
        noise_rate: 8.0
        noise_weight: 0.002
        th_ampl: 0.05
        targets:
          Exc:
            spatial: true
            density: 0.5
            delay: 0.001
            STDP_frac: 1.0
            A_p: 1.5
            A_d: 0.5
            sigma: 0.5
            wmax: 1.0
          Inh:
            spatial: true
            density: 0.5
            delay: 0.001
            sigma: 0.5
         

In [4]:
second = int(1/dt)

In [5]:
n_steps = int(10 * second)
rate_off, rate_on = 0, 0 # Hz
t_off, t_on = .15*second, .15*second

In [6]:
def get_inputs(n_steps, periods, rates, device = torch.device('cpu'), reshape = True):
    seq = torch.arange(n_steps, device=device)
    pattern = torch.zeros(n_steps, device=device)
    stop, total = 0, int(sum(periods))
    for rate, period in zip(rates, periods):
        pattern = torch.where(
            seq % total > stop,
            torch.tensor([rate],device=device,dtype=torch.float),
            pattern
        )
        stop += int(period)
    return pattern.reshape(-1,1,1).expand(-1, batch_size, -1) if reshape else pattern

In [7]:
# inputs = get_inputs(n_steps, (t_off,t_on), (rate_off,rate_on), device, reshape=False)
# plt.plot(inputs.cpu())

# inputs = inputs.reshape(-1,1,1).expand(-1, batch_size, -1)

In [8]:
# m = cantata.Conductor(conf, batch_size, dt, STDP=cantata.elements.Abbott).to(device)
# with torch.no_grad():
#     X = m(inputs)

In [9]:
def get_STDP_mask(area, name_pre, name_post):
    e,o = area.N, area.N
    pre = area.p_idx[area.p_names.index(name_pre)]
    post = area.p_idx[area.p_names.index(name_post)]
    synapse = area.synapses_int
    
    projection = torch.zeros(e,o, dtype=torch.bool) # (e,o)
    projection[np.ix_(pre, post)] = True
    
    active = synapse.signs != 0 # (e,o)
    
    plastic = (synapse.longterm.A_p != 0) * (synapse.longterm.A_d != 0) # (e,o)
    
    mask = active * plastic * projection.to(active) # (e,o)
    return mask
    
#     mask_cpu = mask.cpu()
#     e,o = synapse.signs.shape
#     pre = torch.arange(e).unsqueeze(1).expand(e,o)[mask_cpu]
#     post = torch.arange(o).expand(e,o)[mask_cpu]
#     return mask, (pre, post)

In [10]:
def quantify_unstimulated(model, early = (10,20), late = (50,60)):
    def get_empty(secs):
        return get_inputs(int(secs * second), (0,), (0,), device)
#     def record_W(mask):
#         def inner(m, *args):
#             W.append(m.W[:, mask])
#         return inner
    STDP = model.areas[0].synapses_int.longterm
    mask = get_STDP_mask(model.areas[0], 'Exc', 'Exc')
    
    # Settle without observation until early begins:
    if early[0] > 0:
        model(get_empty(early[0]))
    
    # Observe spikes, STDP in early window:
#     observer = STDP.register_forward_hook(record_W(mask))
#     W = []
    Wpre = STDP.W[:, mask]
    X = model(get_empty(early[1] - early[0]))
    Wpost = STDP.W[:, mask]
#     observer.remove()
    w0 = get_stdp_measures(torch.stack((Wpre, Wpost)) / conf.areas.L1.populations.Exc.targets.Exc.wmax)
    r0 = get_rate_measures(X[1], model.areas[0])
    
    # Run through to late window:
    model(get_empty(late[0] - early[1]))
    
    # Observe spikes, STDP in late window:
#     observer = STDP.register_forward_hook(record_W(mask))
#     W = []
    Wpre = STDP.W[:, mask]
    X = model(get_empty(late[1] - late[0]))
    Wpost = STDP.W[:, mask]
#     observer.remove()
    w1 = get_stdp_measures(torch.stack((Wpre, Wpost)) / conf.areas.L1.populations.Exc.targets.Exc.wmax, t=-1)
    r1 = get_rate_measures(X[1], model.areas[0])
    
    return r0, w0, r1, w1

In [11]:
def get_rate_measures(X, area, quantiles = [0, .1, .25, .5, .75, .9, 1]):
    # X: (t,b,N) in area
    rates = X.sum(dim=(0,1)) / (X.shape[0] * X.shape[1] * dt) # Hz, (N)
    ret = torch.zeros(len(area.p_idx), len(quantiles)+2)
    for i, idx in enumerate(area.p_idx):
        ret[i,2:] = torch.quantile(rates[idx], torch.tensor(quantiles).to(rates))
        ret[i,1], ret[i,0] = torch.std_mean(rates[idx])
    return ret.cpu()

In [12]:
def get_stdp_measures(W, t = 0, quantiles = [0, .1, .25, .5, .75, .9, 1], tol = 1e-4):
    # W: (t,b,masked)
    q = torch.quantile(W[t], torch.tensor(quantiles).to(W)).cpu()
    std, mean = torch.std_mean(W[t])
    
    # Cosine similarity between the weight vectors in different batches
    # : Distance between batches
    batch_size = W.shape[1]
    i,j = torch.triu_indices(batch_size, batch_size, 1)
    cs = torch.nn.functional.cosine_similarity(W[t], W[t].unsqueeze(1), dim=-1)
    batch_std, batch_mean = torch.std_mean(cs[i,j])
    
    # Cosine similarity between weight vectors at start & end of period; mean & std:
    # : Distance between time points
    cs = torch.nn.functional.cosine_similarity(W[0], W[-1], dim=1)
    time_std, time_mean = torch.std_mean(cs)
    
    # Proportion of saturated weights at higher/lower weight bound
    sat_norm = W.shape[1] * W.shape[2]
    sat_high, sat_low = torch.sum(W[t] > 1-tol)/sat_norm, torch.sum(W[t] < tol)/sat_norm
    
    ret = torch.tensor([batch_mean, batch_std, time_mean, time_std, mean, std, sat_high, sat_low])
    return torch.cat((ret, q))

In [13]:
def get_stimulated_measures(X, area, periods, onset = 20, quantiles = [0, .1, .25, .5, .75, .9, 1], sig=.05):
    # X: (t,b,N) in area
    stop, total = 0, int(sum(periods))
    nperiods, nreps = len(periods), X.shape[0] // total
    batch_size, N = X.shape[1], X.shape[2]
    assert total*nreps == X.shape[0], 'Periods must neatly tile the input.'
    
    r_pulse = torch.empty(nperiods, nreps, batch_size, N)
    r_onset = torch.empty(nperiods, nreps, batch_size, N)
    for i,p in enumerate(periods):
        p = int(p)
        idx = torch.arange(stop, stop+p).expand(nreps,-1) + (torch.arange(nreps)*total)[:,None]
        S = X[idx,:,:] # (nreps, p, b, N)
        r_pulse[i] = S.sum(dim=1).cpu() / (p * dt) # Hz, (nreps, b, N)
        r_onset[i] = S[:,:onset].sum(dim=1).cpu() / (onset * dt) # Hz
        stop += p
    
    ret = torch.empty(len(area.p_idx), nperiods, len(quantiles)+4)
    for i,idx in enumerate(area.p_idx):
        rx_pulse = r_pulse[:,:,:,idx] # (nperiods, nreps, b, Nx)
        rx_onset = r_onset[:,:,:,idx]

        # Firing rates during each stimulus (full pulse duration)
        # Note, batch instances are treated as separate networks, since they may have self-organised differently
        r = rx_pulse.sum(dim=1) / nreps # Hz, (nperiods, b, Nx)
        std_pulse, mean_pulse = torch.std_mean(r, dim=(1,2)) # (nperiods)
        q_pulse = np.quantile(r.numpy(), np.array(quantiles), axis=(1,2)) # (|q|, nperiods)
        q_pulse = torch.tensor(q_pulse, dtype=torch.float)
        
        # Finding pulse level sensitive units
        # 2-tailed independent test; assumes that on and off are independent, even though they follow each other
        # Value reflects the fraction of units that significantly increase their firing for a given stimulus.
        level_sensitive = torch.empty(nperiods) # (nperiods)
        for j in range(nperiods):
            off = [k for k in range(nperiods) if k != j]
            if nperiods > 2:
                raise NotImplemented
            else:
                rx_off = rx_pulse[off]
            _, p = scipy.stats.ttest_ind(
                rx_pulse[j].numpy(), rx_off.numpy(), axis=0, alternative='greater') # (b, Nx)
            level_sensitive[j] = np.sum(p<sig) / (batch_size*N) # scalar

        # Finding onset-selective units
        # 2-tailed paired test, since onset and rest-of-pulse are tightly linked
        _, p = scipy.stats.ttest_rel(rx_onset, rx_pulse, axis=1) # (nperiods, b, Nx)
        onset_sensitive = np.sum(p<sig, axis=(1,2)) / (batch_size*N) # (nperiods)
        onset_sensitive = torch.tensor(onset_sensitive, dtype=torch.float)
        
        ret[i] = torch.cat((
            mean_pulse[:,None],
            std_pulse[:,None],
            q_pulse.T,
            level_sensitive[:,None],
            onset_sensitive[:,None]
        ), dim=1)
    return ret # (nareas, nperiods, |q|+4)

In [14]:
def quantify_stimulated(model, periods, rates, onset = 20, settle = 5, early = (10,15), late = (55,60)):
    '''Expects period, onset in ticks; settle, early, late in seconds.'''
    assert second % sum(periods) == 0, 'Periods must tile neatly into 1-second segments.'
    for p in periods:
        assert p > onset, 'Periods must be longer than the onset'
    def record_W(mask):
        def inner(m, *args):
            W.append(m.W[:, mask])
        return inner
    STDP = model.areas[0].synapses_int.longterm
    mask = get_STDP_mask(model.areas[0], 'Exc', 'Exc')
    
    # Settle without observation and without stimulation:
    if settle > 0:
        model(get_inputs(int(settle * second), (0,), (0,), device))
    
    inputs = get_inputs(int(late[1]*second), periods, rates)
    
    # Run without observation until early begins:
    t = 0
    if early[0] > 0:
        t = int(early[0] * second)
        model(inputs[:t].to(device))
    
    # Observe spikes, STDP in early window:
    t0, t = t, int(early[1] * second)
    Wpre = STDP.W[:, mask]
    X = model(inputs[t0:t].to(device))
    Wpost = STDP.W[:, mask]
    w0 = get_stdp_measures(torch.stack((Wpre, Wpost)) / conf.areas.L1.populations.Exc.targets.Exc.wmax)
    r0 = get_rate_measures(X[1], model.areas[0])
    s0 = get_stimulated_measures(X[1], model.areas[0], periods)
    
    # Run through to late window:
    t0, t = t, int(late[0] * second)
    model(inputs[t0:t].to(device))
    
    # Observe spikes, STDP in late window:
    t0, t = t, int(late[1] * second)
    Wpre = STDP.W[:, mask]
    X = model(inputs[t0:t].to(device))
    Wpost = STDP.W[:, mask]
    w1 = get_stdp_measures(torch.stack((Wpre, Wpost)) / conf.areas.L1.populations.Exc.targets.Exc.wmax, t=-1)
    r1 = get_rate_measures(X[1], model.areas[0])
    s1 = get_stimulated_measures(X[1], model.areas[0], periods)
    
    return r0, w0, s0, r1, w1, s1

In [15]:
periods = (int(50e-3/dt), int(50e-3/dt)) # ticks
rates = (0, 50) # Hz

settle, early, late = 5, (10,15), (55,60) # seconds

assert second % sum(periods) == 0

In [16]:
batch_size = 32
m = cantata.Conductor(conf, batch_size, dt, STDP=cantata.elements.Abbott).to(device)
with torch.no_grad():
    r0_u, w0_u, r1_u, w1_u = quantify_unstimulated(m, early=early, late=late)
    m.reset()
    r0_s, w0_s, s0, r1_s, w1_s, s1 = quantify_stimulated(m, periods, rates, settle=settle, early=early, late=late)

  return _methods._var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  ret = um.true_divide(


In [17]:
r0_u, w0_u, r1_u, w1_u

(tensor([[12.8770,  3.0011,  0.0000,  9.8681, 12.2375, 13.4250, 14.4734, 15.6700,
          19.2125],
         [12.1128,  5.6502,  0.4813,  4.7375,  9.0250, 13.0125, 14.4703, 16.5688,
          26.6562]]),
 tensor([0.9887, 0.0036, 0.9977, 0.0016, 0.9261, 0.1797, 0.5062, 0.0115, 0.0000,
         0.9062, 0.9314, 1.0000, 1.0000, 1.0000, 1.0000]),
 tensor([[13.4670,  2.9992,  0.0000, 10.3294, 12.7453, 14.0000, 15.1156, 16.3100,
          19.7000],
         [12.6037,  5.8171,  0.2875,  4.8250,  9.3938, 13.5188, 15.1875, 17.3469,
          27.1562]]),
 tensor([9.8792e-01, 3.6498e-03, 9.9875e-01, 3.5875e-04, 9.3460e-01, 1.7337e-01,
         5.1857e-01, 1.7836e-02, 0.0000e+00, 9.0808e-01, 9.3410e-01, 1.0000e+00,
         1.0000e+00, 1.0000e+00, 1.0000e+00]))

In [18]:
r0_s, w0_s, r1_s, w1_s

(tensor([[13.7054,  3.1407,  0.0000, 10.6238, 13.0016, 14.3125, 15.4297, 16.5637,
          19.9500],
         [12.8563,  5.9270,  0.5000,  5.0781,  9.4453, 13.7531, 15.5891, 17.5531,
          27.5000]]),
 tensor([9.8923e-01, 3.3350e-03, 9.9833e-01, 9.8585e-04, 9.2852e-01, 1.7353e-01,
         4.7546e-01, 1.1571e-02, 0.0000e+00, 9.0411e-01, 9.3253e-01, 9.7510e-01,
         1.0000e+00, 1.0000e+00, 1.0000e+00]),
 tensor([[13.5378,  2.9360,  0.0000, 10.5050, 12.8375, 14.0875, 15.1719, 16.3188,
          19.7812],
         [12.7041,  5.8457,  0.4813,  4.8813,  9.3891, 13.5719, 15.2328, 17.5031,
          27.2437]]),
 tensor([9.8861e-01, 3.4260e-03, 9.9866e-01, 7.9792e-04, 9.3570e-01, 1.6746e-01,
         5.1252e-01, 1.5468e-02, 0.0000e+00, 9.1033e-01, 9.3234e-01, 1.0000e+00,
         1.0000e+00, 1.0000e+00, 1.0000e+00]))

In [19]:
s0, s1

(tensor([[[1.3376e+01, 4.0665e+00, 0.0000e+00, 9.6000e+00, 1.1600e+01,
           1.3600e+01, 1.5600e+01, 1.7600e+01, 2.5200e+01, 0.0000e+00,
           1.7944e-02],
          [1.4035e+01, 3.9213e+00, 0.0000e+00, 1.0400e+01, 1.2400e+01,
           1.4400e+01, 1.6400e+01, 1.8000e+01, 2.4000e+01, 0.0000e+00,
           6.5918e-03]],
 
         [[1.2567e+01, 6.3235e+00, 0.0000e+00, 4.0000e+00, 8.8000e+00,
           1.2800e+01, 1.6000e+01, 2.0800e+01, 3.3600e+01, 0.0000e+00,
           5.1270e-03],
          [1.3146e+01, 6.2060e+00, 0.0000e+00, 4.8000e+00, 9.6000e+00,
           1.3400e+01, 1.6400e+01, 2.0760e+01, 3.3600e+01, 0.0000e+00,
           2.5635e-03]]]),
 tensor([[[1.3261e+01, 3.7881e+00, 0.0000e+00, 9.6000e+00, 1.1600e+01,
           1.3600e+01, 1.5600e+01, 1.7200e+01, 2.4400e+01, 0.0000e+00,
           2.4780e-02],
          [1.3815e+01, 4.1342e+00, 0.0000e+00, 1.0000e+01, 1.2000e+01,
           1.4000e+01, 1.6000e+01, 1.8400e+01, 2.6800e+01, 0.0000e+00,
           1.5869e-02]