In [1]:
import numpy as np
import numba as nb
from scipy.spatial import cKDTree

In [32]:
class Sperm2D:
    def __init__(self,
                 length=70.1,
                 n_segments=50,
                 bending_modulus=1800,
                 amplitude=0.2,
                 wavenumber=1.0, # Note: not the conventional wavenumber. 2*pi*k*l/L, where k is the wavenumber
                 frequency=10.0,
                 init_position=None, # tip of head; defaults to [0, 0]
                 init_angle=0,
                 phase=0,
                 head_semi_major=3,
                 head_semi_minor=1): # Also the tail radius
        """
        Initialize a single 2D sperm: a head followed by a discretized flagellum.

        The filament starts straight, pointing along ``init_angle`` from the head tip.

        Parameters
        ----------
        length : float
            Length of the flagellum (L), not including the head.
        n_segments : int
            Number of discrete flagellum segments (N).
        bending_modulus : float
            Bending stiffness K_B.
        amplitude, wavenumber, frequency : float
            Parameters K_0, k, omega of the preferred-curvature waveform kappa(s,t).
        init_position : sequence of 2 floats, optional
            Position of the head tip. Defaults to [0, 0]. A fresh list is created per
            instance, so instances never share it.
        init_angle : float
            Initial orientation of the straight filament, in radians.
        phase : float
            Stored as ``self.phi``.
            NOTE(review): not used by ``preferred_curvature``. Confirm whether it should
            offset the wave.
        head_semi_major, head_semi_minor : float
            Head semi-axes a and b. ``head_semi_minor`` is also used as the tail radius.
        """
        # Geometry
        self.L = length
        self.N_flag = n_segments
        self.Delta_L = length / n_segments
        self.a = head_semi_major
        self.b = head_semi_minor

        # Mechanical parameters
        self.K_B = bending_modulus
        self.K_0 = amplitude
        self.k = wavenumber
        self.omega = frequency
        self.phi = phase

        # State: positions Y[0..N], angles theta[0..N]
        # The list literal is evaluated per call, so no mutable default is shared.
        self.Y_0 = [0, 0] if init_position is None else init_position
        self.theta_0 = init_angle

        # Arc-length coordinates measured from the head tip, interleaved as
        #   [edge, midpoint, edge, midpoint, ..., midpoint, edge].
        # Head: edges at 0 and 2.1a (the extra 0.1a accounts for the head-flagellum
        # linkage), midpoint at a.
        # Flagellum: 2N+1 points spaced Delta_L/2 starting at 2.1a. These are N segment
        # midpoints and N+1 segment edges, the last one being the tail tip at 2.1a + L.
        # Built from integer indices instead of a float-step np.arange, whose length can
        # change by one because of round-off.
        tail = 2.1*self.a + np.arange(2*self.N_flag + 1) * (self.Delta_L/2)
        edge_midpoint = np.hstack([0.0, self.a, tail])
        midpoints = edge_midpoint[1::2]  # N+1 entries: head centre + N segment midpoints
        self.Y = np.column_stack([self.Y_0[0] + midpoints*np.cos(self.theta_0),
                                  self.Y_0[1] + midpoints*np.sin(self.theta_0)])
        # Explicit float dtype: an int init_angle (the default 0) would otherwise give
        # an integer array, silently truncating later angle updates.
        self.theta = np.full(self.N_flag + 1, init_angle, dtype=float)
        self.edge = edge_midpoint[::2]  # N+2 entries: 0, 2.1a, 2.1a+Delta_L, ..., 2.1a+L

        # Lagrange multipliers for constraints (N+2 of them), one per segment edge
        self.Lambda = np.zeros(self.N_flag+2)

    def preferred_curvature(self, t):
        """
        Traveling-wave preferred curvature kappa(s,t) along the flagellum (Eq. 2.1 of
        Schoeller et al. 2018).

        Returns an array of length N, evaluated at the midpoint of each flagellum
        segment only (no head).
        """
        # Segment midpoints s = (n + 1/2) * Delta_L. The integer index guarantees
        # exactly N values.
        s = (np.arange(self.N_flag) + 0.5) * self.Delta_L
        base = self.K_0 * np.sin(2*np.pi*self.k*s/self.L - self.omega*t)
        # Full amplitude on the proximal half, then linear decay to 0 at the tail tip.
        decay = np.where(s > self.L/2, 2*(self.L - s)/self.L, 1.0)
        return base * decay

    def internal_moment(self, t):
        """
        Compute M_{n+1/2} for n=1..N+1 using Eq. 34 of Schoeller et al. 2020.

        Returns an array of length N+2, one entry per segment edge. The first and last
        entries (head tip and tail tip) are left at zero.
        """
        kappa = self.preferred_curvature(t)
        t_hat_x, t_hat_y = np.cos(self.theta), np.sin(self.theta)
        # z-component of t_n x t_{n+1}, i.e. sin(theta_{n+1} - theta_n), for neighbouring points
        cross = t_hat_x[:-1]*t_hat_y[1:] - t_hat_y[:-1]*t_hat_x[1:]
        # Arc length between neighbouring points. Head centre (a) to the first segment
        # midpoint (2.1a + Delta_L/2) is 1.1a + Delta_L/2; every other spacing is Delta_L.
        delta_s = np.full(self.N_flag, self.Delta_L, dtype=float)
        delta_s[0] = 1.1*self.a + self.Delta_L/2
        M = np.zeros(self.N_flag+2)
        M[1:self.N_flag+1] = self.K_B * (cross/delta_s - kappa)
        return M

