In [81]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Jan 2

@author: yaning
"""

import numpy as np
import importlib
import matplotlib.pyplot as plt
import torch
from torch.distributions import Normal, Uniform
from tqdm import tqdm
import time 
import math
import pickle

# functions and classes that i wrote
import with_learning.run_network as run_network
import with_learning.learning_NN.Receptors as Receptors
import with_learning.learning_NN.Network as Network

importlib.reload(Receptors)
importlib.reload(Network)

np.set_printoptions(threshold=np.inf)

path = "/home/yaning/Documents/Spiking_NN/with_learning/"

In [82]:
pointCount = 6500
deltaTms = 0.05
times = np.arange(pointCount) * deltaTms
initial_Vm = 1.3458754117369027

In [83]:
# functions
# get the minimum and maximum of the voltages
def get_min_max_distance(voltages):
    min_voltage = abs(np.nanmin(voltages) + 70)
    max_voltage = abs(np.nanmax(voltages) - 40.1)
    return min_voltage, max_voltage

def sigmoid(x):
    return 1 / (1 + math.exp(-x))

def get_current_distance(currents):
    current_d = abs(np.nanmax(currents)-175)
    return current_d

def get_nan_inf_amount(arr):
    return np.isinf(arr).sum() + np.isnan(arr).sum()


In [84]:
all_synapses = []

def create_synapse(send_neuron, receive_neuron, type):
    
    # create receptors accordingly
    if type == "AMPA":
        # temporal solution for weight randomise
        # Receptors.LigandGatedChannelFactory.set_params()
        ampa_receptor = Receptors.AMPA(0.0072, 1, -70, 1.35, 0.9, 1, 1, 1, 12, 10, 20, 10, 35, 7, 0.7, "AMPA")
        synapse = Network.Synapse(0.05, 0, send_neuron, receive_neuron, ampa_receptor)
        
    elif type == "AMPA+NMDA":
        # Receptors.LigandGatedChannelFactory.set_params()
        ampa_receptor = Receptors.AMPA(0.0072, 1, -70, 1.35, 0.9, 1, 1, 1, 12, 10, 20, 10, 35, 7, 0.7, "AMPA")
        nmda_receptor = Receptors.NMDA(0.0012, 1, -70, 1.35, 0.9, 1, 1, 1, 12, 10, 20, 10, 15, 7, 0.7, "NMDA")
        synapse = Network.Synapse(0.05, 0, send_neuron, receive_neuron, ampa_receptor, nmda_receptor)
    
    elif type == "GABA":
        # Receptors.LigandGatedChannelFactory.set_params()
        # print(Receptors.LigandGatedChannelFactory.w_init_GABA)
        gaba_receptor = Receptors.GABA(0.0012, 1, -140, 1.35, 0.9, 1, 1, 1, 12, 10, 20, 10, 20, 7, 0.7, "GABA")
        synapse = Network.Synapse(0.05, 0, send_neuron, receive_neuron, gaba_receptor)

    send_neuron.outgoing_synapses.append(synapse)
    receive_neuron.incoming_synapses.append(synapse)

    all_synapses.append(synapse)

def update_synapse_initial_values(infer_params):
    for synapse in all_synapses:
        for receptor in synapse.receptors:
            receptor.Vm = initial_Vm
            receptor.gP = 1
            
            receptor.e = infer_params["e"]
            receptor.u_se = infer_params["u_se"]
            receptor.g_decay = infer_params["g_decay"]
            receptor.g_rise = infer_params["g_rise"]
            receptor.w = infer_params["w"]
            receptor.tau_rec = infer_params["tau_rec"]
            receptor.tau_pre = infer_params["tau_pre"]
            receptor.tau_post = infer_params["tau_post"]

            if receptor.label == "GABA":
                receptor.gMax = infer_params["gMax_GABA"]
                receptor.tau_decay = infer_params["tau_decay_GABA"]
                receptor.tau_rise = infer_params["tau_rise_GABA"]
            
            elif receptor.label == "NMDA":
                receptor.tau_decay = infer_params["tau_decay_NMDA"]
                receptor.tau_rise = infer_params["tau_rise_NMDA"]
            
            elif receptor.label == "AMPA":
                receptor.tau_decay = infer_params["tau_decay_AMPA"]
                receptor.tau_rise = infer_params["tau_rise_AMPA"]

In [85]:
input_pattern = np.load(path + "dataset.npy") 
output_pattern = np.load(path + "output.npy")

input_pattern = input_pattern[1]

In [86]:
input_pattern.shape

(3, 6500)

In [87]:
with open(path + "XOR_dataset_black_white.pkl", "rb") as f:
    dataset, binary = pickle.load(f)

In [88]:
dataset.shape

(8, 3, 6000)

In [89]:
# ----------1 input, 1 excite main, 1 excite sub, 1 inhibit main, 1 inhibit sub, output-----------

# create the network
# Neuron: deltaTms, I, Vm, Name
neuron_input_0 = Network.Neuron(deltaTms, 0, initial_Vm, "input_0")
neuron_input_1 = Network.Neuron(deltaTms, 0, initial_Vm, "input_1")
neuron_input_2 = Network.Neuron(deltaTms, 0, initial_Vm, "input_2")

neuron_excite_main = Network.Neuron(deltaTms, 0, initial_Vm, "excite_main")
neuron_excite_sub = Network.Neuron(deltaTms, 0, initial_Vm, "excite_sub")

neuron_inhibit_main = Network.Neuron(deltaTms, 0, initial_Vm, "inhibit_main")
neuron_inhibit_sub = Network.Neuron(deltaTms, 0, initial_Vm, "inhibit_sub")

neuron_output = Network.Neuron(deltaTms, 0, initial_Vm, "output")

neurons = [neuron_input_0, neuron_input_1, neuron_input_2, 
           neuron_excite_main, neuron_excite_sub, 
        neuron_inhibit_main, neuron_inhibit_sub, neuron_output]

neuron_names = ["input_0", "input_1", "input_2",
                "excite_main", "excite_sub", "inhibit_main", "inhibit_sub", "output"]






#*********************full layer***************************
# ----------------first input layer------------------------
create_synapse(neuron_input_0, neuron_excite_main, "AMPA")
create_synapse(neuron_input_1, neuron_excite_main, "AMPA")
create_synapse(neuron_input_2, neuron_excite_main, "AMPA")

create_synapse(neuron_input_0, neuron_inhibit_main, "GABA")
create_synapse(neuron_input_1, neuron_inhibit_main, "GABA")
create_synapse(neuron_input_2, neuron_inhibit_main, "GABA")



# ----------------self recurrent layer----------------
create_synapse(neuron_excite_main, neuron_excite_sub, "AMPA+NMDA")
create_synapse(neuron_excite_sub, neuron_excite_main, "AMPA+NMDA")

create_synapse(neuron_inhibit_main, neuron_inhibit_sub, "GABA")
create_synapse(neuron_inhibit_sub, neuron_inhibit_main, "GABA")

# --------------between excitatory and inhibitory----------------
create_synapse(neuron_excite_main, neuron_inhibit_main, "AMPA+NMDA")
create_synapse(neuron_inhibit_main, neuron_excite_main, "GABA")


# ----------------output layer----------------------
create_synapse(neuron_excite_main, neuron_output, "AMPA")


In [90]:
input_pattern.shape

(3, 6500)

In [91]:
def run(input_pattern):
    currents = []

    for t in range(pointCount):
        currents_tstep = []
        
        if input_pattern[0,t]:
            neuron_input_0.sending_signal()
            neuron_input_0.fire_tstep.append(t)
            
        if input_pattern[1,t]:
            neuron_input_1.sending_signal()
            neuron_input_1.fire_tstep.append(t)
            
        if input_pattern[2,t]:
            neuron_input_2.sending_signal()
            neuron_input_2.fire_tstep.append(t)


        # update the synapse states then each neuron\
        num_cycle = 0
        for neuron in neurons[3:]:
            if neuron.fire_tstep == []:
                last_fire = -2
            else:
                last_fire = neuron.fire_tstep[-1]
            fire = neuron.check_firing(t)
            if fire:
                if neuron.fire_tstep == [] or last_fire + 1 != t:
                    # print("this line runs")
                    neuron.update_weights(t)

            
            neuron.update()
            
            # only record output current
            if num_cycle == 4:
                currents_tstep.append(neuron.I)
            num_cycle += 1
            
        # set the synapse states back to 0
        for synapse in all_synapses:
            synapse.state = 0
        
        currents.append(currents_tstep)
    return currents

In [92]:
infer_params = {
    # this might be extremly small
    "gMax_GABA" : 0.01,
    
    # between 0-1
    "e" : 1, # has to be 0-1 because the differentiate formula
    "u_se" : 1, # fraction of avaible transmitter 
    
    # do not need to be between 0-1
    "g_decay" : 1, # technically substraction of 
    "g_rise" : 1, # both should be smaller than 1 (fraction)

    "w" : 1,

    "tau_rec" : 10,
    "tau_pre" : 20,
    "tau_post" : 20,
    
    "tau_decay_AMPA" : 35,
    "tau_rise_AMPA" : 7,
    "tau_decay_NMDA" : 20,
    "tau_rise_NMDA" : 9,
    "tau_decay_GABA" : 40,
    "tau_rise_GABA" : 5 
    }

infer_names = ["gMax_GABA", "e", "u_se", "g_decay", "g_rise", "w", "tau_rec",
               "tau_pre", "tau_post", "tau_decay_AMPA", "tau_rise_AMPA",
               "tau_decay_NMDA", "tau_rise_NMDA", "tau_decay_GABA", "tau_rise_GABA"]

In [93]:
# start = time.time()
# run(input_pattern)
# for neuron in neurons:
#     neuron.erase(initial_Vm)
# update_synapse_initial_values(infer_params)
# print("Time taken:", time.time() - start)

# times = []
# for i in range(500):
#     start = time.time()
#     run(input_pattern)
#     for neuron in neurons:
#         neuron.erase(initial_Vm)
#     update_synapse_initial_values(infer_params)
#     times.append(time.time() - start)

In [94]:
# # start = time.time()
# Normal(last_pure_sample[j], std).sample().numpy()
# # print("Time taken:", time.time() - start)

In [95]:
# # try use MCMC result params
# path = "/home/yaning/Documents/Spiking_NN/without_learning/"
# samples = np.load(path + "MCMC_samples/static_std_initial_0.npy")
# cut_samples = samples[1000:, :]
# values = np.mean(cut_samples, axis=0)
# infer_names = Receptors.LigandGatedChannelFactory.infer_names
# infer_params = dict(zip(infer_names, values))

start = time.time()
#-------------------------initialise MCMC---------------------------
samples = []
pure_samples = []

for i, key in enumerate(infer_params):
    if i <= 2:
        infer_params[key] = 0.5
    else:
        # factor is 30, multiply by 0.5 (initial)
        infer_params[key] = 15

initial_sample = []
for key in infer_params:
    initial_sample.append(infer_params[key])

# put first pure and normal samples in the record
pure_samples.append(np.zeros(len(infer_params)))
samples.append(initial_sample)

# run first round
# return voltages, currents, neuron_names
update_synapse_initial_values(infer_params)
currents = run(input_pattern)

# # evaluation scores
# old_score = abs(235-len(firing_tstep))
# min_voltage_old, max_voltage_old = get_min_max_distance(voltages)

# score_old = get_nan_inf_amount(voltages) + get_nan_inf_amount(currents)
score_old = get_current_distance(currents)


# print("Time taken to finish initial run:", time.time() - start)
#-----------------------officially run MCMC----------------------
num = 1000
big_factor = 30
small_factor = 0.01

for i in tqdm(range(num), desc="Processing", ncols=100):
    # # make the std decrease as MCMC keeps sampling
    # std = np.exp(-i*3/num)
    std = 8/(1+(24/num)*i)
    # std = 6
    start = time.time()
    one_round_pure_sample = []
    one_round_sample = []
    
    temp_infer_params = {}
    
    last_pure_sample = pure_samples[-1]

    for j in range(len(infer_params)):
        # using j only for separate with and without factor
        temp_pure_sample = Normal(last_pure_sample[j], std).sample().item()
        # temp_pure_sample = 1
        one_round_pure_sample.append(temp_pure_sample)
        # if j == 0:
        #     temp_infer_params[infer_names[j]] = small_factor*sigmoid(temp_pure_sample)
        if j <= 2:
            temp_infer_params[infer_names[j]] = sigmoid(temp_pure_sample)
        else:
            temp_infer_params[infer_names[j]] = big_factor*sigmoid(temp_pure_sample)

    # print(temp_infer_params)
    # print("Time taken to choose params:", time.time() - start)
    # print(one_round_pure_sample)
    # run with sampled value
    
    # print(temp_infer_params)
    # restart the network
    for neuron in neurons:
        neuron.erase(initial_Vm)
    update_synapse_initial_values(temp_infer_params)
    # print(all_synapses[0].receptors[0].w)
    
    start = time.time()
    currents = run(input_pattern)
    # print("Time taken to finish mcmc run:", time.time() - start)
    # # evaluation scores
    # new_score = abs(235-len(firing_tstep))
    # acceptance_ratio = old_score/new_score
    # print(old_score, new_score)
    
    # min_voltage_new, max_voltage_new = get_min_max_distance(voltages)
    # min_voltage_rate = min_voltage_old/min_voltage_new
    # max_voltage_rate = max_voltage_old/max_voltage_new
    
    # acceptance_ratio = max_voltage_rate
    # print(min_voltage_new, max_voltage_new)
    # print(acceptance_ratio)
    # score_new = get_nan_inf_amount(voltages) + get_nan_inf_amount(currents)
    start = time.time()
    score_new = get_current_distance(currents)
    lambda_value = 0.01
    # print(score_new, score_old)
    score_old_softmax = np.exp(lambda_value * score_old) / (np.exp(lambda_value*score_old) + np.exp(lambda_value*score_new))
    score_new_softmax = np.exp(lambda_value * score_new) / (np.exp(lambda_value*score_old) + np.exp(lambda_value*score_new))
    
    acceptance_ratio = score_old_softmax/score_new_softmax
    # print(score_old, score_new)
    # print(acceptance_ratio)

    u = np.random.uniform(0, 1)

    if acceptance_ratio >= u:
        # print("replace")
        # old_score = new_score
        # current_d_old = current_d_new
        # min_voltage_old = min_voltage_new
        # max_voltage_old = max_voltage_new
        score_old = score_new

        
        infer_params = temp_infer_params
        # if accept new, the add the pure for next round use
        pure_samples.append(one_round_pure_sample)

    
    for key in infer_params:
        one_round_sample.append(infer_params[key])
    
    samples.append(one_round_sample)
    # print("Time taken to finish choose the sample:", time.time() - start)
    

samples = np.array(samples)
pure_samples = np.array(pure_samples)


np.save(path + 'MCMC_samples/samples_local.npy', samples)
np.save(path + 'MCMC_samples/pure_samples_local.npy', pure_samples)

  score_old_softmax = np.exp(lambda_value * score_old) / (np.exp(lambda_value*score_old) + np.exp(lambda_value*score_new))
  score_new_softmax = np.exp(lambda_value * score_new) / (np.exp(lambda_value*score_old) + np.exp(lambda_value*score_new))
  score_new_softmax = np.exp(lambda_value * score_new) / (np.exp(lambda_value*score_old) + np.exp(lambda_value*score_new))
  return -g_decay/self.tau_decay + w*e
  k2 = f(y0 + h*k1/2, *arg)
  return -g_decay/self.tau_decay + w*e
  return -g_rise/self.tau_rise + w*e
  I = self.gMax * self.gP *(self.Vm - self.rE)
Processing: 100%|███████████████████████████████████████████████| 1000/1000 [18:24<00:00,  1.10s/it]


In [96]:
import pickle

with open(path + "infer_params_names.pkl", "wb") as f:
    pickle.dump((infer_params, infer_names), f)