In [None]:
import nest
import numpy as np
import matplotlib.pyplot as plt
# Reset the NEST simulation kernel before the network below is built, so that
# re-running the notebook does not stack new nodes/connections onto old ones.
nest.ResetKernel()

In [None]:
import nest
import matplotlib.pyplot as plt

def create_populations(pop_size_E=10, pop_size_I=10, E_neuron_params=None,
                       I_neuron_params=None):
    """
    Create the excitatory (E) and inhibitory (I) neuron populations.

    Both populations use the ``hh_cond_exp_traub`` neuron model.

    Args:
        pop_size_E: Number of neurons in the E population.
        pop_size_I: Number of neurons in the I population.
        E_neuron_params: Model parameters for the E neurons. When ``None``,
            the default parameter set defined below is used.
        I_neuron_params: Model parameters for the I neurons. When ``None``,
            the I population reuses the E parameter set.

    Returns:
        Tuple ``(E_pops, I_pops)`` with the two created node collections.
    """
    if E_neuron_params is None:
        E_neuron_params = {
            "E_L": 0.0,
            "C_m": 1.0,
            "V_T": -63.0,
            "t_ref": 5.0,
            "tau_syn_ex": 5.0,
            "tau_syn_in": 5.0,
            "g_L": 0.3,
        }
    if I_neuron_params is None:
        # NOTE(review): by default the I population shares the E parameter
        # set; confirm whether inhibitory neurons should get their own values.
        I_neuron_params = E_neuron_params

    E_pops = nest.Create("hh_cond_exp_traub", n=pop_size_E, params=E_neuron_params)
    I_pops = nest.Create("hh_cond_exp_traub", n=pop_size_I, params=I_neuron_params)

    return E_pops, I_pops

def connect_pop():
    """
    Build the E/I populations and connect every population to every population.

    Each of the four projections (E->E, I->E, E->I, I->I) is an all-to-all
    ``stdp_synapse`` connection with initial weight 1.0 and delay 1.0.

    NOTE(review): the projections leaving the I population use the same
    positive weight as those leaving E. Confirm whether I->* is meant to be
    inhibitory (NEST conductance-based models typically expect a negative
    weight for that).

    Returns:
        Tuple ``(E_pops, I_pops)`` with the two populations.
    """
    exc, inh = create_populations()
    stdp_spec = {"synapse_model": "stdp_synapse", "weight": 1.0, "delay": 1.0}

    # Projections are created in this fixed order: E->E, I->E, E->I, I->I.
    projections = ((exc, exc), (inh, exc), (exc, inh), (inh, inh))
    for pre, post in projections:
        nest.Connect(pre, post, "all_to_all", syn_spec=stdp_spec)

    return exc, inh

def prune(threshold=0.5, synapse_model="stdp_synapse"):
    """
    Remove synapses whose weight is below ``threshold``.

    Only connections of ``synapse_model`` are candidates. The device
    connections in this notebook (Poisson input, multimeter, spike recorder)
    are created without an explicit synapse model, so they are not touched
    when the default ``"stdp_synapse"`` is used.

    Args:
        threshold: Connections with ``weight < threshold`` are removed.
        synapse_model: Synapse model whose connections may be pruned.

    Returns:
        The number of connections that were removed.
    """
    connections = nest.GetConnections(synapse_model=synapse_model)
    # GetStatus with a single key returns one weight per connection.
    weights = nest.GetStatus(connections, "weight")
    weak = [idx for idx, weight in enumerate(weights) if weight < threshold]

    for idx in weak:
        # Indexing a SynapseCollection yields a one-connection collection,
        # which Disconnect accepts directly.
        nest.Disconnect(connections[idx])

    print(f"Pruned {len(weak)}")
    return len(weak)