In [52]:
@nb.njit
def barrier_force(positions, radii, F_S, chi, neighbors):
    """
    Compute pairwise steric barrier (repulsion) forces between filament points.

    Pairs with consecutive indices (|n - m| == 1), i.e. directly connected segments,
    are skipped.
    Assume no repulsion due to head-on tail-tail/head-tail interactions. Fair
    assumption since there are few such interactions.

    Parameters
    ----------
    positions : array, shape (N, dim)
        Coordinates of N segment centers. Any number of spatial dimensions (2D or 3D)
        is handled.
    radii : array of N elements
        Radius of each segment. By convention the first entry is the head and the rest
        are tail segments.
    F_S : float
        Reference strength of the repulsive force.
    chi : float
        Range factor. A pair repels when its distance is below chi * (r_n + r_m).
        Must be > 1, otherwise the normalisation chi**2 - 1 is not positive.
    neighbors : integer array, shape (K, 2)
        Candidate index pairs (n, m), e.g. from ``cKDTree.query_pairs(...,
        output_type='ndarray')``. Pairs farther apart than the interaction range are
        ignored, so a superset is fine.

    Returns
    -------
    forces : ndarray, shape (N, dim)
        Steric barrier force on each segment. Forces are equal and opposite within each
        pair.
    """
    dim = positions.shape[1]
    F_B = np.zeros_like(positions)
    chi2m1 = chi**2 - 1
    diff = np.empty(dim)  # reusable buffer for positions[n] - positions[m]
    for i in range(neighbors.shape[0]):
        n = neighbors[i, 0]
        m = neighbors[i, 1]
        if abs(n - m) == 1:
            continue
        contact = radii[n] + radii[m]
        r_threshold = chi*contact
        # Separation vector and distance over all spatial dimensions
        dist2 = 0.0
        for d in range(dim):
            diff[d] = positions[n, d] - positions[m, d]
            dist2 += diff[d]*diff[d]
        dist_nm = np.sqrt(dist2)
        # dist_nm > 0 avoids dividing by zero for coincident points
        if (dist_nm > 0.0) and (dist_nm < r_threshold):
            num   = r_threshold*r_threshold - dist_nm*dist_nm
            denom = contact*contact * chi2m1
            mag   = F_S * (num/denom)**4 / contact
            for d in range(dim):
                F_d = mag * (diff[d] / dist_nm)  # component along the unit vector m -> n
                F_B[n, d] += F_d
                F_B[m, d] -= F_d
    return F_B

In [None]:
# 1.1 Domain and grid
L_x = L_y = 931.3                   # physical size (periodic)
N_x, N_y = 64, 64                   # FFT grid points
dx, dy = L_x/N_x, L_y/N_y
# Grid-point coordinates: N_x (N_y) points spanning [0, L_x) ([0, L_y)).
x = np.arange(N_x) * dx
y = np.arange(N_y) * dy
# make 2D mesh if needed for plotting

# 1.2 Steric barrier parameters
F_S = 550   # reference repulsion strength
CHI = 1.1   # range factor: a pair repels when dist < CHI*(r_n + r_m)

# A single sperm with default parameters, created here so the cell runs on a fresh kernel.
sperm_1 = Sperm2D()
n_points = sperm_1.Y.shape[0]   # head centre + flagellum segment midpoints

# Per-point radii: index 0 (head) uses a, tail segments use b.
radii = np.full(n_points, float(sperm_1.b))
radii[0] = sperm_1.a
# Candidate pairs within the largest possible interaction range, CHI*(a + a).
neighbors = cKDTree(sperm_1.Y).query_pairs(r=2*CHI*sperm_1.a, output_type='ndarray')
F_b = barrier_force(sperm_1.Y, radii, F_S, CHI, neighbors)

# 1.3 Spread point forces onto the periodic grid (each point goes to the cell containing it).
# np.floor rather than int(): int() truncates toward zero, so x in (-dx, 0) would land in
# cell 0 instead of wrapping to the last cell.
f_grid = np.zeros((N_x, N_y, 2))
i_idx = np.floor(sperm_1.Y[:, 0] / dx).astype(int) % N_x
j_idx = np.floor(sperm_1.Y[:, 1] / dy).astype(int) % N_y
# FIXME(review): Lambda has one scalar per segment *edge* (N+2 entries), yet here it is
# indexed by point and added to both force components. This is kept as originally written;
# confirm the intended constraint-force coupling.
point_forces = sperm_1.Lambda[:n_points, None] + F_b
# np.add.at accumulates correctly when several points fall in the same cell.
np.add.at(f_grid, (i_idx, j_idx), point_forces)