In [None]:
import glob
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
from mpmath import mp, nstr
%matplotlib inline

# Directory holding the local `ctw` / `utils` entropy modules imported below.
# The default is a machine-specific path; set the CA3_ENTROPY_DIR environment
# variable to point elsewhere instead of editing the notebook.
ENTROPY_LIB_DIR = os.environ.get('CA3_ENTROPY_DIR', '/Users/daniele/Postdoc/Research/CA3/entropy')
if ENTROPY_LIB_DIR not in sys.path:
    sys.path.append(ENTROPY_LIB_DIR)
import ctw
from utils import binarize

def despine(ax, sides=('right','top')):
    """Hide selected spines of an Axes-like object (right and top by default).

    ax    : object exposing a ``spines`` mapping whose values have ``set_visible``.
    sides : iterable of spine names to hide, processed in order.
    """
    for spine in (ax.spines[name] for name in sides):
        spine.set_visible(False)

In [None]:
def rasterplot(spikes, tend=np.inf, ax=None, **kwargs):
    """Draw a spike raster: one row of dots per spike train.

    Train number k (0-based) is drawn at height k + 1, and only spikes strictly
    earlier than `tend` are shown. Extra keyword arguments go to ``Axes.plot``.
    Each element of `spikes` must support boolean-mask indexing (e.g. ndarray).
    Draws on `ax`, or on the current axes when `ax` is None.
    """
    target = plt.gca() if ax is None else ax
    for row, train in enumerate(spikes, start=1):
        shown = train[train < tend]
        target.plot(shown, np.full(shown.size, row, dtype=float), '.', **kwargs)

In [None]:
def compute_ISI_entropy(spike_times, n_bins, ISI = None, ISI_min = None, ISI_max = None):
    """Shannon entropy (bits) of the inter-spike-interval (ISI) histogram.

    Parameters
    ----------
    spike_times : array-like or None
        Sorted spike times; the ISIs are their successive differences. Ignored
        when `ISI` is given.
    n_bins : int
        Number of equal-width histogram bins spanning [ISI_min, ISI_max].
    ISI : array-like, optional
        Precomputed inter-spike intervals; takes precedence over `spike_times`.
    ISI_min, ISI_max : float, optional
        Histogram range; default to the smallest / largest ISI. ISIs outside
        the range are ignored.

    Returns
    -------
    float
        Entropy of the binned ISI distribution in bits; 0.0 when there are no
        ISIs or none of them falls inside the histogram range.

    Raises
    ------
    ValueError
        If neither `spike_times` nor `ISI` is provided, or `n_bins` < 1.
    """
    if ISI is None:
        if spike_times is None:
            raise ValueError('either spike_times or ISI must be provided')
        ISI = np.diff(spike_times)
    ISI = np.asarray(ISI, dtype=float)
    if n_bins < 1:
        raise ValueError('n_bins must be at least 1')
    if ISI.size == 0:
        return 0.0
    if ISI_min is None:
        ISI_min = ISI.min()
    if ISI_max is None:
        ISI_max = ISI.max()
    # n_bins bins need n_bins + 1 edges (linspace(..., n_bins) gave only n_bins - 1 bins)
    edges = np.linspace(ISI_min, ISI_max, n_bins + 1)
    counts, _ = np.histogram(ISI, bins=edges)
    total = counts.sum()
    if total == 0:
        # nothing inside [ISI_min, ISI_max]: avoid 0/0 and report zero entropy
        return 0.0
    p = counts[counts > 0] / total
    return float(-(p * np.log2(p)).sum())

In [None]:
# Locate the synaptic-cooperativity simulation results of the selected cell.
# NOTE(review): machine-specific absolute path — adapt base_folder to your setup.
base_folder = '/Users/daniele/Postdoc/Research/CA3/OPTIMIZATIONS'
cell_type = 'Thorny'
if cell_type == 'Thorny':
    cell_id = 'DH070813'
    optimization_run = '20191208071008_DH070813_'
else:
    cell_id = 'DH070213C3'
    optimization_run = '20191206232623_DH070213C3_'
folder = base_folder + '/' + cell_type + '/' + cell_id + '/' + optimization_run + \
    '/synaptic_cooperativity_experiment/'
# glob returns files in arbitrary, filesystem-dependent order; sort them so that
# data[0] (used below for the configuration and the example trial) is reproducible.
data_files = sorted(glob.glob(folder + 'synaptic_activation_202102*.npz'))
if not data_files:
    raise FileNotFoundError(f'no synaptic_activation_202102*.npz files found in {folder}')
# allow_pickle is needed because the 'config' entry is read back as an object
# (see data[0]['config'][()]); only load files from a trusted source.
data = [np.load(data_file, allow_pickle=True) for data_file in data_files]

