## Imports

In [None]:
import sys
import os 
# Make the package in the parent directory importable. The membership check keeps
# the cell idempotent: re-running it no longer appends a duplicate sys.path entry.
pkg_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if pkg_path not in sys.path:
    sys.path.append(pkg_path)

In [None]:
import numpy as np

from pynn_brainscales import brainscales2 as pynn

import shtmbss2.addsrc
from shtmbss2.core import hardware_initialization
from shtmbss2.network import SHTMTotal

## Configuration

In [None]:
# Uncomment for verbose PyNN logging:
# pynn.logger.default_config(level=pynn.logger.LogLevel.DEBUG)

# --- Network size -------------------------------------------------------------
alphabet_size = 4            # number of symbols; later cells map indices to letters via shtm.id_to_letter
num_neurons_per_symbol = 15  # passed to SHTMTotal alongside alphabet_size

# --- Emulation schedule -------------------------------------------------------
runtime = 0.18       # passed as runtime= to shtm.run() and plot_v_exc(); units per SHTMTotal — TODO confirm
num_repetitions = 1  # forwarded to shtm.init_external_input(num_repetitions=...)
num_sim_steps = 5    # steps= for the second, multi-step plastic run
debug = False        # forwarded to shtm.init_connections(debug=...)

# --- External input -----------------------------------------------------------
# Symbols handed to shtm.init_external_input(sequence=...), in this order.
# An empty list (input_sequence = []) was the alternative tried during development.
input_sequence = ['A', 'C', 'B']

# Project-specific hardware setup from shtmbss2.core; runs before the network is built.
hardware_initialization()

## Network Initialization

In [None]:
# Build the network step by step. The meaning and units of every argument are defined
# by shtmbss2.network.SHTMTotal.
# NOTE(review): the two-element lists below presumably give one value per neuron
# type/compartment — TODO confirm against SHTMTotal.init_neurons.
shtm = SHTMTotal(
    alphabet_size,
    num_neurons_per_symbol,
    log_permanence='all',
    log_weights='all',
)

shtm.init_neurons(
    v_rest=[60, 60],
    v_reset=[100, 60],
    v_thresh=[120, 140],
)
shtm.init_connections(
    debug=debug,
    w_ext_exc=500,
    w_exc_exc=0.01,
    w_exc_inh=200,
    w_inh_exc=-500,
    p_exc_exc=0.4,
    mature_weight=500,
    learning_factor=100,
)
shtm.init_external_input(
    sequence=input_sequence,
    num_repetitions=num_repetitions,
)
shtm.init_rec_exc()

## Network Emulation & Plotting

In [None]:
# First emulation: a single step with plasticity enabled.
shtm.run(
    runtime=runtime,
    steps=1,
    plasticity_enabled=True,
)

In [None]:
%matplotlib inline

shtm.plot_events(neuron_types="all")

In [None]:
# Second emulation: num_sim_steps steps with plasticity enabled.
shtm.run(
    runtime=runtime,
    steps=num_sim_steps,
    plasticity_enabled=True,
)

In [None]:
%matplotlib inline

shtm.plot_events(neuron_types="all")

## Additional Plotting

In [None]:
%matplotlib inline

shtm.plot_permanence_diff()

In [None]:
%matplotlib inline

shtm.plot_permanence_history(plot_con_ids=[0,1])

In [None]:
%matplotlib inline

shtm.plot_v_exc(alphabet_range=[0], neuron_range=[0,2], neuron_type=1, runtime=runtime)

In [None]:
%matplotlib inline

shtm.plot_v_exc(alphabet_range=[0], neuron_range='all', neuron_type=1, runtime=runtime)

In [None]:
%matplotlib inline

shtm.plot_v_exc(alphabet_range=range(1, alphabet_size))

## Additional Analysis

In [None]:
# For each plastic connection, print its index, the second '_'-separated field of its
# projection label, and its connection ids.
# NOTE(review): this loop also sets mature_weight = 120 on every plastic connection
# (init_connections was called with mature_weight=500). That state change is hidden
# inside an inspection cell — confirm it is intended.
for idx, con in enumerate(shtm.con_plastic):
    con.mature_weight = 120
    label_part = con.projection.label.split('_')[1]
    print(idx, label_part, con.get_all_connection_ids())

