In [1]:
import numpy as np
import numba as nb
from scipy.spatial import cKDTree
from numpy.fft import fft2, ifft2

In [2]:
class Sperm2D:
    """
    A single 2D sperm: a head (segment 0) followed by an N-segment flagellum.

    Layout along the body axis, measured from the head tip:
      * head tip edge at 0, head midpoint at ``a``;
      * head/flagellum edge at ``2.1*a`` (the extra ``0.1*a`` models the linkage);
      * flagellum midpoints and edges alternating every ``Delta_L/2`` up to the
        tail tip at ``2.1*a + L``.
    This gives N+1 midpoints (head + N flagellum segments) and N+2 edges.
    """

    def __init__(self,
                 length=70.1,
                 n_segments=50,
                 bending_modulus=1800,
                 amplitude=0.2,
                 wavenumber=1.0,  # Not the conventional wavenumber: the waveform phase is 2*pi*k*s/L
                 frequency=10.0,
                 init_position=(0.0, 0.0),  # tip of head
                 init_angle=0,
                 phase=0,
                 head_semi_major=3,
                 head_semi_minor=1):  # Also the tail radius
        """
        Initialize a single 2D sperm filament, initially straight.

        Parameters
        ----------
        length : float
            Total length of the flagellum (L), head excluded.
        n_segments : int
            Number of discrete flagellum segments (N).
        bending_modulus : float
            Bending stiffness K_B.
        amplitude, wavenumber, frequency : float
            Parameters of the preferred-curvature waveform kappa(s, t).
        init_position : sequence of 2 floats
            Position of the head tip. It is copied, so the caller's object is
            never aliased.
        init_angle : float
            Initial orientation of the (straight) body axis.
        phase : float
            Phase offset added to the preferred-curvature waveform.
        head_semi_major, head_semi_minor : float
            Head semi-axes a and b; b is also used as the tail radius.
        """
        # Geometry
        self.L = length
        self.N_flag = n_segments
        self.Delta_L = length / n_segments
        self.a = head_semi_major
        self.b = head_semi_minor

        # Mechanical / waveform parameters
        self.K_B = bending_modulus
        self.K_0 = amplitude
        self.k = wavenumber
        self.omega = frequency
        self.phi = phase

        # State. Copy the position so instances never share or mutate the
        # caller's list/array (the old list default was shared between calls).
        self.Y_0 = np.array(init_position, dtype=float)
        self.theta_0 = init_angle

        # Flagellum points are integer multiples of Delta_L/2 from the junction,
        # so there are always exactly 2N+1 of them (N+1 edges interleaved with
        # N midpoints). A float-step np.arange could silently drop the tail tip.
        flag_start = 2.1 * self.a  # extra 0.1a accounts for the head/flagellum linkage
        flag_points = flag_start + 0.5 * self.Delta_L * np.arange(2 * self.N_flag + 1)
        self.edge_midpoint = np.hstack([0.0, self.a, flag_points])  # length 2N+3

        midpoints = self.edge_midpoint[1::2]  # N+1 segment midpoints
        self.Y = np.column_stack([self.Y_0[0] + midpoints * np.cos(self.theta_0),
                                  self.Y_0[1] + midpoints * np.sin(self.theta_0)])
        # Float dtype so later in-place angle updates are not truncated to ints.
        self.theta = np.full(self.N_flag + 1, init_angle, dtype=float)  # N+1 midpoints
        self.edge = self.edge_midpoint[::2]  # N+2 segment edges

        # Vector Lagrange multipliers (constraint forces), one per edge (N+2),
        # in the (N+2, 2) layout expected by constraint_torque / fcm_spread.
        self.Lambda = np.zeros((self.N_flag + 2, 2))

    def preferred_curvature(self, t):
        """
        Traveling-wave preferred curvature kappa(s, t) along the flagellum (Eq. 2.1 of Schoeller et al. 2018).

        The wave is ``K_0 * sin(2*pi*k*s/L - omega*t + phi)``. Over the distal half
        of the flagellum (s > L/2) it is multiplied by ``2*(L - s)/L``, a linear
        decay to zero at the tip.

        Returns an array of length N evaluated at the midpoint of each flagellum
        segment (head excluded).
        """
        s = (np.arange(self.N_flag) + 0.5) * self.Delta_L
        base = self.K_0 * np.sin(2 * np.pi * self.k * s / self.L - self.omega * t + self.phi)
        decay = np.where(s > self.L / 2, 2 * (self.L - s) / self.L, 1.0)
        return base * decay

    def internal_moment(self, t):
        """
        Compute the internal bending moments M_{n+1/2} (Eq. 34 of Schoeller et al. 2020).

        Returns an array of length N+2, one entry per segment edge. The two end
        edges (head tip and tail tip) are left at zero, which is the free-end
        condition.
        """
        kappa = self.preferred_curvature(t)
        t_hat_x, t_hat_y = np.cos(self.theta), np.sin(self.theta)
        # z-component of t_n x t_{n+1} = sin(theta_{n+1} - theta_n)
        cross = t_hat_x[:-1] * t_hat_y[1:] - t_hat_y[:-1] * t_hat_x[1:]
        # Distance between consecutive midpoints: head midpoint (at a) to the
        # first flagellum midpoint (at 2.1a + Delta_L/2), then Delta_L thereafter.
        delta_s = np.full(self.N_flag, self.Delta_L)
        delta_s[0] = 1.1 * self.a + self.Delta_L / 2
        M = np.zeros(self.N_flag + 2)
        M[1:self.N_flag + 1] = self.K_B * (cross / delta_s - kappa)
        return M

