In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import PrintNNinfo as out
import os, subprocess
from time import time
from tqdm import tqdm
from scipy.ndimage import gaussian_filter1d, gaussian_filter

In [3]:
def CreateERnetwork(cell_types, ps):
    """Build a random (Erdos-Renyi style) connectivity map between cells.

    Each ordered pair of distinct cells ``(i, j)`` is connected with
    probability ``ps[cell_types[i]][cell_types[j]]``. Exactly one
    ``np.random.uniform()`` draw is consumed per off-diagonal pair, in
    row-major order, so results are reproducible under ``np.random.seed``.

    Parameters
    ----------
    cell_types : sequence of int
        Type index of every cell.
    ps : square nested sequence or 2-D array
        ``ps[pre_type][post_type]`` is the connection probability.

    Returns
    -------
    np.ndarray of int, shape (len(cell_types), len(cell_types))
        ``cnt_map[i][j]`` is the synapse-type id ``pre_type * len(ps) + post_type``
        for a connection i -> j, or -1 where there is no connection
        (the diagonal is always -1).
    """
    n_types = len(ps)
    syn_id = np.arange(n_types * n_types).reshape(n_types, n_types)
    n_cells = len(cell_types)
    cnt_map = np.full((n_cells, n_cells), -1, dtype=int)

    for row, pre in enumerate(cell_types):
        for col, post in enumerate(cell_types):
            if row == col:
                continue  # no self-connections, and no RNG draw for them
            if np.random.uniform() < ps[pre][post]:
                cnt_map[row][col] = syn_id[pre][post]
    return cnt_map


def CreateFullnetwork(cell_types, ps):
    """Build an all-to-all connectivity map (every distinct ordered pair connected).

    ``ps`` is only used for its length, i.e. the number of cell types; the
    probabilities themselves are ignored.

    Parameters
    ----------
    cell_types : sequence of int
        Type index of every cell.
    ps : square nested sequence or 2-D array
        Type-by-type table whose length gives the number of types.

    Returns
    -------
    np.ndarray of int, shape (len(cell_types), len(cell_types))
        ``cnt_map[i][j]`` is ``pre_type * len(ps) + post_type`` for i != j,
        and -1 on the diagonal.
    """
    n_types = len(ps)
    syn_id = np.arange(n_types * n_types).reshape(n_types, n_types)
    n_cells = len(cell_types)
    cnt_map = np.full((n_cells, n_cells), -1, dtype=int)

    for row, pre in enumerate(cell_types):
        for col, post in enumerate(cell_types):
            if row != col:
                cnt_map[row][col] = syn_id[pre][post]
    return cnt_map


def get_Poisson_t(f, N, tmax, dt, t0=0):
    """Draw ``N`` independent Bernoulli-approximated Poisson spike trains.

    On the time grid ``np.arange(t0, tmax, dt)``, each bin spikes when a
    uniform draw in [0, 1000) falls below ``f * dt``, i.e. with probability
    ``f * dt / 1000``. That factor is consistent with ``f`` in Hz and ``dt``
    in ms (units assumed from the 1e3 factor; confirm with callers).

    Returns
    -------
    list of np.ndarray
        One array of spike times per train.
    """
    grid = np.arange(t0, tmax, dt)
    threshold = f * dt
    # One fresh vector of draws per train, in train order (same RNG stream use).
    return [
        grid[np.random.uniform(low=0, high=1e3, size=len(grid)) < threshold]
        for _ in range(N)
    ]