def create_thalamic_input(E_pops, rate=500.0, weight=1.0, delay=1.0):
    """
    Drive ``E_pops`` with a Poisson "thalamic" input.

    A single ``poisson_generator`` is created and connected all-to-all to the
    neurons in ``E_pops``.

    Args:
        E_pops: Target population.
        rate: Firing rate of the Poisson generator.
        weight: Synaptic weight of the generator->neuron connections.
        delay: Synaptic delay (ms) of the generator->neuron connections.

    Returns:
        The created generator node collection, so callers can inspect or
        modify it (e.g. change the rate) after construction.
    """
    thalamic_input = nest.Create("poisson_generator", params={"rate": rate})
    nest.Connect(thalamic_input, E_pops, "all_to_all",
                 syn_spec={"weight": weight, "delay": delay})
    return thalamic_input
    

def recording_devices(E_pops, I_pops):
    """
    Attach recording devices to every neuron of both populations.

    A multimeter recording ``V_m`` is connected to all neurons, and all
    neurons are connected to a spike recorder.

    Returns:
        Tuple ``(multimeter, spikerecorder)``.
    """
    # Devices are created in a fixed order: multimeter first, then the recorder.
    vm_meter = nest.Create("multimeter", params={"record_from": ["V_m"]})
    spike_rec = nest.Create("spike_recorder")

    neurons = E_pops + I_pops
    nest.Connect(vm_meter, neurons)   # device -> neurons: sample V_m
    nest.Connect(neurons, spike_rec)  # neurons -> device: collect spikes

    return vm_meter, spike_rec

def plot_spikes(multimeter):
    """
    Plot the membrane potential recorded by ``multimeter``, one trace per neuron.

    Despite its name, this function plots ``V_m`` traces, not spikes. The
    multimeter's event arrays are flat and carry samples from every neuron
    it is connected to. Samples are therefore grouped by their ``senders``
    entry before plotting; plotting the raw arrays would join samples of
    different neurons into a single line.

    Args:
        multimeter: Multimeter node collection recording ``V_m``.
    """
    events = nest.GetStatus(multimeter)[0].get("events", {})
    v_m = events.get("V_m", [])
    if len(v_m) == 0:
        print("No membrane potential events recorded.")
        return

    # Group (time, V_m) samples by the node id that produced them.
    traces = {}
    for sender, t, v in zip(events["senders"], events["times"], v_m):
        times, values = traces.setdefault(int(sender), ([], []))
        times.append(t)
        values.append(v)

    plt.figure()
    for sender in sorted(traces):
        times, values = traces[sender]
        plt.plot(times, values, linewidth=0.8)
    plt.xlabel("Time (ms)")
    plt.ylabel("Membrane potential (mV)")
    plt.title("Membrane potential over time")
    plt.show()

def log_weights(time):
    """
    Print source, target and weight of every connection in the network.

    Args:
        time: Value shown in the header line, labelled in ms.
    """
    rows = nest.GetStatus(nest.GetConnections(), keys=["source", "target", "weight"])
    print(f"Weights at time {time} ms:")
    for source, target, weight in rows:
        print(f"Source: {source}, Target: {target}, Weight: {weight}")

def create_and_simulate(simulation_steps=(10.0, 20.0, 30.0, 40.0, 50.0)):
    """
    Build the network, simulate it in chunks, log weights and plot V_m.

    ``nest.Simulate(step)`` advances the kernel by ``step`` ms, so the chunks
    accumulate. Starting from a freshly reset kernel, the default chunks end
    at 10, 30, 60, 100 and 150 ms. Each weight log is therefore labelled with
    the kernel's actual ``biological_time`` rather than with the chunk length.

    Args:
        simulation_steps: Durations (ms) of consecutive simulation chunks.

    Returns:
        Tuple ``(multimeter, spikerecorder)`` so recorded data can be inspected
        after the call.
    """
    E_pops, I_pops = connect_pop()
    create_thalamic_input(E_pops)
    # prune(threshold=0.5)  # optional: drop weak plastic synapses
    multimeter, spikerecorder = recording_devices(E_pops, I_pops)

    log_weights(nest.GetKernelStatus("biological_time"))

    for step in simulation_steps:
        nest.Simulate(step)
        log_weights(nest.GetKernelStatus("biological_time"))

    plot_spikes(multimeter)
    return multimeter, spikerecorder




In [None]:
# Build the network, run the chunked simulation with weight logging, and plot
# the recorded membrane potentials.
create_and_simulate()