In [3]:
def generate_neighbors(positions, chi, a_head):
    """
    Find candidate pairs of segment centres for steric interaction.

    A pair is returned when its centre-to-centre distance is at most
    ``chi * 2 * a_head`` (cKDTree.query_pairs uses ``<= r``).

    Parameters
    ----------
    positions : array_like, shape (n_points, 2)
        Segment centre coordinates.
    chi : float
        Range factor of the steric barrier.
    a_head : float
        Radius used for the cutoff (the head radius).

    Returns
    -------
    pairs : ndarray of int, shape (n_pairs, 2)
        Index pairs (i, j) with i < j. When nothing is close enough the shape is
        (0, 2), so consumers that index ``pairs[k, 0]`` (e.g. the numba
        ``barrier_force``) always receive a 2-D integer array.
    """
    tree = cKDTree(positions)
    pairs = tree.query_pairs(r=chi * 2 * a_head, output_type='ndarray')
    # Enforce the 2-D integer layout explicitly. The previous
    # np.array(list(...)) round trip produced a 1-D float array when empty.
    return np.asarray(pairs, dtype=np.intp).reshape(-1, 2)

In [4]:
@nb.njit
def barrier_force(positions, head_radius, tail_radius, F_S, chi, neighbors):
    """
    Steric repulsion between non-adjacent segments (index 0 is the head, the rest are tail).

    For every candidate pair (n, m) in ``neighbors`` with |n - m| != 1, let
    c = R_n + R_m and r be the centre distance. If 0 < r < chi * c, the pair gets
    equal and opposite forces of magnitude

        F_S * ((chi^2 c^2 - r^2) / (c^2 (chi^2 - 1)))^4 / c

    along the unit vector from m to n, which pushes the two apart.
    Head-on tail-tail and head-tail interactions are assumed to cause no extra
    repulsion. This is a fair assumption because few such interactions occur.

    Parameters
    ----------
    positions : array, shape (N+1, 2)
        Coordinates of the N+1 segment centres in 2D.
    head_radius, tail_radius : float
        Radius of segment 0 and of every other segment, respectively.
    F_S : float
        Reference strength of the repulsive force.
    chi : float
        Range factor: the force is active up to chi times the contact distance.
    neighbors : integer array, shape (n_pairs, 2)
        Candidate index pairs, e.g. from ``generate_neighbors``, whose rows are
        pairs within chi * 2 * head_radius.

    Returns
    -------
    forces : ndarray, shape (N+1, 2)
        Steric barrier force on the midpoint of each segment.
    """
    n_points = positions.shape[0]
    radii = np.empty(n_points)
    radii[0] = head_radius
    radii[1:] = tail_radius

    forces = np.zeros_like(positions)
    range_factor = chi**2 - 1

    for k in range(neighbors.shape[0]):
        p = neighbors[k, 0]
        q = neighbors[k, 1]
        # Neighbouring segments along the body are linked and never repel.
        if abs(p - q) == 1:
            continue

        sigma = radii[p] + radii[q]
        cutoff = chi * sigma
        rx = positions[p, 0] - positions[q, 0]
        ry = positions[p, 1] - positions[q, 1]
        r = np.sqrt(rx * rx + ry * ry)
        # Written as a negation (not De Morgan) so a NaN distance is skipped.
        if not (r > 0.0 and r < cutoff):
            continue

        ratio = (cutoff * cutoff - r * r) / (sigma * sigma * range_factor)
        strength = F_S * ratio**4 / sigma
        fx = strength * (rx / r)
        fy = strength * (ry / r)

        forces[p, 0] += fx
        forces[p, 1] += fy
        forces[q, 0] -= fx
        forces[q, 1] -= fy
    return forces