def select_targets(n_stims, n_cells, n_overlap):
    """Randomly give each stimulus ``n_overlap`` distinct target cells.

    Every cell index in ``range(n_cells)`` is available ``n_overlap`` times,
    so a cell can be targeted by at most ``n_overlap`` stimuli. Stimuli are
    filled greedily, one at a time, using the global ``np.random`` state.

    Parameters
    ----------
    n_stims : int
        Number of stimuli (length of the result).
    n_cells : int
        Number of candidate cells; targets are indices in ``range(n_cells)``.
    n_overlap : int
        Number of distinct cells per stimulus, and also the maximum number
        of stimuli that may share one cell.

    Returns
    -------
    list of list of int
        ``target_ids[k]`` holds the distinct cell indices for stimulus ``k``.

    Raises
    ------
    ValueError
        If the remaining pool cannot supply a cell not already chosen for the
        current stimulus (e.g. ``n_stims > n_cells``, or an unlucky greedy
        draw that leaves only duplicates). The previous implementation spun
        forever in that situation.
    """
    # Pool of remaining slots, same ordering as before: [0..n-1, 0..n-1, ...].
    targets = [cell for _ in range(n_overlap) for cell in range(n_cells)]

    target_ids = []
    for k in range(n_stims):
        tid = []
        for _ in range(n_overlap):
            # Sampling uniformly from the slots whose cell is not yet in `tid`
            # has the same distribution as the old rejection loop, but cannot
            # hang when no such slot is left.
            candidates = [c for c in targets if c not in tid]
            if not candidates:
                raise ValueError(
                    "select_targets: cannot pick %d distinct targets for "
                    "stimulus %d; remaining pool %r" % (n_overlap, k, targets)
                )
            n = int(np.random.choice(candidates))
            tid.append(n)
            targets.remove(n)  # consume one slot of this cell
        target_ids.append(tid)
    return target_ids

def set_ext_types(target_ids, n_ext_exc, cell_types, ext_syn_types):
    """Assign a synapse type to every external-source -> target-cell connection.

    The first ``n_ext_exc`` entries of ``target_ids`` use row 0 of
    ``ext_syn_types`` and the rest use row 1. The names suggest excitatory
    vs. inhibitory external sources.

    Parameters
    ----------
    target_ids : iterable of sequences of int
        ``target_ids[k]`` lists the cell indices driven by external source ``k``.
    n_ext_exc : int
        Number of leading sources that use row 0.
    cell_types : sequence of int
        Type index of each cell.
    ext_syn_types : nested sequence with 2 rows
        ``ext_syn_types[pre][post_type]`` is the synapse-type id, e.g.
        ``[[0, 1], [2, 3]]`` (ext_PN -> ? / ext_Inh -> ?).

    Returns
    -------
    list of list
        ``ext_types[k][m]`` is the synapse type for cell ``target_ids[k][m]``.
    """
    ext_types = []
    for k, tid in enumerate(target_ids):
        pre = 0 if k < n_ext_exc else 1
        ext_types.append([ext_syn_types[pre][cell_types[n]] for n in tid])
    return ext_types
    
    
def show_raster(tspks, n_exc):
    """Draw a spike raster on the current matplotlib axes.

    Row ``i`` shows the spike times ``tspks[i]`` as vertical ticks spanning
    ``i - 0.5`` to ``i + 0.5``. The first ``n_exc`` rows are red, the rest blue.
    """
    for row, spikes in enumerate(tspks):
        color = 'r' if row < n_exc else 'b'
        plt.vlines(spikes, row - 0.5, row + 0.5, color=color)
    
    
def read_all(prefix):
    """Load voltage and current traces written under ``prefix`` and extract spikes.

    Reads ``prefix + '_v.csv'`` and ``prefix + '_i.csv'`` via ``out.readOut``.

    Returns
    -------
    t : time vector returned by ``readOut`` for the voltage file
    v : voltage array, indexed ``v[:, col]`` (presumably one column per cell; confirm)
    ii : data from the current file (its time vector is discarded)
    tspks : list with one array per column of ``v``: the times where that
        column is exactly 30
    """
    t, v = out.readOut(prefix + '_v.csv')
    _, ii = out.readOut(prefix + '_i.csv')
    # NOTE(review): exact equality with 30 is used as the spike marker;
    # presumably the simulator writes v = 30 at spike samples. Confirm.
    tspks = [t[v[:, col] == 30] for col in range(v.shape[1])]
    return t, v, ii, tspks