In [None]:
# Simulation parameters shared by all trials, taken from the first data file.
config = data[0]['config'][()]
# 'delay' and 'stim_dur' of the simulation (times in ms, as converted with 1e-3 later)
delay, dur = config['sim']['delay'], config['sim']['stim_dur']
ttran = 100          # initial transient excluded after `delay`
t0 = delay + ttran   # start of the analysis window
t1 = delay + dur     # end of the analysis window
poisson_freq, bursts_freq = (config['poisson_frequency'],
                             config['synaptic_activation_frequency'])

In [None]:
# For every trial, keep the spikes inside the analysis window
# (delay + ttran, delay + stim_dur), re-referenced to the window start and
# converted from ms to s.
n_trials = len(data)
spike_times = []          # per trial: postsynaptic spike times (s)
presyn_spike_times = []   # per trial: pooled, sorted presynaptic spike times (s); bursts only
ISI = []                  # per trial: inter-spike intervals (s); concatenated after the loop
tend = []                 # per trial: analysis-window duration (ms)
ttran = 100               # initial transient (ms) excluded after the delay
for d in data:
    # separate name so the full `config` dict of the previous cell is not overwritten
    sim_config = d['config'][()]['sim']
    t0, t1 = sim_config['delay'] + ttran, sim_config['delay'] + sim_config['stim_dur']
    tend.append(t1 - t0)
    spks = d['spike_times']
    spks = spks[(spks > t0) & (spks < t1)]
    spike_times.append((spks - t0) * 1e-3)
    # ISIs come from the same windowed spikes used for the entropy estimate, so the
    # firing rate and CV describe the analysed data (previously they also included
    # the transient and any spike outside the window)
    ISI.append(np.diff(spike_times[-1]))
    if bursts_freq > 0:
        pre_spks = np.sort(d['presyn_spike_times'].flatten())
        pre_spks = pre_spks[(pre_spks > t0) & (pre_spks < t1)]
        presyn_spike_times.append((pre_spks - t0) * 1e-3)
ISI = np.concatenate(ISI)

In [None]:
# Entropy-rate estimates from the binarized postsynaptic spike trains.
bin_size = 5e-3   # s, width of the bins used to binarize the spike trains

# One binary sequence per trial, all trials concatenated.
x = np.concatenate([binarize(trial, bin_size, dur_ms * 1e-3)
                    for trial, dur_ms in zip(spike_times, tend)])

# Rate and coefficient of variation from the pooled ISIs.
firing_rate = 1 / ISI.mean()
CV = ISI.std() / ISI.mean()

depth = 10        # context depth passed to the ctw estimators
Hx = ctw.compute_entropy(x, depth, alphabet_length=2)
Hx /= bin_size    # per-bin value -> bits/sec

# Entropy of a Poisson process with the same firing rate:
# r * log2(e / (r * bin_size)).
H_theor = firing_rate * mp.log(mp.e / (firing_rate * bin_size), 2)

if bursts_freq > 0:
    # Presynaptic activity binarized on the same time grid.
    y = np.concatenate([binarize(trial, bin_size, dur_ms * 1e-3)
                        for trial, dur_ms in zip(presyn_spike_times, tend)])
    Hxy = ctw.compute_conditional_entropy(x, y, depth, alphabet_length=4)
    Hxy /= bin_size
    MI = Hx - Hxy  # entropy minus conditional entropy, bits/sec

print(f'Firing rate: {firing_rate:.2f} spike/s')
print(f'CV: {CV:.4f}')
print(f'Entropy: {nstr(Hx)} bits/sec')
print(f'Entropy of a Poisson process at the same firing rate: {nstr(H_theor)} bits/sec')
print(f'Ratio: {nstr(Hx / H_theor)}')
if bursts_freq > 0:
    print(f'MI: {nstr(MI)} bits/sec')

In [None]:
# Example-trial raster figure for trial (data file) `i`:
#   row 1 - the postsynaptic spike train cut into consecutive periods, one per row;
#   with presynaptic bursts, additionally
#   row 2 - postsynaptic spikes aligned to each burst onset, plus the mean rate;
#   row 3 - presynaptic spikes aligned to each burst onset.
# Intermediates use names distinct from `spike_times` / `presyn_spike_times` of the
# earlier cells, so re-running the entropy cell after this one still works.


