In [None]:
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# T_max per layer swept for each dataset (powers of two).
mnist_tmaxs = [1 << k for k in range(1, 6)]     # 2, 4, 8, 16, 32
cifar_tmaxs = [1 << k for k in range(3, 7)]     # 8, 16, 32, 64
imagenet_tmaxs = [1 << k for k in range(4, 7)]  # 16, 32, 64

# Accuracy (%) measured at each T_max above, stored as error = 100 - accuracy.
mnist_error = np.subtract(100, [89.68999981880188, 98.10999631881714, 99.04999732971191, 99.1599977016449, 99.24999475479126])
cifar_error = np.subtract(100, [83.30999612808228, 91.3599967956543, 92.14999675750732, 92.25999712944031])
imagenet_error = np.subtract(100, [54.44, 64.44, 67.24])

# NOTE(review): one value per ImageNet T_max; units/meaning are not stated here
# (presumably run time) and the value is unused in the cells shown — TODO document.
imagenet_times = [146, 290, 578]

In [None]:
# Disabled matplotlib figure: accuracy error [%] vs. T_max per layer (log x-axis),
# one line per dataset, saved to simulation_accs.png. It reads the *_tmaxs / *_error
# arrays from the data cell above; uncomment to regenerate that figure.
# plt.scatter(mnist_tmaxs, mnist_error)
# plt.plot(mnist_tmaxs, mnist_error, label='MNIST')
# plt.scatter(cifar_tmaxs, cifar_error)
# plt.plot(cifar_tmaxs, cifar_error, label="CIFAR10")
# plt.scatter(imagenet_tmaxs, imagenet_error)
# plt.plot(imagenet_tmaxs, imagenet_error, label="ImageNet")
# plt.xscale('log')
# plt.xlim(right=100)
# plt.ylabel("Accuracy error [%]")
# plt.xlabel(r"$T_{\mathrm{max}}$ per layer")
# plt.legend()
# plt.grid()
# plt.tight_layout()
# plt.savefig("simulation_accs.png")

In [None]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio

# Accuracy error vs. operations per frame, one facet per dataset, exported to n_ops.png.
# Reuses the *_tmaxs and *_error arrays from the data cell above instead of re-typing them.

# Operations per frame at each T_max, in the same order as the matching *_tmaxs list.
mnist_ops = np.array([49, 66, 101, 171, 310]) * 1e3
cifar_ops = np.array([4.356, 6.912, 12.024, 22.249]) * 1e6
imagenet_ops = np.array([537.928, 775.855, 1251.711]) * 1e6

# (dataset, T_max values, % accuracy error, ops per frame). List order fixes the facet order.
RESULTS = [
    ("MNIST", mnist_tmaxs, mnist_error, mnist_ops),
    ("CIFAR10", cifar_tmaxs, cifar_error, cifar_ops),
    ("ImageNet", imagenet_tmaxs, imagenet_error, imagenet_ops),
]

# ANN point per dataset as (ops per frame, accuracy %), in RESULTS/facet order.
ANN_POINTS = [
    (45.168e6, 99.22999739646912),
    (21975.072768e6, 92.38999485969543),
    (974005.272576e6, 68.11000108718872),
]

# Raw string: the LaTeX backslashes (\mathrm) must not be parsed as Python escapes.
TMAX_LABEL = r"$T_{\mathrm{max}}=\mathrm{%d}$"
FIG_WIDTH, FIG_HEIGHT = 1000, 350

records = []
for dataset, tmaxs, errors, ops in RESULTS:
    # zip() would silently truncate mismatched inputs, so check the lengths up front.
    assert len(tmaxs) == len(errors) == len(ops), f"{dataset}: T_max/error/ops lengths differ"
    for i, (tmax, error, n_ops) in enumerate(zip(tmaxs, errors, ops)):
        records.append({
            "Dataset": dataset,
            # Full "T_max = ..." label on each dataset's first point only; the rest show the bare value.
            "Text": TMAX_LABEL % tmax if i == 0 else str(tmax),
            "% accuracy error": error,
            "Ops per frame": n_ops,
        })
df = pd.DataFrame(records)

fig = px.line(df, x="Ops per frame", y="% accuracy error", facet_col="Dataset", text="Text")

# Per-point label positions, applied to every dataset trace by point index
# (presumably hand-tuned); traces with fewer points use only the leading entries.
fig.update_traces(textposition=["middle right", "top right", "bottom left", "top center", "middle right"])

ANN_kws = dict(showlegend=False, mode="markers+text", text=['ANN '], textposition="middle left", marker_color=0.5)
for col, (ann_ops, ann_accuracy) in enumerate(ANN_POINTS, start=1):
    fig.add_trace(go.Scatter(x=[ann_ops], y=[100 - ann_accuracy], **ANN_kws), row=1, col=col)

# Each facet gets its own (log-scaled) axis range and tick labels.
fig.update_xaxes(type="log", matches=None, showticklabels=True)
fig.update_yaxes(matches=None, showticklabels=True)

# The output format follows the .png extension. Size is passed explicitly rather than via
# pio.kaleido.scope defaults, which are deprecated in newer plotly/kaleido releases.
# To preview inline instead: fig.show()
fig.write_image("n_ops.png", width=FIG_WIDTH, height=FIG_HEIGHT)