Skip to content

Refactor integrators and morphology; add implicit diffrax#34

Merged
chaoming0625 merged 7 commits intomainfrom
update
May 23, 2025
Merged

Refactor integrators and morphology; add implicit diffrax#34
chaoming0625 merged 7 commits intomainfrom
update

Conversation

@chaoming0625
Copy link
Copy Markdown
Collaborator

@chaoming0625 chaoming0625 commented May 22, 2025

Summary by Sourcery

Refactor integrators to uniformly accept explicit time-step arguments, improve unit handling, and add new implicit diffrax methods; update Runge-Kutta and exponential Euler integrators to pass dt explicitly; enhance morphology API to separate position and diameter arrays with stricter typing; convert channel temperature parameters to use Kelvin quantities internally; update compartment solvers and tests to match new signatures; rename project references in SECURITY policy.

New Features:

  • Introduce implicit diffrax integration methods: backward Euler, Kvaerno3, Kvaerno4, and Kvaerno5.

Enhancements:

  • Unify all integrators to accept explicit dt arguments and convert units appropriately in explicit and RK solvers.
  • Refactor _integrator_util with type aliases, improved unit checks, and dt propagation.
  • Update morphology API to use separate position and diam arrays with enhanced type checking and assertions.
  • Modify channel constructors to default temperature parameters in Kelvin and convert to Celsius internally.
  • Pass dt from environment to single- and multi-compartment solvers.

Documentation:

  • Update SECURITY.md to rename project to BrainCell and correct GitHub issue link for security reports.

Tests:

  • Adjust morphology tests to use new position and diam parameters for point sections.

@sourcery-ai
Copy link
Copy Markdown
Contributor

sourcery-ai Bot commented May 22, 2025

Reviewer's Guide

This PR standardizes time-step handling across integrators, extends diffrax support with implicit solvers, refactors morphology APIs for clearer position/diameter separation, and normalizes temperature units in channel classes.

Sequence Diagram: Neuron Update Method with dt

sequenceDiagram
    participant C as Caller
    participant N as Neuron (Single/MultiCompartment)
    participant ENV as brainstate.environ
    participant S as Solver

    C->>N: update(I_ext)
    N->>ENV: get('t')
    ENV-->>N: t
    N->>ENV: get('dt')
    ENV-->>N: dt
    N->>S: solver(self, t, dt, I_ext)
Loading

Sequence Diagram: apply_standard_solver_step with dt Propagation

sequenceDiagram
    participant SpecificSolver as Solver Interface (e.g., diffrax_euler_step)
    participant apply_standard_solver_step as apply_standard_solver_step
    participant target as DiffEqModule
    participant transform_fn as _transform_diffeq_module_into_dimensionless_fn
    participant check_deriv as _check_diffeq_state_derivative
    participant actual_step as actual_solver_step (e.g., _explicit_solver)

    SpecificSolver->>apply_standard_solver_step: (target, t, dt, *args)
    apply_standard_solver_step->>target: pre_integral(*args)
    apply_standard_solver_step->>transform_fn: (target, dt, method)
    Note right of transform_fn: vector_field internally calls:
    transform_fn->>target: compute_derivative(*args)
    transform_fn->>check_deriv: (state, dt)
    transform_fn-->>apply_standard_solver_step: dimensionless_fn, y0_dimless, ...
    apply_standard_solver_step->>actual_step: (dimensionless_fn, y0_dimless, t, dt, ...)
    actual_step-->>apply_standard_solver_step: y1_dimless, ...
    apply_standard_solver_step->>target: post_integral(*args)
Loading

Sequence Diagram: _general_rk_step with dt

sequenceDiagram
    participant RKMethod as RK Method (e.g., euler_step)
    participant general_rk_step as _general_rk_step
    participant target as DiffEqModule
    participant rk_update as _rk_update

    RKMethod->>general_rk_step: (tableau, target, t, dt, *args)
    general_rk_step->>target: pre_integral(*args)
    loop Stages (tableau.C)
        general_rk_step->>rk_update: (coeff, st, y0, dt, *ks)
        general_rk_step->>target: compute_derivative(*args)
    end
    general_rk_step->>rk_update: (tableau.B, st, y0, dt, *ks) 
    general_rk_step->>target: post_integral(*args)
