# Visualizer
### Use this script to visualize neuron traces following the paper "A Generalized Linear Integrate-and-Fire Neural Model Produces Diverse Spiking Behaviors" by Stefan Mihalas and Ernst Niebur. The data were generated with a fixed length of 1 s (1 ms time steps), with noise on the input current and/or temporal jitter on the time point of the step for dynamic inputs.

### The script also calculates the inter-spike intervals (ISIs) for a single trial and, whenever possible, for all repeated trials. For repeated trials, all ISIs are pooled, and the resulting statistics represent the outcome of all repetitions per class.

In [35]:
import pickle
import torch
import os
import progressbar

import matplotlib.pyplot as plt
import numpy as np

from tactile_encoding.utils.utils import create_directory
from utils.functions import return_isi_fix_len


In [36]:

data_path_braille = './data/braille_mn_output'  # path to output from Braille data
data_path_original = './data/original_mn_output'  # path to output from MN paper
plot_out = './plots/braille'  # path to save plots
create_directory(plot_out)
# file-name suffixes of the original-MN data sets (combined with
# 'data_encoding_fix_len' to form the file names loaded below)
data_types = ['', '_noisy', '_temp_jitter', '_offset', '_noisy_temp_jitter',
              '_noisy_offset', '_temp_jitter_offset', '_noisy_temp_jitter_offset']

# letter labels; used below as keys into classes_list
braille_letters = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L',
                   'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'Space']

# passed to return_isi_fix_len as its trial limit
max_trials = 100

# spiking-behaviour class name per letter; only 'A'-'T' (20 classes) are
# defined, so only the first 20 entries of braille_letters can be looked up
classes_list = {
    'A': "Tonic spiking",
    'B': "Class 1",
    'C': "Spike frequency adaptation",
    'D': "Phasic spiking",
    'E': "Accommodation",
    'F': "Threshold variability",
    'G': "Rebound spike",
    'H': "Class 2",
    'I': "Integrator",
    'J': "Input bistability",
    'K': "Hyperpolarizing spiking",
    'L': "Hyperpolarizing bursting",
    'M': "Tonic bursting",
    'N': "Phasic bursting",
    'O': "Rebound burst",
    'P': "Mixed mode",
    'Q': "Afterpotentials",
    'R': "Basal bistability",
    'S': "Preferred frequency",
    'T': "Spike latency",
}

# Neuron parameters (kept as a comment: a bare string literal here would be a
# no-op statement that the notebook merely echoes as cell output):
# a: 2.743
# A1: 0.03712
# A2: -0.5089
# b: 11.4
# G: 47.02
# k1: 200
# k2: 20
# R1: 0
# R2: 1


'\nneuron parameters:\na: 2.743\nA1: 0.03712\nA2: -0.5089\nb: 11.4\nG: 47.02\nk1: 200\nk2: 20\nR1: 0\nR2: 1\n'

# Pre-processing
### Extract and concatenate all ISIs of each sensor over all data (training and test)

In [37]:
# load ISIs from paper
norm_count = True
norm_time = True
# for possible datatypes look at the top
filename = 'data_encoding_fix_len_noisy_temp_jitter_offset'
# `with` closes the file even if unpickling fails.
# NOTE: pickle.load can execute code stored in the file -- only load trusted data.
with open(f"{data_path_original}/{filename}.pkl", 'rb') as infile:
    data_original = pickle.load(infile)
isi_original = return_isi_fix_len(
    data_original, max_trials, norm_count=norm_count, norm_time=norm_time)

In [38]:
# Collect, per Braille sensor (channel), the ISIs of every trial saved in
# data_path_braille. Each file is unpacked as [mn_spk, input_current,
# trial_label]; mn_spk is indexed below as [batch][time, channel].
# Result: isi_list[channel] is a (1, n_isi) float array (the leading axis is
# kept because later cells read entry[0]); isi_list is [] if no file exists.
file_names = sorted(os.listdir(data_path_braille))
# per channel: list of (1, k) ISI arrays, concatenated once after the loop
# (np.append inside the loop would re-copy the growing array every batch)
isi_chunks = []

for file_name in file_names:
    [mn_spk, input_current, trial_label] = torch.load(
        os.path.join(data_path_braille, file_name), map_location=torch.device('cpu'))
    # convert to numpy
    mn_spk = mn_spk.numpy()
    input_current = input_current.numpy()
    trial_label = trial_label.numpy()

    # extract traces from every batch entry
    for batch_spk in mn_spk:
        # loop over all channels
        for channel in range(batch_spk.shape[-1]):
            # time indices at which this channel emitted a spike
            spike_idx = np.flatnonzero(batch_spk[:, channel] == 1.0)
            # 1E-2 scales index differences to ISI units -- presumably the
            # Braille time step; TODO confirm
            isi_channel = np.diff(spike_idx)[np.newaxis, :] * 1E-2
            if channel == len(isi_chunks):
                isi_chunks.append([])
            isi_chunks[channel].append(isi_channel)

