In [1]:
import numpy as np
import json
import os
from copy import deepcopy

from stp_utils import *

fr_file = os.path.join('..', '..', 'Analysis', 'analysis_results', 'baseline_unit_fr.npz')
SYN_PATH = os.path.join('../', 'components', 'synaptic_models', 'synapses_STP')
NO_STP_PATH = os.path.join('../', 'components', 'synaptic_models', 'synapses_no_STP')

write_syn_params = True
estimate_from_unit_fr = True  # or directly using the population mean firing rate
update_weight = 1.0  # weight for the update. New U = weight * estimated P + (1 - weight) * old U

### Load data

In [2]:
if estimate_from_unit_fr:
    with np.load(fr_file) as f:
        pop_fr = dict(f.items())
    population = list(pop_fr)
    valid_idx = {p: np.nonzero(fr > 0)[0] for p, fr in pop_fr.items()}
else:
    os.chdir('..')
    from build_input import SHELL_FR
    os.chdir('Tuning')
    population = SHELL_FR.index.tolist()

syn_files = os.listdir(SYN_PATH)
syn_params = {}
conn_name = {}
for file in syn_files:
    conn, _ = os.path.splitext(file)
    syn = tuple(conn.split('2'))
    with open(os.path.join(SYN_PATH, file), 'r') as f:
        syn_params[syn] = json.load(f)
    conn_name[syn] = conn

In [3]:
select_syn = []
PN = ['CP', 'CS']
ITN = ['FSI', 'LTS']
for pre, post in syn_params:
    sel = pre in population
    # P2I = pre in PN and post in ITN
    # I2P = pre in ITN and post in PN
    # sel = sel and (P2I or I2P)
    if sel:
        select_syn.append((pre, post))

### Estimate STP efficacy from baseline firing rates

In [4]:
P_est = {}
P_std = {}
for syn in select_syn:
    p = syn_params[syn]
    pre = syn[0]
    U, tau_d, tau_f = p['Use'], p['Dep'] / 1000, p['Fac'] / 1000
    if estimate_from_unit_fr:
        rates = pop_fr[pre]
        idx = valid_idx[pre]
        P = np.full_like(rates, U)
        P[idx] = estimate_steady_state(rates[idx], U=U, tau_d=tau_d, tau_f=tau_f)['m_P']
        P_est[syn] = np.average(P, weights=rates)  # weighted average
        P_std[syn] = np.std(P)
    else:
        rate = SHELL_FR.loc[pre, 'mean']
        P_est[syn] = estimate_steady_state(rate, U=U, tau_d=tau_d, tau_f=tau_f)['m_P'].item()
        P_std[syn] = 0.

### Set new synapse parameters without STP

In [5]:
syn_params_no_STP = deepcopy(syn_params)

print("Set `Use` according to estimated P")
for syn in select_syn:
    p = syn_params[syn]
    new_p = syn_params_no_STP[syn]
    new_p['Use'] = update_weight * P_est[syn] + (1 - update_weight) * p['Use']
    new_p['Use'] = np.clip(new_p['Use'], 0, 1).item()
    new_p['Dep'] = 0.
    new_p['Fac'] = 0.
    print(f"{conn_name[syn]:s}:")
    print(f"mean: {P_est[syn]:.4f}, stdev: {P_std[syn]:.4f}")
    print(f"Update use:  {p['Use']} --> {new_p['Use']:.4f}")

Set `Use` according to estimated P
CP2CP:
mean: 0.5487, stdev: 0.0717
Update use:  0.37 --> 0.5487
CP2CS:
mean: 0.5487, stdev: 0.0717
Update use:  0.37 --> 0.5487
CP2FSI:
mean: 0.0329, stdev: 0.0010
Update use:  0.035 --> 0.0329
CP2LTS:
mean: 0.0829, stdev: 0.0162
Update use:  0.05 --> 0.0829
CS2CP:
mean: 0.2794, stdev: 0.0576
Update use:  0.41 --> 0.2794
CS2CS:
mean: 0.2794, stdev: 0.0576
Update use:  0.41 --> 0.2794
CS2FSI:
mean: 0.1420, stdev: 0.0174
Update use:  0.18 --> 0.1420
CS2LTS:
mean: 0.1862, stdev: 0.0273
Update use:  0.13 --> 0.1862
FSI2CP:
mean: 0.1351, stdev: 0.0557
Update use:  0.3 --> 0.1351
FSI2CS:
mean: 0.1351, stdev: 0.0557
Update use:  0.3 --> 0.1351
FSI2FSI:
mean: 0.1351, stdev: 0.0557
Update use:  0.3 --> 0.1351
FSI2LTS:
mean: 0.1351, stdev: 0.0557
Update use:  0.3 --> 0.1351
LTS2CP:
mean: 0.4596, stdev: 0.0689
Update use:  0.3 --> 0.4596
LTS2CS:
mean: 0.4596, stdev: 0.0689
Update use:  0.3 --> 0.4596
LTS2FSI:
mean: 0.4596, stdev: 0.0689
Update use:  0.3 --> 0.45

In [6]:
if write_syn_params:
    if not os.path.isdir(NO_STP_PATH):
        os.mkdir(NO_STP_PATH)

    for syn, p in syn_params_no_STP.items():
        file = os.path.join(NO_STP_PATH, conn_name[syn] + '.json')
        with open(file, 'w') as f:
            json.dump(p, f, indent=4)