Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an example to visualize the evolutuion of fitted parameters in bounded parameter space #43

Open
wants to merge 3 commits into
base: master
Choose a base branch
from

Conversation

akapet00
Copy link
Member

@akapet00 akapet00 commented Apr 7, 2021

Simple example that shows the error (and parameters) evolution in time by using the nevergrad optimizer.
Voltage traces are fitted using the default approach with the sample size of 40 for 25 rounds of fitting.
The first figure shows fitted traces and the input current that drives each neuron.
The second figure (animation) shows parameter population per each round of fitting (frame by frame), where the highlighted parameters are the one with the smallest overall error.
There is no additional dependencies used for this example other than matplotlib which is already used by brian2 by default.

Additionally, custom callback is also added in the fitter.fit(...) method which is very similar to the one that can be found in the documentation under the Advanced Features section: Callback Function subsection.

@mstimberg
Copy link
Member

This looks very nice, thank you! I will wait with the merge until I have the time to do a final thorough review in ~2 weeks (don't worry that this will be after the GSoC deadline).

@mstimberg
Copy link
Member

Will still do a bit of minor nit-picking later, but here's a thing I just encountered: while running the example, I get the following error (multiple times) during the 3D plotting:

Traceback (most recent call last):
  File "/home/marcel/anaconda3/envs/brian2modelfitting/lib/python3.7/site-packages/matplotlib/backend_bases.py", line 1187, in _on_timer
    ret = func(*args, **kwargs)
  File "/home/marcel/anaconda3/envs/brian2modelfitting/lib/python3.7/site-packages/matplotlib/animation.py", line 1449, in _step
    still_going = Animation._step(self, *args)
  File "/home/marcel/anaconda3/envs/brian2modelfitting/lib/python3.7/site-packages/matplotlib/animation.py", line 1169, in _step
    self._draw_next_frame(framedata, self._blit)
  File "/home/marcel/anaconda3/envs/brian2modelfitting/lib/python3.7/site-packages/matplotlib/animation.py", line 1188, in _draw_next_frame
    self._draw_frame(framedata)
  File "/home/marcel/anaconda3/envs/brian2modelfitting/lib/python3.7/site-packages/matplotlib/animation.py", line 1766, in _draw_frame
    self._drawn_artists = self._func(framedata, *self._args)
  File "/home/marcel/programming/brian2modelfitting/examples/hh_nevergrad_errorevolution.py", line 153, in animate
    'r*', markersize=8, label='best params')
  File "/home/marcel/anaconda3/envs/brian2modelfitting/lib/python3.7/site-packages/mpl_toolkits/mplot3d/axes3d.py", line 1421, in plot
    zs = np.broadcast_to(zs, len(xs))
TypeError: len() of unsized object

In case it matters, this is with matplotlib 3.2.1. Any idea what could cause this?

@akapet00
Copy link
Member Author

akapet00 commented Apr 21, 2021

The problem seems to be related to the version of matplotlib you are using. I run the script seamlessly when using newer version, concretely:

(brian) alk@alk:~/github/brian2modelfitting/examples$ conda list | grep matplotlib
matplotlib                3.3.4            py38h06a4308_0  
matplotlib-base           3.3.4            py38h62a2d02_0

The mentioned problem is related to how 3-D plotting is performed in animate function. For some reason matplotlib calls len on the data and in this case the data is a number and not iterable:

ax1.plot3D(res['g_k'], res['g_na'], res['g_l'], ...)

This can be solved easily by inserting the data to a list:

ax1.plot3D([res['g_k']], [res['g_na']], [res['g_l']], ...)

The error occurred multiple times because the animation was repeated until it was stopped manually.
Setting repeat=False in FuncAnimation resolved the issue.

Should I commit these changes?

@mstimberg
Copy link
Member

Ah, thanks for looking into this. I think it makes sense to have one element lists in the plot3D calls, this is also more consistent with all the other calls.

Copy link
Member

@mstimberg mstimberg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looking great, I made a few minor comments but nothing super important.

from brian2 import *
from brian2modelfitting import *
import pandas as pd
from mpl_toolkits.mplot3d import Axes3D
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This Axes3D import does not seem to be used.


def callback(params, errors, best_params, best_error, index):
"""Custom callback"""
print(f'[round {index + 1}]\t{np.min(errors)}')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really minor thing, but maybe add an "Best error: " before the error to make the output a bit more meaningful.

g_k=[6.e-07 * siemens, 6.e-05 * siemens])

# visualization of best fitted traces
start_scope()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The start_scope shouldn't be needed here.

Comment on lines 111 to 116
# visualization of errors and parameters evolving over time
full_output = fitter.results(format='dataframe', use_units=False)
g_k = full_output['g_k'].to_numpy()
g_na = full_output['g_na'].to_numpy()
g_l = full_output['g_l'].to_numpy()
error = full_output['error'].to_numpy()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yet another minor thing: the code might be easier to follow if it did not use the dataframe format here (that gets converted into numpy arrays right away), but instead format='dict'.

Comment on lines 121 to 125
ax1.set_xlabel('$g_K$ [S]')
ax1.set_ylabel('$g_{Na}$ [S]')
ax1.set_zlabel('$g_l$ [S]')
ax2.set_xlabel('round')
ax2.set_ylabel('error')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit of a matter of taste, but I prefer the slightly more compact set syntax, e.g.

ax1.set(xlabel='$g_K$ [S]', ylabel='$g_{Na}$ [S]', zlabel='$g_l$ [S]')

ax2.grid()


def init():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the init and animate function could use a short comment for those not familiar with FuncAnimation.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants