In [None]:
import jax
import jax.numpy as jnp
import jax.scipy.special
import dataclasses
from functools import partial

import vtk
from vtk.util import numpy_support

import jaxdem as jd

jax.config.update("jax_enable_x64", True)

import matplotlib.pyplot as plt
from matplotlib.patches import Circle
import numpy as np

In [None]:
# TODO:
# make rigid body particles
# make spheres
# make ellipsoids
# make h5 io module

In [1]:
def save_spheres_vtp(filename, positions, radii, particle_ids=None):
    """
    Writes particle positions, radii, and IDs to a VTK XML PolyData (.vtp) file.

    Each particle becomes one VTK point carrying three point-data arrays:
    "Diameter" (2 * radius; usable as a glyph scale array with scale factor 1.0),
    "Radius", and "ParticleID" (cast to int32 and set as the active scalars, so
    it is the default coloring). The file is written in binary data mode.

    Args:
        filename (str): Output filename. Any existing file at this path is
            removed first, but only after the inputs have been validated.
        positions (array): (N, 2) or (N, 3) centers; 2D centers get z = 0.
        radii (array or scalar): (N,) radii, or a single radius shared by all
            particles.
        particle_ids (array, optional): (N,) integer IDs for grouping/coloring
            (a scalar is shared by all particles). If None, defaults to 0..N-1.

    Raises:
        ValueError: if `positions` is not (N, 2)/(N, 3), or if `radii` or
            `particle_ids` do not hold exactly N values.
    """
    import os

    pos = np.array(positions)
    if pos.ndim == 1 and pos.size == 0:
        # an empty input such as [] has shape (0,); treat it as zero 3D points
        pos = pos.reshape(0, 3)
    if pos.ndim != 2 or pos.shape[1] not in (2, 3):
        raise ValueError(f"positions must have shape (N, 2) or (N, 3), got {pos.shape}")
    n_particles = pos.shape[0]

    rad = np.array(radii)
    if rad.ndim == 0:
        # one scalar radius is shared by every particle
        rad = np.full(n_particles, rad)
    if rad.size != n_particles:
        raise ValueError(f"radii must hold {n_particles} values, got shape {rad.shape}")
    rad = rad.reshape(n_particles)

    # default to simple enumeration if no IDs provided
    if particle_ids is None:
        p_ids = np.arange(n_particles, dtype=np.int32)
    else:
        p_ids = np.array(particle_ids, dtype=np.int32)
        if p_ids.ndim == 0:
            p_ids = np.full(n_particles, p_ids, dtype=np.int32)
        if p_ids.size != n_particles:
            raise ValueError(
                f"particle_ids must hold {n_particles} values, got shape {p_ids.shape}"
            )
        p_ids = p_ids.reshape(n_particles)

    # Remove a stale file only once we know we can write a valid replacement.
    if os.path.exists(filename):
        os.remove(filename)

    # 1. Setup Points (VTK points are always 3D, so embed 2D data at z = 0)
    points = vtk.vtkPoints()
    if pos.shape[1] == 2:
        pos_3d = np.column_stack((pos, np.zeros(n_particles)))
    else:
        pos_3d = pos
    points.SetData(numpy_support.numpy_to_vtk(pos_3d, deep=True))

    # 2. PolyData
    polydata = vtk.vtkPolyData()
    polydata.SetPoints(points)

    # 3. Add Data Arrays
    point_data = polydata.GetPointData()

    # Diameter for scaling (Scale Factor 1.0)
    diameters = rad * 2.0
    scale_array = numpy_support.numpy_to_vtk(diameters, deep=True)
    scale_array.SetName("Diameter")
    point_data.AddArray(scale_array)

    # Radius (optional, good to keep)
    rad_array = numpy_support.numpy_to_vtk(rad, deep=True)
    rad_array.SetName("Radius")
    point_data.AddArray(rad_array)

    # Particle ID for coloring
    id_array = numpy_support.numpy_to_vtk(p_ids, deep=True)
    id_array.SetName("ParticleID")
    point_data.AddArray(id_array)

    # Set active scalars (default coloring)
    point_data.SetActiveScalars("ParticleID")

    # 4. Write
    writer = vtk.vtkXMLPolyDataWriter()
    writer.SetFileName(filename)
    writer.SetInputData(polydata)
    writer.SetDataModeToBinary()
    writer.Write()
    print(f"Wrote {len(pos)} particles to {filename}")