In [None]:
# Copy of the logged permanences of plastic connection 1 as a NumPy array, for
# interactive inspection (np.array copies its input). Shape depends on how
# SHTMTotal records permanences — TODO confirm.
arr = np.array(
    shtm.con_plastic[1].permanences
)

In [None]:
# One line per synapse of plastic connection 1: pre/post index and current weight.
for conn in shtm.con_plastic[1].projection.connections:
    pre_idx = conn.presynaptic_index
    post_idx = conn.postsynaptic_index
    print(f'C[{pre_idx}, {post_idx}].weight = {conn.weight}')

In [None]:
# Weight matrix of plastic connection 1 (PyNN array format: rows are presynaptic,
# columns postsynaptic cells).
shtm.con_plastic[1].projection.get(
    "weight",
    format="array",
)

In [None]:
# Weight matrix of plastic connection 7 (PyNN array format: rows are presynaptic,
# columns postsynaptic cells).
shtm.con_plastic[7].projection.get(
    "weight",
    format="array",
)

In [None]:
# Spike trains of the postsynaptic population of plastic connection 1, taken from
# the last recorded segment.
print(
    shtm.con_plastic[1]
    .projection.post
    .get_data("spikes")
    .segments[-1]
    .spiketrains
)

In [None]:
def summarize_spike_times(network, num_symbols):
    """Print the earliest/latest spike time and their difference for each symbol.

    For every symbol index ``a`` in ``range(num_symbols)``, all spikes recorded in the
    last segment of ``network.neurons_exc[a][1]`` are pooled across neurons. Symbols
    without a single spike are reported as such. Times are printed in whatever
    units the spike trains carry, rounded to 8 decimals.

    Args:
        network: the SHTMTotal instance whose recorded spikes are summarized.
        num_symbols: number of symbol indices to visit (normally ``alphabet_size``).
    """
    for a in range(num_symbols):
        letter = network.id_to_letter(a)
        trains = network.neurons_exc[a][1].get_data("spikes").segments[-1].spiketrains
        # Pool every spike of every neuron. The previous float(train) only worked for
        # trains holding exactly one spike and raised TypeError otherwise.
        # np.asarray(train) reads exactly this train's values; ndarray.base can be
        # None or a larger parent array.
        spike_times = sorted(float(t) for train in trains for t in np.asarray(train))

        if not spike_times:
            print(f"No spikes for '{letter}'\n")
            continue

        first, last = spike_times[0], spike_times[-1]
        print(f"Spike times for '{letter}'")
        print(f"Min: {round(first, 8)}")
        print(f"Max: {round(last, 8)}")
        print(f"Diff: {round(last - first, 8)}\n")


summarize_spike_times(shtm, alphabet_size)

In [None]:
# Print earliest/latest spike time and their difference per symbol, pooled over all
# neurons of shtm.neurons_exc[a][1] in the last recorded segment.
# NOTE(review): this cell repeats the previous cell verbatim; consider deleting one
# of them or wrapping the loop in a function.
for a in range(alphabet_size):
    letter = shtm.id_to_letter(a)
    trains = shtm.neurons_exc[a][1].get_data("spikes").segments[-1].spiketrains
    # Pool every spike of every neuron. float(train) only works for single-spike
    # trains and raises TypeError otherwise. np.asarray(train) reads exactly this
    # train's values; ndarray.base can be None or a larger parent array.
    spike_times = sorted(float(t) for train in trains for t in np.asarray(train))

    if not spike_times:
        print(f"No spikes for '{letter}'\n")
        continue

    print(f"Spike times for '{letter}'")
    print(f"Min: {round(spike_times[0], 8)}")
    print(f"Max: {round(spike_times[-1], 8)}")
    print(f"Diff: {round(spike_times[-1] - spike_times[0], 8)}\n")

## Set New Weights

In [None]:
# Overwrite every positive weight of plastic connection 1 with a single fixed value,
# and set that connection's mature_weight to the same value.
new_weight = 450

weights = shtm.con_plastic[1].projection.get("weight", format="array")
print(weights)  # state before the update
# With format="array", PyNN marks absent connections with NaN. Since NaN > 0 is
# False, those entries pass through unchanged.
weights = np.where(weights > 0, new_weight, weights)
shtm.con_plastic[1].projection.set(weight=weights)
shtm.con_plastic[1].mature_weight = new_weight

# Read the weights back to show the result.
shtm.con_plastic[1].projection.get("weight", format="array")