Refactor integrators and morphology; add implicit diffrax#34
Refactor integrators and morphology; add implicit diffrax#34chaoming0625 merged 7 commits intomainfrom
Conversation
…ivity default value
Reviewer's GuideThis PR standardizes time-step handling across integrators, extends diffrax support with implicit solvers, refactors morphology APIs for clearer position/diameter separation, and normalizes temperature units in channel classes. Sequence Diagram: Neuron Update Method with dtsequenceDiagram
participant C as Caller
participant N as Neuron (Single/MultiCompartment)
participant ENV as brainstate.environ
participant S as Solver
C->>N: update(I_ext)
N->>ENV: get('t')
ENV-->>N: t
N->>ENV: get('dt')
ENV-->>N: dt
N->>S: solver(self, t, dt, I_ext)
Sequence Diagram: apply_standard_solver_step with dt PropagationsequenceDiagram
participant SpecificSolver as Solver Interface (e.g., diffrax_euler_step)
participant apply_standard_solver_step as apply_standard_solver_step
participant target as DiffEqModule
participant transform_fn as _transform_diffeq_module_into_dimensionless_fn
participant check_deriv as _check_diffeq_state_derivative
participant actual_step as actual_solver_step (e.g., _explicit_solver)
SpecificSolver->>apply_standard_solver_step: (target, t, dt, *args)
apply_standard_solver_step->>target: pre_integral(*args)
apply_standard_solver_step->>transform_fn: (target, dt, method)
Note right of transform_fn: vector_field internally calls:
transform_fn->>target: compute_derivative(*args)
transform_fn->>check_deriv: (state, dt)
transform_fn-->>apply_standard_solver_step: dimensionless_fn, y0_dimless, ...
apply_standard_solver_step->>actual_step: (dimensionless_fn, y0_dimless, t, dt, ...)
actual_step-->>apply_standard_solver_step: y1_dimless, ...
apply_standard_solver_step->>target: post_integral(*args)
Sequence Diagram: _general_rk_step with dtsequenceDiagram
participant RKMethod as RK Method (e.g., euler_step)
participant general_rk_step as _general_rk_step
participant target as DiffEqModule
participant rk_update as _rk_update
RKMethod->>general_rk_step: (tableau, target, t, dt, *args)
general_rk_step->>target: pre_integral(*args)
loop Stages (tableau.C)
general_rk_step->>rk_update: (coeff, st, y0, dt, *ks)
general_rk_step->>target: compute_derivative(*args)
end
general_rk_step->>rk_update: (tableau.B, st, y0, dt, *ks)
general_rk_step->>target: post_integral(*args)
Updated Class Diagram: Integrator Utility Functions in _integrator_util.pyclassDiagram
class IntegratorUtil {
<<Module: _integrator_util.py>>
+T: TypeAlias
+DT: TypeAlias
+VectorFiled: TypeAlias
+Y0: TypeAlias
+Y1: TypeAlias
+Jacobian: TypeAlias
+Args: TypeAlias
+Aux: TypeAlias
+_check_diffeq_state_derivative(state: DiffEqState, dt: DT)
+_transform_diffeq_module_into_dimensionless_fn(target: DiffEqModule, dt: DT, method: str) : tuple
+apply_standard_solver_step(solver_step: Callable, target: DiffEqModule, t: T, dt: DT, *args, merging_method: str)
+jacrev_last_dim(fn: Callable, hid_vals: Y0, has_aux: bool) : tuple
}
note for IntegratorUtil "Focus on changed/added dt parameter and new type hints."
Updated Class Diagram: Diffrax Integrators in _integrator_diffrax.pyclassDiagram
class DiffraxIntegrators {
<<Module: _integrator_diffrax.py>>
+_explicit_solver(solver, fn: VectorFiled, y0: Y0, t0: T, dt: DT, args=()) : tuple
+_diffrax_explicit_solver(solver, target: DiffEqModule, t: T, dt: DT, *args)
+diffrax_euler_step(target: DiffEqModule, t: T, dt: DT, *args)
+diffrax_heun_step(target: DiffEqModule, t: T, dt: DT, *args)
+diffrax_midpoint_step(target: DiffEqModule, t: T, dt: DT, *args)
+diffrax_ralston_step(target: DiffEqModule, t: T, dt: DT, *args)
+diffrax_bosh3_step(target: DiffEqModule, t: T, dt: DT, *args)
+diffrax_tsit5_step(target: DiffEqModule, t: T, dt: DT, *args)
+diffrax_dopri5_step(target: DiffEqModule, t: T, dt: DT, *args)
+diffrax_dopri8_step(target: DiffEqModule, t: T, dt: DT, *args)
+diffrax_bwd_euler_step(target: DiffEqModule, t: T, dt: DT, *args, tol: float) // New
+_implicit_solver(solver, fn: VectorFiled, y0: Y0, t0: T, dt: DT, args=()) : tuple // New
+_diffrax_implicit_solver(solver, target: DiffEqModule, t: T, dt: DT, *args) // New
}
note for DiffraxIntegrators "All listed functions now take 'dt: DT'. New implicit solvers added."
Updated Class Diagram: Runge-Kutta Integrators in _integrator_runge_kutta.pyclassDiagram
class RungeKuttaIntegrators {
<<Module: _integrator_runge_kutta.py>>
+_rk_update(coeff: Sequence, st: brainstate.State, y0: PyTree, dt: DT, *ks)
+_general_rk_step(tableau: ButcherTableau, target: DiffEqModule, t: T, dt: DT, *args)
+euler_step(target: DiffEqModule, t: T, dt: DT, *args)
+midpoint_step(target: DiffEqModule, t: T, dt: DT, *args)
+rk2_step(target: DiffEqModule, t: T, dt: DT, *args)
+heun2_step(target: DiffEqModule, t: T, dt: DT, *args)
+ralston2_step(target: DiffEqModule, t: T, dt: DT, *args)
+rk3_step(target: DiffEqModule, t: T, dt: DT, *args)
+heun3_step(target: DiffEqModule, t: T, dt: DT, *args)
+ssprk3_step(target: DiffEqModule, t: T, dt: DT, *args)
+ralston3_step(target: DiffEqModule, t: T, dt: DT, *args)
+rk4_step(target: DiffEqModule, t: T, dt: DT, *args)
+ralston4_step(target: DiffEqModule, t: T, dt: DT, *args)
}
note for RungeKuttaIntegrators "All listed functions now take 'dt: DT'."
Updated Class Diagram: Exponential Euler Integrator in _integrator_exp_euler.pyclassDiagram
class ExpEulerIntegrator {
<<Module: _integrator_exp_euler.py>>
+exp_euler_step(target: DiffEqModule, t: T, dt: DT, *args)
}
note for ExpEulerIntegrator "Function 'exp_euler_step' now takes 'dt: DT'."
Updated Class Diagram: Morphology Classes in _morphology.pyclassDiagram
class Section {
+name: SectionName
+positions: u.Quantity[u.meter]
+diam: u.Quantity[u.meter]
+nseg: int
+Ra: u.Quantity[u.ohm * u.cm]
+cm: u.Quantity[u.uF / u.cm ** 2]
+__init__(name, positions, diam, nseg, Ra, cm)
#_compute_area_and_resistance()
}
class CylinderSection {
+__init__(name, length: u.Quantity, diam: u.Quantity, nseg, Ra, cm)
}
class PointSection {
+__init__(name, position: u.Quantity[u.meter], diam: u.Quantity[u.meter], nseg, Ra, cm)
}
class Morphology {
+sections: dict
+add_cylinder_section(name, length, diam, nseg, Ra, cm)
+add_point_section(name, position: u.Quantity[u.meter], diam: u.Quantity[u.meter], nseg, Ra, cm)
+add_multiple_sections(section_dicts: Dict)
}
CylinderSection --|> Section
PointSection --|> Section
Morphology o-- Section : contains
note for PointSection "__init__ now takes 'position' and 'diam' instead of 'points'."
note for Morphology "add_point_section signature updated. add_multiple_sections logic for point sections updated."
Updated Class Diagram: Channel Temperature Handling (Example: ICaT_HM1992)classDiagram
class ICaT_HM1992 {
<<Channel>>
+g_max: ArrayLike
+T: ArrayLike
+__init__(size, T=u.celsius2kelvin(36.), T_base_p, T_base_q, g_max, V_sh, phi_p=None, phi_q=None, name=None)
# Internal: T = u.kelvin2celsius(T) upon initialization
}
note for ICaT_HM1992 "'T' parameter default is now in Kelvin (u.celsius2kelvin(36.)) and converted to Celsius internally. Applies to other channel classes as well."
Updated Class Diagram: Neuron Compartment ModelsclassDiagram
class SingleCompartmentNeuron {
+update(I_ext=0. * u.nA / u.cm ** 2)
# solver(self, t, dt, I_ext) # Behavior changed to use dt
}
class MultiCompartmentNeuron {
+pop_size: Tuple[int, ...]
+n_compartment: int
+update(I_ext=0. * u.nA)
# solver(self, t, dt, I_ext) # Behavior changed to use dt
}
note for SingleCompartmentNeuron "update() method now passes 'dt' to its internal solver."
note for MultiCompartmentNeuron "update() method now passes 'dt' to its internal solver. Added pop_size and n_compartment properties."
File-Level Changes
Tips and commandsInteracting with Sourcery
Customizing Your ExperienceAccess your dashboard to:
Getting Help
|
|
@sourcery-ai title |
There was a problem hiding this comment.
Hey @chaoming0625 - I've reviewed your changes - here's some feedback:
- Consider deprecating the old
pointsargument in morphology APIs (and emitting a warning) while still supporting it for a transition period to avoid breaking existing user code. - The repeated temperature conversion pattern (
T = u.kelvin2celsius(T)) across multiple channel constructors could be factored into a shared helper to reduce duplication. - Please update the PR description to summarize the major API changes (integrator signature updates, morphology refactoring, channel temperature handling) for clearer context and easier maintenance.
Here's what I looked at during the review
- 🟡 General issues: 2 issues found
- 🟢 Security: all looks good
- 🟡 Testing: 1 issue found
- 🟡 Complexity: 1 issue found
- 🟢 Documentation: all looks good
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
| def _implicit_solver(solver, fn: VectorFiled, y0: Y0, t0: T, dt: DT, args=()): | ||
| dt = u.Quantity(dt) | ||
| t0 = u.Quantity(t0).to_decimal(dt.unit) | ||
| dt = u.get_magnitude(dt) | ||
| y1 = solver.step( | ||
| diffrax.ODETerm(lambda t, y, args_: fn(t, y, *args_)[0]), | ||
| t0, | ||
| t0 + dt, | ||
| y0, | ||
| args, | ||
| None, | ||
| made_jump=False | ||
| )[0] | ||
| return y1, {} |
There was a problem hiding this comment.
suggestion: The new _implicit_solver function does not handle solver state or auxiliary outputs.
This implementation won't support solvers or ODETerm functions that require state or auxiliary outputs. Please verify if this limitation is acceptable for your use cases.
| def _implicit_solver(solver, fn: VectorFiled, y0: Y0, t0: T, dt: DT, args=()): | |
| dt = u.Quantity(dt) | |
| t0 = u.Quantity(t0).to_decimal(dt.unit) | |
| dt = u.get_magnitude(dt) | |
| y1 = solver.step( | |
| diffrax.ODETerm(lambda t, y, args_: fn(t, y, *args_)[0]), | |
| t0, | |
| t0 + dt, | |
| y0, | |
| args, | |
| None, | |
| made_jump=False | |
| )[0] | |
| return y1, {} | |
| def _implicit_solver( | |
| solver, | |
| fn: VectorFiled, | |
| y0: Y0, | |
| t0: T, | |
| dt: DT, | |
| args=(), | |
| solver_state=None, | |
| return_aux=False, | |
| ): | |
| dt = u.Quantity(dt) | |
| t0 = u.Quantity(t0).to_decimal(dt.unit) | |
| dt = u.get_magnitude(dt) | |
| term = diffrax.ODETerm(lambda t, y, args_: fn(t, y, *args_)[0]) | |
| # If solver_state is None, initialize it as required by the solver | |
| if solver_state is None and hasattr(solver, "init"): | |
| solver_state = solver.init(term, t0, y0, args) | |
| # Call step with or without solver_state depending on solver API | |
| step_kwargs = dict( | |
| term=term, | |
| t0=t0, | |
| t1=t0 + dt, | |
| y0=y0, | |
| args=args, | |
| solver_state=solver_state, | |
| made_jump=False, | |
| ) | |
| # Remove solver_state if not supported by the solver | |
| import inspect | |
| if "solver_state" not in inspect.signature(solver.step).parameters: | |
| step_kwargs.pop("solver_state") | |
| result = solver.step(**step_kwargs) | |
| # result may be (y1, solver_state, aux) or (y1, solver_state) | |
| if return_aux and len(result) == 3: | |
| y1, solver_state, aux = result | |
| return y1, solver_state, aux | |
| elif len(result) == 2: | |
| y1, solver_state = result | |
| return y1, solver_state | |
| else: | |
| # fallback for legacy API | |
| y1 = result[0] | |
| return y1, {} |
| # Construct conductance matrix for the model | ||
| print(morphology.conductance_matrix) | ||
| print(morphology.area) |
There was a problem hiding this comment.
suggestion (testing): Replace print statements with explicit assertions.
Use assertions to automatically verify the shapes or values of conductance_matrix and area, rather than relying on manual inspection via print statements.
| # Construct conductance matrix for the model | |
| print(morphology.conductance_matrix) | |
| print(morphology.area) | |
| # Construct conductance matrix for the model | |
| assert morphology.conductance_matrix is not None, "Conductance matrix should not be None" | |
| assert hasattr(morphology.conductance_matrix, "shape"), "Conductance matrix should have a shape attribute" | |
| assert morphology.area is not None, "Area should not be None" | |
| assert hasattr(morphology.area, "shape") or isinstance(morphology.area, (int, float)), "Area should have a shape attribute or be a number" |
| ## Reporting a bug in ``BrainCell`` | ||
|
|
||
| Report security bugs in ``dendritex`` via [Github Issue](https://github.com/chaoming0625/dendritex/issues). | ||
| Report security bugs in ``BrainCell`` via [Github Issue](https://github.com/chaobrain/braincell/issues). |
There was a problem hiding this comment.
suggestion (typo): Consider standard capitalization for 'GitHub'.
Please update 'Github Issue' to 'GitHub Issue' in the link text.
| Report security bugs in ``BrainCell`` via [Github Issue](https://github.com/chaobrain/braincell/issues). | |
| Report security bugs in ``BrainCell`` via [GitHub Issue](https://github.com/chaobrain/braincell/issues). |
| @@ -197,7 +198,8 @@ def _general_rk_step( | |||
| @set_module_as('braincell') | |||
| def euler_step( | |||
There was a problem hiding this comment.
issue (complexity): Consider introducing a decorator to generate the repetitive Runge-Kutta step functions automatically.
# add at top of module
def register_rk_step(tableau):
def decorator(func):
@set_module_as('braincell')
def wrapper(target: DiffEqModule, t: T, dt: DT, *args):
_general_rk_step(tableau, target, t, dt, *args)
wrapper.__name__ = func.__name__
wrapper.__doc__ = func.__doc__
return wrapper
return decoratorThen simplify each explicit step:
Before:
@set_module_as('braincell')
def euler_step(target: DiffEqModule, t: T, dt: DT, *args):
"""Perform a single step of the Euler method..."""
_general_rk_step(euler_tableau, target, t, dt, *args)After:
@register_rk_step(euler_tableau)
def euler_step(target, t, dt, *args):
"""Perform a single step of the Euler method..."""
# body is empty; wrapper does the workDo the same for all other *_step definitions. This removes the repetitive _general_rk_step(...) boilerplate while preserving signatures and docstrings.
…rential equation integration
|
@sourcery-ai summary |
Summary by Sourcery
Refactor integrators to uniformly accept explicit time-step arguments, improve unit handling, and add new implicit diffrax methods; update Runge-Kutta and exponential Euler integrators to pass dt explicitly; enhance morphology API to separate position and diameter arrays with stricter typing; convert channel temperature parameters to use Kelvin quantities internally; update compartment solvers and tests to match new signatures; rename project references in SECURITY policy.
New Features:
Enhancements:
Documentation:
Tests: