Skip to content

Commit

Permalink
Improve callback text
Browse files Browse the repository at this point in the history
  • Loading branch information
mstimberg committed Mar 12, 2021
1 parent e613f1b commit b2add67
Showing 1 changed file with 26 additions and 18 deletions.
44 changes: 26 additions & 18 deletions brian2modelfitting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,45 +5,53 @@
from tqdm.autonotebook import tqdm


def _format_quantity(v, precision=3):
if isinstance(v, Quantity):
return f'{v.in_best_unit(precision=precision)}'
else:
return f'{v:.{precision}g}'


def callback_text(params, errors, best_params, best_error, index, additional_info):
"""Default callback print-out for Fitters"""
params = []
for p, v in sorted(best_params.items()):
if isinstance(v, Quantity):
params.append(f'{p}={v.in_best_unit(precision=3)}')
else:
params.append(f'{p}={v:.3g}')
params.append(f'{p}={_format_quantity(v)}')
param_str = ', '.join(params)
if isinstance(best_error, Quantity):
best_error_str = best_error.in_best_unit(precision=4)
else:
best_error_str = f'{best_error:.4g}'
round = f'Round {index}: '
if len(additional_info.get('objective_errors', [])) > 1:
best_error_str = _format_quantity(best_error, precision=4)
errors = []
for error, normed_error, varname in zip(additional_info['objective_errors'],
additional_info['objective_errors_normalized'],
additional_info['output_var']):

if not have_same_dimensions(error, normed_error) or error != normed_error:
if isinstance(error, Quantity):
raw_error_str = f', unnormalized error: {error.in_best_unit(precision=3)}'
else:
raw_error_str = f', unnormalized error: {error:.3g}'
raw_error_str = f', unnormalized error: {_format_quantity(error)}'
else:
raw_error_str = ''

if isinstance(normed_error, Quantity):
errors.append(f'{normed_error.in_best_unit(precision=3)} ({varname}{raw_error_str})')
else:
errors.append(f'{normed_error:.3g} ({varname}{raw_error_str})')
errors.append(f'{_format_quantity(normed_error)} ({varname}{raw_error_str})')

error_sum = ' + '.join(errors)
print(f"{round}Best parameters {param_str}\n"
f"{' '*len(round)}Best error: {best_error_str} = {error_sum}")
else:
print(f"{round}Best parameters {param_str}\n"
f"{' '*len(round)}Best error: {best_error_str} ({additional_info['output_var'][0]})")
print(f"{round}Best parameters {param_str}")
if 'objective_errors_normalized' in additional_info:
best_error_normed = _format_quantity(additional_info['objective_errors_normalized'][0])
best_error_raw = _format_quantity(additional_info['objective_errors'][0])
if (not have_same_dimensions(additional_info['objective_errors_normalized'][0],
additional_info['objective_errors'][0]) or
best_error_normed != best_error_raw):
print(f"{' ' * len(round)}Best error: {best_error_normed} ({additional_info['output_var'][0]}, "
f"unnormalized error: {best_error_raw})")
else:
print(f"{' ' * len(round)}Best error: {best_error_normed} ({additional_info['output_var'][0]})")
else:
best_error_str = _format_quantity(best_error, precision=4)
print(f"{' ' * len(round)}Best error: {best_error_str} ({additional_info['output_var'][0]})")



def callback_none(params, errors, best_params, best_error, index, additional_info):
Expand Down

0 comments on commit b2add67

Please sign in to comment.