In [73]:
import plotly.graph_objs as go
import numpy as np
import torch
from thesis.utils import *
from thesis.nets import *
import random
from _plotly_utils.colors.qualitative import Plotly as colors

In [82]:
# Base layout shared by the figures below: 800x500 canvas, highlighted zero
# lines on both axes and a framed, semi-transparent legend.
layout = go.Layout(
    height=500,
    width=800,
    xaxis={'zeroline': True, 'zerolinecolor': '#002c75'},
    yaxis={'zeroline': True, 'zerolinecolor': '#002c75'},
    legend={
        'font': {'size': 12, 'color': 'black'},
        'bgcolor': "rgba(25,211,243,0.5)",
        'bordercolor': "black",
        'borderwidth': 1,
    },
)

In [55]:
import plotly.subplots  # explicit submodule import; still binds the name `plotly`
from torchvision import transforms
from torchvision.datasets import MNIST


def _draw_forever(loader):
    """Yield batches from `loader` endlessly, starting a new pass when one ends."""
    while True:
        yield from loader


# MNIST as tensors, served one shuffled sample at a time.
dataset = MNIST(root='torchMNIST', download=False,
                transform=transforms.Compose(
                    [
                        transforms.ToTensor(),
                        transforms.Lambda(lambda x: x * 1),
                    ]))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
# One shared stream of samples: calling iter(dataloader) for every draw would
# build (and reshuffle) a brand-new iterator each time.
samples = _draw_forever(dataloader)

# 5 x 10 grid of digits; column k shows samples whose label is k - 1.
fig = plotly.subplots.make_subplots(rows=5, cols=10,
                                    vertical_spacing=0.04, horizontal_spacing=0.02)
# The x and y axis of every cell share the same frame and tick styling.
axis_style = dict(mirror=True, linewidth=2, zeroline=False, linecolor='#2e4b73',
                  tickvals=[0, 23], ticktext=['1', '24'], range=[0, 23],
                  scaleanchor='x', scaleratio=1, constrain="domain")
for row in range(1, 6):
    for col in range(1, 11):
        # Rejection sampling: keep drawing until the label matches this column.
        batch = next(samples)
        while batch[1].item() != col - 1:
            batch = next(samples)
        fig.add_heatmap(
            # NOTE(review): x has 24 and y has 25 coordinates while the axes are
            # limited to [0, 23]; MNIST images are presumably 28x28 - confirm
            # that this cropping / coordinate mismatch is intended.
            x=list(range(24)), y=list(range(25)),
            # Drop the batch and channel dims, then reverse the row order
            # (presumably so the digit is upright on an upward-growing y axis).
            z=batch[0].squeeze(0).squeeze(0).flip(0), coloraxis='coloraxis',
            row=row, col=col
        )
        fig.update_xaxes(row=row, col=col, **axis_style)
        fig.update_yaxes(row=row, col=col, **axis_style)
# All heatmaps share one inverted-gray colour axis.
fig.update_layout(height=750, width=1500, coloraxis_colorscale='gray_r',
                  margin=dict(
                      t=20, l=20, b=20, r=20))

In [56]:
# Save the current figure as PDF and SVG.
for out_path in ('misc//MNIST.pdf', 'misc//MNIST.svg'):
    fig.write_image(out_path)

In [71]:
# Stand-alone picture of a single digit.
# NOTE(review): `batch` is whatever sample an earlier cell left behind.
digit_pixels = batch[0].squeeze(0).squeeze(0).flip(0)
# Both axes get the same frame styling and no ticks.
frame_style = dict(mirror=True, linewidth=2, zeroline=False, linecolor='#2e4b73',
                   tickvals=[], ticktext=[], range=[0, 23],
                   scaleanchor='x', scaleratio=1, constrain="domain")

fig = go.Figure(layout=layout)
fig.add_heatmap(
    x=list(range(24)), y=list(range(25)),
    z=digit_pixels, coloraxis='coloraxis',
)
fig.update_xaxes(**frame_style)
fig.update_yaxes(**frame_style)
fig.update_layout(height=400, width=400, coloraxis_colorscale='gray_r', coloraxis_showscale=False,
                  margin=dict(t=0, l=0, b=0, r=0))