# concatenate all batches of a channel along the ISI axis
isi_list = [np.concatenate(chunks, axis=1) for chunks in isi_chunks]


# Compare to paper output

## Per channel over all classes

In [39]:
# Linear fit of count vs. ISI for every class of the original MN data.
# lin_fit_original[i] holds [slope, offset] for class i, or [[], []] when the
# class has too few ISIs for a line; cov_matr_original[i] holds the 2x2
# covariance of (slope, offset), or [] when no fit was made.
lin_fit_original = []
cov_matr_original = []
for isi_original_sel in isi_original:
    n_isi = len(isi_original_sel[0])
    # a line needs at least two points; with a single point the normal
    # matrix is singular and np.polyfit(..., cov=...) fails
    if n_isi > 1:
        # NOTE(review): cov='unscaled' is meant for weights w = 1/sigma; without
        # weights the covariance is not scaled by the fit residuals
        # (cov=True would do that) -- confirm which is intended.
        [slope, offset], cov_matr = np.polyfit(
            isi_original_sel[0], isi_original_sel[1], 1, cov='unscaled')
        # sqrt of the covariance diagonal = std. errors of slope and offset
        print('uncertainty (slope, offset):', np.sqrt(np.diag(cov_matr)))
        lin_fit_original.append([slope, offset])
        cov_matr_original.append(cov_matr)
    else:
        print('Only got ', n_isi, ' ISIs.')
        lin_fit_original.append([[], []])
        cov_matr_original.append([])

uncertainty (x, y): [4.72455591 4.17475406]
uncertainty (x, y): [0.58361602 0.13815952]
uncertainty (x, y): [1.62697843 1.24558042]
uncertainty (x, y): [0.77674075 0.47061617]
uncertainty (x, y): [0.30205903 0.21230182]
uncertainty (x, y): [1.08463873 0.98699511]
Only got  0  ISIs.
uncertainty (x, y): [1.62697843 1.24558042]
uncertainty (x, y): [2.14662172 1.60380898]
uncertainty (x, y): [0.44806227 0.20821254]
Only got  0  ISIs.
Only got  0  ISIs.
uncertainty (x, y): [0.42753135 0.31420317]
uncertainty (x, y): [1.91236577 1.37667647]
Only got  0  ISIs.
uncertainty (x, y): [0.44567503 0.25964504]
uncertainty (x, y): [0.77089091 0.43685787]
uncertainty (x, y): [0.51849612 0.40057013]
uncertainty (x, y): [0.31224206 0.21069833]
uncertainty (x, y): [0.53528914 0.25689683]


In [40]:
# Compare every Braille sensor (channel) with every class of the original MN
# data: scatter the (ISI, count) pairs of both, draw their linear fits, and
# save one 5x4 figure per channel to plot_out.
# Uses lin_fit_original from the linear-fit cell; that cell must have been run
# on the same file as `filename` below, otherwise fits and scatters disagree.
# load ISIs from paper
norm_count = True
norm_time = True
# data_types = ['', '_noisy', '_temp_jitter', '_noisy_temp_jitter_offset']
filename = 'data_encoding_fix_len_noisy_temp_jitter_offset'
# NOTE: pickle.load can execute code stored in the file -- only load trusted data.
with open(f"{data_path_original}/{filename}.pkl", 'rb') as infile:
    data_original = pickle.load(infile)
isi_original = return_isi_fix_len(
    data_original, max_trials, norm_count=norm_count, norm_time=norm_time)
# error_list[channel][class]: RMS distance between the Braille and the original
# line fit evaluated on x in [0, 1], or [] when the two are not compared
error_list = []

n_classes = len(isi_original)
x_fit = np.linspace(0.0, 1.0, 10)  # x grid on which fitted lines are evaluated

bar = progressbar.ProgressBar(
    maxval=len(isi_list)*n_classes,
    widgets=[progressbar.Bar('=', '[', ']'), ' ', progressbar.Percentage()])
bar.start()

