In [1]:
%matplotlib

import sys
sys.path.append('..')

from brian2 import defaultclock, ms, pA, nA, Hz, seed, mV, second
from params import paramsJercog as p
from params import (paramsJercogEphysBuono22, paramsJercogEphysBuono4, paramsJercogEphysBuono5, paramsJercogEphysBuono6,
                    paramsJercogEphysBuono7)
import numpy as np
from generate import convert_kicks_to_current_series, norm_weights, weight_matrix_from_flat_inds_weights, adjacency_indices_within
from trainer import JercogTrainer
from results import Results
import matplotlib.pyplot as plt
from plot import weight_matrix
import matplotlib
# Font type used when exporting figures to PDF / PostScript. Per matplotlib's rcParams
# documentation, 42 selects TrueType (Type 42) output rather than the default Type 3;
# presumably chosen so text in the saved figures stays editable - TODO confirm.
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42

Using matplotlib backend: Qt5Agg


INFO       Cache size for target "cython": 15612 MB.
You can call "clear_cache('cython')" to delete all files from the cache or manually delete files in the "C:\Users\mikejseay\.cython\brian_extensions" directory. [brian2]


In [2]:
# Seed value (applied to NumPy and Brian further down) and the Brian default clock step.
rngSeed = 43
defaultclock.dt = p['dt']

# Feature switches, plus a private copy of the Buono7 ephys preset that may be merged into p.
p['useNewEphysParams'] = True
p['useSecondPopExc'] = False
ephysParams = paramsJercogEphysBuono7.copy()

if p['useNewEphysParams']:
    # these keys must not be overwritten by the preset, so strip them from the copy
    # before merging (a missing key raises KeyError)
    protectedKeys = ('nUnits', 'propInh', 'duration')
    for pK in protectedKeys:
        ephysParams.pop(pK)
    p.update(ephysParams)

# --- bookkeeping, saving and recording options ---
p.update({
    'useRule': 'upCrit',
    'nameSuffix': '',
    'saveFolder': 'C:/Users/mikejseay/Documents/BrianResults/',
    'saveWithDate': True,
    'useOldWeightMagnitude': True,
    'disableWeightScaling': True,
    'applyLogToFR': False,
    'setMinimumBasedOnBalance': False,
    'recordMovieVariables': False,
    'downSampleVoltageTo': 1 * ms,
    'stateVariableDT': 1 * ms,
    'recordAllVoltage': True,
})

# --- simulation params ---
p.update({
    'nUnits': 2e3,
    'propConnect': 0.25,
    # alternative initWeightMethod values kept for reference:
    #   'guessBuono7Weights2e3p025'
    #   'guessBuono7Weights2e3p025SlightLowTuned'
    'initWeightMethod': 'resumePrior',
    'initWeightPrior': 'buonoEphysBen1_2000_0p25_cross-homeo-pre-outer-homeo_guessBuono7Weights2e3p025SlightLow__2021-09-04-08-20_results',
    'kickType': 'spike',  # kick or spike
})
# 5% of the units, computed only after nUnits has been set above
p['nUnitsToSpike'] = int(np.round(0.05 * p['nUnits']))
p.update({
    'timeToSpike': 100 * ms,
    'timeAfterSpiked': 3000 * ms,
    'spikeInputAmplitude': 0.96,  # in nA
    'allowAutapses': False,
})

# --- params not important unless using "kick" instead of "spike" ---
p.update({
    'propKicked': 0.1,
    'onlyKickExc': True,
    'kickTimes': [100 * ms],
    'kickSizes': [1],
})
# kickDur, kickTau, duration and dt are read from p as already configured before this section
iKickRecorded = convert_kicks_to_current_series(
    p['kickDur'], p['kickTau'], p['kickTimes'], p['kickSizes'], p['duration'], p['dt'])
p['iKickRecorded'] = iKickRecorded

# boring params

# derived counts: propConnect x population fraction x nUnits, truncated to int
p.update(
    nIncInh=int(p['propConnect'] * p['propInh'] * p['nUnits']),
    nIncExc=int(p['propConnect'] * (1 - p['propInh']) * p['nUnits']),
)

# one extra excitatory index to record from (roughly nExc - 1); note that this appends to the
# list object already stored in p rather than replacing it
indUnkickedExc = int(p['nUnits'] - (p['propInh'] * p['nUnits']) - 1)
p['indsRecordStateExc'].append(indUnkickedExc)