In [5]:
def elastic_torque(M):
    """
    Elastic torque on each segment from the bending moments at its two edges.

    Segment n lies between edges n and n+1, so its torque is the jump in
    moment across it: ``T_E[n] = M[n+1] - M[n]``.

    Parameters
    ----------
    M : ndarray, shape (N+2,)
        Bending moments at the segment edges. With free-end boundary conditions,
        M[0] and M[N+1] should be zero.

    Returns
    -------
    T_E : ndarray, shape (N+1,)
        Elastic torque at the midpoint of each segment.
    """
    rear_edges = M[:-1]
    front_edges = M[1:]
    return front_edges - rear_edges

In [6]:
def constraint_torque(Lambda, angles, length, n_segments):
    """
    Torque from the edge constraint forces on each segment midpoint (vectorized).

    For segment n with unit tangent t_n = (cos theta_n, sin theta_n):

        T_C[n] = (Delta_L / 2) * [t_n x (Lambda[n] + Lambda[n+1])]_z,

    where Delta_L = length / n_segments and the 2D cross product is
    t_x * lam_y - t_y * lam_x.

    Parameters
    ----------
    Lambda : array_like, shape (N+2, 2)
        Constraint forces at each segment edge, head inclusive (edges 0..N+1).
    angles : array_like, shape (N+1,)
        Tangent angles at each segment midpoint, head inclusive (0..N).
    length : float
        Total flagellum length, head exclusive.
    n_segments : int
        Number of flagellum segments, head exclusive (N).

    Returns
    -------
    T_C : ndarray, shape (N+1,)
        Scalar torque on each segment midpoint, head inclusive.

    Notes
    -----
    NOTE(review): the same half-length Delta_L/2 is applied to segment 0 (the
    head). In ``Sperm2D`` the head midpoint sits a and 1.1a from its two edges.
    Confirm whether the head should use its own lever arms.
    """
    half_Delta_L = 0.5 * length / n_segments
    Lambda = np.asarray(Lambda, dtype=float)
    angles = np.asarray(angles, dtype=float)

    # Sum of the constraint forces at the two edges of every segment: (N+1, 2)
    lam_sum = Lambda[:-1] + Lambda[1:]

    # z-component of t x lam_sum
    torques = np.cos(angles) * lam_sum[:, 1] - np.sin(angles) * lam_sum[:, 0]
    return half_Delta_L * torques

In [49]:
@nb.njit
def _min_image(d, period):
    """Map a coordinate difference onto roughly [-period/2, period/2) (minimum-image convention)."""
    return d - period * np.floor(d / period + 0.5)