def getSTFFT(x, t, wbin=50000, mbin=1000, dt=0.01, maxf=200):
    """Short-time Fourier magnitude spectrum of a 1-D signal.

    Hann-tapered windows of ``wbin`` samples are centred every ``mbin``
    samples. Only windows that lie entirely inside ``x`` are used.

    Parameters
    ----------
    x : 1-D array
        Signal samples.
    t : 1-D array
        Time of each sample (same length as ``x``).
    wbin : int
        Window length in samples.
    mbin : int
        Hop between window centres, in samples.
    dt : float
        Sampling step. Frequencies are computed with a spacing of
        ``dt / 1e3`` seconds, i.e. ``dt`` is taken to be in ms.
    maxf : float
        Upper frequency cut-off (exclusive), in Hz. This was previously
        ignored and hard-coded to 200.

    Returns
    -------
    t_centres : times of the window centres
    freqs : frequencies strictly between 1 Hz and ``maxf``
    psd : array, shape (len(freqs), len(t_centres))
        ``|FFT| / wbin`` for each window. This is a magnitude, not a power.
    """
    half = wbin // 2
    centres = np.arange(half, len(x) - half, mbin)
    freqs = np.fft.fftfreq(wbin, dt / 1e3)[:half]
    keep = (freqs < maxf) & (freqs > 1)
    # The window and segment now both have exactly `wbin` samples, so the FFT
    # bins match the `fftfreq(wbin)` frequency axis (they were wbin-1 before).
    window = np.hanning(wbin)
    psd = np.zeros([np.count_nonzero(keep), len(centres)])
    for col, centre in enumerate(centres):
        start = centre - half
        segment = x[start:start + wbin] * window
        spectrum = np.abs(np.fft.fft(segment) / wbin)[:half]
        psd[:, col] = spectrum[keep]
    return t[centres], freqs[keep], psd
    
    
def getFFT(x, idt, dt):
    """One-sided FFT magnitude of the selected samples ``x[idt]``.

    Parameters
    ----------
    x : 1-D array-like
        Signal samples.
    idt : boolean mask, integer index array, or slice
        Selects the samples to transform. The sample count is taken from the
        selection itself. It used to be ``sum(idt)``, which is only correct
        for boolean masks.
    dt : float
        Sampling step. Frequencies use a spacing of ``dt / 1e3`` seconds,
        i.e. ``dt`` is taken to be in ms.

    Returns
    -------
    f : non-negative frequencies in Hz, length ``n // 2``
    fx : ``|FFT| / n`` at those frequencies, where ``n`` is the number of
        selected samples
    """
    segment = np.asarray(x)[idt]
    n = len(segment)
    spectrum = np.abs(np.fft.fft(segment) / n)[:n // 2]
    f = np.fft.fftfreq(n, dt / 1e3)[:n // 2]
    return f, spectrum

def show_grid(v, t, wbin=50000, mbin=None, s=5, norm=False, vmax=0.02):
    """Plot a Gaussian-smoothed spectrogram of ``v`` (frequencies above 10 Hz).

    Parameters
    ----------
    v, t : array
        Signal and its sample times, forwarded to ``getSTFFT``.
    wbin : int
        STFT window length in samples.
    mbin : int or None
        Hop between windows in samples. ``None`` (the default) uses
        ``getSTFFT``'s own default hop. That hop is what this function always
        used before, because ``mbin`` was previously accepted but never
        forwarded.
    s : float
        Sigma of ``gaussian_filter``, applied over both the frequency and
        time axes, in bins.
    norm : bool
        If True, subtract each frequency's time-average before plotting.
    vmax : float
        Upper colour limit. The lower limit is fixed at 0.
    """
    stft_kwargs = {'wbin': wbin}
    if mbin is not None:
        stft_kwargs['mbin'] = mbin
    t, f, psd = getSTFFT(v, t, **stft_kwargs)
    idf = f > 10
    txy, fxy = np.meshgrid(t, f[idf])

    if norm:
        # Remove the per-frequency mean so temporal modulation stands out.
        # NOTE(review): with vmin=0 below, negative deviations are clipped.
        psd = psd - np.average(psd, axis=1).reshape(-1, 1)
    im = gaussian_filter(psd[idf, :], s)

    plt.pcolormesh(txy, fxy, im, cmap='jet', vmin=0, vmax=vmax)
    