# TODO save error values!
# iterate channel (sensors)
for channel, entry in enumerate(isi_list):
    # extract single ISIs and their count
    isi, count = np.unique(entry[0], return_counts=True)

    # a line fit needs at least two distinct ISIs (with one point
    # np.polyfit(..., cov=...) fails on a singular matrix)
    if len(isi) < 2:
        error_list.append([])  # placeholder keeps error_list indexed by channel
        bar.update((channel + 1)*n_classes)
        continue

    if norm_time:
        isi = isi/max(isi)
    if norm_count:
        count = count/max(count)

    # linear fit on ISIs from channel (sensor)
    # https://numpy.org/doc/stable/reference/generated/numpy.polyfit.html
    [slope_braille, offset_braille], cov_matr_braille = np.polyfit(
        isi, count, 1, cov='unscaled')
    line_braille = slope_braille*x_fit + offset_braille

    # one subplot per original class, each overlaid with this Braille sensor
    figname = f'comparison ISI sensor: {channel}'
    plt.figure(figname, figsize=(12, 12))
    plt.suptitle(figname)

    error = []
    # compare to original traces
    for num, isi_original_sel in enumerate(isi_original):
        plt.subplot(5, 4, num+1)
        plt.title(f'{classes_list[braille_letters[num]]}')
        if len(isi_original_sel[0]) > 0:
            slope_original, offset_original = lin_fit_original[num]
            # the fit cell stores [[], []] for classes it could not fit
            has_fit = np.ndim(slope_original) == 0
            # TODO include check of uncertainty.
            # High uncertainty -> bad fit -> bad data representation -> not reliable
            # find threshold

            # matching slope signs are a first indicator of similarity
            if has_fit and np.sign(slope_original) == np.sign(slope_braille):
                # TODO error = max(error). No sqrt needed
                line_original = slope_original*x_fit + offset_original
                error.append(np.sqrt(np.mean((line_braille - line_original)**2)))
            else:
                error.append([])

            plt.scatter(isi_original_sel[0], isi_original_sel[1], color='tab:blue')
            if has_fit:
                plt.plot(x_fit, slope_original*x_fit + offset_original, color='tab:blue')
        else:
            error.append([])
            plt.text(0.3, 0.5, f'nbr. ISIs = {len(isi_original_sel[0])}')
        # https://matplotlib.org/stable/gallery/color/named_colors.html
        plt.scatter(isi, count, color='tab:orange')
        plt.plot(x_fit, line_braille, color='tab:orange')
        # fixed limits only make sense on normalized axes
        if norm_time:
            plt.xlim((0, 1.1))
        if norm_count:
            plt.ylim((0, 1.1))
        if num % 4 == 0:  # first column of the 5x4 grid
            plt.ylabel('Count')
        if num >= 16:  # bottom row of the 5x4 grid
            plt.xlabel('ISI')
        bar.update(channel*n_classes + num)

    error_list.append(error)
    plt.tight_layout()
    plt.savefig(
        f'{plot_out}/comparison_all_classes_channel_{channel}_scatter.png', dpi=300)
    plt.close()
bar.finish()

[                                                                        ]   0%
[=                                                                       ]   1%
[=                                                                       ]   2%
[==                                                                      ]   3%
[===                                                                     ]   5%
[====                                                                    ]   6%
[=====                                                                   ]   7%


In [41]:
# Bar-plot counterpart of the scatter comparison: for every Braille sensor
# (channel) overlay its ISI distribution (orange) on each class of the
# original MN data (blue) and save one 5x4 figure per channel to plot_out.
# TODO use numpy hist function to select the number of bins!
# TODO check how to calculate a probability out of hist -> same dimensions despite the input dimensions
# load ISIs from paper
norm_count = True
norm_time = True
# data_types = ['', '_noisy', '_temp_jitter', '_noisy_temp_jitter_offset']
filename = 'data_encoding_fix_len_noisy_temp_jitter_offset'
# NOTE: pickle.load can execute code stored in the file -- only load trusted data.
with open(f"{data_path_original}/{filename}.pkl", 'rb') as infile:
    data_original = pickle.load(infile)
isi_original = return_isi_fix_len(
    data_original, max_trials, norm_count=norm_count, norm_time=norm_time)

n_classes = len(isi_original)
# narrow bars on the normalized ISI axis; matplotlib's default width otherwise
bar_kwargs = {'width': 0.01} if norm_time else {}

bar = progressbar.ProgressBar(
    maxval=len(isi_list)*n_classes,
    widgets=[progressbar.Bar('=', '[', ']'), ' ', progressbar.Percentage()])
bar.start()