In [None]:
def get_fibonacci_sphere_coordinates(num_points, radius=1.0, asperity_radius=0.9):
    """
    Build a "bumpy sphere": `num_points` asperities on a Fibonacci lattice plus a core.

    Asperity centers are spread nearly uniformly over a sphere of radius `radius`
    centered at the origin. Point k (k = 0.5, 1.5, ...) has polar angle
    arccos(1 - 2k / num_points) and azimuth pi * (1 + sqrt(5)) * k. One extra point
    at the origin is appended to represent the core.

    Args:
        num_points: number of asperities.
        radius: radius of the sphere the asperity centers lie on; also the core radius.
        asperity_radius: radius assigned to every asperity.

    Returns:
        positions: (num_points + 1, 3) array; the last row is the origin (the core).
        radii: (num_points + 1,) float array; `asperity_radius` for each asperity,
            `radius` for the core.
        ids: float array of ones with the same shape as `radii`.
    """
    k = np.arange(num_points, dtype=float) + 0.5

    # Polar angles chosen so each point covers an equal band of surface area.
    polar = np.arccos(1 - 2 * k / num_points)
    # Azimuth advances by 2*pi*phi (phi = golden ratio) per point, which modulo
    # 2*pi is the golden angle taken in the opposite direction.
    azimuth = np.pi * (1 + 5**0.5) * k

    sin_polar = np.sin(polar)
    surface = np.stack(
        (
            radius * np.cos(azimuth) * sin_polar,
            radius * np.sin(azimuth) * sin_polar,
            radius * np.cos(polar),
        ),
        axis=1,
    )

    positions = np.concatenate((surface, np.zeros((1, 3))), axis=0)
    radii = np.concatenate(
        (np.full(num_points, asperity_radius, dtype=float), np.array([radius], dtype=float))
    )
    return positions, radii, np.ones_like(radii)

# Demo: one bumpy sphere per asperity radius, laid out along the x axis.
# For every cluster core_radius + asperity_radius == 1.0, so no asperity extends
# farther than 1.0 from its cluster center and clusters 2.0 apart do not overlap.
N = 50  # asperities per cluster
ASPERITY_RADII = [0.99, 0.9, 0.5, 0.3, 0.1, 0.01]
CLUSTER_SPACING = 2.0

p_all = []
rad_all = []
p_ids_all = []

for i, asperity_radius in enumerate(ASPERITY_RADII):
    core_radius = 1.0 - asperity_radius
    p, rad, p_ids = get_fibonacci_sphere_coordinates(N, core_radius, asperity_radius)
    p[:, 0] += i * CLUSTER_SPACING  # shift cluster i along x
    p_all.append(p)
    rad_all.append(rad)
    # get_fibonacci_sphere_coordinates returns all-ones IDs, so cluster i is tagged i + 1
    p_ids_all.append(p_ids + i)

p = np.concatenate(p_all)
rad = np.concatenate(rad_all)
p_ids = np.concatenate(p_ids_all)

save_spheres_vtp("asperities_3d.vtp", p, rad, p_ids)

In [None]:
def calc_mu_eff(vertex_radius, outer_radius, num_vertices):
    """
    Effective friction coefficient for a particle decorated with `num_vertices` vertices.

    Evaluates
        mu_eff = 1 / sqrt((2 * r_v / ((R - r_v) * sin(pi / n)))**2 - 1)
    with r_v = vertex_radius, R = outer_radius and n = num_vertices. Works
    elementwise on numpy arrays.

    The result is infinite when the ratio inside the square is exactly 1, i.e. at
    vertex_radius = R * s / (2 + s) with s = sin(pi / n). Below that radius the
    result is NaN, with a numpy RuntimeWarning. As vertex_radius increases toward
    R, mu_eff decreases toward 0.
    """
    spacing = (outer_radius - vertex_radius) * np.sin(np.pi / num_vertices)
    ratio = (2 * vertex_radius) / spacing
    return 1 / np.sqrt(ratio ** 2 - 1)