# population sizes
p.update(
    nExc=int(p['nUnits'] * (1 - p['propInh'])),
    nInh=int(p['nUnits'] * p['propInh']),
)

if p['recordAllVoltage']:
    # record state from every unit, replacing the index lists (including the append above)
    p.update(
        indsRecordStateExc=list(range(p['nExc'])),
        indsRecordStateInh=list(range(p['nInh'])),
    )

# END OF PARAMS

# set RNG seeds...
p['rngSeed'] = rngSeed
rng = np.random.default_rng(rngSeed)  # for numpy
seed(rngSeed)  # for Brian... will insert code to set the random number generator seed into the generated code
p['rng'] = rng

# NOTE(review): the trainer is handed `p` as it exists here. The resume branch below rebinds `p` to a
# new merged dict, so later edits to `p` reach JT only if the trainer shares state some other way - confirm.
JT = JercogTrainer(p)

if p['initWeightMethod'] == 'resumePrior':
    # load the prior run named by initWeightPrior from saveFolder
    PR = Results()
    PR.init_from_file(p['initWeightPrior'], p['saveFolder'])
    # merge the two parameter sets; on key clashes the prior run's value wins
    p = {**p, **PR.p}
    # p = PR.p.copy()  # note this completely overwrites all settings above
    p['nameSuffix'] = p['initWeightMethod'] + p['nameSuffix']  # a way to remember what it was...
    seedTagAt = p['nameSuffix'].find('seed')
    if seedTagAt != -1:
        # read the whole run of digits after 'seed' so multi-digit seeds are recovered too;
        # int('') raises ValueError if no digit follows the tag
        digitsFrom = seedTagAt + len('seed')
        digitsTo = digitsFrom
        while digitsTo < len(p['nameSuffix']) and p['nameSuffix'][digitsTo] in '0123456789':
            digitsTo += 1
        # NOTE(review): rng and Brian were already seeded above with the original rngSeed; this
        # recovered value is not re-applied anywhere in this cell - confirm whether it should be.
        rngSeed = int(p['nameSuffix'][digitsFrom:digitsTo])
    p['initWeightMethod'] = 'resumePrior'  # and then we put this back...
else:
    PR = None

In [63]:
# modify for schematic

# toy network for the schematic figure; nChanged units (a fraction pChanged of nUnits) are the ones
# set apart by the divider lines drawn in the plotting cell below
nUnits, pConn, pChanged = 20, 0.25, 0.2
nChanged = int(np.round(nUnits * pChanged))
print(nChanged)

# connection index pairs -> one weight per connection -> dense nUnits x nUnits matrix
preInds, postInds = adjacency_indices_within(nUnits, pConn)
nConnections = preInds.size
weights = norm_weights(nConnections)
wMat = weight_matrix_from_flat_inds_weights(
    nUnits, nUnits,
    preInds, postInds,
    weights,
)

4


In [64]:
# Heat map of the schematic weight matrix, with divider lines setting apart the last nChanged units.
values = wMat
useCmap = 'Reds'
limsMethod = 'custom'
xlabel = ''
ylabel = ''
clabel = ''
vlims = (0, weights.max())

# clear=True wipes figure 1 if it already exists, so the cell runs on a fresh kernel (where no `f`
# is defined yet) and never clears an unrelated figure that `f` happened to reference
f, ax = plt.subplots(figsize=(6, 6), num=1, clear=True)
i = ax.imshow(values,
              cmap=getattr(plt.cm, useCmap),
              aspect='auto',
              interpolation='none')
ax.set(xlabel=xlabel, ylabel=ylabel)

# choose the colour limits
if limsMethod == 'absmax':
    # symmetric about zero
    vmax = np.nanmax(np.fabs(values))
    vmin = -vmax
elif limsMethod == 'minmax':
    vmax, vmin = np.nanmax(values), np.nanmin(values)
elif limsMethod == 'custom':
    vmin, vmax = vlims
else:
    # fail loudly instead of reaching set_clim with undefined or stale limits
    raise ValueError("limsMethod must be 'absmax', 'minmax' or 'custom', got %r" % (limsMethod,))

i.set_clim(vmin, vmax)

# imshow centres cells on integer coordinates, hence the 0.5 offsets to land on cell edges
ax.vlines(nUnits - nChanged - 0.5, -0.5, nUnits - 0.5, color='k')
ax.hlines(nUnits - nChanged - 0.5, -0.5, nUnits - 0.5, color='k')