Loading

Updated Class Diagram: Integrator Utility Functions in _integrator_util.py

classDiagram
    class IntegratorUtil {
        <<Module: _integrator_util.py>>
        +T: TypeAlias
        +DT: TypeAlias
        +VectorFiled: TypeAlias
        +Y0: TypeAlias
        +Y1: TypeAlias
        +Jacobian: TypeAlias
        +Args: TypeAlias
        +Aux: TypeAlias
        +_check_diffeq_state_derivative(state: DiffEqState, dt: DT)
        +_transform_diffeq_module_into_dimensionless_fn(target: DiffEqModule, dt: DT, method: str) : tuple
        +apply_standard_solver_step(solver_step: Callable, target: DiffEqModule, t: T, dt: DT, *args, merging_method: str)
        +jacrev_last_dim(fn: Callable, hid_vals: Y0, has_aux: bool) : tuple
    }
    note for IntegratorUtil "Focus on changed/added dt parameter and new type hints."
Loading

Updated Class Diagram: Diffrax Integrators in _integrator_diffrax.py

classDiagram
    class DiffraxIntegrators {
        <<Module: _integrator_diffrax.py>>
        +_explicit_solver(solver, fn: VectorFiled, y0: Y0, t0: T, dt: DT, args=()) : tuple
        +_diffrax_explicit_solver(solver, target: DiffEqModule, t: T, dt: DT, *args)
        +diffrax_euler_step(target: DiffEqModule, t: T, dt: DT, *args)
        +diffrax_heun_step(target: DiffEqModule, t: T, dt: DT, *args)
        +diffrax_midpoint_step(target: DiffEqModule, t: T, dt: DT, *args)
        +diffrax_ralston_step(target: DiffEqModule, t: T, dt: DT, *args)
        +diffrax_bosh3_step(target: DiffEqModule, t: T, dt: DT, *args)
        +diffrax_tsit5_step(target: DiffEqModule, t: T, dt: DT, *args)
        +diffrax_dopri5_step(target: DiffEqModule, t: T, dt: DT, *args)
        +diffrax_dopri8_step(target: DiffEqModule, t: T, dt: DT, *args)
        +diffrax_bwd_euler_step(target: DiffEqModule, t: T, dt: DT, *args, tol: float)  // New
        +_implicit_solver(solver, fn: VectorFiled, y0: Y0, t0: T, dt: DT, args=()) : tuple // New
        +_diffrax_implicit_solver(solver, target: DiffEqModule, t: T, dt: DT, *args) // New
    }
    note for DiffraxIntegrators "All listed functions now take 'dt: DT'. New implicit solvers added."
Loading

Updated Class Diagram: Runge-Kutta Integrators in _integrator_runge_kutta.py

classDiagram
    class RungeKuttaIntegrators {
        <<Module: _integrator_runge_kutta.py>>
        +_rk_update(coeff: Sequence, st: brainstate.State, y0: PyTree, dt: DT, *ks)
        +_general_rk_step(tableau: ButcherTableau, target: DiffEqModule, t: T, dt: DT, *args)
        +euler_step(target: DiffEqModule, t: T, dt: DT, *args)
        +midpoint_step(target: DiffEqModule, t: T, dt: DT, *args)
        +rk2_step(target: DiffEqModule, t: T, dt: DT, *args)
        +heun2_step(target: DiffEqModule, t: T, dt: DT, *args)
        +ralston2_step(target: DiffEqModule, t: T, dt: DT, *args)
        +rk3_step(target: DiffEqModule, t: T, dt: DT, *args)
        +heun3_step(target: DiffEqModule, t: T, dt: DT, *args)
        +ssprk3_step(target: DiffEqModule, t: T, dt: DT, *args)
        +ralston3_step(target: DiffEqModule, t: T, dt: DT, *args)
        +rk4_step(target: DiffEqModule, t: T, dt: DT, *args)
        +ralston4_step(target: DiffEqModule, t: T, dt: DT, *args)
    }
    note for RungeKuttaIntegrators "All listed functions now take 'dt: DT'."
Loading

Updated Class Diagram: Exponential Euler Integrator in _integrator_exp_euler.py

classDiagram
    class ExpEulerIntegrator {
        <<Module: _integrator_exp_euler.py>>
        +exp_euler_step(target: DiffEqModule, t: T, dt: DT, *args)
    }
    note for ExpEulerIntegrator "Function 'exp_euler_step' now takes 'dt: DT'."
Loading

Updated Class Diagram: Morphology Classes in _morphology.py

classDiagram
    class Section {
        +name: SectionName
        +positions: u.Quantity[u.meter]
        +diam: u.Quantity[u.meter]
        +nseg: int
        +Ra: u.Quantity[u.ohm * u.cm]
        +cm: u.Quantity[u.uF / u.cm ** 2]
        +__init__(name, positions, diam, nseg, Ra, cm)
        #_compute_area_and_resistance()
    }
    class CylinderSection {
        +__init__(name, length: u.Quantity, diam: u.Quantity, nseg, Ra, cm)
    }
    class PointSection {
        +__init__(name, position: u.Quantity[u.meter], diam: u.Quantity[u.meter], nseg, Ra, cm)
    }
    class Morphology {
        +sections: dict
        +add_cylinder_section(name, length, diam, nseg, Ra, cm)
        +add_point_section(name, position: u.Quantity[u.meter], diam: u.Quantity[u.meter], nseg, Ra, cm)
        +add_multiple_sections(section_dicts: Dict)
    }
    CylinderSection --|> Section
    PointSection --|> Section
    Morphology o-- Section : contains
    note for PointSection "__init__ now takes 'position' and 'diam' instead of 'points'."
    note for Morphology "add_point_section signature updated. add_multiple_sections logic for point sections updated."
Loading

Updated Class Diagram: Channel Temperature Handling (Example: ICaT_HM1992)

classDiagram
    class ICaT_HM1992 {
        <<Channel>>
        +g_max: ArrayLike
        +T: ArrayLike
        +__init__(size, T=u.celsius2kelvin(36.), T_base_p, T_base_q, g_max, V_sh, phi_p=None, phi_q=None, name=None)
        # Internal: T = u.kelvin2celsius(T) upon initialization
    }
    note for ICaT_HM1992 "'T' parameter default is now in Kelvin (u.celsius2kelvin(36.)) and converted to Celsius internally. Applies to other channel classes as well."
Loading

Updated Class Diagram: Neuron Compartment Models

classDiagram
    class SingleCompartmentNeuron {
        +update(I_ext=0. * u.nA / u.cm ** 2)
        # solver(self, t, dt, I_ext) # Behavior changed to use dt
    }
    class MultiCompartmentNeuron {
        +pop_size: Tuple[int, ...]
        +n_compartment: int
        +update(I_ext=0. * u.nA)
        # solver(self, t, dt, I_ext) # Behavior changed to use dt
    }
    note for SingleCompartmentNeuron "update() method now passes 'dt' to its internal solver."
    note for MultiCompartmentNeuron "update() method now passes 'dt' to its internal solver. Added pop_size and n_compartment properties."
Loading

File-Level Changes

Change Details Files
Standardize dt parameter and typing in integration utilities
  • Introduced type aliases (T, DT, VectorFiled, Y0)
  • Added dt argument to explicit/exp_euler and Runge–Kutta step functions
  • Updated apply_standard_solver_step and vector_field to propagate dt
  • Adjusted single/multi-compartment update to pass dt
_integrator_util.py
_integrator_diffrax.py
_integrator_runge_kutta.py
_integrator_exp_euler.py
_integrator.py
_single_compartment.py
_multi_compartment.py
Add implicit diffrax solver support
  • Implemented _implicit_solver and _diffrax_implicit_solver
  • Exposed diffrax_bwd_euler_step with tolerance parameter
_integrator_diffrax.py
Refactor morphology Section and PointSection APIs
  • Imported SectionName type and constrained unit types
  • Replaced combined points array with separate position and diam arguments
  • Added shape and unit assertions
  • Updated add_point_section and related tests
_morphology.py
_morphology_utils.py
_morphology_test.py
Normalize temperature handling in channel modules
  • Default T parameters now use Kelvin conversion
  • Inserted u.kelvin2celsius calls in constructors
