Skip to content

Commit

Permalink
Reservoir states visualisation
Browse files Browse the repository at this point in the history
  • Loading branch information
nschaetti committed Jan 26, 2019
1 parent 7b196f2 commit aedbf03
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 4 deletions.
4 changes: 2 additions & 2 deletions echotorch/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
# Imports
from .error_measures import nrmse, nmse, rmse, mse, perplexity, cumperplexity
from .utility_functions import spectral_radius, deep_spectral_radius, normalize, average_prob, max_average_through_time
from .visualisation import show_3d_timeseries, show_2d_timeseries, show_1d_timeseries
from .visualisation import show_3d_timeseries, show_2d_timeseries, show_1d_timeseries, neurons_activities_1d, neurons_activities_2d, neurons_activities_3d

__all__ = [
'nrmse', 'nmse', 'rmse', 'mse', 'perplexity', 'cumperplexity', 'spectral_radius', 'deep_spectral_radius',
'normalize', 'average_prob', 'max_average_through_time', 'show_3d_timeseries', 'show_2d_timeseries',
'show_1d_timeseries'
'show_1d_timeseries', 'neurons_activities_1d', 'neurons_activities_2d', 'neurons_activities_3d'
]
108 changes: 106 additions & 2 deletions echotorch/utils/visualisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,103 @@
from mpl_toolkits.mplot3d import Axes3D


# Display neurons activities on a 3D plot
def neurons_activities_3d(stats, neurons, title, timesteps=-1, start=0):
"""
Display neurons activities on a 3D plot
:param stats:
:param neurons:
:param title:
:param timesteps:
:param start:
:return:
"""
# Fig
ax = plt.axes(projection='3d')

# Two by two
n_neurons = neurons.shape[0]
stats = stats[:, neurons].view(-1, n_neurons // 3, 3)

# Plot
if timesteps == -1:
time_length = stats.shape[0]
ax.plot3D(stats[:, :, 0].view(time_length).numpy(), stats[:, :, 1].view(time_length).numpy(), stats[:, :, 2].view(time_length).numpy(), 'o')
else:
ax.plot3D(stats[start:start + timesteps, :, 0].numpy(), stats[start:start + timesteps, :, 1].numpy(), stats[start:start + timesteps, :, 2].numpy(), 'o', lw=0.5)
# end if
ax.set_xlabel("X Axis")
ax.set_ylabel("Y Axis")
ax.set_zlabel("Z Axis")
ax.set_title(title)
plt.show()
plt.close()
# end neurons_activities_3d


# Display neurons activities on a 2D plot
def neurons_activities_2d(stats, neurons, title, colors, timesteps=-1, start=0):
"""
Display neurons activities on a 2D plot
:param stats:
:param neurons:
:param title:
:param timesteps:
:param start:
:return:
"""
# Fig
fig = plt.figure()
ax = fig.gca()

# Two by two
n_neurons = neurons.shape[0]

# For each plot
for i, stat in enumerate(stats):
# Stats
stat = stat[:, neurons].view(-1, n_neurons // 2, 2)

# Plot
if timesteps == -1:
ax.plot(stat[:, :, 0].numpy(), stat[:, :, 1].numpy(), colors[i])
else:
ax.plot(stat[start:start + timesteps, :, 0].numpy(), stat[start:start + timesteps, :, 1].numpy(), colors[i])
# end if
# end for
ax.set_xlabel("X Axis")
ax.set_ylabel("Y Axis")
ax.set_title(title)
plt.show()
plt.close()
# end neurons_activities_2d


# Display neurons activities
def neurons_activities_1d(stats, neurons, title, timesteps=-1, start=0):
"""
Display neurons activities
:param stats:
:param neurons:
:return:
"""
# Fig
fig = plt.figure()
ax = fig.gca()

if timesteps == -1:
ax.plot(stats[:, neurons].numpy())
else:
ax.plot(stats[start:start + timesteps, neurons].numpy())
# end if

ax.set_xlabel("Timesteps")
ax.set_title(title)
plt.show()
plt.close()
# end neurons_activities_1d


# Show 3D time series
def show_3d_timeseries(ts, title):
"""
Expand All @@ -26,6 +123,7 @@ def show_3d_timeseries(ts, title):
ax.set_zlabel("Z Axis")
ax.set_title(title)
plt.show()
plt.close()
# end show_3d_timeseries


Expand All @@ -46,11 +144,12 @@ def show_2d_timeseries(ts, title):
ax.set_ylabel("Y Axis")
ax.set_title(title)
plt.show()
plt.close()
# end show_2d_timeseries


# Show 1D time series
def show_1d_timeseries(ts, title):
def show_1d_timeseries(ts, title, start=0, timesteps=-1):
"""
Show 1D time series
:param ts:
Expand All @@ -61,8 +160,13 @@ def show_1d_timeseries(ts, title):
fig = plt.figure()
ax = fig.gca()

ax.plot(ts[:, 0].numpy())
if timesteps == -1:
ax.plot(ts[:, 0].numpy())
else:
ax.plot(ts[start:start+timesteps, 0].numpy())
# end if
ax.set_xlabel("X Axis")
ax.set_title(title)
plt.show()
plt.close()
# end show_1d_timeseries

0 comments on commit aedbf03

Please sign in to comment.