# iterate channel (sensors)
for channel, entry in enumerate(isi_list):
    # TODO try numpy hist
    # hist docs: https://numpy.org/doc/stable/reference/generated/numpy.histogram.html
    # setting bin size: https://numpy.org/doc/stable/reference/generated/numpy.histogram_bin_edges.html#numpy.histogram_bin_edges
    # extract single ISIs and their count
    isi, count = np.unique(entry[0], return_counts=True)
    # only plot if ISIs found in channel
    if len(isi) == 0:
        bar.update((channel + 1)*n_classes)
        continue

    if norm_time:
        isi = isi/max(isi)
    if norm_count:
        count = count/max(count)

    # one subplot per original class, each overlaid with this Braille sensor
    figname = f'comparison ISI sensor: {channel}'
    plt.figure(figname, figsize=(12, 12))
    plt.suptitle(figname)

    # compare to original traces
    for num, isi_original_sel in enumerate(isi_original):
        plt.subplot(5, 4, num+1)
        plt.title(f'{classes_list[braille_letters[num]]}')
        if len(isi_original_sel[0]) > 0:
            plt.bar(isi_original_sel[0], isi_original_sel[1], color='tab:blue', **bar_kwargs)
        else:
            plt.text(0.3, 0.5, f'nbr. ISIs = {len(isi_original_sel[0])}')
        plt.bar(isi, count, color='tab:orange', **bar_kwargs)
        # a fixed x range only fits normalized ISIs
        if norm_time:
            plt.xlim((0, 1.1))
        if num % 4 == 0:  # first column of the 5x4 grid
            plt.ylabel('Count')
        if num >= 16:  # bottom row of the 5x4 grid
            plt.xlabel('ISI')
        bar.update(channel*n_classes + num)

    plt.tight_layout()
    plt.savefig(
        f'{plot_out}/comparison_all_classes_channel_{channel}_bar.png', dpi=300)
    plt.close()
bar.finish()


[                                                                        ]   0%
[=                                                                       ]   1%
[=                                                                       ]   2%
[==                                                                      ]   3%
[===                                                                     ]   5%
[====                                                                    ]   6%
[=====                                                                   ]   7%


# Find temporal evolution of predicted classes
### Given the output spike trains of the classifier trained on the original MN classes, a sliding window can be used to track how the network's prediction evolves over time.

In [42]:
# Draft (disabled): sum the classifier's output spikes per channel in windows
# of window_size time steps to follow its prediction over time.
# NOTE(review): issues to fix before enabling this cell:
# - the directory listed is data_path_original_classifier, but torch.load
#   below reads from data_path_braille
# - the batch loop iterates range(mn_spk.shape[0]) (left over from the Braille
#   cell) instead of original_classifier_spk.shape[0]
# - range(len(...), window_size) starts at the trace length, so it is empty
#   whenever the trace has at least window_size samples; a window count such
#   as len(...) // window_size looks intended -- confirm
# - the per-window sums are only appended; no winning class is selected yet,
#   and the "extend with ISIs" comment below is copied from the ISI cell
# data_path_original_classifier = './data/original_classifier'
# file_names = os.listdir(data_path_original_classifier)
# file_names = np.sort(file_names)
# init = True
# window_size = 50  # ms

# for _, file_name in enumerate(file_names):
#     original_classifier_spk = torch.load(
#         data_path_braille + '/' + file_name, map_location=torch.device('cpu'))
#     # convert to numpy
#     original_classifier_spk = original_classifier_spk.numpy()

#     # extract traces from single batch
#     for batch in range(mn_spk.shape[0]):
#         prediction_list_channel = []

#         # loop over all channels
#         for channel in range(original_classifier_spk[batch].shape[-1]):
#             # loop over temporal increments
#             for window in range(len(original_classifier_spk[batch][:, channel]), window_size):
#                 # calc winning class in time window
#                 prediction_list_channel.append(np.sum(original_classifier_spk[batch][window_size*window:window_size*(window+1), channel], axis=0))

#         if init:
#             # init at first run
#             prediction_list = prediction_list_channel
#             init = False
#         else:
#             # extend with ISIs found in channel per batch
#             for num, _ in enumerate(prediction_list_channel):
#                 # print(len(isi_list[num][0]), len(isi_list_channel[num][0]))
#                 prediction_list[num] = np.append(
#                     prediction_list[num], prediction_list_channel[num], axis=1)
#                 # print(len(isi_list[num][0]))


<a rel="license" href="http://creativecommons.org/licenses/by/4.0/"><img alt="Creative Commons License" style="border-width:0" src="https://i.creativecommons.org/l/by/4.0/88x31.png" /></a><br />This work is licensed under a <a rel="license" href="http://creativecommons.org/licenses/by/4.0/">Creative Commons Attribution 4.0 International License</a>.