### Phase-space plots of one P1–P2 pair from 5 sets of a 2-coupled network

# References:
    1) https://www.neuron.yale.edu/neuron/static/new_doc/programming/hocsyntax.html
    2) https://www.neuron.yale.edu/neuron/static/new_doc/programming/python.html
    3) https://www.neuron.yale.edu/neuron/static/py_doc/programming/python.html
    4) https://www.geeksforgeeks.org/single-neuron-neural-network-python/
    5) https://github.com/piazentin/ksets
    6) AP - https://www.moleculardevices.com/applications/patch-clamp-electrophysiology/what-action-potential#gref
    7) https://www.neuron.yale.edu/neuron/static/py_doc/modelspec/programmatic/network/netcon.html
    8) http://neupy.com/apidocs/neupy.algorithms.associative.hebb.html
    9) https://qbi-software.github.io/NEURON-tutorial/lessons/network
    10) https://www.neuron.yale.edu/neuron/static/py_doc/modelspec/programmatic/topology/geometry.html
    11) http://www.cnel.ufl.edu/courses/EEL6814/chapter6.pdf
    12) https://www.slideshare.net/mentelibre/hebbian-learning
    13) http://www.diva-portal.org/smash/get/diva2:1089220/FULLTEXT02

In [1]:
# Install the NEURON simulator (IPython shell escape; only valid inside a notebook).
!pip install neuron

Collecting neuron
  Downloading NEURON-8.0.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (12.6 MB)