In [72]:
# Save the current figure as PDF and SVG.
for out_path in ('misc//MNIST_digit.pdf', 'misc//MNIST_digit.svg'):
    fig.write_image(out_path)

In [79]:
# Two exponential branches of the weight-update curve, sampled on
# [-100, 0] and [0, 100] with 100 points each.
stdp_points = 100
stdp_t_max = 100
stdp_amp_neg, stdp_tau_neg = 0.0003, 8
stdp_amp_pos, stdp_tau_pos = 0.001, 20

t_neg = np.linspace(-stdp_t_max, 0, stdp_points)
t_pos = np.linspace(0, stdp_t_max, stdp_points)
# Negative branch grows in magnitude towards t = 0; positive branch decays from it.
dw_neg = -stdp_amp_neg * np.exp(t_neg / stdp_tau_neg)
dw_pos = stdp_amp_pos * np.exp(-t_pos / stdp_tau_pos)

In [80]:
# Weight-update curve: one spline trace per branch, Russian labels.
fig = go.Figure(layout=layout)
fig.add_scatter(x=t_neg, y=dw_neg,
                name='Негативное обновление', line_shape='spline')
fig.add_scatter(x=t_pos, y=dw_pos,
                name='Позитивное обновление', line_shape='spline')
fig.layout.xaxis.title.text = '$t_{post} - t_{pre}, 10^{-3} с$'
# Raw string: '\D' is not a valid escape sequence in a normal string literal.
fig.layout.yaxis.title.text = r'$\Delta w$'

# Legend in the top-left corner, tight margins for the exported image.
fig.layout.legend.y = 1
fig.layout.legend.x = 0
fig.layout.margin.t = 20
fig.layout.margin.b = 20
fig.layout.margin.r = 20

In [81]:
# Render the current figure inline.
fig

In [8]:
# Export the figure with its current (Russian) labels first.
for out_path in ('misc//STDP_ru.pdf', 'misc//STDP_ru.svg'):
    fig.write_image(out_path)

# Switch trace names and the x-axis title to English, then export again.
fig.data[0].name = 'negative update'
fig.data[1].name = 'positive update'
fig.update_layout(xaxis_title_text='$t_{post} - t_{pre}, 10^{-3} s$')

for out_path in ('misc//STDP_eng.pdf', 'misc//STDP_eng.svg'):
    fig.write_image(out_path)


In [9]:
# Load a previously saved figure object from disk.
# NOTE(review): torch.load unpickles arbitrary Python objects - only use it on
# files from a trusted source.
fig = torch.load('misc//LC_SNN-bench-widget.pt')

In [10]:
# Apply the shared layout, then place the legend and set the margins.
fig.update_layout(layout)
fig.update_layout(
    legend=dict(x=0.7, y=0.1),
    margin=dict(t=40, b=20, r=20, l=60),
);

In [11]:
# Render the current figure inline.
fig

FigureWidget({
    'data': [{'error_y': {'array': [0.015387494922826133, 0.011988494484296184,
               …

In [12]:
# Export the figure with its current labels first.
for out_path in ('misc//LCSNN_learning_rate.pdf', 'misc//LCSNN_learning_rate.svg'):
    fig.write_image(out_path)

# Russian axis titles, no figure title, a space after every y tick label.
fig.update_layout(
    xaxis_title_text='Число итераций обучения',
    yaxis_title_text='Точность',
    yaxis_ticksuffix=' ',
    title_text='',
)
# Russian legend entries for the four traces.
fig.data[0].name = 'голосование патчей'
fig.data[1].name = 'общее голосование'
fig.data[2].name = 'отбор по спайкам'
fig.data[3].name = 'линейный классификатор'

for out_path in ('misc//LCSNN_learning_rate_ru.pdf', 'misc//LCSNN_learning_rate_ru.svg'):
    fig.write_image(out_path)

In [13]:
# Name of the LC_SNN entry with the highest accuracy in the database.
data = view_database()
data = data.loc[data['network_type'] == 'LC_SNN']
ranked = data.sort_values('accuracy', ascending=False)
best_network = ranked['name'].values[0]

In [14]:
# Load the selected network by name (the loader prints its parameters).
net = load_network(best_network)

Created LC_SNN network 9b905a7666a70a6b2371ba1b68fa5436076b719c5bb80772e34403b3 with parameters
{'network_type': 'LC_SNN', 'mean_weight': 0.32000000000000006, 'n_iter': 5000, 'c_w': -50.0, 'c_w_min': -inf, 'time_max': 250, 'crop': 20, 'kernel_size': 12, 'kernel_prod': 144, 'stride': 4, 'n_filters': 100, 'intensity': 127.5, 'dt': 1, 'c_l': False, 'A_pos': None, 'A_neg': None, 'tau_pos': 20.0, 'tau_neg': 20.0, 'weight_decay': None, 'train_method': None, 'immutable_name': False}




source code of class 'bindsnet.network.network.Network' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.


source code of class 'bindsnet.network.nodes.Input' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.


source code of class 'bindsnet.network.nodes.AdaptiveLIFNodes' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.



In [15]:
# Weight plot with the shared layout on a square 800x800 canvas.
fig = net.plot_weights_XY()
fig.update_layout(layout)
fig.update_layout(height=800, width=800);

In [16]:
# Save the current figure as PDF and SVG.
for out_path in ('misc//weights_XY.pdf', 'misc//weights_XY.svg'):
    fig.write_image(out_path)

In [17]:
# Russian axis titles, no figure title.
fig.update_layout(
    title_text='',
    xaxis_title_text='Индекс',
    yaxis_title_text='Индекс',
);

In [18]:
# Save the current figure as PDF and SVG.
for out_path in ('misc//weights_XY_ru.pdf', 'misc//weights_XY_ru.svg'):
    fig.write_image(out_path)

In [19]:
# Most accurate LC_SNN entry whose 'c_l' flag is True.
data = view_database()
is_lc_snn = data['network_type'] == 'LC_SNN'
has_c_l = data['c_l'] == True
data = data[is_lc_snn & has_c_l]
ranked = data.sort_values('accuracy', ascending=False)
best_network = ranked['name'].values[0]
net = load_network(best_network)

Created LC_SNN network 8f0d908db7571346ab8e257bf31fddb69bb8faccbb0574e2ee8db82b with parameters
{'network_type': 'LC_SNN', 'mean_weight': 0.32, 'n_iter': 5000, 'c_w': -50, 'c_w_min': -inf, 'time_max': 250, 'crop': 20, 'kernel_size': 12, 'kernel_prod': 144, 'stride': 4, 'n_filters': 100, 'intensity': 127.5, 'dt': 1, 'c_l': True, 'A_pos': -0.3, 'A_neg': -3, 'tau_pos': 8.0, 'tau_neg': 20.0, 'weight_decay': 0, 'train_method': None, 'immutable_name': False}



In [20]:
# Competition-weight distribution figure, restyled with the shared layout;
# update_layout returns the figure, so the cell also displays it.
w, fig = net.competition_distribution()
fig.update_layout(layout)

FigureWidget({
    'data': [{'type': 'histogram',
              'uid': '050bcb86-e6c0-4522-93ce-929108a11307',…

In [21]:
# Save the current figure as PDF and SVG.
for out_path in ('misc//comp_distr.pdf', 'misc//comp_distr.svg'):
    fig.write_image(out_path)

In [22]:
# Russian axis titles, no figure title; then show the result.
fig.update_layout(
    xaxis_title_text='Вес конкуренции',
    yaxis_title_text='N',
    title_text='',
)
fig

FigureWidget({
    'data': [{'type': 'histogram',
              'uid': '050bcb86-e6c0-4522-93ce-929108a11307',…

In [23]:
# Save the current figure as PDF and SVG.
for out_path in ('misc//comp_distr_ru.pdf', 'misc//comp_distr_ru.svg'):
    fig.write_image(out_path)

In [24]:
# Accuracy-distribution figure with fixed axis ranges and the shared layout.
w, fig = net.accuracy_distribution()
fig.update_layout(xaxis_range=[0, 9.5], yaxis_range=[0.7, 1])
fig.update_layout(layout)

FigureWidget({
    'data': [{'error_y': {'array': array([0.00351738, 0.00477236, 0.01007173, 0.01018094, 0.011…

In [25]:
# Save the current figure as PDF and SVG.
for out_path in ('misc//acc_distr.pdf', 'misc//acc_distr.svg'):
    fig.write_image(out_path)

In [26]:
# Russian axis titles, no figure title; then show the result.
fig.update_layout(
    xaxis_title_text='Метка',
    yaxis_title_text='Точность',
    title_text='',
)
fig

FigureWidget({
    'data': [{'error_y': {'array': array([0.00351738, 0.00477236, 0.01007173, 0.01018094, 0.011…

In [27]:
# Save the current figure as PDF and SVG.
for out_path in ('misc//acc_distr_ru.pdf', 'misc//acc_distr_ru.svg'):
    fig.write_image(out_path)

In [28]:
# Votes-distribution figure, restyled with the shared layout;
# update_layout returns the figure, so the cell also displays it.
fig = net.votes_distribution()
fig.update_layout(layout)

FigureWidget({
    'data': [{'error_y': {'array': array([1.2934201 , 0.65735716, 0.42495543, 0.25497413, 0.235…

In [29]:
# Save the current figure as PDF and SVG.
for out_path in ('misc//votes_distr.pdf', 'misc//votes_distr.svg'):
    fig.write_image(out_path)

In [30]:
# Russian axis titles, no figure title; then show the result.
fig.update_layout(
    xaxis_title_text='Значимость класса',
    yaxis_title_text='Средний голос',
    title_text='',
)
fig

FigureWidget({
    'data': [{'error_y': {'array': array([1.2934201 , 0.65735716, 0.42495543, 0.25497413, 0.235…

In [31]:
# Save the current figure as PDF and SVG.
for out_path in ('misc//votes_distr_ru.pdf', 'misc//votes_distr_ru.svg'):
    fig.write_image(out_path)

In [32]:
# Voltage trace of one neuron with a horizontal reference line at -65.
# NOTE(review): feed_label(1) presumably runs a sample of class 1 through the
# network (the cell prints "Prediction: 1") - confirm against its definition.
net.feed_label(1)
fig = net.plot_neuron_voltage(net.best_voters_locations[0][0], 0, 0)
# Constant line at y = -65 spanning the whole simulated time range.
fig.add_trace(go.Scatter(x=[0, net.time_max], y=[-65, -65], name='$v_{rest}$',
                         line=dict(color=colors[2], width=3, dash=None), mode='lines'))
# Move the newly added line (index 2) to the front; the two original traces follow.
fig.data = (fig.data[2], fig.data[0], fig.data[1])
# Names are assigned after the reordering, so index 0 is the reference line.
fig.data[0].name = 'Resting potential'
fig.data[1].name = 'Voltage'
fig.data[2].name = 'Spikes'
fig.layout.legend.y = 1
fig.layout.legend.x = 0.75
fig.layout.showlegend = True
# Trailing semicolon suppresses the inline display of the returned figure.
fig.update_layout(layout);

Prediction: 1


In [33]:
# English axis titles, no figure title; then show the result.
fig.update_layout(
    xaxis_title_text='Time, ms',
    yaxis_title_text='Voltage, mV',
    title_text='',
)
fig

FigureWidget({
    'data': [{'line': {'color': '#00CC96', 'width': 3},
              'mode': 'lines',
        …

In [34]:
# Save the current figure as PDF and SVG.
for out_path in ('misc//voltage.pdf', 'misc//voltage.svg'):
    fig.write_image(out_path)

In [35]:
# Russian axis titles and legend entries; the legend is shifted to x = 0.7.
fig.update_layout(
    xaxis_title_text='Время, мс',
    yaxis_title_text='Потенциал нейрона, mV',
    title_text='',
    legend_x=0.7,
)
fig.data[0].name = 'Уровень релаксации'
fig.data[1].name = 'Потенциал'
fig.data[2].name = 'Спайки'
fig

FigureWidget({
    'data': [{'line': {'color': '#00CC96', 'width': 3},
              'mode': 'lines',
        …

In [36]:
# Save the current figure as PDF and SVG.
for out_path in ('misc//voltage_ru.pdf', 'misc//voltage_ru.svg'):
    fig.write_image(out_path)

In [37]:
# Most accurate LC_SNN entry with 25 filters, and its votes plot.
data = view_database()
data = data[(data['n_filters'] == 25) & (data['network_type'] == 'LC_SNN')]
# Ascending sort: the last row has the highest accuracy.
top_row = data.sort_values('accuracy').iloc[-1]
net = load_network(top_row['name'])
fig = net.plot_votes()

Created LC_SNN network a00fb88cd2a9fb0a62c55afda5e2e30a376f5ae7b6338df6c91bf7b4 with parameters
{'network_type': 'LC_SNN', 'mean_weight': 0.46, 'n_iter': 4000, 'c_w': -100, 'c_w_min': -inf, 'time_max': 250, 'crop': 20, 'kernel_size': 12, 'kernel_prod': 144, 'stride': 4, 'n_filters': 25, 'intensity': 127.5, 'dt': 1, 'c_l': True, 'A_pos': -0.05557006922455341, 'A_neg': -1.6260407487677235, 'tau_pos': 17.720578799356897, 'tau_neg': 16.475865309906567, 'weight_decay': 0, 'train_method': None, 'immutable_name': False}



In [38]:
# Shared layout, then custom margins and a 250 px height; show the result.
fig.update_layout(layout)
fig.update_layout(
    margin=dict(l=30, t=40, b=40, r=0),
    height=250,
)
fig

FigureWidget({
    'data': [{'colorbar': {'title': {'text': 'Vote'}},
              'colorscale': [[0.0, 'rgb(…

In [39]:
# Save the current figure as PDF and SVG.
for out_path in ('misc//votes.pdf', 'misc//votes.svg'):
    fig.write_image(out_path)

In [40]:
# Russian axis titles and colour-bar caption, no figure title.
fig.layout.xaxis.title.text = 'Индекс Y нейрона'
fig.layout.yaxis.title.text = 'Метка класса'
# Set the caption through title.text, like every other title in this notebook,
# rather than assigning a bare string to the title object.
fig.data[0].colorbar.title.text = 'Голос'
fig.layout.title.text = ''

In [41]:
# Save the current figure as PDF and SVG.
for out_path in ('misc//votes_ru.pdf', 'misc//votes_ru.svg'):
    fig.write_image(out_path)

In [42]:
# Votes-distribution figure with the shared layout at a reduced 250 px height.
fig = net.votes_distribution()
fig.update_layout(layout)
fig.update_layout(height=250)
fig

FigureWidget({
    'data': [{'error_y': {'array': array([2.2396517 , 1.2031698 , 0.75150585, 0.7565857 , 0.539…

In [43]:
# Save the current figure as PDF and SVG.
for out_path in ('misc//votes_distribution.pdf', 'misc//votes_distribution.svg'):
    fig.write_image(out_path)

In [44]:
# Russian axis titles, no figure title.
fig.update_layout(
    xaxis_title_text='Топ класс',
    yaxis_title_text='Средний голос',
    title_text='',
);

In [45]:
# Save the current figure as PDF and SVG.
for out_path in ('misc//votes_distribution_ru.pdf', 'misc//votes_distribution_ru.svg'):
    fig.write_image(out_path)

In [46]:
# FIXME(review): `max_votes` is not defined anywhere in the visible notebook and
# this cell fails with NameError (see its recorded output); define it or drop the cell.
np.prod(max_votes.shape)

NameError: name 'max_votes' is not defined

In [None]:
# Weight plot of a newly constructed 100-filter LC_SNN, with the shared layout
# on a square 800x800 canvas.
net = LC_SNN(n_filters=100)
fig = net.plot_weights_XY()
fig.update_layout(layout)
fig.update_layout(height=800, width=800);

In [None]:
# Render the current figure inline.
fig

In [None]:
# Save the current figure as PDF and SVG.
for out_path in ('misc//weights_XY_untrained.pdf', 'misc//weights_XY_untrained.svg'):
    fig.write_image(out_path)

In [None]:
# Russian axis titles, no figure title.
fig.update_layout(
    title_text='',
    xaxis_title_text='Индекс',
    yaxis_title_text='Индекс',
);

In [None]:
# Save the current figure as PDF and SVG.
for out_path in ('misc//weights_XY_untrained_ru.pdf', 'misc//weights_XY_untrained_ru.svg'):
    fig.write_image(out_path)

In [None]:
# Least accurate entry among the 25-filter LC_SNN runs with c_l False,
# mean_weight below 0.55 and c_w above -20.
data = view_database()
selection = (
    (data['n_filters'] == 25)
    & (data['c_l'] == False)
    & (data['mean_weight'] < 0.55)
    & (data['c_w'] > -20)
    & (data['network_type'] == 'LC_SNN')
)
data = data[selection]
# Ascending sort: the first row has the lowest accuracy.
ranked = data.sort_values('accuracy')
name = ranked.iloc[0]['name']

In [None]:
# Load the selected network and build its weight plot.
net = load_network(name)
fig = net.plot_weights_XY()

In [None]:
# Save the current figure as PDF and SVG.
for out_path in ('misc//weights_XY_bad.pdf', 'misc//weights_XY_bad.svg'):
    fig.write_image(out_path)

In [None]:
# Russian axis titles, no figure title.
fig.update_layout(
    title_text='',
    xaxis_title_text='Индекс',
    yaxis_title_text='Индекс',
);

In [None]:
# Save the current figure as PDF and SVG.
for out_path in ('misc//weights_XY_bad_ru.pdf', 'misc//weights_XY_bad_ru.svg'):
    fig.write_image(out_path)

In [None]:
# Most accurate 25-filter LC_SNN entry in the database.
data = view_database()
data = data[(data['n_filters'] == 25) & (data['network_type'] == 'LC_SNN')]
ranked = data.sort_values('accuracy', ascending=False)
name = ranked.iloc[0]['name']
net = load_network(name)

In [None]:
# Weight plot of the currently loaded network.
fig = net.plot_weights_XY()

In [None]:
# Save the current figure as PDF and SVG.
for out_path in ('misc//weights_XY_good.pdf', 'misc//weights_XY_good.svg'):
    fig.write_image(out_path)

In [None]:
# Russian axis titles, no figure title.
fig.update_layout(
    title_text='',
    xaxis_title_text='Индекс',
    yaxis_title_text='Индекс',
);

In [None]:
# Save the current figure as PDF and SVG.
for out_path in ('misc//weights_XY_good_ru.pdf', 'misc//weights_XY_good_ru.svg'):
    fig.write_image(out_path)

In [None]:
# Most accurate 25-filter LC_SNN entry whose 'c_l' flag is True.
data = view_database()
selection = (
    (data['n_filters'] == 25)
    & (data['network_type'] == 'LC_SNN')
    & (data['c_l'] == True)
)
data = data[selection]
ranked = data.sort_values('accuracy', ascending=False)
name = ranked.iloc[0]['name']
net = load_network(name)

In [None]:
# Competition-distribution figure with the accuracy appended to its title.
w, fig = net.competition_distribution()
fig.layout.title.text += f'<br>Accuracy: {net.accuracy}'
# Larger title font and extra top margin for the two-line title.
fig.update_layout(title_font_size=24, margin_t=80)
fig

In [None]:
# Save the current figure as PDF and SVG.
for out_path in ('misc//competition_distribution_best.pdf',
                 'misc//competition_distribution_best.svg'):
    fig.write_image(out_path)

In [None]:
# Russian title and axis labels ('Точность' = accuracy, spelled as elsewhere in the notebook).
fig.layout.title.text = f'Точность: {net.accuracy}'
fig.layout.xaxis.title.text = 'Вес конкуренции'
fig.layout.yaxis.title.text = 'N'

In [None]:
# Save the current figure as PDF and SVG.
for out_path in ('misc//competition_distribution_best_ru.pdf',
                 'misc//competition_distribution_best_ru.svg'):
    fig.write_image(out_path)

In [None]:
# Last row of the descending ranking, i.e. the lowest accuracy in `data`.
ranked = data.sort_values('accuracy', ascending=False)
name = ranked.iloc[-1]['name']
net = load_network(name)

In [None]:
# Competition-distribution figure with the accuracy appended to its title.
w, fig = net.competition_distribution()
fig.layout.title.text += f'<br>Accuracy: {net.accuracy}'
# Extra top margin and larger title font for the two-line title.
fig.update_layout(margin_t=80, title_font_size=24)
fig

In [None]:
# Save the current figure as PDF and SVG.
for out_path in ('misc//competition_distribution_worst.pdf',
                 'misc//competition_distribution_worst.svg'):
    fig.write_image(out_path)

In [None]:
# Russian title and axis labels ('Точность' = accuracy, spelled as elsewhere in the notebook).
fig.layout.title.text = f'Точность: {net.accuracy}'
fig.layout.xaxis.title.text = 'Вес конкуренции'
fig.layout.yaxis.title.text = 'N'

In [None]:
# Save the current figure as PDF and SVG.
for out_path in ('misc//competition_distribution_worst_ru.pdf',
                 'misc//competition_distribution_worst_ru.svg'):
    fig.write_image(out_path)

In [None]:
# Median-accuracy entry among the 25-filter LC_SNN runs with c_l True and
# accuracy strictly between 0.5 and 0.7.
data = view_database()
selection = (
    (data['n_filters'] == 25)
    & (data['network_type'] == 'LC_SNN')
    & (data['c_l'] == True)
    & (data['accuracy'] < 0.7)
    & (data['accuracy'] > 0.5)
)
data = data[selection]
ranked = data.sort_values('accuracy', ascending=False)
# Middle row of the ranking.
name = ranked.iloc[len(data)//2]['name']
net = load_network(name)

In [None]:
# Competition-distribution figure with the accuracy appended to its title.
w, fig = net.competition_distribution()
fig.layout.title.text += f'<br>Accuracy: {net.accuracy}'
# Extra top margin and larger title font for the two-line title.
fig.update_layout(margin_t=80, title_font_size=24)
fig

In [None]:
# Save the current figure as PDF and SVG.
for out_path in ('misc//competition_distribution_medium_good.pdf',
                 'misc//competition_distribution_medium_good.svg'):
    fig.write_image(out_path)

In [None]:
# Russian title and axis labels ('Точность' = accuracy, spelled as elsewhere in the notebook).
fig.layout.title.text = f'Точность: {net.accuracy}'
fig.layout.xaxis.title.text = 'Вес конкуренции'
fig.layout.yaxis.title.text = 'N'

In [None]:
# Save the current figure as PDF and SVG.
for out_path in ('misc//competition_distribution_medium_good_ru.pdf',
                 'misc//competition_distribution_medium_good_ru.svg'):
    fig.write_image(out_path)

In [None]:
# Median-accuracy entry among the 25-filter LC_SNN runs with c_l True and
# accuracy strictly between 0.3 and 0.5.
data = view_database()
selection = (
    (data['n_filters'] == 25)
    & (data['network_type'] == 'LC_SNN')
    & (data['c_l'] == True)
    & (data['accuracy'] < 0.5)
    & (data['accuracy'] > 0.3)
)
data = data[selection]
ranked = data.sort_values('accuracy', ascending=False)
# Middle row of the ranking.
name = ranked.iloc[len(data)//2]['name']
net = load_network(name)

In [None]:
# Competition-distribution figure with the accuracy appended to its title.
w, fig = net.competition_distribution()
fig.layout.title.text += f'<br>Accuracy: {net.accuracy}'
# Extra top margin and larger title font for the two-line title.
fig.update_layout(margin_t=80, title_font_size=24)
fig

In [None]:
# Save the current figure as PDF and SVG.
for out_path in ('misc//competition_distribution_medium_bad.pdf',
                 'misc//competition_distribution_medium_bad.svg'):
    fig.write_image(out_path)

In [None]:
# Russian title and axis labels ('Точность' = accuracy, spelled as elsewhere in the notebook).
fig.layout.title.text = f'Точность: {net.accuracy}'
fig.layout.xaxis.title.text = 'Вес конкуренции'
fig.layout.yaxis.title.text = 'N'

In [None]:
# Save the current figure as PDF and SVG.
for out_path in ('misc//competition_distribution_medium_bad_ru.pdf',
                 'misc//competition_distribution_medium_bad_ru.svg'):
    fig.write_image(out_path)

In [None]:
# Sets a LaTeX-formatted x-axis title (raw string keeps the backslashes intact).
# NOTE(review): looks like a rendering experiment - nothing visible here exports it.
fig.layout.xaxis.title.text = r'$\text{text} \lambda$'

In [None]:
# Enlarge the x-axis title font of the current figure.
fig.layout.xaxis.title.font.size = 20

In [None]:
# Increase the bottom margin of the current figure.
fig.layout.margin.b = 100