### Imports

In [None]:
import numpy as np
import sympy as sym
import numpy as np
from scipy.optimize import fsolve
import scipy.signal as signal
import matplotlib.pyplot as plt
from matplotlib import cm, colors
import pandas as pd
import seaborn as sns
import os
import json
from scipy.io import loadmat
from matplotlib.colors import ListedColormap

import sys
sys.path.append(r'path\to\shared_utils')

In [None]:
# Colour lookup table loaded from disk. The `with` block closes the file handle,
# which the previous `json.load(open(...))` leaked. (`global` at module level is a
# no-op, so those statements were dropped.)
with open('color_dict.json', encoding='utf-8') as _color_file:
    color_dict = json.load(_color_file)

# Colour data stored in a MATLAB file; entries are looked up by name
# (e.g. all_color_data['PiYG5']) in the plotting cells.
all_color_data = loadmat('colorList.mat')['all_data'][0][0]

def make_color_pastel(colormap, c=0.25, n=256):
    """Return a washed-out copy of ``colormap``.

    ``colormap`` is sampled at ``n`` evenly spaced points in [0, 1], and every
    RGBA sample is blended linearly towards 1.0 in all four channels by the
    fraction ``c``. The blended samples are wrapped in a new ``ListedColormap``.
    """
    samples = colormap(np.linspace(0., 1., n))
    white = np.ones((n, 4))
    blended = samples * (1. - c) + white * c
    return ListedColormap(blended)


<div>
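Hassenstein-Reichardt elementary motion detector (EMD) <br>
High-pass filter each input signal (first order; tau = 50 ms in this description, while the code default is 150 ms) <br>
Half-wave rectification into ON and OFF channels <br>
Low-pass filter each channel to obtain the delayed signal (first order; tau = 15 ms in this description, while the code default is 50 ms) <br>
Multiply each delayed signal with the undelayed signal of the neighbouring input <br>
Do this twice in a mirror-symmetrical manner and subtract the two products <br>
Scale the result by a "visual weight" <br>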
</div>
<p>
HSright = WHSPD * FtBright - WHSNPD * BtFright + WHScontra * BtFleft + H2_to_HS_cc * H2left <br>
H2right = WH2PD * BtFright - WH2NPD * FtBright + WH2contra * FtBleft + HS_to_H2_cc * HSleft <br>
</p>
<p>
Bright = WHSB * HSright + WH2B * H2left - WUB * Uright <br>
Uright = WHSU * HSright + WH2U * H2left - WBU * Bleft <br>
</p>
<p>
DNp15right = WHSp15 * HSright + WH2p15 * H2left - WUp15 * Uright - WBp15 * Bleft <br>
</p>
<p>
Vangular = Wvisual * (DNp15right - DNp15left) + noise (the solver equations omit the noise term; noise is set to 0) <br>
</p>

### EMD

#### rectifier

In [None]:
def ReLU(x):
    """Half-wave rectify the positive part: ``x`` where ``x > 0``, else 0."""
    keep = x > 0
    return x * keep
def ReLU_OFF(x):
    """Half-wave rectify the negative part, sign-flipped: ``-x`` where ``x < 0``, else 0."""
    keep = x < 0
    return -x * keep
def d_ReLU(x):
    """Derivative of ``ReLU``: 1 where ``x > 0``, else 0."""
    return (x > 0) * 1
def d_ReLU_OFF(x):
    """Derivative of ``ReLU_OFF``: -1 where ``x < 0``, else 0."""
    return (x < 0) * -1

#### filters

In [None]:
def high_pass_butter(input, tau=150):
    """First-order Butterworth high-pass of a signal sampled at 1 kHz.

    The cutoff frequency is derived from the time constant ``tau`` (ms) as
    f_c = 1000 / (2 * pi * tau) Hz, e.g. tau=150 gives roughly 1.06 Hz.
    """
    sos = signal.butter(1, 1000/(2*np.pi*tau), 'highpass', fs=1000, output='sos')
    return signal.sosfilt(sos, input)
def low_pass_butter(input, tau=50):
    """First-order Butterworth low-pass of a signal sampled at 1 kHz.

    Cutoff f_c = 1000 / (2 * pi * tau) Hz, e.g. tau=50 gives roughly 3.18 Hz.
    """
    sos = signal.butter(1, 1000/(2*np.pi*tau), 'lowpass', fs=1000, output='sos')
    return signal.sosfilt(sos, input)

def high_pass(input, tau=150):
    """First-order discrete high-pass: the input minus its exponential low-pass.

    The subtracted low-pass follows y[i] = y[i-1] + alpha * (x[i] - y[i-1]) for
    i >= 1, with alpha = 1 - exp(-1/tau) and y[0] = 0. The first sample therefore
    passes through unchanged. ``tau`` is measured in samples (1 sample = 1 ms for
    the inputs generated in this notebook).

    Parameters
    ----------
    input : 1-D array_like
        Signal to filter.
    tau : float
        Time constant of the subtracted low-pass, in samples.

    Returns
    -------
    ndarray of float with the same length as ``input``.
    """
    x = np.asarray(input)
    n = x.shape[0]
    smoothed = np.zeros(n)
    if n > 1:
        alpha = 1 - np.exp(-(1 / tau))
        # The recurrence is y[i] = alpha * x[i] + (1 - alpha) * y[i-1] run over
        # x[1:] from a zero state: an IIR filter that lfilter evaluates in C
        # instead of a per-sample Python loop.
        smoothed[1:] = signal.lfilter([alpha], [1.0, alpha - 1.0], x[1:])
    return x - smoothed

def low_pass(input, tau=50):
    """First-order discrete (exponential) low-pass filter.

    Computes y[i] = y[i-1] + alpha * (x[i] - y[i-1]) for i >= 1, with
    alpha = 1 - exp(-1/tau) and y[0] = 0. The first input sample is therefore
    never used. ``tau`` is measured in samples (1 sample = 1 ms for the inputs
    generated in this notebook).

    Parameters
    ----------
    input : 1-D array_like
        Signal to filter.
    tau : float
        Filter time constant, in samples.

    Returns
    -------
    ndarray of float with the same length as ``input``.
    """
    x = np.asarray(input)
    n = x.shape[0]
    filtered = np.zeros(n)
    if n > 1:
        alpha = 1 - np.exp(-(1 / tau))
        # Same recurrence as an IIR filter over x[1:] from a zero initial state,
        # evaluated by lfilter instead of a per-sample Python loop.
        filtered[1:] = signal.lfilter([alpha], [1.0, alpha - 1.0], x[1:])
    return filtered

#### Model

In [None]:
def create_input_new(temp=1, spatial=1/20, PD=True, time=10, interommatidial_angle=5):
    """Sample a moving square-wave grating at two neighbouring points.

    The grating is stored at 1/1000-degree resolution: one period of
    ``1/spatial`` degrees is half 0 and half 1, and three periods are tiled.
    Every millisecond the pattern is shifted circularly by ``int(temp / spatial)``
    samples, towards lower indices when ``PD`` is True and towards higher indices
    otherwise. The two outputs are read at pattern positions 100 and
    ``100 + interommatidial_angle * 1000``.

    Parameters
    ----------
    temp : float
        Temporal frequency.
    spatial : float
        Spatial frequency (cycles per degree; one period is ``1/spatial`` degrees).
    PD : bool
        Direction of the shift.
    time : int
        Duration; ``time * 1000`` samples are produced.
    interommatidial_angle : int
        Separation of the two sample points, in degrees.

    Returns
    -------
    left_input, right_input : ndarray of float, each of length ``time * 1000``.

    Raises
    ------
    IndexError
        If a sample position lies outside the pattern.
    """
    # np.repeat needs an integer count; the float expression used to raise TypeError.
    half_period = int(round((1/spatial)/0.001/2))
    pattern = np.tile(np.repeat([0, 1], half_period), 3)  # 1/1000th of a degree
    n_points = pattern.shape[0]
    speed = int((temp/spatial))  # these many indices per ms
    n_samples = time * 1000
    right_start = 100 + interommatidial_angle * 1000
    if n_samples:
        for start in (100, right_start):
            if not -n_points <= start < n_points:
                raise IndexError(f"index {start} is out of bounds for a pattern of {n_points} samples")
    # After i circular shifts, position k holds pattern[(k + i*speed) % n_points]
    # (PD) or pattern[(k - i*speed) % n_points] (otherwise). Index directly rather
    # than calling np.roll on the whole pattern once per millisecond.
    offsets = (speed if PD else -speed) * np.arange(n_samples)
    left_input = pattern[(100 + offsets) % n_points].astype(float)
    right_input = pattern[(right_start + offsets) % n_points].astype(float)
    return left_input, right_input

In [None]:
def create_sine_input(temp=1, spatial=1/20, PD=True, time=10, interommatidial_angle=5):
    """Sample a moving sinusoidal grating (values in [0, 1]) at two neighbouring points.

    The grating is stored at 1/1000-degree resolution over three periods of
    ``1/spatial`` degrees. Every millisecond it is shifted circularly by
    ``int(temp / spatial)`` samples, towards lower indices when ``PD`` is True
    and towards higher indices otherwise. The outputs are read at positions 100
    and ``100 + interommatidial_angle * 1000``.

    Returns
    -------
    left_input, right_input : ndarray of float, each of length ``time * 1000``.

    Raises
    ------
    IndexError
        If a sample position lies outside the pattern.
    """
    size = 2 * (((1/spatial)/0.001)/2) * 3
    sine_pattern = (np.sin(np.deg2rad(np.arange(size) * (360/(1/spatial))/1000))+1)/2  # 1/1000th of a degree
    n_points = sine_pattern.shape[0]
    speed = int((temp/spatial))  # these many indices per ms
    n_samples = time * 1000
    right_start = 100 + interommatidial_angle * 1000
    if n_samples:
        for start in (100, right_start):
            if not -n_points <= start < n_points:
                raise IndexError(f"index {start} is out of bounds for a pattern of {n_points} samples")
    # Position k after i circular shifts is sine_pattern[(k +/- i*speed) % n_points].
    # Index directly instead of np.roll-ing the whole pattern every millisecond
    # (the values are identical, and the cost drops from O(T * L) to O(T)).
    offsets = (speed if PD else -speed) * np.arange(n_samples)
    left_input = sine_pattern[(100 + offsets) % n_points]
    right_input = sine_pattern[(right_start + offsets) % n_points]
    return left_input, right_input

In [None]:
def create_input(temp=1, spatial=1/20, PD=True, time=10, interommatidial_angle=5):
    """Build square-wave inputs for two neighbouring sample points.

    A +/-1 square wave of frequency ``temp`` sampled at 1 kHz for ``time``
    seconds is offset between the two outputs by
    ``interommatidial_angle / (1/spatial)`` of a cycle, truncated to whole samples.
    With ``PD`` the right output is the wave advanced by that offset; otherwise
    the left output is. Both outputs have the same length.

    Returns
    -------
    left_input, right_input : ndarray
    """
    fs = 1000
    t = np.arange(0, time, 1/fs)
    wave = signal.square(temp*2*np.pi*t)
    cycle_period = fs/temp  # these many frames per cycle
    # delay between the two sample points, as a fraction of a cycle, in samples
    shift = int((interommatidial_angle/(1/spatial)) * cycle_period)
    leading = wave[shift:]
    # wave[:-0] would be empty, so a zero shift must keep the full signal.
    lagging = wave[:-shift] if shift != 0 else wave
    if PD:
        right_input, left_input = leading, lagging
    else:
        left_input, right_input = leading, lagging
    return left_input, right_input

def ommatidia_output(input, highpass_tau=150, lowpass_tau=50):
    """Split one input signal into rectified ON/OFF channels and their delayed copies.

    Pipeline: ``high_pass`` (time constant ``highpass_tau``), half-wave
    rectification into an ON channel (``ReLU``) and a sign-flipped OFF channel
    (``ReLU_OFF``), then ``low_pass`` (``lowpass_tau``) on each channel to form
    the delayed arm of the motion detector.

    Returns
    -------
    (rectified_ON, rectified_OFF, delayed_ON, delayed_OFF)
    """
    transient = high_pass(input, tau=highpass_tau)
    on_channel, off_channel = ReLU(transient), ReLU_OFF(transient)
    on_delayed, off_delayed = (low_pass(channel, tau=lowpass_tau) for channel in (on_channel, off_channel))
    return on_channel, off_channel, on_delayed, off_delayed

def two_quadrant_EMD(temporal_frequency=1, spatial_frequency=1/20, contrast=1, PD=True, lowpass_tau = 50, highpass_tau=150):
    """
    Two-quadrant Hassenstein-Reichardt elementary motion detector.

    A sine grating (``create_sine_input``) is sampled at two neighbouring points.
    Each signal is split into ON/OFF channels and their delayed copies
    (``ommatidia_output``), and every channel is scaled by the contrast gain
    1 - exp(-contrast / 0.3). ON is correlated only with ON and OFF only with
    OFF, in two mirror-symmetric half-detectors whose products are subtracted.
    The ON and OFF results are summed.

    Returns the detector output as a time series (one value per input sample).
    """
    left_input, right_input = create_sine_input(temp=temporal_frequency, spatial=spatial_frequency, PD=PD)
    gain = 1 - np.exp(-contrast/0.3)
    l_on, l_off, l_on_delayed, l_off_delayed = (
        channel * gain for channel in ommatidia_output(left_input, highpass_tau, lowpass_tau))
    r_on, r_off, r_on_delayed, r_off_delayed = (
        channel * gain for channel in ommatidia_output(right_input, highpass_tau, lowpass_tau))
    # each half-detector multiplies an undelayed signal with the delayed neighbouring signal
    on_term = l_on * r_on_delayed - l_on_delayed * r_on
    off_term = l_off * r_off_delayed - l_off_delayed * r_off
    return on_term + off_term

def four_quadrant_EMD(temporal_frequency=1, spatial_frequency=1/20, contrast=1, PD=True, lowpass_tau = 50, highpass_tau=150):
    """
    Four-quadrant Hassenstein-Reichardt elementary motion detector.

    Square-wave inputs (``create_input``) are split into ON/OFF channels and
    their delayed copies (``ommatidia_output``). Each channel is scaled once by
    the contrast gain 1 - exp(-contrast / 0.3). Same-polarity correlations
    (ON x ON, OFF x OFF) are added and mixed-polarity correlations
    (ON x OFF, OFF x ON) are subtracted.

    Returns the detector output as a time series (one value per input sample).
    """
    left_input, right_input = create_input(temp=temporal_frequency, spatial=spatial_frequency, PD=PD)
    gain = 1 - np.exp(-contrast/0.3)
    # ommatidia_output returns a tuple, so each channel is scaled individually.
    # Multiplying the tuple itself by a float raised TypeError, and the gain was
    # then applied a second time. It is applied exactly once here, as in
    # two_quadrant_EMD.
    left_ON, left_OFF, left_delayed_ON, left_delayed_OFF = (
        channel * gain for channel in ommatidia_output(left_input, highpass_tau, lowpass_tau))
    right_ON, right_OFF, right_delayed_ON, right_delayed_OFF = (
        channel * gain for channel in ommatidia_output(right_input, highpass_tau, lowpass_tau))
    # multiply each undelayed signal with the delayed neighbouring signal, mirror-symmetrically
    coincidence_ON = left_ON * right_delayed_ON - left_delayed_ON * right_ON
    coincidence_ONOFF = left_ON * right_delayed_OFF - left_delayed_ON * right_OFF
    coincidence_OFF = left_OFF * right_delayed_OFF - left_delayed_OFF * right_OFF
    coincidence_OFFON = left_OFF * right_delayed_ON - left_delayed_OFF * right_ON
    return (coincidence_ON + coincidence_OFF) - (coincidence_ONOFF + coincidence_OFFON)

#### low-pass tau

In [None]:
# Mean two-quadrant EMD response vs. temporal frequency for three spatial
# frequencies (lowpass_tau = highpass_tau = 150), each curve normalised to its maximum.
temporal_freqs = [0.2, 0.5, 1, 2, 4, 8, 12, 16, 20, 32]
for i in [1/20, 1/30, 1/40]:
    response_freq = []
    for j in temporal_freqs:
        response_freq.append(np.mean(two_quadrant_EMD(temporal_frequency=j, spatial_frequency = i, lowpass_tau=150, highpass_tau=150)))
    # A plain list cannot be divided by a scalar (TypeError), so convert to an array first.
    response_freq = np.asarray(response_freq)
    plt.plot(response_freq/response_freq.max(), label=str(i))
plt.xticks(np.arange(len(temporal_freqs)), [str(f) for f in temporal_freqs])
plt.legend()

In [None]:
# Same tuning curves as above with lowpass_tau = highpass_tau = 300.
temporal_freqs = [0.2, 0.5, 1, 2, 4, 8, 12, 16, 20, 32]
for i in [1/20, 1/30, 1/40]:
    response_freq = []
    for j in temporal_freqs:
        response_freq.append(np.mean(two_quadrant_EMD(temporal_frequency=j, spatial_frequency = i, lowpass_tau=300, highpass_tau=300)))
    # A plain list cannot be divided by a scalar (TypeError), so convert to an array first.
    response_freq = np.asarray(response_freq)
    plt.plot(response_freq/response_freq.max(), label=str(i))
plt.xticks(np.arange(len(temporal_freqs)), [str(f) for f in temporal_freqs])
plt.legend()

### Neuron network

In [None]:
# Folder that the fig.savefig(...) calls below write into.
# Placeholder value: point it at a real directory before running those cells.
OUTPUT_DIR = (
    r'path\to\output\directory'
)

#### Weights

In [None]:
def set_weights_to_default():
    """Reset every network weight and parameter to its baseline value.

    The values are written as module-level globals, which is where the
    LPTC_system* residual functions read them from. Always returns 1.
    """
    baseline = {
        # HS
        'WHSPD': 1, 'WHSNPD': 0.2, 'WHScontra': 0.2,
        # H2
        'WH2PD': 0.8, 'WH2NPD': 0.3, 'WH2contra': 0.3,
        # bIPS and uLPTC
        'WHSB': 0.9, 'WH2B': 0, 'WUB': 0.7,
        'WHSU': 0.5, 'WH2U': 0, 'WBU': 0.5,
        # DNp15
        'WHSp15': 0.8, 'WH2p15': 0.3, 'WUp15': 0.2, 'WBp15': 0.5,
        'Wvisual': 55,
        'noise': 0,
        # coupling weights
        'HS_to_H2_cc': 0.3, 'H2_to_HS_cc': 0.3,
        # H2 threshold
        'H2threshold': -50, 'slope': 0.5, 'H2_neg': 5,
    }
    globals().update(baseline)
    return 1

#### System equations

In [None]:
def LPTC_system_without_HS(input):
    """Steady-state residuals of the LPTC network with the HS cells removed.

    ``input`` holds the nine unknowns
    (H2right, H2left, Bright, Uright, Bleft, Uleft, DNp15right, DNp15left, Vangular).
    Each returned value is ``unknown - drive``, so a root (e.g. from ``fsolve``)
    is a fixed point of the network. Weights and the stimulus components
    FtBleft/BtFleft/FtBright/BtFright are read from module globals.
    """
    (H2right, H2left, Bright, Uright, Bleft, Uleft,
     DNp15right, DNp15left, Vangular) = input

    # H2: preferred, null and contralateral motion drive
    h2_right_drive = WH2PD * BtFright - WH2NPD * FtBright + WH2contra * FtBleft
    h2_left_drive = WH2PD * BtFleft - WH2NPD * FtBleft + WH2contra * FtBright

    # bIPS (B) and uLPTC (U): H2 of the opposite side minus the U/B term
    b_right_drive = WH2B * H2left - WUB * Uright
    u_right_drive = WH2U * H2left - WBU * Bleft
    b_left_drive = WH2B * H2right - WUB * Uleft
    u_left_drive = WH2U * H2right - WBU * Bright

    # DNp15
    dn_right_drive = WH2p15 * H2left - WUp15 * Uright - WBp15 * Bleft
    dn_left_drive = WH2p15 * H2right - WUp15 * Uleft - WBp15 * Bright

    turning = Wvisual * (DNp15right - DNp15left)

    pairs = (
        (H2right, h2_right_drive), (H2left, h2_left_drive),
        (Bright, b_right_drive), (Uright, u_right_drive),
        (Bleft, b_left_drive), (Uleft, u_left_drive),
        (DNp15right, dn_right_drive), (DNp15left, dn_left_drive),
        (Vangular, turning),
    )
    return tuple(value - drive for value, drive in pairs)

In [None]:
def LPTC_system_without_H2(input):
    """Steady-state residuals of the LPTC network with the H2 cells removed.

    ``input`` holds the nine unknowns
    (HSright, HSleft, Bright, Uright, Bleft, Uleft, DNp15right, DNp15left, Vangular).
    Each returned value is ``unknown - drive``, so a root is a fixed point of
    the network. Weights and stimulus components are read from module globals.
    """
    (HSright, HSleft, Bright, Uright, Bleft, Uleft,
     DNp15right, DNp15left, Vangular) = input

    # HS: preferred, null and contralateral motion drive
    hs_right_drive = WHSPD * FtBright  - WHSNPD * BtFright + WHScontra * BtFleft
    hs_left_drive = WHSPD * FtBleft  - WHSNPD * BtFleft + WHScontra * BtFright

    # bIPS (B) and uLPTC (U): same-side HS minus the U/B term
    b_right_drive = WHSB * HSright - WUB * Uright
    u_right_drive = WHSU * HSright - WBU * Bleft
    b_left_drive = WHSB * HSleft - WUB * Uleft
    u_left_drive = WHSU * HSleft - WBU * Bright

    # DNp15
    dn_right_drive = WHSp15 * HSright - WUp15 * Uright - WBp15 * Bleft
    dn_left_drive = WHSp15 * HSleft - WUp15 * Uleft - WBp15 * Bright

    turning = Wvisual * (DNp15right - DNp15left)

    pairs = (
        (HSright, hs_right_drive), (HSleft, hs_left_drive),
        (Bright, b_right_drive), (Uright, u_right_drive),
        (Bleft, b_left_drive), (Uleft, u_left_drive),
        (DNp15right, dn_right_drive), (DNp15left, dn_left_drive),
        (Vangular, turning),
    )
    return tuple(value - drive for value, drive in pairs)

In [None]:
def LPTC_system(input):
    """Steady-state residuals of the full linear LPTC network.

    ``input`` holds the 11 unknowns
    (HSright, H2right, HSleft, H2left, Bright, Uright, Bleft, Uleft,
     DNp15right, DNp15left, Vangular).
    Each returned value is ``unknown - drive``, so a root (e.g. from ``fsolve``)
    is a fixed point of the network. Weights, gap-junction couplings
    (HS_to_H2_cc, H2_to_HS_cc) and the stimulus components
    FtBleft/BtFleft/FtBright/BtFright are read from module globals.
    """
    (HSright, H2right, HSleft, H2left, Bright, Uright,
     Bleft, Uleft, DNp15right, DNp15left, Vangular) = input

    # HS and H2: motion drive plus coupling to the other cell type on the opposite side
    hs_right_drive = WHSPD * FtBright  - WHSNPD * BtFright + WHScontra * BtFleft + H2_to_HS_cc * H2left
    h2_right_drive = WH2PD * BtFright - WH2NPD * FtBright + WH2contra * FtBleft + HS_to_H2_cc * HSleft
    hs_left_drive = WHSPD * FtBleft  - WHSNPD * BtFleft + WHScontra * BtFright + H2_to_HS_cc * H2right
    h2_left_drive = WH2PD * BtFleft - WH2NPD * FtBleft + WH2contra * FtBright + HS_to_H2_cc * HSright

    # bIPS (B) and uLPTC (U)
    b_right_drive = WHSB * HSright + WH2B * H2left - WUB * Uright
    u_right_drive = WHSU * HSright + WH2U * H2left - WBU * Bleft
    b_left_drive = WHSB * HSleft + WH2B * H2right - WUB * Uleft
    u_left_drive = WHSU * HSleft + WH2U * H2right - WBU * Bright

    # DNp15
    dn_right_drive = WHSp15 * HSright + WH2p15 * H2left - WUp15 * Uright - WBp15 * Bleft
    dn_left_drive = WHSp15 * HSleft + WH2p15 * H2right - WUp15 * Uleft - WBp15 * Bright

    turning = Wvisual * (DNp15right - DNp15left)

    pairs = (
        (HSright, hs_right_drive), (H2right, h2_right_drive),
        (HSleft, hs_left_drive), (H2left, h2_left_drive),
        (Bright, b_right_drive), (Uright, u_right_drive),
        (Bleft, b_left_drive), (Uleft, u_left_drive),
        (DNp15right, dn_right_drive), (DNp15left, dn_left_drive),
        (Vangular, turning),
    )
    return tuple(value - drive for value, drive in pairs)

In [None]:
# Preview of the sigmoid transfer functions used in LPTC_system_nonlinear
# (the H2 curve uses the same parameters as LPTC_system_spiking):
#   response = supremum / (1 + exp(-(input - threshold) / slope)) - offset
# NOTE: this cell runs at module level, so it rebinds the globals H2_neg and
# H2threshold (previously set by set_weights_to_default), as well as x and y.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
H2supremum = 50
H2_neg = 5
H2threshold = 7
H2slope = 3
# H2threshold > H2_neg
x = np.arange(-100, 100)
y = ((H2supremum/(1+(np.exp(-((x)-H2threshold)/H2slope)))) - H2_neg)
ax[0].plot(x, y)
ax[0].set_title('H2')
ax[0].axhline(y=0, c='grey', ls='--')
ax[0].axvline(x=0, c='grey', ls='--')
ax[0].set_xlabel('input')
ax[0].set_ylabel('response')

# HS curve: wider, shallower sigmoid with a larger negative offset
HSsupremum = 80
HS_neg = 40
HSthreshold = 7
HSslope = 20
x = np.arange(-100, 100)
y = ((HSsupremum/(1+(np.exp(-((x)-HSthreshold)/HSslope)))) - HS_neg)
ax[1].plot(x, y)
ax[1].set_title('HS')
ax[1].axhline(y=0, c='grey', ls='--')
ax[1].axvline(x=0, c='grey', ls='--')
ax[1].set_xlabel('input')
# ax[1].set_ylabel('response')

In [None]:
def LPTC_system_spiking(input):
    """Steady-state residuals of the LPTC network with a sigmoidal H2 output.

    Identical to ``LPTC_system`` except that each H2 cell's summed drive passes
    through ``50 / (1 + exp(-(drive - 7) / 3)) - 5`` before it is compared with
    the H2 unknown. ``input`` holds the same 11 unknowns in the same order as
    ``LPTC_system``. Weights and stimulus components are read from module
    globals; the sigmoid constants are local.
    """
    (HSright, H2right, HSleft, H2left, Bright, Uright,
     Bleft, Uleft, DNp15right, DNp15left, Vangular) = input

    supremum, H2_neg, H2threshold, slope = 50, 5, 7, 3
    # H2threshold > H2_neg

    def h2_sigmoid(drive):
        return (supremum/(1+(np.exp(-((drive)-H2threshold)/slope)))) - H2_neg

    hs_right_drive = WHSPD * FtBright  - WHSNPD * BtFright + WHScontra * BtFleft + H2_to_HS_cc * H2left
    h2_right_out = h2_sigmoid(WH2PD * BtFright - WH2NPD * FtBright + WH2contra * FtBleft + HS_to_H2_cc * HSleft)
    hs_left_drive = WHSPD * FtBleft  - WHSNPD * BtFleft + WHScontra * BtFright + H2_to_HS_cc * H2right
    h2_left_out = h2_sigmoid(WH2PD * BtFleft - WH2NPD * FtBleft + WH2contra * FtBright + HS_to_H2_cc * HSright)

    b_right_drive = WHSB * HSright + WH2B * H2left - WUB * Uright
    u_right_drive = WHSU * HSright + WH2U * H2left - WBU * Bleft
    b_left_drive = WHSB * HSleft + WH2B * H2right - WUB * Uleft
    u_left_drive = WHSU * HSleft + WH2U * H2right - WBU * Bright

    dn_right_drive = WHSp15 * HSright + WH2p15 * H2left - WUp15 * Uright - WBp15 * Bleft
    dn_left_drive = WHSp15 * HSleft + WH2p15 * H2right - WUp15 * Uleft - WBp15 * Bright

    turning = Wvisual * (DNp15right - DNp15left)

    pairs = (
        (HSright, hs_right_drive), (H2right, h2_right_out),
        (HSleft, hs_left_drive), (H2left, h2_left_out),
        (Bright, b_right_drive), (Uright, u_right_drive),
        (Bleft, b_left_drive), (Uleft, u_left_drive),
        (DNp15right, dn_right_drive), (DNp15left, dn_left_drive),
        (Vangular, turning),
    )
    return tuple(value - target for value, target in pairs)

In [None]:
def LPTC_system_nonlinear(input):
    """Steady-state residuals of the LPTC network with sigmoidal HS and H2 outputs.

    Identical to ``LPTC_system`` except that the summed drive of each HS cell
    passes through ``80 / (1 + exp(-(drive - 7) / 20)) - 40``, and that of each
    H2 cell through ``50 / (1 + exp(-(drive - 7) / 3)) - 5``. ``input`` holds
    the same 11 unknowns in the same order as ``LPTC_system``. Weights and
    stimulus components are read from module globals.
    """
    (HSright, H2right, HSleft, H2left, Bright, Uright,
     Bleft, Uleft, DNp15right, DNp15left, Vangular) = input

    H2supremum, H2_neg, H2threshold, H2slope = 50, 5, 7, 3
    # H2threshold > H2_neg
    HSsupremum, HS_neg, HSthreshold, HSslope = 80, 40, 7, 20

    def squash(drive, supremum, threshold, slope, offset):
        return (supremum/(1+(np.exp(-((drive)-threshold)/slope)))) - offset

    hs_right_out = squash(WHSPD * FtBright  - WHSNPD * BtFright + WHScontra * BtFleft + H2_to_HS_cc * H2left,
                          HSsupremum, HSthreshold, HSslope, HS_neg)
    h2_right_out = squash(WH2PD * BtFright - WH2NPD * FtBright + WH2contra * FtBleft + HS_to_H2_cc * HSleft,
                          H2supremum, H2threshold, H2slope, H2_neg)
    hs_left_out = squash(WHSPD * FtBleft  - WHSNPD * BtFleft + WHScontra * BtFright + H2_to_HS_cc * H2right,
                         HSsupremum, HSthreshold, HSslope, HS_neg)
    h2_left_out = squash(WH2PD * BtFleft - WH2NPD * FtBleft + WH2contra * FtBright + HS_to_H2_cc * HSright,
                         H2supremum, H2threshold, H2slope, H2_neg)

    b_right_drive = WHSB * HSright + WH2B * H2left - WUB * Uright
    u_right_drive = WHSU * HSright + WH2U * H2left - WBU * Bleft
    b_left_drive = WHSB * HSleft + WH2B * H2right - WUB * Uleft
    u_left_drive = WHSU * HSleft + WH2U * H2right - WBU * Bright

    dn_right_drive = WHSp15 * HSright + WH2p15 * H2left - WUp15 * Uright - WBp15 * Bleft
    dn_left_drive = WHSp15 * HSleft + WH2p15 * H2right - WUp15 * Uleft - WBp15 * Bright

    turning = Wvisual * (DNp15right - DNp15left)

    pairs = (
        (HSright, hs_right_out), (H2right, h2_right_out),
        (HSleft, hs_left_out), (H2left, h2_left_out),
        (Bright, b_right_drive), (Uright, u_right_drive),
        (Bleft, b_left_drive), (Uleft, u_left_drive),
        (DNp15right, dn_right_drive), (DNp15left, dn_left_drive),
        (Vangular, turning),
    )
    return tuple(value - target for value, target in pairs)

#### Helper functions

In [None]:
def get_DI(response, cell):
    """Direction index contrasting the fwd/bwd pair with the cw/ccw pair.

    ``response`` is expected to be a DataFrame with a 'stimulus' column
    containing 'cw', 'ccw', 'fwd' and 'bwd' rows, plus a column ``cell``.
    With OFyaw = |cw - ccw| and OFprog = |fwd - bwd|,
    DI = (OFprog - OFyaw) / (|OFyaw| + |OFprog|). It lies in [-1, 1]: -1 when
    only cw/ccw differ, +1 when only fwd/bwd differ. The result is NaN (with a
    numpy warning) when both differences are 0.

    Returns an ndarray with one entry per matching row of each stimulus
    (normally one).
    """
    def _values(stimulus):
        return response.loc[response['stimulus'] == stimulus, cell].values

    cw, ccw, fwd, bwd = (_values(s) for s in ('cw', 'ccw', 'fwd', 'bwd'))
    # |a - b| equals max(a, b) - min(a, b) exactly, but it is element-wise. The
    # builtin max/min raised ValueError for arrays with more than one row.
    OFyaw = np.abs(cw - ccw)
    OFprog = np.abs(fwd - bwd)
    DI = (OFprog - OFyaw)/(np.abs(OFyaw) + np.abs(OFprog))
    return DI

In [None]:
def get_AI(left, right):
    """Asymmetry index (right - left) / (|right| + |left|), bounded in [-1, 1]."""
    difference = right - left
    magnitude = abs(right) + abs(left)
    return difference / magnitude

#### Solve

##### Cell Colors

In [None]:
# Plot colour per cell type; this replaces the mapping loaded from color_dict.json.
color_dict = dict(
    HS='purple',
    H2='steelblue',
    bIPS='darkgreen',
    H2rn='violet',
    uLPTCrn='darkorange',
    DNp15='grey',
)

In [None]:
# Stimulus vectors, unpacked by the sweep cells as (FtBleft, BtFleft, FtBright, BtFright).
# FtB / BtF presumably denote front-to-back / back-to-front motion on each side.
# An earlier push-pull variant put -1 on the opposite direction of the same side,
# e.g. left_ftb = [1, -1, 0, 0] and cw = [-1, 1, 1, -1].
_stimulus_vectors = {
    'left_ftb':  [1, 0, 0, 0],
    'right_ftb': [0, 0, 1, 0],
    'left_btf':  [0, 1, 0, 0],
    'right_btf': [0, 0, 0, 1],
    'cw':        [0, 1, 1, 0],
    'ccw':       [1, 0, 0, 1],
    'fwd':       [1, 0, 1, 0],
    'bwd':       [0, 1, 0, 1],
}
motion_stimuli = {name: np.array(vector) for name, vector in _stimulus_vectors.items()}

#### explore weights

##### change directionality of gap junction

In [None]:
# Sweep both gap-junction coupling directions (HS_to_H2_cc: outer loop / index i;
# H2_to_HS_cc: inner loop / index j) over 0..0.8 under the 'right_ftb' stimulus.
# For each grid point, record the steady-state Vangular and the DNp15 asymmetry
# index, for the linear model (LPTC_system) and for the sigmoid-H2 model
# (LPTC_system_spiking).
# NOTE: the `for` targets below rebind the module-level globals that the
# LPTC_system* functions read, so they must keep these exact names.
set_weights_to_default()
i=0
j=0
# NaN marks grid points where fsolve did not return a valid root.
not_spiking = np.empty((11, 11,)) * np.nan
spiking = np.empty((11, 11,)) * np.nan
not_spiking_AI = np.empty((11, 11,)) * np.nan
spiking_AI = np.empty((11, 11,)) * np.nan
WHScontra = 0
WH2contra = 0.1
for i, HS_to_H2_cc in enumerate(np.linspace(0, 0.8, 11)):
    for j, H2_to_HS_cc in enumerate(np.linspace(0, 0.8, 11)):
        # stimulus components are also globals read by the residual functions
        FtBleft, BtFleft, FtBright, BtFright  = motion_stimuli['right_ftb'];
        # unseeded random initial guess: results may differ between runs
        sol = fsolve(LPTC_system, np.random.rand(11));
        # accept only genuine roots (all residuals ~0)
        if np.all(np.isclose(LPTC_system(sol), np.zeros(11))):
            sol = np.round(sol, decimals=3);
            solution = {"HSright":[sol[0]], "H2right":[sol[1]], "HSleft":[sol[2]], "H2left":[sol[3]], "Bright":[sol[4]], "Uright":[sol[5]], \
                        "Bleft":[sol[6]], "Uleft":[sol[7]], "DNp15right":[sol[8]], "DNp15left":[sol[9]], "Vangular":[sol[10]]};
            not_spiking[i, j] = solution['Vangular'][0]
            not_spiking_AI[i, j] = get_AI(solution['DNp15left'][0], solution['DNp15right'][0])

        sol = fsolve(LPTC_system_spiking, np.random.rand(11));
        # print(np.round(sol, 2))
        if np.all(np.isclose(LPTC_system_spiking(sol), np.zeros(11))):
            sol = np.round(sol, decimals=3);
            solution = {"HSright":[sol[0]], "H2right":[sol[1]], "HSleft":[sol[2]], "H2left":[sol[3]], "Bright":[sol[4]], "Uright":[sol[5]], \
                        "Bleft":[sol[6]], "Uleft":[sol[7]], "DNp15right":[sol[8]], "DNp15left":[sol[9]], "Vangular":[sol[10]]};
            spiking[i, j] = solution['Vangular'][0]
            spiking_AI[i, j] = get_AI(solution['DNp15left'][0], solution['DNp15right'][0])

###### plots

In [None]:
# Gap-junction sweep, non-spiking model. `not_spiking[i, j]` was filled with
# i -> HS_to_H2_cc and j -> H2_to_HS_cc. imshow puts the first index on the
# y-axis, so the array is transposed to put HS => H2 on the x-axis and
# H2 => HS on the y-axis, as the axis labels state.
fig, ax = plt.subplots(1, 1)
max_response = np.nanmax(np.abs(not_spiking))
im = ax.imshow(not_spiking.T/max_response, origin='lower', cmap=make_color_pastel(ListedColormap(all_color_data['PiYG5'])), vmin=-1, vmax=1);
ax.set_xticks([0, 10], labels=['0', '0.8'])
ax.set_yticks([0, 10], labels=['0', '0.8'])
ax.set_xlabel('HS => H2', fontdict={'fontsize':12})
ax.set_ylabel('H2 => HS', fontdict={'fontsize':12})
# add colorbar
cbar = fig.colorbar(im, ax=ax)
fig.savefig(os.path.join(OUTPUT_DIR, 'gap_junc_noeffect_direction_right_ftb.pdf'), dpi=300)

# Same map in absolute units (colour scale clipped at +/-100).
fig, ax = plt.subplots(1, 1)
im = ax.imshow(not_spiking.T, origin='lower', cmap=make_color_pastel(ListedColormap(all_color_data['PiYG5'])), vmin=-100, vmax=100);
ax.set_xticks([0, 10], labels=['0', '0.8'])
ax.set_yticks([0, 10], labels=['0', '0.8'])
ax.set_xlabel('HS => H2', fontdict={'fontsize':12})
ax.set_ylabel('H2 => HS', fontdict={'fontsize':12})
# add colorbar
cbar = fig.colorbar(im, ax=ax)

In [None]:
# Gap-junction sweep, sigmoid-H2 model. `spiking[i, j]` was filled with
# i -> HS_to_H2_cc and j -> H2_to_HS_cc; the array is transposed so that the
# x-axis is HS => H2 and the y-axis is H2 => HS, matching the labels.
fig, ax = plt.subplots(1, 1)
max_response = np.nanmax(np.abs(spiking))
ax.imshow(spiking.T/max_response, origin='lower', cmap='bwr', vmin=-1, vmax=1);
ax.set_xticks([0, 10], labels=['0', '0.8'])
ax.set_yticks([0, 10], labels=['0', '0.8'])
ax.set_xlabel('HS => H2', fontdict={'fontsize':12})
ax.set_ylabel('H2 => HS', fontdict={'fontsize':12})

##### HS and H2 weights

In [None]:
# 3-D sweep: WHSp15 (index i) x WH2p15 (index j) x symmetric gap-junction
# coupling cc (index k, HS_to_H2_cc = H2_to_HS_cc = cc), each over 0..1.
# Records the steady-state Vangular for LPTC_system (not_spiking) and
# LPTC_system_spiking (spiking); NaN where no valid root was found.
# NOTE: the loop targets rebind module globals read by the LPTC_system* functions.
# NOTE(review): the stimulus here is scaled by 10, unlike the gap-junction and
# HS->bIPS/uLPTC sweeps; confirm this is intended.
set_weights_to_default()
i=0
j=0
not_spiking = np.empty((11, 11, 11)) * np.nan
spiking = np.empty((11, 11, 11)) * np.nan
WHScontra = 0
WH2contra = 0
for i, WHSp15 in enumerate(np.linspace(0, 1, 11)):
    for j, WH2p15 in enumerate(np.linspace(0, 1, 11)):
        for k, cc in enumerate(np.linspace(0, 1, 11)):
            HS_to_H2_cc = cc
            H2_to_HS_cc = cc
            FtBleft, BtFleft, FtBright, BtFright  = motion_stimuli['right_ftb'] * 10;
        
            # unseeded random initial guess; accept only roots with all residuals ~0
            sol = fsolve(LPTC_system, np.random.rand(11));
            if np.all(np.isclose(LPTC_system(sol), np.zeros(11))):
                sol = np.round(sol, decimals=3);
                solution = {"HSright":[sol[0]], "H2right":[sol[1]], "HSleft":[sol[2]], "H2left":[sol[3]], "Bright":[sol[4]], "Uright":[sol[5]], \
                            "Bleft":[sol[6]], "Uleft":[sol[7]], "DNp15right":[sol[8]], "DNp15left":[sol[9]], "Vangular":[sol[10]]};
                not_spiking[i, j, k] = solution['Vangular'][0]

            sol = fsolve(LPTC_system_spiking, np.random.rand(11));
            # print(np.round(sol, 2))
            if np.all(np.isclose(LPTC_system_spiking(sol), np.zeros(11))):
                sol = np.round(sol, decimals=3);
                solution = {"HSright":[sol[0]], "H2right":[sol[1]], "HSleft":[sol[2]], "H2left":[sol[3]], "Bright":[sol[4]], "Uright":[sol[5]], \
                            "Bleft":[sol[6]], "Uleft":[sol[7]], "DNp15right":[sol[8]], "DNp15left":[sol[9]], "Vangular":[sol[10]]};
                spiking[i, j, k] = solution['Vangular'][0]

In [None]:
# One panel per coupling strength cc, non-spiking model.
# not_spiking[i, j, k] was filled with i -> WHSp15 and j -> WH2p15. Each slice is
# transposed so that HS -> DNp15 is on the x-axis and H2 -> DNp15 on the y-axis,
# matching the labels and the marker at (WHSp15, WH2p15) = (0.8, 0.3), the
# values set by set_weights_to_default.
fig, ax = plt.subplots(3, 4, figsize=(10, 8))
cc_values = np.linspace(0, 1, 11)
axis = ax.flatten()
for k in range(11):
    panel = not_spiking[:, :, k].T
    # NaN-aware: non-converged grid points are NaN and would otherwise blank the panel
    max_response = np.nanmax(np.abs(panel))
    print(max_response, np.nanmin(panel))
    axis[k].imshow(panel/max_response, origin='lower', cmap='bwr', vmin=-1, vmax=1);
    axis[k].set_title('{}'.format(np.round(cc_values[k], 2)))
    axis[k].set_xticks([0, 10], labels=['0', '1'])
    axis[k].set_yticks([0, 10], labels=['0', '1'])
    axis[k].scatter(8, 3, color='grey', s=50)
# 12 subplot slots, 11 panels: drop the unused one
fig.delaxes(axis[k+1])
fig.text(0.1, 0.5, 'H2 -> DNp15', ha='center', va='center', rotation='vertical', fontdict={'fontsize':12})
fig.text(0.5, 0.08, 'HS -> DNp15', ha='center', va='center', fontdict={'fontsize':12})


In [None]:
# One panel per coupling strength cc, sigmoid-H2 model.
# spiking[i, j, k] was filled with i -> WHSp15 and j -> WH2p15. Each slice is
# transposed so that HS -> DNp15 is on the x-axis and H2 -> DNp15 on the y-axis,
# matching the labels and the marker at (WHSp15, WH2p15) = (0.8, 0.3).
fig, ax = plt.subplots(3, 4, figsize=(10, 8))
cc_values = np.linspace(0, 1, 11)
axis = ax.flatten()
for k in range(11):
    panel = spiking[:, :, k].T
    # NaN-aware: non-converged grid points are NaN and would otherwise blank the panel
    max_response = np.nanmax(np.abs(panel))
    print(max_response, np.nanmin(panel))
    axis[k].imshow(panel/max_response, origin='lower', cmap='bwr', vmin=-1, vmax=1);
    axis[k].set_xticks([0, 10], labels=['0', '1'])
    axis[k].set_yticks([0, 10], labels=['0', '1'])
    axis[k].set_title('{}'.format(np.round(cc_values[k], 2)))
    axis[k].scatter(8, 3, color='grey', s=50)
# 12 subplot slots, 11 panels: drop the unused one
fig.delaxes(axis[k+1])
fig.text(0.1, 0.5, 'H2 -> DNp15', ha='center', va='center', rotation='vertical', fontdict={'fontsize':12})
fig.text(0.5, 0.08, 'HS -> DNp15', ha='center', va='center', fontdict={'fontsize':12})

##### HS to bIPS and uLPTCrn

In [None]:
# 3-D sweep: WHSU (HS -> uLPTC, index i) x WHSB (HS -> bIPS, index j) x symmetric
# gap-junction coupling cc (index k), each over 0..1, under the 'right_ftb'
# stimulus. Records steady-state Vangular for LPTC_system (not_spiking) and
# LPTC_system_spiking (spiking); NaN where no valid root was found.
# NOTE: the loop targets rebind module globals read by the LPTC_system* functions.
set_weights_to_default()
i=0
j=0
WHScontra = 0
WH2contra = 0
not_spiking = np.empty((11, 11, 11)) * np.nan
spiking = np.empty((11, 11, 11)) * np.nan
for i, WHSU in enumerate(np.linspace(0, 1, 11)):
    for j, WHSB in enumerate(np.linspace(0, 1, 11)):
        for k, cc in enumerate(np.linspace(0, 1, 11)):
            HS_to_H2_cc = cc
            H2_to_HS_cc = cc
            FtBleft, BtFleft, FtBright, BtFright  = motion_stimuli['right_ftb'];

            # unseeded random initial guess; accept only roots with all residuals ~0
            sol = fsolve(LPTC_system, np.random.rand(11));
            if np.all(np.isclose(LPTC_system(sol), np.zeros(11))):
                sol = np.round(sol, decimals=3);
                solution = {"HSright":[sol[0]], "H2right":[sol[1]], "HSleft":[sol[2]], "H2left":[sol[3]], "Bright":[sol[4]], "Uright":[sol[5]], \
                            "Bleft":[sol[6]], "Uleft":[sol[7]], "DNp15right":[sol[8]], "DNp15left":[sol[9]], "Vangular":[sol[10]]};
                not_spiking[i, j, k] = solution['Vangular'][0]

            sol = fsolve(LPTC_system_spiking, np.random.rand(11));
            if np.all(np.isclose(LPTC_system_spiking(sol), np.zeros(11))):
                sol = np.round(sol, decimals=3);
                solution = {"HSright":[sol[0]], "H2right":[sol[1]], "HSleft":[sol[2]], "H2left":[sol[3]], "Bright":[sol[4]], "Uright":[sol[5]], \
                            "Bleft":[sol[6]], "Uleft":[sol[7]], "DNp15right":[sol[8]], "DNp15left":[sol[9]], "Vangular":[sol[10]]};
                spiking[i, j, k] = solution['Vangular'][0]

In [None]:
# One panel per coupling strength cc, non-spiking model.
# not_spiking[i, j, k] was filled with i -> WHSU (rows, y-axis) and j -> WHSB
# (columns, x-axis), so the axes are HS -> bIPS (x) and HS -> uLPTCrn (y).
# The labels and marker were previously copied from the WUp15/WBp15 sweep; the
# marker now sits at the set_weights_to_default values WHSB = 0.9, WHSU = 0.5.
fig, ax = plt.subplots(3, 4, figsize=(10, 8))
cc_values = np.linspace(0, 1, 11)
axis = ax.flatten()
for k in range(11):
    # NaN-aware: non-converged grid points are NaN and would otherwise blank the panel
    max_response = np.nanmax(np.abs(not_spiking[:,:,k]))
    print(max_response, np.nanmin(not_spiking[:,:,k]))
    axis[k].imshow(not_spiking[:,:,k]/max_response, origin='lower', cmap='bwr', vmin=-1, vmax=1);
    axis[k].set_title('{}'.format(np.round(cc_values[k], 2)))
    axis[k].set_xticks([0, 10], labels=['0', '1'])
    axis[k].set_yticks([0, 10], labels=['0', '1'])
    axis[k].scatter(9, 5, color='grey', s=50)
# 12 subplot slots, 11 panels: drop the unused one
fig.delaxes(axis[k+1])
fig.text(0.1, 0.5, 'HS -> uLPTCrn', ha='center', va='center', rotation='vertical', fontdict={'fontsize':12})
fig.text(0.5, 0.08, 'HS -> bIPS', ha='center', va='center', fontdict={'fontsize':12})

##### uLPTC and bIPS

In [None]:
# 3-D sweep: WUp15 (uLPTC -> DNp15, index i) x WBp15 (bIPS -> DNp15, index j)
# x symmetric gap-junction coupling cc (index k), each over 0..1. Records
# steady-state Vangular for LPTC_system (not_spiking) and LPTC_system_spiking
# (spiking); NaN where no valid root was found.
# NOTE: the loop targets rebind module globals read by the LPTC_system* functions.
# NOTE(review): the stimulus here is scaled by 10, unlike some other sweeps;
# confirm this is intended.
set_weights_to_default()
i=0
j=0
WHScontra = 0
WH2contra = 0
not_spiking = np.empty((11, 11, 11)) * np.nan
spiking = np.empty((11, 11, 11)) * np.nan
for i, WUp15 in enumerate(np.linspace(0, 1, 11)):
    for j, WBp15 in enumerate(np.linspace(0, 1, 11)):
        for k, cc in enumerate(np.linspace(0, 1, 11)):
            HS_to_H2_cc = cc
            H2_to_HS_cc = cc
            FtBleft, BtFleft, FtBright, BtFright  = motion_stimuli['right_ftb'] * 10;

            # unseeded random initial guess; accept only roots with all residuals ~0
            sol = fsolve(LPTC_system, np.random.rand(11));
            if np.all(np.isclose(LPTC_system(sol), np.zeros(11))):
                sol = np.round(sol, decimals=3);
                solution = {"HSright":[sol[0]], "H2right":[sol[1]], "HSleft":[sol[2]], "H2left":[sol[3]], "Bright":[sol[4]], "Uright":[sol[5]], \
                            "Bleft":[sol[6]], "Uleft":[sol[7]], "DNp15right":[sol[8]], "DNp15left":[sol[9]], "Vangular":[sol[10]]};
                not_spiking[i, j, k] = solution['Vangular'][0]

            sol = fsolve(LPTC_system_spiking, np.random.rand(11));
            if np.all(np.isclose(LPTC_system_spiking(sol), np.zeros(11))):
                sol = np.round(sol, decimals=3);
                solution = {"HSright":[sol[0]], "H2right":[sol[1]], "HSleft":[sol[2]], "H2left":[sol[3]], "Bright":[sol[4]], "Uright":[sol[5]], \
                            "Bleft":[sol[6]], "Uleft":[sol[7]], "DNp15right":[sol[8]], "DNp15left":[sol[9]], "Vangular":[sol[10]]};
                spiking[i, j, k] = solution['Vangular'][0]

In [None]:
# One panel per coupling strength cc, non-spiking model. Rows (y) index WUp15 and
# columns (x) index WBp15. The marker sits at the set_weights_to_default values
# WBp15 = 0.5, WUp15 = 0.2.
fig, ax = plt.subplots(3, 4, figsize=(10, 8))
cc_values = np.linspace(0, 1, 11)
axis = ax.flatten()
for k in range(11):
    # NaN-aware: non-converged grid points are NaN and would otherwise blank the panel
    max_response = np.nanmax(np.abs(not_spiking[:,:,k]))
    print(max_response, np.nanmin(not_spiking[:,:,k]))
    axis[k].imshow(not_spiking[:,:,k]/max_response, origin='lower', cmap='bwr', vmin=-1, vmax=1);
    axis[k].set_title('{}'.format(np.round(cc_values[k], 2)))
    axis[k].set_xticks([0, 10], labels=['0', '1'])
    axis[k].set_yticks([0, 10], labels=['0', '1'])
    axis[k].scatter(5, 2, color='grey', s=50)
# 12 subplot slots, 11 panels: drop the unused one
fig.delaxes(axis[k+1])
fig.text(0.1, 0.5, 'uLPTCrn -> DNp15', ha='center', va='center', rotation='vertical', fontdict={'fontsize':12})
fig.text(0.5, 0.08, 'bIPS -> DNp15', ha='center', va='center', fontdict={'fontsize':12})

In [None]:
# One panel per coupling strength cc, sigmoid-H2 model. Rows (y) index WUp15 and
# columns (x) index WBp15; marker at WBp15 = 0.5, WUp15 = 0.2.
fig, ax = plt.subplots(3, 4, figsize=(10, 8))
cc_values = np.linspace(0, 1, 11)
axis = ax.flatten()
for k in range(11):
    # NaN-aware: non-converged grid points are NaN and would otherwise blank the panel
    max_response = np.nanmax(np.abs(spiking[:,:,k]))
    print(max_response, np.nanmin(spiking[:,:,k]))
    axis[k].imshow(spiking[:,:,k]/max_response, origin='lower', cmap='bwr', vmin=-1, vmax=1);
    axis[k].set_xticks([0, 10], labels=['0', '1'])
    axis[k].set_yticks([0, 10], labels=['0', '1'])
    axis[k].set_title('{}'.format(np.round(cc_values[k], 2)))
    axis[k].scatter(5, 2, color='grey', s=50)
# 12 subplot slots, 11 panels: drop the unused one
fig.delaxes(axis[k+1])
fig.text(0.1, 0.5, 'uLPTCrn -> DNp15', ha='center', va='center', rotation='vertical', fontdict={'fontsize':12})
fig.text(0.5, 0.08, 'bIPS -> DNp15', ha='center', va='center', fontdict={'fontsize':12})

##### uLPTC and HSp15 weights

In [None]:
# 3-D sweep: WUp15 (uLPTC -> DNp15, index i) x WHSp15 (HS -> DNp15, index j)
# x symmetric gap-junction coupling cc (index k), each over 0..1, under the
# 'right_ftb' stimulus with WH2contra = 0.1. Records steady-state Vangular for
# LPTC_system (not_spiking) and for LPTC_system_nonlinear; NaN where no valid
# root was found.
# NOTE(review): the LPTC_system_nonlinear results are stored in the array named
# `spiking`.
# NOTE: the loop targets rebind module globals read by the LPTC_system* functions.
set_weights_to_default()
i=0
j=0
WHScontra = 0
WH2contra = 0.1
not_spiking = np.empty((11, 11, 11)) * np.nan
spiking = np.empty((11, 11, 11)) * np.nan
for i, WUp15 in enumerate(np.linspace(0, 1, 11)):
    for j, WHSp15 in enumerate(np.linspace(0, 1, 11)):
        for k, cc in enumerate(np.linspace(0, 1, 11)):
            HS_to_H2_cc = cc
            H2_to_HS_cc = cc
            FtBleft, BtFleft, FtBright, BtFright  = motion_stimuli['right_ftb'];

            # unseeded random initial guess; accept only roots with all residuals ~0
            sol = fsolve(LPTC_system, np.random.rand(11));
            if np.all(np.isclose(LPTC_system(sol), np.zeros(11))):
                sol = np.round(sol, decimals=3);
                solution = {"HSright":[sol[0]], "H2right":[sol[1]], "HSleft":[sol[2]], "H2left":[sol[3]], "Bright":[sol[4]], "Uright":[sol[5]], \
                            "Bleft":[sol[6]], "Uleft":[sol[7]], "DNp15right":[sol[8]], "DNp15left":[sol[9]], "Vangular":[sol[10]]};
                not_spiking[i, j, k] = solution['Vangular'][0]

            sol = fsolve(LPTC_system_nonlinear, np.random.rand(11));
            if np.all(np.isclose(LPTC_system_nonlinear(sol), np.zeros(11))):
                sol = np.round(sol, decimals=3);
                solution = {"HSright":[sol[0]], "H2right":[sol[1]], "HSleft":[sol[2]], "H2left":[sol[3]], "Bright":[sol[4]], "Uright":[sol[5]], \
                            "Bleft":[sol[6]], "Uleft":[sol[7]], "DNp15right":[sol[8]], "DNp15left":[sol[9]], "Vangular":[sol[10]]};
                spiking[i, j, k] = solution['Vangular'][0]

In [None]:
# One panel per coupling strength cc, linear model. Rows (y) index WUp15 and
# columns (x) index WHSp15; the marker sits at the set_weights_to_default values
# WHSp15 = 0.8, WUp15 = 0.2.
# NOTE(review): the sweep cell used the 'right_ftb' stimulus but the saved file
# name says 'right_btf'; confirm which is correct.
fig, ax = plt.subplots(3, 4, figsize=(10, 8))
cc_values = np.linspace(0, 1, 11)
axis = ax.flatten()
for k in range(11):
    # NaN-aware: non-converged grid points are NaN and would otherwise blank the panel
    max_response = np.nanmax(np.abs(not_spiking[:,:,k]))
    print(max_response, np.nanmin(not_spiking[:,:,k]))
    axis[k].imshow(not_spiking[:,:,k]/max_response, origin='lower', cmap=make_color_pastel(ListedColormap(all_color_data['PiYG5'])), vmin=-0.5, vmax=0.5);
    axis[k].set_title('{}'.format(np.round(cc_values[k], 2)))
    axis[k].set_xticks([0, 10], labels=['0', '1'])
    axis[k].set_yticks([0, 10], labels=['0', '1'])
    axis[k].scatter(8, 2, color='grey', s=50)
# 12 subplot slots, 11 panels: drop the unused one
fig.delaxes(axis[k+1])
fig.text(0.1, 0.5, 'uLPTCrn -> DNp15', ha='center', va='center', rotation='vertical', fontdict={'fontsize':12})
fig.text(0.5, 0.08, 'HS -> DNp15', ha='center', va='center', fontdict={'fontsize':12})
fig.savefig(os.path.join(OUTPUT_DIR, 'gapjunc_nonbio_weight_direction_right_btf.pdf'), dpi=300)

In [None]:
# Enlarged view of one coupling slice (index 2) of the non-spiking sweep,
# normalised by its largest |Vangular| (NaN-aware), with a colour bar.
cc_index = 2
cc_slice = not_spiking[:, :, cc_index]
max_response = np.nanmax(np.abs(cc_slice))
label_font = {'fontsize': 12}

fig, ax = plt.subplots(1, 1)
im = ax.imshow(
    cc_slice / max_response,
    origin='lower',
    cmap=make_color_pastel(ListedColormap(all_color_data['PiYG5'])),
    vmin=-0.5,
    vmax=0.5,
)
for set_ticks in (ax.set_xticks, ax.set_yticks):
    set_ticks([0, 10], labels=['0', '1'])
ax.set_xlabel('HS -> DNp15', fontdict=label_font)
ax.set_ylabel('uLPTCrn -> DNp15', fontdict=label_font)
ax.scatter(8, 2, color='grey', s=50)
cbar = fig.colorbar(im, ax=ax)
fig.savefig(os.path.join(OUTPUT_DIR, 'nonbio_weight_direction_right_btf.pdf'), dpi=300)

In [None]:
# Nonlinear (spiking) model: one heat map per gap-junction coupling value cc,
# showing Vangular over the (HS -> DNp15, uLPTCrn -> DNp15) weight grid,
# normalised by the largest |Vangular| within that panel.
cc_values = np.linspace(0, 1, 11)
fig, ax = plt.subplots(3, 4, figsize=(10, 8))
axis = ax.flatten()
for k, cc in enumerate(cc_values):
    panel = spiking[:, :, k]
    # nan-aware reductions: if unsolved grid points hold NaN, np.amax/np.min
    # would return NaN and blank the whole panel.
    max_response = np.nanmax(np.abs(panel))
    print(max_response, np.nanmin(panel))
    # Avoid 0/0 on an all-zero panel (NaN max also falls back to 1).
    scale = max_response if max_response > 0 else 1.0
    axis[k].imshow(panel / scale, origin='lower', cmap='bwr', vmin=-1, vmax=1)
    axis[k].set_title('{}'.format(np.round(cc, 2)))
    axis[k].set_xticks([0, 10], labels=['0', '1'])
    axis[k].set_yticks([0, 10], labels=['0', '1'])
    axis[k].scatter(8, 2, color='grey', s=50)
# Remove the subplot slots left over after one panel per cc value.
for unused_ax in axis[len(cc_values):]:
    fig.delaxes(unused_ax)
fig.text(0.1, 0.5, 'uLPTCrn -> DNp15', ha='center', va='center', rotation='vertical', fontdict={'fontsize':12})
fig.text(0.5, 0.08, 'HS -> DNp15', ha='center', va='center', fontdict={'fontsize':12})

##### Different stimuli with different gap junction weights

In [None]:
# Sweep the HS<->H2 gap-junction coupling (11 values in [0, 0.8]) for each motion
# stimulus. For every (stimulus, cc) pair:
# - the full model (LPTC_system, 11 unknowns) is solved from a random initial guess;
# - so is the reduced model without H2 (LPTC_system_without_H2, 9 unknowns).
# Both solutions are rounded to 3 decimals. Each is appended as one row to
# total_solution and total_solution_without_H2 respectively; the H2 columns of
# the latter are never filled and so end up NaN.
# NOTE(review): unlike the weight sweep, fsolve convergence is not checked here,
# so a non-converged solve is recorded as if it were a solution.
lptc_columns = ["HSright", "H2right", "HSleft", "H2left", "Bright", "Uright",
                "Bleft", "Uleft", "DNp15right", "DNp15left", "Vangular"]
no_h2_columns = [name for name in lptc_columns if name not in ("H2right", "H2left")]
frame_columns = lptc_columns + ['stimulus', 'gap_junction']

set_weights_to_default()
total_solution = pd.DataFrame.from_dict({name: [] for name in frame_columns})
total_solution_without_H2 = pd.DataFrame.from_dict({name: [] for name in frame_columns})
WHScontra = 0
WH2contra = 0.1

for keys in ['left_ftb', 'left_btf', 'cw', 'ccw', 'fwd', 'right_btf', 'right_ftb', 'bwd']:
    FtBleft, BtFleft, FtBright, BtFright = motion_stimuli[keys]
    print(FtBleft, BtFleft, FtBright, BtFright)
    for cc in np.linspace(0, 0.8, 11):
        HS_to_H2_cc = H2_to_HS_cc = cc

        sol = np.round(fsolve(LPTC_system, np.random.rand(11)), decimals=3)
        row = {name: [value] for name, value in zip(lptc_columns, sol)}
        row.update(stimulus=[keys], gap_junction=[cc])
        total_solution = pd.concat([total_solution, pd.DataFrame.from_dict(row)], ignore_index=True)

        sol = np.round(fsolve(LPTC_system_without_H2, np.random.rand(9)), decimals=3)
        row = {name: [value] for name, value in zip(no_h2_columns, sol)}
        row.update(stimulus=[keys], gap_junction=[cc])
        total_solution_without_H2 = pd.concat(
            [total_solution_without_H2, pd.DataFrame.from_dict(row)], ignore_index=True
        )

##### Discrimination Index

In [None]:
# Discrimination index (DI) of each cell type as a function of the HS<->H2
# gap-junction coupling: one line per cell type. Each point is get_DI over the
# total_solution rows that share one gap_junction value.
cell_to_column = {
    'H2': 'H2right',
    'HS': 'HSright',
    'bIPS': 'Bleft',
    'uLPTCrn': 'Uright',
    'DNp15': 'DNp15right',
}
# Coupling values actually present in the data, in first-seen order. The x tick
# labels are derived from these so they always match the swept values rather
# than a hard-coded 0..1 grid.
gap_values = total_solution['gap_junction'].unique()

fig, ax = plt.subplots(1, 1)
for cell, column in cell_to_column.items():
    di_values = [
        get_DI(total_solution.loc[total_solution['gap_junction'] == cc], column)
        for cc in gap_values
    ]
    ax.plot(di_values, label=cell, c=color_dict[cell])
    print(cell, di_values)
ax.legend()
ax.axhline(0, color='grey', linestyle='--')
ax.set_ylim(top=0.5, bottom=-1)
ax.set_xlim(left=0, right=len(gap_values) - 1)
tick_positions = np.arange(0, len(gap_values), 2)
ax.set_xticks(tick_positions, labels=[str(round(float(gap_values[p]), 2)) for p in tick_positions])
ax.set_yticks([0.5, 0, -0.5, -1], labels=['0.5', '0', '-0.5', '-1'])
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
# Offset the bottom and left spines slightly outside the plotting area.
ax.spines['bottom'].set_position(('axes', -0.05))
ax.spines['left'].set_position(('axes', -0.05))
ax.set_xlabel('Gap junction strength', fontdict={'fontsize':12})
ax.set_ylabel('Discrimination Index', fontdict={'fontsize':12})
ax.grid()
fig.savefig(os.path.join(OUTPUT_DIR, 'gap_junc_discrimination_index.pdf'), dpi=300)

##### Plot for paper

In [None]:
# Build total_solution for the paper figure. For every motion stimulus,
# LPTC_system is solved under six model variants and one row is recorded per
# (stimulus, model_type), with values rounded to 3 decimals.
lptc_columns = ["HSright", "H2right", "HSleft", "H2left", "Bright", "Uright",
                "Bleft", "Uleft", "DNp15right", "DNp15left", "Vangular"]

# Parameter overrides applied on top of set_weights_to_default() for each variant.
# LPTC_system is called with only the state vector, so these parameters are
# presumably read from notebook globals; they are assigned via globals().update.
# Dict order is the solve order (and thus the order of np.random.rand calls).
model_overrides = {
    'original': {'HS_to_H2_cc': 0, 'H2_to_HS_cc': 0},
    'gap_0.5': {'WHScontra': 0, 'WH2contra': 0.1, 'HS_to_H2_cc': 0.5, 'H2_to_HS_cc': 0.5},
    'gap_0.3': {'WHScontra': 0, 'WH2contra': 0.1, 'HS_to_H2_cc': 0.3, 'H2_to_HS_cc': 0.3},
    'gap_0.0': {'WHScontra': 0, 'WH2contra': 0.1, 'HS_to_H2_cc': 0, 'H2_to_HS_cc': 0},
    # HS outputs removed; H2 drives DNp15 only via WH2p15.
    'HS_0.0': {'WHScontra': 0, 'WH2contra': 0.1, 'HS_to_H2_cc': 0.5, 'H2_to_HS_cc': 0.5,
               'WHSB': 0, 'WH2B': 0, 'WHSU': 0, 'WH2U': 0, 'WHSp15': 0, 'WH2p15': 0.3},
    # H2 outputs removed; HS weights set explicitly.
    'H2_0.0': {'WHScontra': 0, 'WH2contra': 0.1, 'HS_to_H2_cc': 0.5, 'H2_to_HS_cc': 0.5,
               'WHSB': 0.9, 'WH2B': 0, 'WHSU': 0.7, 'WH2U': 0, 'WHSp15': 0.8, 'WH2p15': 0},
}
# Variants whose solution dict is printed as diagnostic output.
verbose_models = {'original', 'gap_0.5'}

set_weights_to_default()
solution_frames = []
for keys in ['left_ftb', 'left_btf', 'cw', 'ccw', 'fwd', 'right_btf', 'right_ftb', 'bwd']:
    FtBleft, BtFleft, FtBright, BtFright = motion_stimuli[keys]
    for model_type, overrides in model_overrides.items():
        set_weights_to_default()
        globals().update(overrides)
        sol = np.round(fsolve(LPTC_system, np.random.rand(11)), decimals=3)
        solution = {name: [value] for name, value in zip(lptc_columns, sol)}
        solution.update(stimulus=[keys], model_type=[model_type])
        if model_type in verbose_models:
            print(solution)
        solution_frames.append(pd.DataFrame.from_dict(solution))
# Single concat: avoids repeated copying and pandas' deprecated empty-frame concat.
total_solution = pd.concat(solution_frames, ignore_index=True)

In [None]:
# Discrimination index per (cell type, model variant): get_DI over the
# total_solution rows of that model_type, using the cell type's response column.
cell_to_column = {
    'H2': 'H2right',
    'HS': 'HSright',
    'bIPS': 'Bleft',
    'uLPTCrn': 'Uright',
    'DNp15': 'DNp15right',
}
model_types = ['original', 'gap_0.5', 'gap_0.3', 'gap_0.0', 'HS_0.0', 'H2_0.0']

di_frames = []
for cell, column in cell_to_column.items():
    for model in model_types:
        DI = get_DI(total_solution.loc[total_solution['model_type'] == model], column)
        di_frames.append(pd.DataFrame({'celltype': [cell], 'model_type': [model], 'DI': DI}))
# Single concat: avoids repeated copying and pandas' deprecated empty-frame concat.
total_data = pd.concat(di_frames, ignore_index=True)

In [None]:
# Grouped bar plot of the discrimination index per model variant, one bar per
# cell type (colours from color_dict).
model_order = ['original', 'gap_0.0', 'gap_0.3', 'HS_0.0', 'H2_0.0']
label_font = {'fontsize': 12}

fig, ax = plt.subplots(1, 1, figsize=(9, 3))
sns.barplot(data=total_data, x='model_type', y='DI', hue='celltype', ax=ax,
            palette=color_dict, order=model_order)
# Hide the right, top and bottom spines; push the left spine slightly outward.
for side in ('right', 'top', 'bottom'):
    ax.spines[side].set_visible(False)
ax.spines['left'].set_position(('axes', -0.05))
ax.set_xlabel('Models', fontdict=label_font)
ax.set_ylabel('Discrimination Index', fontdict=label_font)
ax.set_ylim(bottom=-1, top=0.2)

fig.savefig(os.path.join(OUTPUT_DIR, 'gap_junc_discrimination_index_models.pdf'), dpi=300)

In [None]:
# DNp15-only discrimination index bar plot. Bar order and tick labels are kept
# together in one mapping: model_type -> display label.
dnp15_labels = {
    'gap_0.5': 'with gap\njunction',
    'gap_0.0': 'without\ngap junction',
    'HS_0.0': 'without \nHS',
    'H2_0.0': 'without \nH2',
}
dnp15_rows = total_data.loc[total_data['celltype'] == 'DNp15']

fig, ax = plt.subplots(1, 1, figsize=(5, 3))
sns.barplot(data=dnp15_rows, x='model_type', y='DI', color=color_dict['DNp15'],
            ax=ax, order=list(dnp15_labels))
ax.set_ylim(top=0, bottom=-1)
# Hide the right, top and bottom spines; push the left spine slightly outward.
for side in ('right', 'top', 'bottom'):
    ax.spines[side].set_visible(False)
ax.spines['left'].set_position(('axes', -0.05))
ax.set_xlabel('Models', fontdict={'fontsize': 12})
ax.set_ylabel('Discrimination Index', fontdict={'fontsize': 12})
ax.set_title('DNp15', fontdict={'fontsize': 12})
ax.set_xticklabels(list(dnp15_labels.values()), rotation=0)
fig.savefig(os.path.join(OUTPUT_DIR, 'gap_junc_discrimination_index_models_DNp15.pdf'), dpi=300)

### Sympy versions

In [None]:
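# NOTE(review): the residuals in this commented-out sympy draft do not match the model
# equations in the notebook introduction. For each row, "residual = X - (right-hand side)"
# requires every right-hand-side term to change sign, but here:
#   - the HS/H2 rows negate only the first drive term;
#   - the B, U and DNp15 rows flip the H2 and inhibitory (U/B) terms;
#   - noise enters Vangular with the opposite sign.
# Correct the signs before reusing this code.
# HSright, H2right, HSleft, H2left, Bright, Uright, Bleft, Uleft, DNp15right, DNp15left, Vangular = \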
#     sym.symbols('HSright H2right HSleft H2left Bright Uright Bleft Uleft DNp15right DNp15left Vangular', real=True)
# FtBright, BtFright, FtBleft, BtFleft = 10, -10, 10, -10
# LPTC_system = (HSright - WHSPD * FtBright  - WHSNPD * BtFright + WHScontra * BtFleft + WH2PD * H2_to_HS_cc * BtFleft,
#                 H2right - WH2PD * BtFright - WH2NPD * FtBright + WH2contra * FtBleft + WHSPD * HS_to_H2_cc * FtBleft,
#                 HSleft - WHSPD * FtBleft  - WHSNPD * BtFleft + WHScontra * BtFright + WH2PD * H2_to_HS_cc * BtFright,
#                 H2left - WH2PD * BtFleft - WH2NPD * FtBleft + WH2contra * FtBright + WHSPD * HS_to_H2_cc * FtBright,
#                 Bright - (WHSB * HSright) + (WH2B * H2left) - (WUB * Uright),
#                 Uright - WHSU * HSright + WH2U * H2left - WBU * Bleft,
#                 Bleft - WHSB * HSleft + WH2B * H2right - WUB * Uleft,
#                 Uleft - WHSU * HSleft + WH2U * H2right - WBU * Bright,
#                 DNp15right - WHSp15 * HSright + WH2p15 * H2left - WUp15 * Uright - WBp15 * Bleft,
#                 DNp15left - WHSp15 * HSleft + WH2p15 * H2right - WUp15 * Uleft - WBp15 * Bright,
#                 Vangular - Wvisual * (DNp15right - DNp15left) + noise)

# outputs = (HSright, H2right, HSleft, H2left, Bright, Uright, Bleft, Uleft, DNp15right, DNp15left, Vangular)

In [None]:
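# NOTE(review): this spiking draft has the same sign mismatches against the model equations
# in the notebook introduction (residual terms should all be subtracted as a whole).
# In addition, the H2right row uses WHSPD * HS_to_H2_cc * FtBright, whereas the
# introduction's H2right equation uses FtBleft for the gap-junction term.
# HSright, H2right, HSleft, H2left, Bright, Uright, Bleft, Uleft, DNp15right, DNp15left, Vangular = \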
#     sym.symbols('HSright H2right HSleft H2left Bright Uright Bleft Uleft DNp15right DNp15left Vangular')
# slope = 1
# FtBright, BtFright, FtBleft, BtFleft = 10, -10, 10, -10
# LPTC_system_spiking = (HSright - WHSPD * FtBright  - WHSNPD * BtFright + WHScontra * BtFleft + WH2PD * H2_to_HS_cc * BtFleft,
#                 H2right - sym.exp(sym.log((WH2PD * BtFright - WH2NPD * FtBright + WH2contra * FtBleft + WHSPD * HS_to_H2_cc * FtBright - H2threshold) - slope)**2),
#                 HSleft - WHSPD * FtBleft  - WHSNPD * BtFleft + WHScontra * BtFright + WH2PD * H2_to_HS_cc * BtFright,
#                 H2left - sym.exp(sym.log((WH2PD * BtFleft - WH2NPD * FtBleft + WH2contra * FtBright + WHSPD * HS_to_H2_cc * FtBright - H2threshold) - slope)**2),
#                 Bright - WHSB * HSright + WH2B * H2left - WUB * Uright,
#                 Uright - WHSU * HSright + WH2U * H2left - WBU * Bleft,
#                 Bleft - WHSB * HSleft + WH2B * H2right - WUB * Uleft,
#                 Uleft - WHSU * HSleft + WH2U * H2right - WBU * Bright,
#                 DNp15right - WHSp15 * HSright + WH2p15 * H2left - WUp15 * Uright - WBp15 * Bleft,
#                 DNp15left - WHSp15 * HSleft + WH2p15 * H2right - WUp15 * Uleft - WBp15 * Bright,
#                 Vangular - Wvisual * (DNp15right - DNp15left) + noise)

# outputs = (HSright, H2right, HSleft, H2left, Bright, Uright, Bleft, Uleft, DNp15right, DNp15left, Vangular)