[K     |████████████████████████████████| 12.6 MB 6.9 MB/s 
Installing collected packages: neuron
Successfully installed neuron-8.0.2


### Stage 1: Import the required libraries and import base neurons



In [3]:
import os
import sys

if os.getcwd() != "/content":
    # Not running on Google Colab: assume the Hebbian helper module is in the current dir.
    from hebb import MCELL
else:
    # Running on Google Colab: download and unpack the helper library first.
    import shutil, requests
    url = 'https://cloud.operationtulip.com/s/Nd7Wq4x7sCxB6Qx/download/git2.zip'
    # The context manager closes the streamed connection even if writing fails.
    with requests.get(url, stream=True, timeout=60) as response:
        # Fail loudly on an HTTP error instead of saving an error page as the zip.
        response.raise_for_status()
        # response.raw is the undecoded socket stream; have urllib3 undo any
        # Content-Encoding (e.g. gzip) so the bytes written are the zip itself.
        response.raw.decode_content = True
        with open('git2.zip', 'wb') as out_file:
            shutil.copyfileobj(response.raw, out_file)
    del response
    import zipfile
    with zipfile.ZipFile('git2.zip', 'r') as zip_ref:
        zip_ref.extractall('')
    # NOTE(review): this branch imports from hebb_test while the local branch
    # imports from hebb -- confirm both provide the same MCELL.
    from hebb_test import MCELL
from neuron import h  # h is the HOC interpreter instance; gui can also be imported from neuron
from math import pi
from neuron.units import ms,mV
h.load_file('stdrun.hoc')  # standard run system (provides h.continuerun)

import random
import matplotlib.pyplot as plt1

from bokeh.io import output_notebook
import bokeh.plotting as plt2
output_notebook()

import seaborn as sns

### Stage 3: Generate random initial weights
We initialize the synaptic weights randomly because, initially, the weights are not tuned to make the neurons fire synchronously.

In [4]:
# Number N of 2-coupled (P1-P2) sets, all laterally connected.
# To ask interactively instead:
#   given_input_from_user = int(input('Enter the number of 2 coupled sets required: '))
import numpy as np
given_input_from_user = 5

# Initial synaptic weights, drawn uniformly from [low, high).
# Layout used by CONNECTING / LATERAL: index 0 is the intra-set weight and
# indices N .. 2N-2 are the lateral weights between neighbouring sets.
low, high = 0.05, 0.1
all_weights = np.random.uniform(low=low, high=high, size=2 * given_input_from_user - 1)

In [5]:
def stairstep(lo: float, hi: float, duration: int, max_duration: int = 200) -> list:
    """Return a step waveform sampled once per ms.

    The first ``duration`` samples equal ``hi`` and every later sample equals
    ``lo``; the list has ``max_duration`` entries. Both durations are in ms.
    """
    waveform = []
    for t_ms in range(max_duration):
        waveform.append(lo if t_ms >= duration else hi)
    return waveform

### Stage 4: Couple a pair of neurons according to Freeman's KI set, using the weights generated above

In [6]:
# One 2-coupled set of neurons, following Freeman's olfactory-bulb topology.
class CONNECTING:
    """Two MCELL neurons (P1 = ``M1``, P2 = ``M2``) with reciprocal excitation.

    ``M1`` excites ``M2`` and ``M2`` excites ``M1`` through NetCons that detect
    on the source soma voltage and deliver to the target's ``dendexcisyn``
    synapse. ``M1`` additionally receives an IClamp on its dendrite.

    Args:
        M: index of this set, forwarded to ``MCELL``.
        weights: weight sequence; only ``weights[0]`` is used here, for both
            reciprocal NetCons.
        delay1, delay2: accepted for backward compatibility but currently
            unused -- both reciprocal NetCons use a delay of 0 ms.
    """

    def __init__(self, M, weights, delay1=1, delay2=1):
        mutual_weight = weights[0]

        self.M = M  # set number
        # NetCon spike-detection threshold (mV).
        # NOTE(review): equals the -70 mV passed to h.finitialize elsewhere in
        # the notebook; if the cells rest above -70 mV, the initial relaxation
        # may register as a threshold crossing -- confirm this is intended.
        self.th = -70
        self.maindelay = 0  # public attribute; not used inside this class

        # The two neurons of the set.
        self.M1 = MCELL(1, self.M)
        self.M2 = MCELL(2, self.M)
        self.cells = [self.M1, self.M2]

        # Current injection into P1's dendrite.
        self.stim = h.IClamp(self.M1.dend(0.9))
        self.stim.delay = 1     # ms
        self.stim.dur = 100     # ms
        self.stim.amp = 0.8     # nA

        # P1 -> P2
        self.nc1 = h.NetCon(self.M1.soma(0.5)._ref_v, self.M2.dendexcisyn, sec=self.M1.soma)
        self.nc1.weight[0] = mutual_weight
        self.nc1.delay = 0  # tm1m2
        self.nc1.threshold = self.th

        # P2 -> P1
        self.nc2 = h.NetCon(self.M2.soma(0.5)._ref_v, self.M1.dendexcisyn, sec=self.M2.soma)
        self.nc2.weight[0] = mutual_weight
        self.nc2.delay = 0  # tm2m1
        self.nc2.threshold = self.th
        
        

### Stage 5: Laterally connect each set of 2 coupled neurons according to Freeman's KI Set

In [7]:
class LATERAL:
    """Network of N 2-coupled sets whose P1 cells are laterally connected.

    Builds ``N`` CONNECTING sets, then for every neighbouring pair (r, r+1)
    creates two NetCons between the sets' ``M1`` cells: axon of one to the
    ``dendexcisyn`` synapse of the other, in both directions.

    Args:
        N: number of 2-coupled sets.
        weights: weight sequence. ``weights[0]`` is passed to every
            CONNECTING set and ``weights[r + N]`` is the lateral weight of
            neighbouring pair r, so ``len(weights)`` must be >= ``2*N - 1``.
    """

    def __init__(self, N, weights):
        self.N = N
        self.th = -70          # NetCon threshold (mV)
        self.maindelay = 20    # base lateral delay (ms)
        self.sets = [CONNECTING(i, weights) for i in range(N)]

        # twoCupArr[r] == [forward NetCon (set r -> r+1), backward NetCon (set r+1 -> r)]
        self.twoCupArr = []
        for r in range(N - 1):
            src, dst = self.sets[r].M1, self.sets[r + 1].M1
            lateral_weight = weights[r + N]

            forward = h.NetCon(src.axon(0.5)._ref_v, dst.dendexcisyn, sec=src.axon)
            forward.weight[0] = lateral_weight
            forward.delay = self.maindelay + 1
            forward.threshold = self.th

            # NOTE(review): the backward link is 1 ms slower than the forward
            # one -- confirm the asymmetry is intended.
            backward = h.NetCon(dst.axon(0.5)._ref_v, src.dendexcisyn, sec=dst.axon)
            backward.weight[0] = lateral_weight
            backward.delay = self.maindelay + 2
            backward.threshold = self.th

            self.twoCupArr.append([forward, backward])


### Stage 6: Build the network from the requested number of 2-coupled sets (all laterally connected)

In [None]:
# Build the whole network: given_input_from_user 2-coupled sets, laterally
# connected, using the random initial weights.
L1 = LATERAL(N=given_input_from_user, weights=all_weights)

### Stage 7: Visualize the topology of the neurons and the 3D space they are in. Also visualize the density mechanisms added to each neuron of a set.

In [None]:
# Print the section tree of every section created so far.
h.topology()

In [None]:
# Interactive shape plot of all sections, coloured by membrane potential 'v',
# rendered in the notebook through gui2's 'jupyter' backend.
from neuron import h, gui2
gui2.set_backend('jupyter')
ps = gui2.PlotShape()
ps.variable('v')
ps.show(0)

In [None]:
# List the density mechanisms inserted in every section of the model.
for section in h.allsec():
    density_mechs = section.psection()['density_mechs']
    print(f"{section}: {', '.join(density_mechs)}")

### Stage 8: Plot Activation and Inactivation Parameters

In [None]:
import numpy as np
from matplotlib import pyplot

# Probe point: middle of the axon of P1 in the first set.
checkCell = L1.sets[0].M1.axon(0.5)

tvec = h.Vector().record(h._ref_t)

# Membrane potential, ionic currents and hh gating variables at the probe.
vvecA = h.Vector().record(checkCell._ref_v)
kvecA = h.Vector().record(checkCell.k_ion._ref_ik)     # K+ current
navecA = h.Vector().record(checkCell.na_ion._ref_ina)  # Na+ current
mvecA = h.Vector().record(checkCell.hh._ref_m)         # Na+ activation
hvecA = h.Vector().record(checkCell.hh._ref_h)         # Na+ inactivation
nvecA = h.Vector().record(checkCell.hh._ref_n)         # K+ activation


h.finitialize(-70)
h.continuerun(300)

fig = pyplot.figure()
pyplot.plot(tvec, vvecA, label="Membrane potential")
pyplot.xlabel('t (ms)')
pyplot.ylabel('V$_m$ (mV)')
pyplot.legend(frameon=False)

fig = pyplot.figure()
pyplot.plot(tvec, hvecA, '-b', label='h')
pyplot.plot(tvec, nvecA, '-r', label='n')
pyplot.xlabel('t (ms)')
pyplot.ylabel('state')
pyplot.legend(frameon=False)

# Ionic currents. The gate n is dimensionless and does not belong on a
# current axis, so plot the recorded K+ and Na+ currents instead.
fig = pyplot.figure()
pyplot.plot(tvec, kvecA.as_numpy(), '-b', label='i$_K$')
pyplot.plot(tvec, navecA.as_numpy(), '-r', label='i$_{Na}$')
pyplot.xlabel('t (ms)')
pyplot.ylabel('current (mA/cm$^2$)')
pyplot.legend(frameon=False)

        


In [None]:
# Phase-space view: membrane potential against the Na+ inactivation gate h.
# Single unlabeled curve, so no legend (it would only warn about missing labels).
fig = pyplot.figure()
pyplot.plot(hvecA, vvecA)
pyplot.xlabel('h')
pyplot.ylabel('V$_m$ (mV)')
pyplot.title('Voltage vs Sodium inactivation parameter')

In [None]:
# Phase-space view: membrane potential against the K+ activation gate n.
# Single unlabeled curve, so no legend (it would only warn about missing labels).
fig = pyplot.figure()
pyplot.plot(nvecA, vvecA)
pyplot.xlabel('n')
pyplot.ylabel('V$_m$ (mV)')
pyplot.title('Voltage vs Potassium activation parameter')

In [None]:
# 3-D phase portrait (V, n, h) of the recorded trajectory. A 3-D axis can only
# hold three coordinates, so the fourth variable (m) is shown as point colour.
fig = pyplot.figure()
ax = fig.add_subplot(111, projection='3d')
points = ax.scatter(vvecA.as_numpy(), nvecA.as_numpy(), hvecA.as_numpy(),
                    c=mvecA.as_numpy(), s=2, cmap='viridis')
ax.set_xlabel('V$_m$ (mV)')
ax.set_ylabel('n')
ax.set_zlabel('h')
fig.colorbar(points, ax=ax, label='m')

In [None]:
# Gate-space trajectory: Na+ inactivation h against K+ activation n.
fig = pyplot.figure()
pyplot.plot(nvecA, hvecA)
pyplot.xlabel('n')
pyplot.ylabel('h')
pyplot.title('Sodium inactivation vs Potassium activation parameter')

### Stage 9: Record & visualize the axon and dendrite voltages of the neurons in the 1st set (and of P1 in the 2nd set)

In [None]:
# Voltage traces of P1 in the first set: axon vs dendrite over a 200 ms run.
recording_cell = L1.sets[0].M1
# Recorders are attached before h.finitialize so they capture the whole run.
axon_m1 = h.Vector().record(recording_cell.axon(0.5)._ref_v)
dend_m1 = h.Vector().record(recording_cell.dend(0.5)._ref_v)
t = h.Vector().record(h._ref_t)

h.finitialize(-70 * mV)
h.continuerun(200 * ms)

f = plt2.figure(x_axis_label='t (ms)', y_axis_label='v (mV)')
f.line(t, list(axon_m1), line_width=1,legend_label='P1 axon',line_color='black')
f.line(t, list(dend_m1), line_width=2,legend_label='P1 dendrite',line_color='red', line_dash='dashed')
plt2.show(f)

In [None]:
# Voltage traces of P2 in the first set: axon vs dendrite over a 200 ms run.
recording_cell = L1.sets[0].M2
# Recorders are attached before h.finitialize so they capture the whole run.
axon_m2 = h.Vector().record(recording_cell.axon(0.5)._ref_v)
dend_m2 = h.Vector().record(recording_cell.dend(0.5)._ref_v)
t = h.Vector().record(h._ref_t)

h.finitialize(-70 * mV)
h.continuerun(200 * ms)

f = plt2.figure(x_axis_label='t (ms)', y_axis_label='v (mV)')
f.line(t, list(axon_m2), line_width=1,legend_label='P2 axon',line_color='black')
f.line(t, list(dend_m2), line_width=2,legend_label='P2 dendrite',line_color='red', line_dash='dashed')
plt2.show(f)

In [None]:
# Voltage traces of P1 in the SECOND set (L1.sets[1]). Dedicated variable
# names keep these recordings distinct from the set-0 vectors (axon_m1,
# dend_m1), so the plot really shows set 2.
recording_cell = L1.sets[1].M1
axon_s2_m1 = h.Vector().record(recording_cell.axon(0.5)._ref_v)
dend_s2_m1 = h.Vector().record(recording_cell.dend(0.5)._ref_v)
t = h.Vector().record(h._ref_t)

h.finitialize(-70 * mV)
h.continuerun(200 * ms)

f = plt2.figure(x_axis_label='t (ms)', y_axis_label='v (mV)', title='Set 2')
f.line(t, list(axon_s2_m1), line_width=1,legend_label='P1 axon',line_color='black')
f.line(t, list(dend_s2_m1), line_width=2,legend_label='P1 dendrite',line_color='red', line_dash='dashed')
plt2.show(f)

In [None]:
# Phase-plane views for the first set: each neuron's axon voltage against its
# own dendrite voltage over a 200 ms run (t is recorded but not plotted here).
recording_cell = L1.sets[0].M1
recording_cell_2 = L1.sets[0].M2
axon_m1 = h.Vector().record(recording_cell.axon(0.5)._ref_v)
dend_m1 = h.Vector().record(recording_cell.dend(0.5)._ref_v)
axon_m2 = h.Vector().record(recording_cell_2.axon(0.5)._ref_v)
dend_m2 = h.Vector().record(recording_cell_2.dend(0.5)._ref_v)
t = h.Vector().record(h._ref_t)

h.finitialize(-70 * mV)
h.continuerun(200 * ms)

f = plt2.figure(x_axis_label='v (mV)', y_axis_label='v (mV)',title='P1 axon vs P1 dendrite (2-coupled)')
f.line(list(axon_m1), list(dend_m1), line_width=1,line_color='black')
f2 = plt2.figure(x_axis_label='v (mV)', y_axis_label='v (mV)', title='P2 axon vs P2 dendrite (2-coupled)')
f2.line(list(axon_m2), list(dend_m2), line_width=2,line_color='black')
plt2.show(f)
plt2.show(f2)

In [None]:
# Phase-plane views for the first set: P1 against P2, once for the axons and
# once for the dendrites, over a 200 ms run (t is recorded but not plotted).
recording_cell = L1.sets[0].M1
recording_cell_2 = L1.sets[0].M2
axon_m1 = h.Vector().record(recording_cell.axon(0.5)._ref_v)
dend_m1 = h.Vector().record(recording_cell.dend(0.5)._ref_v)
axon_m2 = h.Vector().record(recording_cell_2.axon(0.5)._ref_v)
dend_m2 = h.Vector().record(recording_cell_2.dend(0.5)._ref_v)
t = h.Vector().record(h._ref_t)

h.finitialize(-70 * mV)
h.continuerun(200 * ms)

f = plt2.figure(x_axis_label='v (mV)', y_axis_label='v (mV)',title='P1 axon vs P2 axon (2-coupled)')
f.line(list(axon_m1), list(axon_m2), line_width=1,line_color='black')
f2 = plt2.figure(x_axis_label='v (mV)', y_axis_label='v (mV)',title='P1 dendrite vs P2 dendrite (2-coupled)')
f2.line(list(dend_m1), list(dend_m2), line_width=2,line_color='black')
plt2.show(f)
plt2.show(f2)


### Stage 10: Visualize the spike timing of each neuron in a set

In [None]:
color = ['green', 'orange']
label = ['P1', 'P2']
plt1.figure(figsize=(16, 10))
# One raster row per neuron: set j's P1/P2 occupy rows 2*j and 2*j + 1, so
# spikes from different sets never overlap, and times stay on the true axis.
yticks, yticklabels = [], []
for j in range(given_input_from_user):
    for i, cell in enumerate(L1.sets[j].cells):
        row = 2 * j + i
        plt1.vlines(list(cell.spike_times), row + 0.5, row + 1, color=color[i])
        yticks.append(row + 0.75)
        yticklabels.append(f'Set {j + 1} {label[i]}')
plt1.yticks(yticks, yticklabels)
plt1.xlabel('time(ms)')
plt1.show()
t1 = list(L1.sets[0].cells[0].spike_times)
t2 = list(L1.sets[1].cells[0].spike_times)
print(t1)
print(t2)

In [None]:
# Raster of P1/P2 spike times for one set. The set is indexed explicitly
# instead of relying on a loop variable left over from another cell.
raster_set = -1  # last set
color = ['green', 'orange']
label = ['P1', 'P2']
plt1.figure(figsize=(15, 8))
for i, cell in enumerate(L1.sets[raster_set].cells):
    plt1.vlines(cell.spike_times, i + 0.5, i + 1, color=color[i], label=label[i])
plt1.xlabel('t (ms)')
plt1.legend()
plt1.show()

##### 

## Stage 11: Induce Learning in the network



In [None]:
def generate_L_weight_delta(first, second, A_plus=0.01, A_minus=-0.0011,
                            tau_pre=20.0, tau_post=20.0):
    """Return the summed pair-based STDP weight change for two spike trains.

    Spikes are paired by index: the k-th spike of ``first`` with the k-th spike
    of ``second``. Unpaired spikes in the longer train are ignored. For each
    pair, ``dt = second[k] - first[k]``:

    * ``dt >= 0`` contributes ``A_plus * exp(-dt / tau_post)``
    * ``dt < 0``  contributes ``A_minus * exp(dt / tau_pre)``

    Args:
        first: spike times (ms) of one cell.
        second: spike times (ms) of the other cell.
        A_plus: amplitude for non-negative ``dt``.
        A_minus: amplitude for negative ``dt``.
        tau_pre: time constant (ms) for negative ``dt``.
        tau_post: time constant (ms) for non-negative ``dt``.

    Returns:
        The sum of the per-pair contributions (0 when either train is empty).
    """
    import math  # local import so the function does not rely on a global `math`

    delta_w = 0.0
    for t_first, t_second in zip(first, second):
        dt = t_second - t_first
        if dt >= 0:
            delta_w += A_plus * math.exp(-dt / tau_post)
        else:
            delta_w += A_minus * math.exp(dt / tau_pre)
    return delta_w

In [None]:
epochs = 50
import numpy as np
import math
out_1_1_data = []
# STDP parameters for the mutual (P1 <-> P2) connections; the lateral updates
# use generate_L_weight_delta's own constants.
A_plus = 0.01      #0.2 to 2.5
A_minus = -0.01
tau_pre = 20*ms
tau_post = 20*ms

# Weight history, indexed [set or pair][epoch]:
#   weights_rec / weights_rec_2     -> nc1 (P1 -> P2) / nc2 (P2 -> P1) of each set
#   L_weights_rec / L_weights_rec_2 -> forward / backward lateral NetCon of each pair
weights_rec = [[0] * epochs for _ in range(len(L1.sets))]
weights_rec_2 = [[0] * epochs for _ in range(len(L1.sets))]
L_weights_rec = [[0] * epochs for _ in range(len(L1.sets) - 1)]
L_weights_rec_2 = [[0] * epochs for _ in range(len(L1.sets) - 1)]

for epoch in range(epochs):
    print("\niteration no is", epoch)
    P1_data = []
    print("\nmutual learning")
    for i in range(len(L1.sets)):
        p1, p2 = L1.sets[i].M1, L1.sets[i].M2
        # Attach the recorder BEFORE initializing so it captures the run, then
        # simulate the whole network from rest with the current weights
        # (every run, including the very first, starts from h.finitialize).
        out_1_1 = h.Vector().record(p1.axon(0.5)._ref_v)
        h.finitialize(-70 * mV)
        h.continuerun(300 * ms)
        spike_1_1 = list(p1.spike_times)
        spike_1_2 = list(p2.spike_times)

        # Pair-based STDP: the k-th P1 spike is paired with the k-th P2 spike.
        delta_w = 0
        for t_p1, t_p2 in zip(spike_1_1, spike_1_2):
            dt = t_p2 - t_p1
            # NOTE(review): potentiation here requires dt >= 0.24 ms, whereas
            # generate_L_weight_delta uses dt >= 0 -- confirm intended.
            if dt >= 0.24:
                delta_w += A_plus * math.exp(-dt / tau_post)
            else:
                delta_w += A_minus * math.exp(dt / tau_pre)

        # Both reciprocal connections receive the same change.
        # NOTE(review): nc2 carries P2 -> P1, for which the pre/post order is
        # reversed; confirm that the symmetric update is intended.
        L1.sets[i].nc1.weight[0] += delta_w
        L1.sets[i].nc2.weight[0] += delta_w
        P1_data.append(list(out_1_1))
        print("weight:", L1.sets[i].nc1.weight[0])
        weights_rec[i][epoch] = L1.sets[i].nc1.weight[0]
        weights_rec_2[i][epoch] = L1.sets[i].nc2.weight[0]

    # Lateral learning between the P1 cells of neighbouring sets.
    print("\tlateral learning")
    for i in range(len(L1.sets) - 1):
        left, right = L1.sets[i].M1, L1.sets[i + 1].M1
        h.finitialize(-70 * mV)
        h.continuerun(300 * ms)
        print("Spike times P{0}".format(1+2*i), list(left.spike_times))
        spike_1_1 = list(left.spike_times)
        print("Spike times P{0}".format(3+2*i), list(right.spike_times))
        spike_2_1 = list(right.spike_times)

        delta_lateral = generate_L_weight_delta(spike_1_1, spike_2_1)
        # NOTE(review): the backward NetCon (set i+1 -> set i) receives the same
        # delta as the forward one; directional STDP would use
        # generate_L_weight_delta(spike_2_1, spike_1_1) for it -- confirm.
        L1.twoCupArr[i][0].weight[0] += delta_lateral
        L1.twoCupArr[i][1].weight[0] += delta_lateral

        # Logging: forward and backward lateral weights go to separate histories.
        print(L1.twoCupArr[i][0].weight[0])
        L_weights_rec[i][epoch] = L1.twoCupArr[i][0].weight[0]
        L_weights_rec_2[i][epoch] = L1.twoCupArr[i][1].weight[0]


## Stage 12: Plot Activation and Inactivation Parameters Again (After Learning)

In [None]:
import numpy as np
from matplotlib import pyplot

# Probe point (after learning): middle of the axon of P1 in the first set.
checkCell = L1.sets[0].M1.axon(0.5)

tvec = h.Vector().record(h._ref_t)

# Membrane potential, ionic currents and hh gating variables at the probe.
vvecA = h.Vector().record(checkCell._ref_v)
kvecA = h.Vector().record(checkCell.k_ion._ref_ik)     # K+ current
navecA = h.Vector().record(checkCell.na_ion._ref_ina)  # Na+ current
mvecA = h.Vector().record(checkCell.hh._ref_m)         # Na+ activation
hvecA = h.Vector().record(checkCell.hh._ref_h)         # Na+ inactivation
nvecA = h.Vector().record(checkCell.hh._ref_n)         # K+ activation


h.finitialize(-70)
h.continuerun(300)

fig = pyplot.figure()
pyplot.plot(tvec, vvecA, label="Membrane potential")
pyplot.xlabel('t (ms)')
pyplot.ylabel('V$_m$ (mV)')
pyplot.legend(frameon=False)

fig = pyplot.figure()
pyplot.plot(tvec, hvecA, '-b', label='h')
pyplot.plot(tvec, nvecA, '-r', label='n')
pyplot.xlabel('t (ms)')
pyplot.ylabel('state')
pyplot.legend(frameon=False)

# Ionic currents. The gate n is dimensionless and does not belong on a
# current axis, so plot the recorded K+ and Na+ currents instead.
fig = pyplot.figure()
pyplot.plot(tvec, kvecA.as_numpy(), '-b', label='i$_K$')
pyplot.plot(tvec, navecA.as_numpy(), '-r', label='i$_{Na}$')
pyplot.xlabel('t (ms)')
pyplot.ylabel('current (mA/cm$^2$)')
pyplot.legend(frameon=False)

        


In [None]:
# Phase-space view (after learning): membrane potential against the Na+
# inactivation gate h. Single unlabeled curve, so no legend.
fig = pyplot.figure()
pyplot.plot(hvecA, vvecA)
pyplot.xlabel('h')
pyplot.ylabel('V$_m$ (mV)')
pyplot.title('Voltage vs Sodium inactivation parameter')

In [None]:
# Phase-space view (after learning): membrane potential against the K+
# activation gate n. Single unlabeled curve, so no legend.
fig = pyplot.figure()
pyplot.plot(nvecA, vvecA)
pyplot.xlabel('n')
pyplot.ylabel('V$_m$ (mV)')
pyplot.title('Voltage vs Potassium activation parameter')

In [None]:
# 3-D phase portrait (V, n, h) after learning. A 3-D axis can only hold three
# coordinates, so the fourth variable (m) is shown as point colour.
fig = pyplot.figure()
ax = fig.add_subplot(111, projection='3d')
points = ax.scatter(vvecA.as_numpy(), nvecA.as_numpy(), hvecA.as_numpy(),
                    c=mvecA.as_numpy(), s=2, cmap='viridis')
ax.set_xlabel('V$_m$ (mV)')
ax.set_ylabel('n')
ax.set_zlabel('h')
fig.colorbar(points, ax=ax, label='m')

## Stage 13: Plot Voltage vs Voltage graphs

In [None]:
# Phase-plane views for the first set after learning: each neuron's axon
# voltage against its own dendrite voltage over a 200 ms run
# (t is recorded but not plotted here).
recording_cell = L1.sets[0].M1
recording_cell_2 = L1.sets[0].M2
axon_m1 = h.Vector().record(recording_cell.axon(0.5)._ref_v)
dend_m1 = h.Vector().record(recording_cell.dend(0.5)._ref_v)
axon_m2 = h.Vector().record(recording_cell_2.axon(0.5)._ref_v)
dend_m2 = h.Vector().record(recording_cell_2.dend(0.5)._ref_v)
t = h.Vector().record(h._ref_t)

h.finitialize(-70 * mV)
h.continuerun(200 * ms)

f = plt2.figure(x_axis_label='v (mV)', y_axis_label='v (mV)',title='P1 axon vs P1 dendrite (2-coupled)')
f.line(list(axon_m1), list(dend_m1), line_width=1,line_color='black')
f2 = plt2.figure(x_axis_label='v (mV)', y_axis_label='v (mV)', title='P2 axon vs P2 dendrite (2-coupled)')
f2.line(list(axon_m2), list(dend_m2), line_width=2,line_color='black')
plt2.show(f)
plt2.show(f2)

In [None]:
# Phase-plane views for the first set after learning: P1 against P2, once for
# the axons and once for the dendrites, over a 200 ms run
# (t is recorded but not plotted here).
recording_cell = L1.sets[0].M1
recording_cell_2 = L1.sets[0].M2
axon_m1 = h.Vector().record(recording_cell.axon(0.5)._ref_v)
dend_m1 = h.Vector().record(recording_cell.dend(0.5)._ref_v)
axon_m2 = h.Vector().record(recording_cell_2.axon(0.5)._ref_v)
dend_m2 = h.Vector().record(recording_cell_2.dend(0.5)._ref_v)
t = h.Vector().record(h._ref_t)

h.finitialize(-70 * mV)
h.continuerun(200 * ms)

f = plt2.figure(x_axis_label='v (mV)', y_axis_label='v (mV)',title='P1 axon vs P2 axon (2-coupled)')
f.line(list(axon_m1), list(axon_m2), line_width=1,line_color='black')
f2 = plt2.figure(x_axis_label='v (mV)', y_axis_label='v (mV)',title='P1 dendrite vs P2 dendrite (2-coupled)')
f2.line(list(dend_m1), list(dend_m2), line_width=2,line_color='black')
plt2.show(f)
plt2.show(f2)