@nb.njit(parallel=True)
def fcm_spread(F_B, Lambda, T_E, T_C, sperm_coordinates, L_x, L_y, N_x, N_y, sigma_Delta, sigma_Theta):
    """
    JIT-compiled spread of FCM forces and torques onto a periodic 2D grid.

    Forces are spread from every edge and midpoint with a Gaussian of width
    ``sigma_Delta``. Torques are spread from midpoints only, as
    ``0.5 * curl(T * Theta)`` with a Gaussian Theta of width ``sigma_Theta``.

    Parameters
    ----------
    F_B : array, shape (N+1, 2)
        Barrier force acting at the midpoint of each segment.
    Lambda : array, shape (N+2, 2)
        Constraint force acting at the edge of each segment.
    T_E, T_C : arrays, shape (N+1,)
        Elastic and constraint torque respectively, acting at the midpoint of each segment.
    sperm_coordinates : array, shape (2N+3, 2)
        Positions of the edges (even rows) and midpoints (odd rows) of each segment.
    L_x, L_y : floats
        Dimensions of the periodic simulation domain.
    N_x, N_y : integers
        Number of grid points along each axis.
    sigma_Delta, sigma_Theta : floats
        Gaussian envelope sizes for force and torque respectively.

    Returns
    -------
    f_grid : array, shape (N_x, N_y, 2)
        Force density exerted on the fluid at each grid point after spreading.

    Raises
    ------
    ValueError
        If ``sperm_coordinates`` has fewer than ``2*len(T_E) + 1`` rows. Numba
        does not bounds-check by default, so this would otherwise read garbage.

    Notes
    -----
    * The parallel loop runs over grid rows ``i``. Each thread writes only
      ``f_grid[i]``, so no two threads update the same cell. The previous
      version parallelised over particles while every iteration did ``+=`` on
      all cells, which is a data race (lost updates, nondeterministic output).
    * Grid-to-particle distances use the minimum-image convention, so kernels
      near a boundary wrap around the periodic domain.
    * NOTE(review): ``np.linspace`` includes both -L/2 and +L/2, which coincide
      in a periodic domain. An FFT-compatible grid usually excludes the endpoint.
      It is kept unchanged here for consistency with other code using this grid.
    """
    # Grid coordinates
    x = np.linspace(-L_x/2, L_x/2, N_x)
    y = np.linspace(-L_y/2, L_y/2, N_y)

    N_s = T_E.shape[0]   # number of midpoints (head + flagellum segments)
    N_pts = 2*N_s + 1    # edges and midpoints interleaved
    if sperm_coordinates.shape[0] < N_pts:
        raise ValueError("sperm_coordinates must have 2*len(T_E)+1 rows")

    F_total = np.zeros((N_pts, 2))
    F_total[0::2] = Lambda  # edges
    F_total[1::2] = F_B     # midpoints
    T_total = T_E + T_C

    sd_2 = sigma_Delta*sigma_Delta
    st_2 = sigma_Theta*sigma_Theta
    inv2sd2 = 1.0/(2*sd_2)
    inv2st2 = 1.0/(2*st_2)
    normF = inv2sd2/np.pi   # 1 / (2 pi sigma^2): unit-mass 2D Gaussian
    normT = inv2st2/np.pi

    f_grid = np.zeros((N_x, N_y, 2))
    for i in nb.prange(N_x):
        # 1) Forces from every edge and midpoint
        for n in range(N_pts):
            Fx = F_total[n, 0]
            Fy = F_total[n, 1]
            sn_y = sperm_coordinates[n, 1]
            xi = _min_image(x[i] - sperm_coordinates[n, 0], L_x)
            for j in range(N_y):
                yj = _min_image(y[j] - sn_y, L_y)
                r2 = xi*xi + yj*yj
                wF = normF * np.exp(-r2 * inv2sd2)
                f_grid[i, j, 0] += Fx * wF
                f_grid[i, j, 1] += Fy * wF

        # 2) Torques (midpoints only)
        for n in range(N_s):
            Tn = T_total[n]
            sn_y = sperm_coordinates[2*n + 1, 1]
            xi = _min_image(x[i] - sperm_coordinates[2*n + 1, 0], L_x)
            for j in range(N_y):
                yj = _min_image(y[j] - sn_y, L_y)
                r2 = xi*xi + yj*yj
                wT = normT * np.exp(-r2 * inv2st2)
                # Gradient of the torque Gaussian
                gx = -xi * wT / st_2
                gy = -yj * wT / st_2
                # 0.5 * curl(T Theta e_z) = 0.5 * T * (dTheta/dy, -dTheta/dx)
                f_grid[i, j, 0] += 0.5 * Tn * gy
                f_grid[i, j, 1] -= 0.5 * Tn * gx

    return f_grid