# cb = plt.colorbar(i, ax=ax)
# cb.ax.set_ylabel(clabel, rotation=270)

<matplotlib.collections.LineCollection at 0x284de504248>

In [60]:
# Define the second excitatory population (E2) and pick which E1<->E2 synapses of the prior
# results will be removed.

# NOTE(review): `nUnits` is the notebook-level variable assigned in the schematic cell (20 there),
# not the simulated network size; in file order this makes nUnitsSecondPopExc == 1 - confirm intent.
p['nUnits'] = nUnits

p['nUnitsSecondPopExc'] = int(np.round(0.05 * p['nUnits']))
p['startIndSecondPopExc'] = p['nUnitsToSpike']  # E2 starts right after the spiked units
p['removePropConn'] = 0.1
p['addBackRemovedConns'] = True


# E2 occupies the half-open index range [startIndExc2, endIndExc2)
startIndExc2 = p['startIndSecondPopExc']
endIndExc2 = p['startIndSecondPopExc'] + p['nUnitsSecondPopExc']

# positions (into the flat E->E arrays) of connections with pre >= endIndExc2 and post inside E2
E2E1Inds = \
    np.where(
        np.logical_and(PR.preEE >= endIndExc2, np.logical_and(PR.posEE >= startIndExc2, PR.posEE < endIndExc2)))[0]
# ...and of connections with post >= endIndExc2 and pre inside E2
E1E2Inds = \
    np.where(
        np.logical_and(PR.posEE >= endIndExc2, np.logical_and(PR.preEE >= startIndExc2, PR.preEE < endIndExc2)))[0]

# draw a fraction removePropConn of each group without replacement, using the seeded Generator
# `rng` from the parameter cell so the selection is reproducible from rngSeed
removeE2E1Inds = rng.choice(E2E1Inds, int(np.round(E2E1Inds.size * p['removePropConn'])), replace=False)
removeE1E2Inds = rng.choice(E1E2Inds, int(np.round(E1E2Inds.size * p['removePropConn'])), replace=False)

removeInds = np.concatenate((removeE2E1Inds, removeE1E2Inds))

# save some info...
# dense E->E weight matrix before any connection is removed
wEEInit = weight_matrix_from_flat_inds_weights(PR.p['nExc'], PR.p['nExc'], PR.preEE, PR.posEE, PR.wEE_final)

In [40]:
# making a different version of the weight matrix where the modify conn population got moved

# relabelling table: units 100-199 and 1500-1599 trade places (offset of 1400 either way)
swapDict = {n: n + 1400 for n in range(100, 200)}
swapDict.update({n: n - 1400 for n in range(1500, 1600)})

preEEInitR = PR.preEE.copy()
posEEInitR = PR.posEE.copy()

# endpoints listed in the table are rewritten; every other entry keeps its copied value
for connInd in range(PR.preEE.size):
    preInd, posInd = PR.preEE[connInd], PR.posEE[connInd]
    preEEInitR[connInd] = swapDict.get(preInd, preInd)
    posEEInitR[connInd] = swapDict.get(posInd, posInd)

In [41]:
# dense E->E weight matrix built from the relabelled index arrays and the current PR.wEE_final
wEEInitR = weight_matrix_from_flat_inds_weights(
    PR.p['nExc'], PR.p['nExc'],
    preEEInitR, posEEInitR,
    PR.wEE_final,
)

In [42]:
# Rewire the prior E->E connectivity: drop the synapses listed in removeInds and, optionally,
# insert the same number of new synapses (re-using the removed weights) within E2 and within E1,
# then rescale the I->E weights to compensate.
# NOTE(review): this cell edits PR in place, so it is not idempotent - re-running it without
# reloading PR applies the stale removeInds to the already-modified arrays.
weightsSaved = PR.wEE_final[removeInds]  # save the weights to be used below

PR.preEE = np.delete(PR.preEE, removeInds, None)
PR.posEE = np.delete(PR.posEE, removeInds, None)
PR.wEE_final = np.delete(PR.wEE_final, removeInds, None)