def generate_bidisperse_radii(total: int, count_ratio: float, size_ratio: float, small_size: float = 0.5, dtype=np.float64):
    """
    Build a bidisperse array of particle sizes.

    The first n_small = floor(total * count_ratio) entries equal `small_size`; the
    remaining entries equal `small_size * size_ratio`.

    total: number of particles
    count_ratio: proportion of particles that are the small size (0 <= count_ratio <= 1)
    size_ratio: large size in terms of the small size (>= 0)
    small_size: value assigned to the small particles. NOTE(review): this was
        documented as a diameter, but the function name says radii. The values are
        returned as given either way, so confirm which one callers expect.
    dtype: dtype of the returned array. Defaults to float64, consistent with
        jax_enable_x64. This replaces a global DT_FLOAT that no cell above defines.

    Raises:
        ValueError: if count_ratio is outside [0, 1] or size_ratio is negative.
    """
    if count_ratio > 1 or count_ratio < 0:
        raise ValueError('count ratio out of bounds: 0 <= count_ratio <= 1')
    if size_ratio < 0:
        raise ValueError('size ratio out of bounds: 0 <= size_ratio')
    # Round away float noise before truncating, e.g. 100 * 0.29 == 28.999999999999996
    # would otherwise yield 28 small particles instead of 29.
    n_small = int(round(total * count_ratio, 9))
    radii = np.empty(total, dtype=dtype)
    radii[:n_small] = small_size
    radii[n_small:] = small_size * size_ratio
    return radii

def generate_polydisperse_radii(total: int, std_dev: float, avg_size: float = 0.5, random_seed: int = 42):
    """
    Draw `total` particle sizes from a normal distribution with mean `avg_size`.

    A private legacy `np.random.RandomState(random_seed)` generates the draws.
    They are identical to seeding the global NumPy generator with `random_seed`
    and calling `np.random.normal`, but the global RNG state is left untouched.

    total: number of particles
    std_dev: standard deviation of the gaussian
    avg_size: size of the average particle
    random_seed: seed for the private generator

    TODO: verify that this is the right way to handle polydispersity. A Gaussian
    can produce zero or negative sizes when std_dev is not small compared to
    avg_size.
    """
    rng = np.random.RandomState(random_seed)
    return rng.normal(size=total, loc=avg_size, scale=std_dev)

def get_closest_vertex_radius_for_mu_eff(mu_eff, outer_radius, num_vertices):
    """
    Find the vertex radius for which `calc_mu_eff` equals a target effective friction.

    calc_mu_eff is real-valued only for vertex_radius above R * s / (2 + s), with
    R = outer_radius and s = sin(pi / num_vertices); it diverges at that bound. It
    then decreases monotonically toward 0 as vertex_radius approaches R. The search
    is restricted to that interval, shrunk by 1e-12 at each end.

    Args:
        mu_eff: target effective friction.
        outer_radius: outer radius of the particle.
        num_vertices: number of vertices.

    Returns:
        The vertex radius as a float. Returns np.nan if mu_eff is NaN, if it lies
        outside the range reachable on the interval, or if no solution is found.
    """
    # scipy.optimize is not imported by the notebook's import cell, so import it here.
    from scipy.optimize import brentq, minimize_scalar

    if np.isnan(mu_eff):
        # NaN compares False against everything and would slip past the range check.
        return np.nan

    # Calculate mathematically valid bounds
    sin_term = np.sin(np.pi / num_vertices)
    # At outer_radius * s / (2 + s) the squared ratio in calc_mu_eff is exactly 1.
    min_vertex_radius = outer_radius * sin_term / (2 + sin_term) + 1e-12
    max_vertex_radius = outer_radius - 1e-12

    # mu_eff decreases with vertex_radius, so its extremes sit at the bounds.
    max_mu_eff = calc_mu_eff(min_vertex_radius, outer_radius, num_vertices)
    min_mu_eff = calc_mu_eff(max_vertex_radius, outer_radius, num_vertices)

    if mu_eff > max_mu_eff or mu_eff < min_mu_eff:
        # Target mu_eff is outside achievable range
        return np.nan

    # Use root finding since we want calc_mu_eff(vertex_radius) = mu_eff
    def objective(vertex_radius):
        return calc_mu_eff(vertex_radius, outer_radius, num_vertices) - mu_eff

    try:
        # Brent's method is robust for this monotonic function
        return brentq(objective, min_vertex_radius, max_vertex_radius, xtol=1e-12)
    except (ValueError, RuntimeError, ZeroDivisionError):
        # Fallback to bounded scalar minimization if root finding fails
        def obj_squared(vertex_radius):
            try:
                return (calc_mu_eff(vertex_radius, outer_radius, num_vertices) - mu_eff) ** 2
            except (ValueError, RuntimeError, ZeroDivisionError):
                return np.inf

        result = minimize_scalar(obj_squared, bounds=(min_vertex_radius, max_vertex_radius), method='bounded')
        return result.x if result.success else np.nan