def fold_spike_train(times, period):
    """Split a sorted, non-negative spike train into consecutive windows of `period`.

    Returns one array per window with spike times relative to the window start.
    Windows cover [0, first multiple of `period` beyond the last spike), so no
    spike is dropped. An empty train gives an empty list.
    """
    times = np.asarray(times)
    if times.size == 0:
        return []
    n_periods = int(np.floor(times[-1] / period)) + 1
    counts, _ = np.histogram(times, np.arange(n_periods + 1) * period)
    bounds = np.concatenate([[0], np.cumsum(counts)])
    return [times[start:stop] - k * period
            for k, (start, stop) in enumerate(zip(bounds[:-1], bounds[1:]))]


def align_to_events(times, event_times, window):
    """Spikes in the open interval (t, t + window) after each event t, relative to t."""
    times = np.asarray(times)
    return [times[(times > t) & (times < t + window)] - t for t in event_times]


i = 0                # trial (data file) to plot
ms = 2               # raster marker size
max_periods = 200    # number of folded periods shown in the top raster
window = 100         # ms shown after each burst onset
rate_bin = 5         # ms, bin width of the burst-triggered firing rate
color = 'm'          # colour of the firing-rate trace
has_bursts = bursts_freq > 0
n_rows = 3 if has_bursts else 1
outfile = 'synaptic_bg_' + ('with' if has_bursts else 'without') + \
    '_presynaptic_bursts_' + cell_type.lower() + '.pdf'

trial_spike_times = data[i]['spike_times']
# Fold with the inter-burst period (1000 / bursts_freq, presumably ms for a frequency
# in Hz). Without bursts this would divide by zero, so fall back to 1000.
# NOTE(review): the fallback period is an arbitrary choice — adjust if needed.
fold_period = 1000 / bursts_freq if has_bursts else 1000
folded = fold_spike_train(trial_spike_times, fold_period)[:max_periods]

fig, ax = plt.subplots(n_rows, 1, figsize=(5, 2.5 * n_rows), squeeze=False)
rasterplot(folded, color='k', markersize=ms, ax=ax[0, 0])
ax[0, 0].set_yticks([1, len(folded)])
ax[0, 0].set_ylabel('Trial #')
if n_rows == 1:
    ax[0, 0].set_xlabel('Time (ms)')

if has_bursts:
    burst_times = data[i]['presyn_burst_times']
    # Columns of presyn_spike_times line up with burst_times; after subtraction and
    # transposition each row holds the presynaptic spikes of one burst relative to
    # its onset. Bursts are then ordered by their latest presynaptic spike.
    presyn_aligned = (data[i]['presyn_spike_times'] - burst_times[np.newaxis, :]).T
    burst_order = np.argsort(presyn_aligned.max(axis=1))
    presyn_aligned = presyn_aligned[burst_order, :]

    post_aligned = align_to_events(trial_spike_times, burst_times, window)
    post_aligned = [post_aligned[k] for k in burst_order]
    rasterplot(post_aligned, color='k', markersize=ms, ax=ax[1, 0])
    ax[1, 0].set_yticks([1, len(post_aligned)])
    ax[1, 0].set_xlim([0, window])
    ax[1, 0].set_ylabel('Presynaptic burst #')

    rasterplot(presyn_aligned, color='k', markersize=ms, ax=ax[2, 0])
    ax[2, 0].set_yticks([1, len(presyn_aligned)])
    ax[2, 0].set_xlim([0, window])
    ax[2, 0].set_ylabel('Presynaptic burst #')
    ax[2, 0].set_xlabel('Time (ms)')

    # Burst-triggered firing rate: window // rate_bin bins of exactly rate_bin ms
    # (window // bin + 1 bins were 100/21 ms wide while the normalisation assumed
    # 5 ms), normalised by the actual bin width and plotted at bin centres.
    counts, edges = np.histogram(np.concatenate(post_aligned),
                                 bins=int(window // rate_bin), range=(0, window))
    widths = np.diff(edges)
    rate = counts / (len(post_aligned) * widths * 1e-3)   # spike/s
    ax_rate = ax[1, 0].twinx()
    ax_rate.plot(edges[:-1] + widths / 2, rate, color=color, lw=2)
    ax_rate.set_ylabel('Firing rate (spike/s)', color=color)
    ax_rate.tick_params(axis='y', labelcolor=color)
    ax_rate.set_ylim([0,50])
    # only exists with bursts: previously done unconditionally (NameError otherwise)
    despine(ax_rate, sides=('top',))
    ax[0, 0].set_title(f'F={firing_rate:.1f} AP/s CV={CV:.2f} H={nstr(Hx)} bits/s MI={nstr(MI)} bits/s')
else:
    ax[0, 0].set_title(f'F={firing_rate:.1f} AP/s CV={CV:.2f} H={nstr(Hx)} bits/s')

for a in ax[:, 0]:
    despine(a)

fig.tight_layout()
fig.savefig(outfile)