channel/calcium.py
channel/potassium.py
channel/potassium_calcium.py
channel/sodium.py
channel/hyperpolarization_activated.py
Miscellaneous documentation and security updates
  • Renamed project in SECURITY.md to BrainCell and fixed repo URL
SECURITY.md

Tips and commands

Interacting with Sourcery

  • Trigger a new review: Comment @sourcery-ai review on the pull request.
  • Continue discussions: Reply directly to Sourcery's review comments.
  • Generate a GitHub issue from a review comment: Ask Sourcery to create an
    issue from a review comment by replying to it. You can also reply to a
    review comment with @sourcery-ai issue to create an issue from it.
  • Generate a pull request title: Write @sourcery-ai anywhere in the pull
    request title to generate a title at any time. You can also comment
    @sourcery-ai title on the pull request to (re-)generate the title at any time.
  • Generate a pull request summary: Write @sourcery-ai summary anywhere in
    the pull request body to generate a PR summary at any time exactly where you
    want it. You can also comment @sourcery-ai summary on the pull request to
    (re-)generate the summary at any time.
  • Generate reviewer's guide: Comment @sourcery-ai guide on the pull
    request to (re-)generate the reviewer's guide at any time.
  • Resolve all Sourcery comments: Comment @sourcery-ai resolve on the
    pull request to resolve all Sourcery comments. Useful if you've already
    addressed all the comments and don't want to see them anymore.
  • Dismiss all Sourcery reviews: Comment @sourcery-ai dismiss on the pull
    request to dismiss all existing Sourcery reviews. Especially useful if you
    want to start fresh with a new review - don't forget to comment
    @sourcery-ai review to trigger a new review!

Customizing Your Experience

Access your dashboard to:

  • Enable or disable review features such as the Sourcery-generated pull request
    summary, the reviewer's guide, and others.
  • Change the review language.
  • Add, remove or edit custom review instructions.
  • Adjust other review settings.

Getting Help

@chaoming0625
Copy link
Copy Markdown
Collaborator Author

@sourcery-ai title

@sourcery-ai sourcery-ai Bot changed the title Update Refactor integrators and morphology; add implicit diffrax May 22, 2025
Copy link
Copy Markdown
Contributor

@sourcery-ai sourcery-ai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @chaoming0625 - I've reviewed your changes - here's some feedback:

  • Consider deprecating the old points argument in morphology APIs (and emitting a warning) while still supporting it for a transition period to avoid breaking existing user code.
  • The repeated temperature conversion pattern (T = u.kelvin2celsius(T)) across multiple channel constructors could be factored into a shared helper to reduce duplication.
  • Please update the PR description to summarize the major API changes (integrator signature updates, morphology refactoring, channel temperature handling) for clearer context and easier maintenance.
Here's what I looked at during the review
  • 🟡 General issues: 2 issues found
  • 🟢 Security: all looks good
  • 🟡 Testing: 1 issue found
  • 🟡 Complexity: 1 issue found
  • 🟢 Documentation: all looks good

Sourcery is free for open source - if you like our reviews please consider sharing them ✨
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.

Comment on lines +317 to +330
def _implicit_solver(solver, fn: VectorFiled, y0: Y0, t0: T, dt: DT, args=()):
dt = u.Quantity(dt)
t0 = u.Quantity(t0).to_decimal(dt.unit)
dt = u.get_magnitude(dt)
y1 = solver.step(
diffrax.ODETerm(lambda t, y, args_: fn(t, y, *args_)[0]),
t0,
t0 + dt,
y0,
args,
None,
made_jump=False
)[0]
return y1, {}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: The new _implicit_solver function does not handle solver state or auxiliary outputs.

This implementation won't support solvers or ODETerm functions that require state or auxiliary outputs. Please verify if this limitation is acceptable for your use cases.

Suggested change
def _implicit_solver(solver, fn: VectorFiled, y0: Y0, t0: T, dt: DT, args=()):
dt = u.Quantity(dt)
t0 = u.Quantity(t0).to_decimal(dt.unit)
dt = u.get_magnitude(dt)
y1 = solver.step(
diffrax.ODETerm(lambda t, y, args_: fn(t, y, *args_)[0]),
t0,
t0 + dt,
y0,
args,
None,
made_jump=False
)[0]
return y1, {}
def _implicit_solver(
solver,
fn: VectorFiled,
y0: Y0,
t0: T,
dt: DT,
args=(),
solver_state=None,
return_aux=False,
):
dt = u.Quantity(dt)
t0 = u.Quantity(t0).to_decimal(dt.unit)
dt = u.get_magnitude(dt)
term = diffrax.ODETerm(lambda t, y, args_: fn(t, y, *args_)[0])
# If solver_state is None, initialize it as required by the solver
if solver_state is None and hasattr(solver, "init"):
solver_state = solver.init(term, t0, y0, args)
# Call step with or without solver_state depending on solver API
step_kwargs = dict(
term=term,
t0=t0,
t1=t0 + dt,
y0=y0,
args=args,
solver_state=solver_state,
made_jump=False,
)
# Remove solver_state if not supported by the solver
import inspect
if "solver_state" not in inspect.signature(solver.step).parameters:
step_kwargs.pop("solver_state")
result = solver.step(**step_kwargs)
# result may be (y1, solver_state, aux) or (y1, solver_state)
if return_aux and len(result) == 3:
y1, solver_state, aux = result
return y1, solver_state, aux
elif len(result) == 2:
y1, solver_state = result
return y1, solver_state
else:
# fallback for legacy API
y1 = result[0]
return y1, {}

Comment on lines 86 to 88
# Construct conductance matrix for the model
print(morphology.conductance_matrix)
print(morphology.area)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Replace print statements with explicit assertions.

Use assertions to automatically verify the shapes or values of conductance_matrix and area, rather than relying on manual inspection via print statements.

Suggested change
# Construct conductance matrix for the model
print(morphology.conductance_matrix)
print(morphology.area)
# Construct conductance matrix for the model
assert morphology.conductance_matrix is not None, "Conductance matrix should not be None"
assert hasattr(morphology.conductance_matrix, "shape"), "Conductance matrix should have a shape attribute"
assert morphology.area is not None, "Area should not be None"
assert hasattr(morphology.area, "shape") or isinstance(morphology.area, (int, float)), "Area should have a shape attribute or be a number"

Comment thread SECURITY.md
## Reporting a bug in ``BrainCell``

Report security bugs in ``dendritex`` via [Github Issue](https://github.com/chaoming0625/dendritex/issues).
Report security bugs in ``BrainCell`` via [Github Issue](https://github.com/chaobrain/braincell/issues).
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (typo): Consider standard capitalization for 'GitHub'.

Please update 'Github Issue' to 'GitHub Issue' in the link text.

Suggested change
Report security bugs in ``BrainCell`` via [Github Issue](https://github.com/chaobrain/braincell/issues).
Report security bugs in ``BrainCell`` via [GitHub Issue](https://github.com/chaobrain/braincell/issues).

@@ -197,7 +198,8 @@ def _general_rk_step(
@set_module_as('braincell')
def euler_step(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (complexity): Consider introducing a decorator to generate the repetitive Runge-Kutta step functions automatically.

# add at top of module
def register_rk_step(tableau):
    def decorator(func):
        @set_module_as('braincell')
        def wrapper(target: DiffEqModule, t: T, dt: DT, *args):
            _general_rk_step(tableau, target, t, dt, *args)
        wrapper.__name__ = func.__name__
        wrapper.__doc__ = func.__doc__
        return wrapper
    return decorator

Then simplify each explicit step:

Before:

@set_module_as('braincell')
def euler_step(target: DiffEqModule, t: T, dt: DT, *args):
    """Perform a single step of the Euler method..."""
    _general_rk_step(euler_tableau, target, t, dt, *args)

After:

@register_rk_step(euler_tableau)
def euler_step(target, t, dt, *args):
    """Perform a single step of the Euler method..."""
    # body is empty; wrapper does the work

Do the same for all other *_step definitions. This removes the repetitive _general_rk_step(...) boilerplate while preserving signatures and docstrings.

@chaoming0625
Copy link
Copy Markdown
Collaborator Author

@sourcery-ai summary

@chaoming0625 chaoming0625 merged commit 728fb50 into main May 23, 2025
24 checks passed
@chaoming0625 chaoming0625 deleted the update branch May 23, 2025 11:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant