In [13]:
from functools import partial

import jax
import jax.numpy as jnp
from jax import jit

In [7]:
# Sanity check: concatenating a literal array with a `jnp.full` array yields one
# flat 1-D array -- the same pattern Sperm2D uses to build `radii` and `lengths`.
test = jnp.arange(3)
test2 = jnp.full(3, 3)
jnp.concatenate((test, test2))

Array([0, 1, 2, 3, 3, 3], dtype=int32)

In [19]:
class Sperm2D:
    """
    Discrete 2D model of a single sperm: one head segment followed by
    ``n_segments`` equal-length flagellum segments.

    Every per-segment array is indexed with the head at 0 and the flagellum
    segments at 1..N_flag, ordered away from the head. Consecutive segment
    midpoints are joined by "links": the head link spans ``1.1*a + dL/2`` and
    every flagellum link spans ``dL`` (see ``_link_lengths``). The initial
    geometry, the bending moments and the midpoint reconstruction all use
    these same link lengths.

    NOTE(review): the jitted methods mark ``self`` as a static argument, so JAX
    keys its compilation cache on the instance (default identity hash).
    Attributes changed after the first call are NOT seen by already-compiled
    methods -- build a new instance instead of mutating an existing one.
    """

    def __init__(self,
                 length: float = 65.0,
                 n_segments: int = 50,
                 bending_modulus: float = 1800.0,
                 amplitude: float = 0.2,
                 wavenumber: float = 1.0,
                 frequency: float = 1.0,
                 phase: float = 0.0,
                 init_position: jnp.ndarray = jnp.array([0.0, 0.0]),
                 init_angle: float = 0.0,
                 head_semi_major: float = 3.0,
                 tail_radius: float = 1.0):
        """
        Initialize a single 2D sperm filament lying straight along ``init_angle``.

        Parameters
        ----------
        length : Total length of the flagellum (L).
        n_segments : Number of flagellum segments (N_flag).
        bending_modulus : Bending stiffness K_B.
        amplitude, wavenumber, frequency, phase : Parameters of the preferred-curvature
            waveform kappa(s, t); see ``preferred_curvature``.
        init_position : Initial (x, y) position of the head tip.
        init_angle : Initial orientation of the whole filament, in radians.
        head_semi_major : Semi-major axis of the head (a); the head segment has length 2a.
        tail_radius : Radius of every flagellum segment (b).
        """
        # Geometry & discretization
        self.L      = length
        self.N_flag = n_segments
        self.dL     = length / n_segments
        self.a      = head_semi_major
        self.b      = tail_radius
        # Distance from the head midpoint to the head/flagellum junction
        # (slightly beyond the semi-major axis a).
        self.head_offset = 1.1 * self.a

        # Per-segment radii and lengths, head first: shape (N_flag+1,)
        self.radii   = jnp.concatenate([jnp.array([self.a]),
                                        jnp.full(self.N_flag, self.b)])
        self.lengths = jnp.concatenate([jnp.array([2 * self.a]),
                                        jnp.full(self.N_flag, self.dL)])

        # Mechanical / waveform parameters
        self.K_B   = bending_modulus
        self.K0    = amplitude
        self.k     = wavenumber
        self.omega = frequency
        self.phi   = phase

        # Arc length of each flagellum-segment midpoint measured from the
        # junction, (i + 1/2) * dL. Built from an integer range so the shape is
        # exactly (N_flag,); a float-step arange is sensitive to rounding.
        self.s_mid = self.dL * (jnp.arange(self.N_flag) + 0.5)

        # Distance of every segment midpoint from the head tip along the body
        # axis: head midpoint at a, flagellum midpoints at a + 1.1a + s_mid, so
        # consecutive midpoints are exactly one link length apart.
        along_axis = jnp.concatenate([
            jnp.array([self.a]),
            self.a + self.head_offset + self.s_mid,
        ])
        self.Y0 = jnp.stack([
            init_position[0] + along_axis * jnp.cos(init_angle),
            init_position[1] + along_axis * jnp.sin(init_angle),
        ], axis=1)                                        # (N_flag+1, 2)
        # All segments start aligned with init_angle.
        self.thetas = jnp.full(self.N_flag + 1, init_angle)

        # Lagrange multipliers for N_flag+2 constraints
        self.Lambdas = jnp.zeros((self.N_flag + 2, 2))

    def _link_lengths(self) -> jnp.ndarray:
        """
        Distance between consecutive segment midpoints, shape (N_flag,).

        Entry 0 is the head link (head midpoint -> first flagellum midpoint,
        ``1.1*a + dL/2``); all remaining entries are ``dL``.
        """
        return jnp.full(self.N_flag, self.dL).at[0].set(self.head_offset + 0.5 * self.dL)

    @partial(jit, static_argnums=(0,))
    def preferred_curvature(self, t: float) -> jnp.ndarray:  # shape (N_flag,)
        """
        Preferred curvature at the midpoint of each flagellum segment (head excluded).

        kappa(s, t) = K0 * sin(2*pi*k*s/L - omega*t + phi) * ramp(s), where the
        ramp grows linearly as 2s/L over the proximal half and is 1 beyond L/2.
        NOTE: ``frequency`` is used directly as the angular rate omega (no 2*pi).

        Parameters
        ----------
        t : time

        Returns
        -------
        kappa : preferred curvature, shape (N_flag,)
        """
        wave_phase = 2 * jnp.pi * self.k * self.s_mid / self.L - self.omega * t + self.phi
        base = self.K0 * jnp.sin(wave_phase)
        ramp = jnp.where(self.s_mid > self.L / 2, 1.0, 2 * self.s_mid / self.L)
        return base * ramp

    @partial(jit, static_argnums=(0,))
    def internal_moment(self,
                        theta: jnp.ndarray,   # shape (N_flag+1,)
                        t: float) -> jnp.ndarray:  # shape (N_flag+2,)
        """
        Internal bending moments at the edges of all segments, head included.

        Entry j+1 (j = 0..N_flag-1) is the moment at the edge between segments
        j and j+1: K_B * (sin(theta[j+1] - theta[j]) / link_j - kappa[j]).
        Entries 0 and N_flag+1 (the two outer boundaries) are zero.

        NOTE(review): kappa[j] is sampled at the flagellum midpoint s_mid[j] =
        (j + 1/2) dL, while the edge it is paired with sits at arc length j*dL,
        which is a half-segment offset. Confirm this is intended.

        Parameters
        ----------
        theta : angle of each segment, head first
        t : time

        Returns
        -------
        M : internal moments at segment edges, shape (N_flag+2,)
        """
        kappa = self.preferred_curvature(t)                               # (N_flag,)
        t_hat = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=1)       # (N_flag+1, 2)

        # z-component of t_j x t_{j+1}, i.e. sin(theta[j+1] - theta[j])
        cross = t_hat[:-1, 0] * t_hat[1:, 1] - t_hat[:-1, 1] * t_hat[1:, 0]  # (N_flag,)

        # Discrete curvature uses the distance between the adjacent midpoints;
        # the head link is 1.1a + dL/2 (previously 1.1a + L/2 by mistake).
        delta_s = self._link_lengths()                                     # (N_flag,)

        M_mid = self.K_B * (cross / delta_s - kappa)                       # (N_flag,)
        zero = jnp.zeros(1, dtype=M_mid.dtype)
        return jnp.concatenate([zero, M_mid, zero])                        # (N_flag+2,)

    @partial(jit, static_argnums=(0,))
    def reconstruct_midpoints(self,
                              Y1: jnp.ndarray,    # shape (2,)
                              θ: jnp.ndarray      # shape (N_flag+1,)
                              ) -> jnp.ndarray:   # returns (N_flag+1, 2)
        """
        Rebuild all segment midpoints from the head midpoint and the segment angles.

        ``self`` is static here (as in the other jitted methods); a bare ``@jit``
        would try to trace the instance and fail.

        Parameters
        ----------
        Y1 : position of the head-segment midpoint (row 0 of the result)
        θ : angle of each segment, head first

        Returns
        -------
        Y : midpoints of all segments, head first, shape (N_flag+1, 2)
        """
        t_hat = jnp.stack([jnp.cos(θ), jnp.sin(θ)], axis=1)               # (N_flag+1, 2)

        # Link j joins midpoint j to midpoint j+1 and is taken as the link
        # length times the mean of the two segment tangents. For two flagellum
        # segments joined end to end this is exactly (dL/2)(t_j + t_{j+1}).
        # NOTE(review): for the head link the end-to-end form would be
        # 1.1a*t_0 + (dL/2)*t_1; the averaged form from the original is kept.
        links = 0.5 * self._link_lengths()[:, None] * (t_hat[:-1] + t_hat[1:])  # (N_flag, 2)

        # Each link is counted exactly once (the head link is no longer also
        # added as a flagellum link). Row 0 gets zero displacement so it stays at Y1.
        disp = jnp.concatenate([jnp.zeros((1, 2), dtype=links.dtype), links], axis=0)
        return Y1 + jnp.cumsum(disp, axis=0)                               # (N_flag+1, 2)


In [20]:
# Default swimmer: flagellum of length 65 in 50 segments, lying straight
# (init_angle 0) with the head tip at the origin.
sperm = Sperm2D()

In [22]:
%%timeit
sperm.internal_moment(sperm.thetas, 0)

15.9 μs ± 51.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