if p['addBackRemovedConns']:
    # split the removed-connection count between E2->E2 and E1->E1
    nConnRemoved = removeInds.size
    propAddedConnToE2E2 = p['nUnitsSecondPopExc'] / (p['nExc'] - p['nUnitsToSpike'])
    # propAddedConnToE2E2 = (p['nUnitsSecondPopExc'] / (p['nExc'] - p['nUnitsToSpike'])) ** 2  # arguably should be this
    nConnAddedToE2E2 = int(np.round(propAddedConnToE2E2 * nConnRemoved))
    nConnAddedToE1E1 = nConnRemoved - nConnAddedToE2E2
    # positions of the connections that already exist within E2 and within E1 (after the deletion)
    E2E2Inds = np.where(np.logical_and(np.logical_and(PR.preEE >= startIndExc2, PR.preEE < endIndExc2),
                                       np.logical_and(PR.posEE >= startIndExc2, PR.posEE < endIndExc2)))[0]
    E1E1Inds = np.where(np.logical_and(PR.preEE >= endIndExc2, PR.posEE >= endIndExc2))[0]

    # construct a probability array that is the shape of E2E2, fill it with the value of 1 / (nExc2*nExc2 - nExistingConns - nExc2)
    # in positions where there are not already connections... set existing connection and the diagonal to 0 (should sum to 1)
    # and same for E1E1
    # this will allow us to choose some number of new synapses to add, where there are not already connections, and not on the diag

    nExc1 = p['nExc'] - p['nUnitsSecondPopExc'] - p['nUnitsToSpike']
    nExc2 = p['nUnitsSecondPopExc']
    probabilityArrayE2E2 = np.full((nExc2, nExc2), 1 / (nExc2 * nExc2 - E2E2Inds.size - nExc2))
    # subtracting nUnitsToSpike maps global unit indices to E2-local ones (E2 starts at that index)
    probabilityArrayE2E2[PR.preEE[E2E2Inds] - p['nUnitsToSpike'], PR.posEE[E2E2Inds] - p['nUnitsToSpike']] = 0
    probabilityArrayE2E2[np.diag_indices_from(probabilityArrayE2E2)] = 0
    probabilityArrayE1E1 = np.full((nExc1, nExc1), 1 / (nExc1 * nExc1 - E1E1Inds.size - nExc1))
    probabilityArrayE1E1[
        PR.preEE[E1E1Inds] - nExc2 - p['nUnitsToSpike'], PR.posEE[E1E1Inds] - nExc2 - p['nUnitsToSpike']] = 0
    probabilityArrayE1E1[np.diag_indices_from(probabilityArrayE1E1)] = 0

    # draw the new synapse positions with the seeded Generator `rng` from the parameter cell so the
    # rewiring is reproducible from rngSeed
    indicesE2E2Flat = rng.choice(nExc2 ** 2, nConnAddedToE2E2, replace=False, p=probabilityArrayE2E2.ravel())
    indicesE1E1Flat = rng.choice(nExc1 ** 2, nConnAddedToE1E1, replace=False, p=probabilityArrayE1E1.ravel())

    # resulting connection densities of the four blocks (printed as a sanity check)
    propConnE2E1 = (E2E1Inds.size - removeE2E1Inds.size) / (nExc2 * nExc1)
    propConnE1E2 = (E1E2Inds.size - removeE1E2Inds.size) / (nExc2 * nExc1)
    propConnE2E2 = (E2E2Inds.size + nConnAddedToE2E2) / (nExc2 * nExc2)
    propConnE1E1 = (E1E1Inds.size + nConnAddedToE1E1) / (nExc1 * nExc1)

    print(propConnE2E1)
    print(propConnE1E2)
    print(propConnE2E2)
    print(propConnE1E1)

    # must add  + p['nUnitsToSpike'] for E2E2 and  + p['nUnitsToSpike'] + nExc2 for E1E1
    preIndsE2E2, posIndsE2E2 = np.unravel_index(indicesE2E2Flat, (nExc2, nExc2))
    preIndsE1E1, posIndsE1E1 = np.unravel_index(indicesE1E1Flat, (nExc1, nExc1))
    # append the new synapses; the first saved weights go to E2->E2, the remainder to E1->E1
    PR.preEE = np.concatenate((PR.preEE, preIndsE2E2 + p['nUnitsToSpike']))
    PR.posEE = np.concatenate((PR.posEE, posIndsE2E2 + p['nUnitsToSpike']))
    PR.wEE_final = np.concatenate((PR.wEE_final, weightsSaved[:preIndsE2E2.size]))
    PR.preEE = np.concatenate((PR.preEE, preIndsE1E1 + p['nUnitsToSpike'] + nExc2))
    PR.posEE = np.concatenate((PR.posEE, posIndsE1E1 + p['nUnitsToSpike'] + nExc2))
    PR.wEE_final = np.concatenate((PR.wEE_final, weightsSaved[preIndsE2E2.size:]))

    # turn the final into a matrix also...
    wEEFinal = weight_matrix_from_flat_inds_weights(PR.p['nExc'], PR.p['nExc'], PR.preEE, PR.posEE, PR.wEE_final)

    # make inhibition stronger/weaker to compensate
    wEICompensate = weight_matrix_from_flat_inds_weights(PR.p['nInh'], PR.p['nExc'], PR.preEI, PR.posEI,
                                                         PR.wEI_final)
    # columns from endIndExc2 onwards are scaled by the new E1->E1 density relative to the prior density
    wEICompensate[:, endIndExc2:] = wEICompensate[:, endIndExc2:] * propConnE1E1 / PR.p['propConnect']

    # decreaseInhOntoE2Factor = np.nansum(wEEFinal[startIndExc2:endIndExc2, :], 0).mean() / np.nansum(wEEInit[startIndExc2:endIndExc2, :], 0).mean()
    # ratio of summed E->E weight in the E2 columns after vs. before the rewiring
    decreaseInhOntoE2Factor = np.nansum(wEEFinal[:, startIndExc2:endIndExc2], 1).mean() / np.nansum(wEEInit[:, startIndExc2:endIndExc2], 1).mean()
    wEICompensate[:, startIndExc2:endIndExc2] = wEICompensate[:, startIndExc2:endIndExc2] * decreaseInhOntoE2Factor
    # back to the flat per-connection representation
    PR.wEI_final = wEICompensate[PR.preEI, PR.posEI]

0.19987142857142856
0.20025
0.34
0.2567035714285714


In [43]:
# making a different version of the weight matrix where the modify conn population got moved

preEEFinalR = PR.preEE.copy()
posEEFinalR = PR.posEE.copy()

# relabel with the existing swapDict; endpoints absent from the table keep their copied value
for connInd in range(PR.preEE.size):
    preInd, posInd = PR.preEE[connInd], PR.posEE[connInd]
    preEEFinalR[connInd] = swapDict.get(preInd, preInd)
    posEEFinalR[connInd] = swapDict.get(posInd, posInd)

In [44]:

# dense E->E weight matrix built from the relabelled index arrays and the current PR.wEE_final
wEEFinalR = weight_matrix_from_flat_inds_weights(
    PR.p['nExc'], PR.p['nExc'],
    preEEFinalR, posEEFinalR,
    PR.wEE_final,
)

In [49]:
# largest non-NaN entry of the relabelled initial weight matrix; the figure cell below uses the
# same value as the upper colour limit for both panels
np.nanmax(wEEInitR)

461.50757

In [52]:
# wEEInit and wEEFinal

# both panels share one colour scale, capped at the largest initial weight
sharedLims = (0, np.nanmax(wEEInitR))
f, ax = plt.subplots(1, 2, figsize=(12, 5))

# left panel: initial matrix, right panel: rewired matrix (both with relabelled indices)
for anax, wMatR in zip(ax, (wEEInitR, wEEFinalR)):
    weight_matrix(anax, wMatR, useCmap='Reds', limsMethod='custom', vlims=sharedLims)
    # divider lines at index 1500, where the relabelled 1500-1599 block begins
    anax.hlines(1500, 0, 1599, color='k')
    anax.vlines(1500, 0, 1599, color='k')
    anax.set(xlabel='Postsynaptic Index', ylabel='Presynaptic Index')

f.savefig('benEphys3Weights.pdf', transparent=True)

In [14]:
# scratch: point `anax` at the first panel of the most recently created `ax` array
anax = ax[0]

In [None]:
# NOTE(review): unfinished scratch call - the other hlines calls in this notebook pass y, xmin and
# xmax, so this argument-less call looks like it would raise TypeError if run; this cell has no
# execution count. Fill in the intended arguments or delete the cell.
anax.hlines()

In [5]:
# interactive check: display the prior-results object loaded in the parameter cell
PR

<results.Results at 0x16736367c08>

In [6]:
# interactive check: length of the flat E->E weight vector held by the prior results
PR.wEE_final.shape

(640000,)

In [None]:
# Hand the prior results to the trainer's upCrit network setup. PR is edited in place by the
# rewiring cell above, so what is passed depends on which cells have been run.
# NOTE(review): recordAllVoltage=True presumably mirrors p['recordAllVoltage'] - confirm.
JT.set_up_network_upCrit(priorResults=PR, recordAllVoltage=True)