def get_closest_num_vertices_for_mu_eff_and_radii(mu_eff, outer_radius, vertex_radius, min_nv=1, max_nv=np.inf):
    """
    Placeholder. Judging by its name, this is meant to return the vertex count in
    [min_nv, max_nv] whose effective friction is closest to `mu_eff`.

    Not implemented yet. It raises instead of silently returning None, so a caller
    cannot mistake the missing result for a valid vertex count.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        "get_closest_num_vertices_for_mu_eff_and_radii is not implemented yet"
    )


def get_closest_num_vertices_for_friction_and_segment_length(vertex_radius, outer_radius, target_segment_length, target_friction, target_num_vertices, vertex_count_offset=5, min_num_vertices=2):
    """
    Pick the vertex count near `target_num_vertices` that best matches a target
    friction and segment length.

    Candidates are the integers n in
        [max(min_num_vertices, 1, target_num_vertices - vertex_count_offset),
         target_num_vertices + vertex_count_offset]
    with both ends inclusive, so the window is symmetric around the target. Each
    candidate is scored by
        |segment_length / target_segment_length - 1| + |friction / target_friction - 1|
    where friction = calc_mu_eff(vertex_radius, outer_radius, n) and
    segment_length = (outer_radius - vertex_radius) / vertex_radius * sin(pi / n).

    NOTE(review): despite its name, segment_length is dimensionless. If vertex
    centers sit at radius outer_radius - vertex_radius, it is half the chord between
    neighboring vertices divided by vertex_radius. Confirm that target_segment_length
    uses the same definition.

    Returns:
        The lowest-cost candidate; the first candidate wins ties. Returns
        target_num_vertices if no candidate has a finite, non-NaN cost.
    """
    inner_radius = outer_radius - vertex_radius
    # Never try n <= 0: pi / n would divide by zero or produce a meaningless angle.
    first_candidate = max(min_num_vertices, 1, target_num_vertices - vertex_count_offset)
    last_candidate = target_num_vertices + vertex_count_offset

    best_num_vertices = target_num_vertices
    min_cost = np.inf
    for num_vertices in range(first_candidate, last_candidate + 1):
        half_vertex_angle = np.pi / num_vertices
        friction = calc_mu_eff(vertex_radius, outer_radius, num_vertices)
        segment_length = inner_radius / vertex_radius * np.sin(half_vertex_angle)

        cost = abs(segment_length / target_segment_length - 1) + abs(friction / target_friction - 1)
        # Invalid geometries give NaN friction and hence a NaN cost; skip them.
        if not np.isnan(cost) and cost < min_cost:
            best_num_vertices = num_vertices
            min_cost = cost
    return best_num_vertices

def get_bumpy_dists(num_vertices, outer_radius, vertex_radius):
    """
    Characteristic center-to-center distances between two particles of the same species.

    With sigma = 2 * outer_radius, sigma_v = 2 * vertex_radius,
    sigma_p = sigma - sigma_v, q = sigma_v / sigma_p and t = pi / num_vertices:

        d_0   = (cos(t) / 2 + 1/2 + sqrt(q**2 - (sin(t) / 2)**2)) * sigma_p
        d_min = sqrt(cos(t)**2 + q**2 + cos(t) * sqrt(4 * q**2 - sin(t)**2)) * sigma_p

    Returns:
        (d_0, d_min): d_0 is the closest center-to-center distance at a symmetric
        contact; d_min is the absolute closest center-to-center distance. A NaN
        results (with a numpy RuntimeWarning) if an inner square root has a
        negative argument.
    """
    vertex_diameter = vertex_radius * 2
    core_diameter = outer_radius * 2 - vertex_diameter
    q = vertex_diameter / core_diameter
    cos_t = np.cos(np.pi / num_vertices)
    sin_t = np.sin(np.pi / num_vertices)

    # symmetric contact
    d_0 = (cos_t / 2 + 1 / 2 + np.sqrt(q ** 2 - (sin_t / 2) ** 2)) * core_diameter
    # absolute closest approach
    d_min = np.sqrt(cos_t ** 2 + q ** 2 + cos_t * np.sqrt(4 * q ** 2 - sin_t ** 2)) * core_diameter
    return d_0, d_min