In [None]:
import plotly.express as px
import pandas as pd
import plotly.graph_objects as go
import numpy as np
import pickle

In [None]:
# Hard-coded experiment results used by the plots below.
# Errors are stored as "100 - accuracy", so these accuracies are on a 0-100 scale.

# MNIST: one measurement per activation-normalization percentile.
mnist_t_max = 16
mnist_percentiles = [97, 98, 99, 99.9, 99.99, 99.999]
mnist_early_spikes = np.array([
    5.9887309074401855,
    3.43815279006958,
    1.6641639471054077,
    0.11234107613563538,
    0.008053530007600784,
    0.001304976874962449,
])
mnist_errors = np.subtract(100, np.array([
    98.72999787330627,
    99.14000034332275,
    99.1599977016449,
    99.11999702453613,
    99.11999702453613,
    99.09999966621399,
]))

# CIFAR: Tmax sweeps powers of two from 2**3 to 2**6.
cifar_tmaxs = [1 << k for k in range(3, 7)]
cifar_error = np.subtract(100, np.array([
    83.30999612808228,
    91.3599967956543,
    92.14999675750732,
    92.25999712944031,
]))

# ImageNet: Tmax sweeps powers of two from 2**4 to 2**6.
imagenet_tmaxs = [1 << k for k in range(4, 7)]
imagenet_error = np.subtract(100, np.array([54.44, 64.44, 67.24]))
# NOTE(review): unit of these timings is not recorded here — TODO confirm.
imagenet_times = [146, 290, 578]

In [None]:
# MNIST: accuracy error for each activation-normalization percentile,
# coloured by the percentage of early spikes (hard-coded values from the cell above).
# Percentiles are converted to strings so plotly draws a categorical x axis;
# on a linear axis 99.9 / 99.99 / 99.999 would be crowded together.
early_spikes_df = pd.DataFrame({
    'Dataset': ['MNIST'] * len(mnist_percentiles),
    'Activation normalization percentile': [str(percentile) for percentile in mnist_percentiles],
    '% early spikes': mnist_early_spikes,
    '% accuracy error': mnist_errors,
})

fig = px.bar(
    early_spikes_df,
    x="Activation normalization percentile",
    y="% accuracy error",
    facet_col="Dataset",
    color='% early spikes',
)
fig.show()

In [None]:
# Load the MNIST result sweep and plot accuracy error per normalization
# percentile, coloured by the share of early spikes.
MNIST_RESULTS_PATH = '../mnist/mnist-results.pkl'

# `with` closes the file handle (a bare open() inside pickle.load leaked it).
# NOTE(review): pickle.load can execute arbitrary code — only load result
# files produced locally by the MNIST experiment.
with open(MNIST_RESULTS_PATH, 'rb') as results_file:
    mnist = pickle.load(results_file)

# `mnist` is a nested dict: percentile -> {t_max -> metrics dict}.
# Collect one record per (percentile, t_max) and build the frame once.
# Concatenating inside the loop is quadratic. It also gave every row index 0,
# so reset_index() added a meaningless all-zero 'index' column.
records = [
    {
        'Dataset': 'MNIST',
        'Normalization percentile': percentile,
        'Tmax': t_max,
        'Time steps': metrics['time steps'],
        # assumes metrics['acc'] is a percentage (0-100) — TODO confirm
        'Accuracy error': 100 - metrics['acc'],
        'Early spikes': metrics['early spikes'],
        'Number of operations per frame': metrics['n_ops'],
        'Neurons': metrics['n_neurons'],
        'SynOps': metrics['n_synops'],
    }
    for percentile, data_dict in mnist.items()
    for t_max, metrics in data_dict.items()
]
df = pd.DataFrame(records)

# NOTE(review): if a percentile has several Tmax entries, px.bar stacks them
# into one bar; filter to a single Tmax or facet on 'Tmax' in that case.
fig = px.bar(df, x='Normalization percentile', y='Accuracy error', color='Early spikes', template="plotly_white")
fig.update_layout(
    margin=dict(l=20, r=20, t=0, b=20),
    width=700,
    height=300,
)
# Categorical axis (like the first plot's str() percentiles) so 99.9 / 99.99 /
# 99.999 get evenly spaced bars instead of overlapping on a linear scale.
fig.update_xaxes(type='category')
fig.show()
fig.write_image('early_spikes.png')