# PD Topology Optimization using global funcitons for calculations

## all combined

In [None]:
from functools import partial

import jax
import jax.numpy as jnp

import numpy as np
import scipy.spatial
import scipy.optimize
import matplotlib.pyplot as plt


import jax.scipy
import jax.scipy.optimize
from jax.scipy.optimize import minimize
from jax.nn import softplus

from jax import grad, jit


from typing import Union, Tuple

import optax
from jax import jit, vmap, grad, value_and_grad
from jax import lax
from jax import grad
from jax.experimental import checkify


import jax
import jax.numpy as jnp
import numpy as np
import scipy.spatial
from typing import Union, NamedTuple, Optional


# ----------------------------
# PARAMETER STRUCT
# ----------------------------
class PDParams(NamedTuple):
    bar_length: float
    number_of_elements: int
    bulk_modulus: float
    density: float
    thickness: jnp.ndarray
    horizon: float
    critical_stretch: Optional[float]
    prescribed_velocity: Optional[float]
    prescribed_force: Optional[float]
    nodes: jnp.ndarray
    lengths: jnp.ndarray
    pd_nodes: jnp.ndarray
    num_nodes: int
    neighborhood: jnp.ndarray
    reference_position_state: jnp.ndarray
    reference_magnitude_state: jnp.ndarray
    num_neighbors: jnp.ndarray
    max_neighbors: int
    no_damage_region_left: jnp.ndarray
    no_damage_region_right: jnp.ndarray
    width: float  
    right_bc_region: jnp.ndarray
    left_bc_region: jnp.ndarray
    undamaged_influence_state_left: jnp.ndarray
    undamaged_influence_state_right: jnp.ndarray


class PDState(NamedTuple):
    disp: jnp.ndarray
    vel: jnp.ndarray
    acc: jnp.ndarray
    vol_state: jnp.ndarray
    rev_vol_state: jnp.ndarray
    influence_state: jnp.ndarray
    undamaged_influence_state: jnp.ndarray
    strain_energy: float
    time: float


# ----------------------------
# GLOBAL INITIALIZATION FUNCTION
# ----------------------------
def init_problem(bar_length: float = 20.0,
                 density: float = 1.0,
                 bulk_modulus: float = 100.0,
                 number_of_elements: int = 20,
                 horizon: Optional[float] = None,
                 thickness: Union[float, np.ndarray] = 1.0,
                 prescribed_velocity: Optional[float] = None,
                 prescribed_force: Optional[float] = None,
                 critical_stretch: Optional[float]= None):

    
    """
    Create PDParams and PDState tuples for a new problem.
    """

    delta_x = bar_length / number_of_elements
    if horizon is None:
        horizon = delta_x * 3.015

    # thickness as array
    if isinstance(thickness, float) or np.isscalar(thickness):
        thickness_arr = jnp.ones(number_of_elements) * thickness
    elif isinstance(thickness, (np.ndarray, jnp.ndarray)):
        thickness_arr = jnp.asarray(thickness)
        assert thickness_arr.shape[0] == number_of_elements, ValueError("Thickness array length must match number of elements")
    else:
        raise ValueError("thickness must be a float or array")
    
    # nodes and element lengths
    nodes = jnp.linspace(-bar_length / 2.0, bar_length / 2.0, num=number_of_elements + 1)
    lengths = jnp.array(nodes[1:] - nodes[0:-1])
    pd_nodes = jnp.array(nodes[0:-1] + lengths / 2.0)
    num_nodes = pd_nodes.shape[0]

    # kdtree setup
    tree = scipy.spatial.cKDTree(pd_nodes[:, None])
    reference_magnitude_state, neighborhood = tree.query(
        pd_nodes[:, None], k=100, p=2, eps=0.0,
        distance_upper_bound=(horizon + np.max(lengths) / 2.0))

    # trim out self-distance column
    reference_magnitude_state = jnp.delete(reference_magnitude_state, 0, 1)

    num_neighbors = jnp.asarray((neighborhood != tree.n).sum(axis=1)) - 1
    max_neighbors = int(np.max((neighborhood != tree.n).sum(axis=1)))

    neighborhood = jnp.asarray(neighborhood[:, :max_neighbors])
    reference_magnitude_state = reference_magnitude_state[0:, :max_neighbors - 1]

    row_indices = jnp.arange(neighborhood.shape[0]).reshape(-1, 1)
    neighborhood = jnp.where(neighborhood == tree.n, row_indices, neighborhood)
    neighborhood = jnp.delete(neighborhood, 0, 1)

    reference_position_state = pd_nodes[neighborhood] - pd_nodes[:, None]
    reference_magnitude_state = jnp.where(reference_magnitude_state == np.inf, 0.0, reference_magnitude_state)


    if prescribed_velocity is not None and prescribed_force is not None:
        raise ValueError("Only one of prescribed_velocity or prescribed_force should be set, not both.")
    
    if prescribed_velocity is None and prescribed_force is None:
        raise ValueError("Either prescribed_velocity or prescribed_force must be set.")

    #:The node indices of the boundary region at the left end of the bar
    li = 0
    left_bc_mask = neighborhood[li] != li
    left_bc_region = neighborhood[li][left_bc_mask]

    #:The node indices of the boundary region at the right end of the bar
    ri = num_nodes - 1
    right_bc_mask = neighborhood[ri] != ri
    right_bc_region = neighborhood[ri][right_bc_mask]

    if prescribed_velocity is not None:
        left_bc_region = jnp.asarray(tree.query_ball_point(pd_nodes[0, None], r=(horizon + np.max(lengths) / 2.0), p=2, eps=0.0)).sort()
        right_bc_region = jnp.asarray(tree.query_ball_point(pd_nodes[-1, None], r=(horizon + np.max(lengths) / 2.0), p=2, eps=0.0)).sort()

    # no-damage regions
    no_damage_region_left = jnp.asarray(
        tree.query_ball_point(pd_nodes[0, None], r=(2.5 * horizon + np.max(lengths) / 2.0), p=2, eps=0.0)
    ).sort()
    no_damage_region_right = jnp.asarray(
        tree.query_ball_point(pd_nodes[-1, None], r=(2.5 * horizon + np.max(lengths) / 2.0), p=2, eps=0.0)
    ).sort()

    # initial vol_state (full volume)
    vol_state = jnp.ones((num_nodes, max_neighbors - 1))
    rev_vol_state = vol_state.copy()

    #jax.debug.print("vol_state in init: {v}",v=vol_state)

    influence_state = jnp.where(vol_state > 1.0e-16, 1.0, 0.0)
    undamaged_influence_state = influence_state.copy()
    
    undamaged_influence_state_left = influence_state.at[no_damage_region_left, :].get()
    undamaged_influence_state_right = influence_state.at[no_damage_region_right, :].get()

    width = 1.0  # Width of the bar, can be adjusted if needed


    # package params
    params = PDParams(
        bar_length, number_of_elements, bulk_modulus, density, thickness_arr,
        horizon, critical_stretch, prescribed_velocity, prescribed_force,
        nodes, lengths, pd_nodes, num_nodes, neighborhood,
        reference_position_state, reference_magnitude_state, num_neighbors, max_neighbors,
        no_damage_region_left, no_damage_region_right, width, right_bc_region, left_bc_region,
        undamaged_influence_state_left, undamaged_influence_state_right
    )


    # package initial state
    state = PDState(
        disp=jnp.zeros(num_nodes),
        vel=jnp.zeros(num_nodes),
        acc=jnp.zeros(num_nodes),
        vol_state=vol_state,
        rev_vol_state=rev_vol_state,
        influence_state=influence_state,
        undamaged_influence_state=undamaged_influence_state,
        strain_energy=0.0,
        time=0.0
    )

    return params, state

def compute_partial_volumes(params, thickness:jax.Array):

    # Setup some local (to function) convenience variables
    neigh = params.neighborhood
    lens = params.lengths
    ref_mag_state = params.reference_magnitude_state
    horiz = params.horizon


    # Initialize the volume_state to the lengths * width * thickness
    width = params.width
    vol_state_uncorrected = lens[neigh] * thickness[neigh] * width 


    #Zero out entries that are not in the family
    vol_state_uncorrected = jnp.where(ref_mag_state < 1.0e-16, 0.0, vol_state_uncorrected) 


    vol_state = jnp.where(ref_mag_state < horiz + lens[neigh] / 2.0, vol_state_uncorrected, 0.0)
    #jax.debug.print("Any NaNs? {y}", y=jnp.any(jnp.isnan(vol_state)))

    # Check to see if the neighboring node has a partial volume
    is_partial_volume = jnp.abs(horiz - ref_mag_state) < lens[neigh] / 2.0
    #jax.debug.print("Any NaNs? {y}", y=jnp.any(jnp.isnan(is_partial_volume)))

    # Two different scenarios:
    is_partial_volume_case1 = is_partial_volume * (ref_mag_state >= horiz)
    is_partial_volume_case2 = is_partial_volume * (ref_mag_state < horiz)

    # Compute the partial volumes conditionally
    vol_state = jnp.where(is_partial_volume_case1, (lens[neigh] / 2.0 - (ref_mag_state - horiz)) * width * thickness[neigh], vol_state)
    vol_state = jnp.where(is_partial_volume_case2, (lens[neigh] / 2.0 + (horiz - ref_mag_state)) * width * thickness[neigh], vol_state)

    # If the partial volume is predicted to be larger than the unocrrected volume, set it back
    vol_state = jnp.where(vol_state > vol_state_uncorrected, vol_state_uncorrected, vol_state)

    # Now compute the "reverse volume state", this is the partial volume of the "source" node, i.e. node i,
    # as seen from node j.  This doesn't show up anywhere in any papers, it's just used here for computational
    # convenience
    vol_array = lens[:,None] * width * thickness[:, None]
    #jax.debug.print("Any NaNs? {y}", y=jnp.any(jnp.isnan(vol_array)))

    rev_vol_state = jnp.ones_like(vol_state) * vol_array
    #jax.debug.print("Any NaNs? {y}", y=jnp.any(jnp.isnan(rev_vol_state)))
    
    rev_vol_state = jnp.where(is_partial_volume_case1, (lens[:, None] / 2.0 - (ref_mag_state - horiz)) * width * thickness[:, None], rev_vol_state)
    #jax.debug.print("Any NaNs? {y}", y=jnp.any(jnp.isnan(rev_vol_state)))
    
    rev_vol_state = jnp.where(is_partial_volume_case2, (lens[:, None] / 2.0 + (horiz - ref_mag_state)) * width * thickness[:, None], rev_vol_state)
    #jax.debug.print("Any NaNs? {y}", y=jnp.any(jnp.isnan(rev_vol_state)))
    
    #If the partial volume is predicted to be larger than the uncorrected volume, set it back
    rev_vol_state = jnp.where(rev_vol_state > vol_array, vol_array, rev_vol_state)


    return (vol_state, rev_vol_state)

'''
@jax.jit
def my_where(x: jax.Array):
    # Elementwise comparison without jnp.where, using list comprehension
    return jnp.array([i if i >= 1E-12 else 1E-12 for i in x])
'''

###### functions to replace jnp.where with lax.cond for vectorized operations ######
@jax.jit
def my_where(x: jax.Array):
    def cond_fn(i):
        return jax.lax.cond(i >= 1E-12, lambda _: i, lambda _: 1E-12, operand=None)
    return jnp.vectorize(cond_fn)(x)

@jax.jit
def my_stretch_where(ref_mag_state: jax.Array, exten_state: jax.Array):
    def cond_fn(r, e):
        return jax.lax.cond(r > 1.0e-16, lambda _: e / r, lambda _: 0.0, operand=None)
    return jnp.vectorize(cond_fn)(ref_mag_state, exten_state)

@jax.jit
def inf_state_where(inf_state: jax.Array, stretch: jax.Array, critical_stretch: float):
    def cond_fn(stretch, inf_state):
        return jax.lax.cond(stretch > critical_stretch, lambda _: 0.0, lambda _: inf_state, operand=None)
    return jnp.vectorize(cond_fn)(stretch, inf_state)

@jax.jit
def my_replace_zero_val_where(x_array: jax.Array, eps: float):
    def cond_fn(i):
        return jax.lax.cond(i == 0.0, lambda _: eps, lambda _: i, operand=None)
    return jnp.vectorize(cond_fn)(x_array)

@jax.jit
def shape_tens_eps_where(shape_tens: jax.Array, epsilon: float):
    def cond_fn(val):
        return jax.lax.cond(jnp.abs(val) < epsilon, lambda _: epsilon, lambda _: val, operand=None)
    return jnp.vectorize(cond_fn)(shape_tens)
###########################

# Compute the force vector-state using a LPS peridynamic formulation
@partial(jit, static_argnums=(4,))
def compute_force_state_LPS(params,
                            disp:jax.Array, 
                            vol_state:jax.Array,
                            inf_state:jax.Array, allow_damage: bool) -> Tuple[jax.Array, jax.Array]:
        
    #Define some local convenience variables     
    ref_pos = params.pd_nodes 
    ref_pos_state = params.reference_position_state
    ref_mag_state = params.reference_magnitude_state
    neigh = params.neighborhood
    K = params.bulk_modulus
    critical_stretch = params.critical_stretch
    no_damage_region_left = params.no_damage_region_left
    no_damage_region_right = params.no_damage_region_right
    undamaged_influence_state_left = params.undamaged_influence_state_left
    undamaged_influence_state_right = params.undamaged_influence_state_right

    #disp = 0.001
    disp =  disp[0]
    #jnp.zeros_like(ref_pos)

    #Compute the deformed positions of the nodes
    def_pos = ref_pos + disp
    #jax.debug.print("def_pos finite: {b}", b=jnp.all(jnp.isfinite(def_pos)))

    #Compute deformation state
    def_state = def_pos[neigh] - def_pos[:,None]
    #jax.debug.print("def_state finite: {b}", b=jnp.all(jnp.isfinite(def_state)))
    #jax.debug.print("def_state zeros? {z}", z=jnp.any(def_state == 0))

    
    # Compute deformation magnitude state
    def_mag_state = jnp.sqrt(def_state * def_state)
    #def_mag_state = jnp.linalg.norm(def_state, axis=-1)
    #jax.debug.print("[def_mag_state] any<=0? {a} min={m} max={M}",
                #a=~jnp.all(def_mag_state > 0), m=jnp.min(def_mag_state), M=jnp.max(def_mag_state))
    #jax.debug.print("def_mag_state? {z}", z=jnp.any(def_mag_state == 0))

    # Compute deformation unit state
    #def_unit_state = jnp.where(def_mag_state > 1.0e-12, def_state / def_mag_state, 0.0)
    def_unit_state = my_stretch_where(def_mag_state, def_state)

    #def_unit_state = jax.vmap(safe_unit)(def_state, def_mag_state)


    # Compute scalar extension state
    exten_state = def_mag_state - ref_mag_state
    #jax.debug.print("[exten_state] finite={f} min={m} max={M}",
                #f=jnp.all(jnp.isfinite(exten_state)), m=jnp.min(exten_state), M=jnp.max(exten_state))

    #stretch = jnp.where(ref_mag_state > 1.0e-16, exten_state / ref_mag_state, 0.0)
    stretch = my_stretch_where(ref_mag_state, exten_state)


    def damage_branch(inf_state):
        #inf_state = jnp.where(stretch > critical_stretch, 0.0, inf_state)
        inf_state = inf_state_where(inf_state, stretch, critical_stretch)
        inf_state = inf_state.at[no_damage_region_left, :].set(undamaged_influence_state_left)
        inf_state = inf_state.at[no_damage_region_right, :].set(undamaged_influence_state_right)
        return inf_state

    def no_damage_branch(inf_state):
        # return state unchanged
        return inf_state


    # Apply a critical strech fracture criteria
    inf_state = lax.cond(allow_damage, damage_branch, no_damage_branch, inf_state)

    #jax.debug.print("[inf_state pre-eps] any==0? {z} finite={f}",
                #z=jnp.any(inf_state == 0.0), f=jnp.all(jnp.isfinite(inf_state)))

    eps = 1e-10  # or smaller if your scale is tiny

    #inf_state = jnp.where(inf_state == 0.0, eps, inf_state)
    inf_state = my_replace_zero_val_where(inf_state, eps)

    #ref_pos_state = jnp.where(ref_pos_state == 0.0, eps, ref_pos_state)
    ref_pos_state = my_replace_zero_val_where(ref_pos_state, eps) 
    #jax.debug.print("ref_pos_state zeros? {z}", z=jnp.any(ref_pos_state == 0))
    
    # Compute the shape tensor (really a scalar because this is 1d), i.e. the "weighted volume" as 
    # defined in Silling et al. 2007
    # added epsilon to prevent dividing by zero
    epsilon = 1e-8
    shape_tens = (inf_state * ref_pos_state * ref_pos_state * vol_state).sum(axis=1)
    #shape_tens = jnp.where(jnp.abs(shape_tens) < epsilon, epsilon, shape_tens)
    shape_tens = shape_tens_eps_where(shape_tens, epsilon)
    

    # Compute scalar force state for a elastic constitutive model
    ######### compute strain energy density here?  or calculation at least should look like this line here ########
    #scalar_force_state = 9.0 * K / shape_tens[:, None] * exten_state
    scalar_force_state = 9.0 * K * safe_divide(exten_state, shape_tens[:, None])

    # bond strain energy calc
    #bond_strain_energy = 9.0 * K / shape_tens[:, None] * exten_state * exten_state * ref_mag_state
    bond_strain_energy = 9.0 * K * safe_divide(exten_state**2 * ref_mag_state, shape_tens[:, None])

    # Compute the force state
    force_state = inf_state * scalar_force_state * def_unit_state

    ###  return bond_strain_energy
    return force_state, inf_state, bond_strain_energy



def smooth_ramp(t, t0, c=1.0, beta=5.0):
    """
    Function that linearly ramps up to c at t0, then smoothly transitions to c.

    Parameters:
    - t: Time variable (scalar or numpy array).
    - t0: Time at which the transition occurs.
    - c: Final constant value after transition.
    - beta: Smoothness parameter (higher values = sharper transition).

    Returns:
    - f: Value of the function at time t.
    """
    # Linear ramp before t0 (with slope c/t0)
    linear_ramp = (c / t0) * t

    # Smooth transition using an exponential decay term
    smooth_transition = c * (1 - jnp.exp(-beta * (t - t0))) + (c / t0) * t0

    # Use `np.where` to define the piecewise function
    f = jnp.where(t < t0, linear_ramp, smooth_transition)
    
    return f

# Internal force calculation
#@jax.jit(static_argnums=7)  
@partial(jax.jit, static_argnums=(7))
def compute_internal_force(params, disp, vol_state, rev_vol_state, inf_state, thickness, time, allow_damage):

        
    # Define some local convenience variables     
    neigh = params.neighborhood
    prescribed_force = params.prescribed_force
    width = params.width
    num_nodes = params.num_nodes
    left_bc_region = params.left_bc_region
    right_bc_region = params.right_bc_region

    #jax.debug.print("disp zeros? {z}", z=jnp.any(disp == 0))
    

    ##### return bond_strain_energy #####
    force_state, inf_state, bond_strain_energy = compute_force_state_LPS(params, disp, vol_state, inf_state, allow_damage)

    #Integrate nodal forces 
    force = (force_state * vol_state).sum(axis=1)
    force = force.at[neigh].add(-force_state * rev_vol_state)

    #strain_energy = bond_strain_energy
    strain_energy = 0.5 * (bond_strain_energy * vol_state).sum(axis=1)

    #total_strain_energy = jnp.sum(strain_energy)
    #jax.debug.print("strain_energy: {s}", s=jnp.sum(strain_energy))

    if prescribed_force is not None:
        li = 0
        ri = num_nodes - 1
        #ramp_force = smooth_ramp(time, t0=1.e-5, c=prescribed_force) 
        ramp_force = smooth_ramp(time, t0=1.e-3, c=prescribed_force) 

        denom_left  = vol_state[li].sum() + rev_vol_state[li][0]
        denom_right = vol_state[ri].sum() + rev_vol_state[ri][0]



        eps = 1e-12
        #denom_left  = jnp.where(jnp.abs(denom_left)  < eps, eps, denom_left)
        #denom_right = jnp.where(jnp.abs(denom_right) < eps, eps, denom_right)

        denom_left  = jnp.clip(denom_left,  1e-8, jnp.inf)
        denom_right = jnp.clip(denom_right, 1e-8, jnp.inf)


        # Compute the left boundary condition nodal forces
        left_bc_area = width * thickness[left_bc_region]
        #left_bc_nodal_forces = (ramp_force * left_bc_area)/(vol_state[li].sum() + rev_vol_state[li][0])
        left_bc_nodal_forces = (ramp_force * left_bc_area)/denom_left
        force = force.at[left_bc_region].add(-left_bc_nodal_forces)

        # For the leftmost node (if needed)
        left_bc_area_li = width * thickness[li]
        force = force.at[li].add(-ramp_force * left_bc_area_li)

        # Compute the right boundary condition nodal forces
        right_bc_area = width * thickness[right_bc_region]
        #right_bc_nodal_forces = (ramp_force * right_bc_area)/(vol_state[ri].sum() + rev_vol_state[ri][0])
        right_bc_nodal_forces = (ramp_force * right_bc_area)/denom_right
        force = force.at[right_bc_region].add(right_bc_nodal_forces)

        # For the rightmost node (if needed)
        right_bc_area_ri = width * thickness[ri]
        force = force.at[ri].add(-ramp_force * right_bc_area_ri)

    return force, inf_state, strain_energy

@partial(jax.jit, static_argnums=(2,))
def solve_one_step(params, vals, allow_damage:bool):

    (disp, vel, acc, vol_state, rev_vol_state, inf_state, thickness, undamaged_inf_state, strain_energy, time) = vals

    prescribed_velocity = params.prescribed_velocity
    bar_length = params.bar_length
    left_bc_region = params.left_bc_region
    right_bc_region = params.right_bc_region
    pd_nodes = params.pd_nodes
    rho = params.density


    # TODO: Solve for stable time step
    time_step = 1E-06

    #jax.debug.print("in solve_one_step: {t}", t=time)
    ##########################
    # Check inputs for NaNs
    ##########################

    # Check if any of the input arrays contain NaNs
    for name, vals in zip(
        ["disp", "vel", "acc", "vol_state", "rev_vol_state", "inf_state", "thickness"],
        [disp, vel, acc, vol_state, rev_vol_state, inf_state, thickness]):
        is_finite = jnp.all(jnp.isfinite(vals))
        #jax.debug.print("no NaNs detected in {n}: {f}", n=vals, f=is_finite)

    if prescribed_velocity is not None:
        bc_value = prescribed_velocity * time
        # Apply displacements bcs
        f = lambda x: 2.0 * bc_value / bar_length * x
        disp = disp.at[left_bc_region].set(f(pd_nodes[left_bc_region]))
        disp = disp.at[right_bc_region].set(f(pd_nodes[right_bc_region]))




    force, inf_state, strain_energy = compute_internal_force(params, disp, vol_state, rev_vol_state, inf_state, thickness, time, allow_damage)

    acc_new = force / rho

    vel = vel.at[:].add(0.5 * (acc + acc_new) * time_step)

    disp = disp.at[:].add(vel * time_step + (0.5 * acc_new * time_step * time_step))
    acc = acc.at[:].set(acc_new)

    #jax.debug.print("disp in solve_one_step: {d}", d=disp)

    strain_energy_total = jnp.sum(strain_energy)
    #jax.debug.print("strain_energy_total in solve: {s}", s=strain_energy_total)

    def nan_debug_print(x):
        jax.debug.print("NaNs detected in strain_energy_total")
        return x

    def no_nan(x):
        return x
    '''
    # This will always execute, but only print if the condition is true
    _ = jax.lax.cond(
        ~jnp.all(jnp.isfinite(strain_energy_total)),
        nan_debug_print,
        no_nan,
        operand=strain_energy_total )
    '''

    #jax.debug.print("disp: {d}", d=disp)
    #jax.debug.print("disp? {z}", z=jnp.any(disp == 0))

    return (disp, vel, acc, vol_state, rev_vol_state, inf_state, thickness, undamaged_inf_state, strain_energy_total, time + time_step)

### put wrapper on solve
#@partial(jax.jit, static_argnums=(3,4))
def _solve(params, state, thickness:jax.Array, allow_damage:bool, max_time:float=1.0):
    '''
        Solves in time using Verlet-Velocity
    '''

    EPS = 1.0e-12  # Minimum safe volume to avoid NaNs

    time_step = 1.0e-7
    num_steps = int(max_time / time_step)

    vol_state, rev_vol_state = compute_partial_volumes(params, thickness)

    # Clamp to avoid divide-by-zero or log(0) NaNs
    vol_state = jnp.maximum(vol_state, EPS)
    rev_vol_state = jnp.maximum(rev_vol_state, EPS)

    #jax.debug.print("vol_state= {V}",V=vol_state)
    #jax.debug.print("vol_state min={vmin}", vmin=jnp.min(vol_state))
    #jax.debug.print("vol_state zeros? {z}", z=jnp.any(vol_state == 0))


    inf_state = state.influence_state.copy() 
    undamaged_inf_state = state.undamaged_influence_state.copy()

    # Initialize a fresh influence state for this run
    #inf_state = jnp.where(vol_state > 1.0e-16, 1.0, 0.0)

    #jax.debug.print("inf_state update after where: {i}",i=inf_state)
    # The fields
    disp = jnp.zeros_like(params.pd_nodes) 
    vel = jnp.zeros_like(params.pd_nodes)
    acc = jnp.zeros_like(params.pd_nodes)
    time = 0.0
    strain_energy = 0.0

    def loop_body(i, vals):
        new_vals = solve_one_step(params, vals, allow_damage)
        return new_vals

    #Solve
    vals = (disp, vel, acc, vol_state, rev_vol_state, inf_state, thickness, undamaged_inf_state, strain_energy, time)

    vals_returned = jax.lax.fori_loop(0, num_steps, loop_body, vals)


    return PDState(
        disp=vals_returned[0],
        vel=vals_returned[1],
        acc=vals_returned[2],
        vol_state=vals_returned[3],
        rev_vol_state=vals_returned[4],
        influence_state=vals_returned[5],
        undamaged_influence_state=vals_returned[7],
        strain_energy=vals_returned[8],
        time=vals_returned[9]
        )


def ensure_thickness_vector(thickness, num_nodes):
    # Accept float, scalar array, or vector. Return vector of length num_nodes (NumPy/JAX array).
    thickness = jnp.asarray(thickness)
    if thickness.ndim == 0 or thickness.size == 1:
        return jnp.full((num_nodes,), float(thickness))
    if thickness.shape != (num_nodes,):
        raise ValueError(f"Thickness must have shape ({num_nodes},), got {thickness.shape}")
    return thickness


def loss(params, state, thickness_vector:jax.Array, allow_damage:bool, max_time:float):


    output_vals = _solve(params, state, thickness=thickness_vector, allow_damage=allow_damage, max_time=max_time)
    #output_vals = _solve_debug(params, state, thickness0, allow_damage, max_time=max_time)


    checkify.check(jnp.all(jnp.isfinite(output_vals[0])), "NaN in solution")

    jax.debug.print("strain energy : {s}", s=output_vals[7])

    strain_energy = output_vals[7]

    normalization_factor = 1

    total_strain_energy = strain_energy / normalization_factor

    #jax.debug.print("vals[0]: {t}", t=output_vals[0])
    jax.debug.print("total strain energy: {s}", s=total_strain_energy)
    #jax.debug.print("vol_state: {t}", t=vol_state)

    #loss_value = total_strain_energy + 0.001 * jnp.sum(raw_thickness)

    loss_value = total_strain_energy
    
    return loss_value

### Main Program ####
if __name__ == "__main__":
    # Define fixed parameters
    fixed_length = 10.0  # Length of the bar
    delta_x = 0.25  # Element length
    fixed_horizon = 3.6 * delta_x  # Horizon size
    thickness = 1.0  # Thickness of the bar
    critical_stretch = None  # Critical stretch for damage, set to None for no damage

    allow_damage = False

    if critical_stretch is None:
        critical_stretch = 1.0e3   # This is a large value, effectively disabling damage

    if critical_stretch is not None:
        allow_damage = True


    # Initialize the problem with fixed parameters
    params, state = init_problem(bar_length=fixed_length,
                                 density=7850.0,
                                 bulk_modulus=200E9,
                                 number_of_elements=int(fixed_length / delta_x),
                                 horizon=fixed_horizon,
                                 thickness=thickness,
                                 prescribed_force=1.0e7,
                                 critical_stretch=critical_stretch)
    max_time = 1E-3
    #max_time = 1E-6
    max_time = float(max_time)


    thickness0 = ensure_thickness_vector(thickness, params.num_nodes)

    results = _solve(params, state, thickness0, allow_damage, max_time=float(max_time))
    #results = _solve_debug(params, state, thickness0, allow_damage, max_time=max_time)

    #results = _solve(params, state, params.thickness, allow_damage, max_time=max_time)

    #print("disp: ", results.disp)
    #print("vel: ", results.vel)



##################################################
    ################  now using optax to maximize ##########################
    # Initial parameter (scalar for thickness)
    key = jax.random.PRNGKey(0)  # Seed for reproducibility

    # Create a random array with values between 0.5 and 1.0
    #shape = (state.num_nodes,)  # Example shape (adjust as needed)
    #minval = 1.15
    minval = 1.00
    maxval = 2.0

    param = jnp.full((params.num_nodes,), 1.0)

    #param = thickness
    #param = jnp.array([2.0])
    learning_rate = 1E-2
    num_steps = 3

    thickness_min= 1.0E-2
    thickness_max = 1.0E2

    #define gradient bounds
    lower = 1E-2
    upper = 20

    max_time = 1.0E-3

    # Optax optimizer
    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(param)

    # Loss function (already defined as 'loss')
    loss_and_grad = jax.value_and_grad(loss, argnums=2)

    #print("param init in optax loop: ", param)

    # Clamp function
    def clamp_params(grads):
        lower = 1E-05
        upper = 1.0E2
        jax.debug.print("entering  clamp_params: {t}", t=grads)
        grads = jax.tree_util.tree_map(lambda x: jnp.clip(x, lower, upper), grads)

        #grads = jnp.clip(grads, a_min=lower, a_max=upper)  # Ensure thickness is within bounds
        jax.debug.print("grad after clamping: {t}", t=grads)

        return grads


    # Optimization loop
    for step in range(num_steps):
        def true_fn(thickness):
            jax.debug.print("thickness is all finite.")
            return thickness

        def false_fn(thickness):
            jax.debug.print("Non-finite thickness detected: {t}", t=thickness)
            return thickness


        jax.debug.print("Initial thickness: {t}", t=param)
        assert jnp.all(jnp.isfinite(param)), "Initial thickness contains NaNs!"

        loss_val, grads = loss_and_grad(params, state, param, allow_damage=allow_damage, max_time=max_time)
        #loss_val, grads = loss_and_grad(param, init_problem, max_time=1.0E-3)
        #jax.debug.print("grads: {g}", g=grads)

        #calling clamp function to restrict gradient
        grads = clamp_params(grads)

        ##print(f"Loss difference: {jnp.abs(loss_val2 - loss_val1)}")

        jax.debug.print("Step {step}, loss: {loss}, thickness: {thickness}, grads: {grads}", step=step, loss=loss_val, thickness=param, grads=grads)
        #breakpoint_if_nonfinite(grads)
        # if jnp.isnan(loss_val) or jnp.any(jnp.isnan(grads)):
        #     print("NaN detected! Stopping optimization.")
        #     break

        updates, opt_state = optimizer.update(grads, opt_state, param)
        param = optax.apply_updates(param, updates)
        ##print(f"Updated param: {param}")
        #print("updated param: ", param)
        #if step % 20 == 0:
            ##print(f"Step {step}, loss: {loss_val}")
            #print(f"Step {step}, loss: {loss_val}, param: {param}")




Initial thickness: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
strain energy : 0.13207894563674927
total strain energy: 0.13207894563674927
entering  clamp_params: [           nan            nan            nan            nan
            nan            nan            nan            nan
            nan  3.8049679e-04  3.6099977e-03  5.2843983e-03
  4.8660608e-03  4.0146741e-03  1.2811575e-03 -2.5271056e-03
 -2.9128604e-03 -1.5753307e-03 -1.3532897e-04 -9.9828350e-04
 -2.4519505e-03 -2.7880346e-04 -1.7339187e-03 -1.8541389e-03
 -3.0508137e-03  1.3138741e-03  3.9915480e-03  5.1036095e-03
  5.2360473e-03  3.5332676e-03  3.2478862e-04 -4.2317244e-03
 -3.8281020e-03 -2.1463435e-03 -4.7509625e-04 -3.3635100e-05
  0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00]
grad after clamping: [          nan           nan           nan           nan           nan
           nan           nan           nan           n

AssertionError: Initial thickness contains NaNs!

## Creating the class

In [133]:
class PDJAX():
    '''
       This class initializes a 1D peridynamic problem.  This problem is essentially
       a long bar with displacement boundary conditions applied to boundary regions
       equal to 1-horizon at each end.

       Initialization parameters are as follows:

       + ``bar_length`` - length of bar, it will be centered at the origin

       + ``number_of_elements`` - the discretization level

       + ``bulk_modulus``

       + ``density``

       + ``thickness``
    '''

    def __init__(self, 
                 bar_length:float=20,
                 number_of_elements:int=20,
                 bulk_modulus:float=100,
                 density:float=1.0,
                 thickness:Union[float, np.ndarray]=1.0,
                 horizon:Union[float, None]=None,
                 critical_stretch:Union[float, None] = None,
                 prescribed_velocity:Union[float, None]=None, 
                 prescribed_force:Union[float, None]=None
                 ) -> None:
        '''
           Initialization function
        '''

        #Problem data
        self.bulk_modulus = bulk_modulus
        self.rho = density

        self.bar_length = bar_length
        self.number_of_elements = number_of_elements

        delta_x = bar_length / number_of_elements

        # This array contains the *element* node locations.  i.e., they define the discrete 
        # regions along the bar. The peridynamic node locations will be at the centroid of 
        # these regions.
        self.nodes = jnp.linspace(-bar_length / 2.0, bar_length / 2.0, num=number_of_elements + 1)

        # Set horizon from parameter list or as default
        if horizon != None:
            self.horizon = horizon
        else:
            self.horizon = delta_x * 3.015

        '''
        if isinstance(thickness, float):
            self.thickness = jnp.ones(number_of_elements) * thickness
        elif isinstance(thickness, np.ndarray):
            self.thickness = jnp.asarray(thickness)
            assert thickness.shape[0] == number_of_elements, ValueError("Thickness array length must match number of elements")
        '''
        if isinstance(thickness, float) or np.isscalar(thickness):
            self.thickness = jnp.ones(number_of_elements) * thickness
        elif isinstance(thickness, np.ndarray) or isinstance(thickness, jnp.ndarray):
            self.thickness = jnp.asarray(thickness)
            assert self.thickness.shape[0] == number_of_elements, ValueError("Thickness array length must match number of elements")
        else:
            raise ValueError("thickness must be a float or array")


        # Compute the pd_node locations, kdtree, nns, reference_position_state, etc.
        self.setup_discretization(self.thickness)

        # Debug plotting
        # _, ax = plt.subplots()
        # self.line, = ax.plot(self.pd_nodes, self.displacement)
        self._allow_damage = False

        self.critical_stretch = critical_stretch

        if self.critical_stretch is not None:
            self.allow_damage() 

        # Set boundary regions
        if prescribed_velocity is not None and prescribed_force is not None:
            raise ValueError("Only one of prescribed_velocity or prescribed_force should be set, not both.")
        
        if prescribed_velocity is None and prescribed_force is None:
            raise ValueError("Either prescribed_velocity or prescribed_force must be set.")

        self.prescribed_velocity = prescribed_velocity
        self.prescribed_force = prescribed_force

        #:The node indices of the boundary region at the left end of the bar
        li = 0
        self.left_bc_mask = self.neighborhood[li] != li
        self.left_bc_region = self.neighborhood[li][self.left_bc_mask]

        #:The node indices of the boundary region at the right end of the bar
        ri = self.num_nodes - 1
        self.right_bc_mask = self.neighborhood[ri] != ri
        self.right_bc_region = self.neighborhood[ri][self.right_bc_mask]

        if prescribed_velocity is not None:
            self.left_bc_region = jnp.asarray(self.tree.query_ball_point(self.pd_nodes[0, None], r=(self.horizon + np.max(self.lengths) / 2.0), p=2, eps=0.0)).sort()
            self.right_bc_region = jnp.asarray(self.tree.query_ball_point(self.pd_nodes[-1, None], r=(self.horizon + np.max(self.lengths) / 2.0), p=2, eps=0.0)).sort()

        self.no_damage_region_left = jnp.asarray(self.tree.query_ball_point(self.pd_nodes[0, None], r=(2.5*self.horizon + np.max(self.lengths) / 2.0), p=2, eps=0.0)).sort()
        #:The node indices of the boundary region at the right end of the bar
        self.no_damage_region_right = jnp.asarray(self.tree.query_ball_point(self.pd_nodes[-1, None], r=(2.5*self.horizon + np.max(self.lengths) / 2.0), p=2, eps=0.0)).sort()

        self.width = 1.0

        # Compute the partial volumes
        vol_state, _ = self.compute_partial_volumes(self.thickness)

        # An  array containing the *influence vector-state* as defined in Silling et al. 2007
        # ratio = self.reference_magnitude_state / self.horizon
        # self.influence_state = jnp.ones_like(self.volume_state) - 35.0 * ratio ** 4.0 + 84.0 * ratio ** 5.0 - 70 * ratio ** 6.0 + 20 * ratio ** 7.0
        self.influence_state = jnp.where(vol_state > 1.0e-16, 1.0, 0.0)

        self.undamaged_influence_state = self.influence_state.copy()

        self.undamaged_influence_state_left = self.influence_state.at[self.no_damage_region_left, :].get()
        self.undamaged_influence_state_right = self.influence_state.at[self.no_damage_region_right, :].get()

        self.strain_energy_total = 0.0

        return

    def compute_partial_volumes(self, thickness:jax.Array):

        # Setup some local (to function) convenience variables
        neigh = self.neighborhood
        lens = self.lengths
        ref_mag_state = self.reference_magnitude_state
        horiz = self.horizon

        #jax.debug.print("neigh: {v}", v=neigh.shape)

        # jax.debug.print("thickness in comp par v: {t}",t=thickness)
        # Initialize the volume_state to the lengths * width * thickness
        width = 1.0
        self.width = width
        vol_state_uncorrected = lens[neigh] * thickness[neigh] * width 

        #jax.debug.print("Any NaNs? {y}", y=jnp.any(jnp.isnan(vol_state_uncorrected)))

        #jax.debug.print("thickness[neigh]: {t} ", t=thickness[neigh])

        #jax.debug.print("vol_state_unc in com p v: {v}", v=vol_state_uncorrected.shape)
        #jax.debug.print("ref_mag_state : {l}", l=ref_mag_state.shape)

        #Zero out entries that are not in the family
        vol_state_uncorrected = jnp.where(ref_mag_state < 1.0e-16, 0.0, vol_state_uncorrected) 
        #jax.debug.print("Any NaNs? {y}", y=jnp.any(jnp.isnan(vol_state_uncorrected)))

        #jax.debug.print("thickness: {t}", t=thickness)
        #jax.debug.print("neigh: {n}", n=neigh)
        #jax.debug.print("ref_mag_state: {r}", r=ref_mag_state)
        #jax.debug.print("vol_state_uncorrected: {v}", v=vol_state_uncorrected)
        #jax.debug.print("vol_state: {v}", v=vol_state)

        vol_state = jnp.where(ref_mag_state < horiz + lens[neigh] / 2.0, vol_state_uncorrected, 0.0)
        #jax.debug.print("Any NaNs? {y}", y=jnp.any(jnp.isnan(vol_state)))


        # Check to see if the neighboring node has a partial volume
        is_partial_volume = jnp.abs(horiz - ref_mag_state) < lens[neigh] / 2.0
        #jax.debug.print("Any NaNs? {y}", y=jnp.any(jnp.isnan(is_partial_volume)))

        # Two different scenarios:
        is_partial_volume_case1 = is_partial_volume * (ref_mag_state >= horiz)
        is_partial_volume_case2 = is_partial_volume * (ref_mag_state < horiz)
        #jax.debug.print("Any NaNs? {y}", y=jnp.any(jnp.isnan(is_partial_volume_case1)))
        #jax.debug.print("Any NaNs? {y}", y=jnp.any(jnp.isnan(is_partial_volume_case2)))


        # Compute the partial volumes conditionally
        vol_state = jnp.where(is_partial_volume_case1, (lens[neigh] / 2.0 - (ref_mag_state - horiz)) * width * thickness[neigh], vol_state)
        #jax.debug.print("Any NaNs? {y}", y=jnp.any(jnp.isnan(vol_state)))
        vol_state = jnp.where(is_partial_volume_case2, (lens[neigh] / 2.0 + (horiz - ref_mag_state)) * width * thickness[neigh], vol_state)
        #jax.debug.print("Any NaNs? {y}", y=jnp.any(jnp.isnan(vol_state)))

        # If the partial volume is predicted to be larger than the unocrrected volume, set it back
        # vol_state = jnp.where(vol_state > vol_state_uncorrected, vol_state_uncorrected, vol_state)
        vol_state = jnp.where(vol_state > vol_state_uncorrected, vol_state_uncorrected, vol_state)
        #jax.debug.print("Any NaNs? {y}", y=jnp.any(jnp.isnan(vol_state)))


        # Now compute the "reverse volume state", this is the partial volume of the "source" node, i.e. node i,
        # as seen from node j.  This doesn't show up anywhere in any papers, it's just used here for computational
        # convenience
        vol_array = lens[:,None] * width * thickness[:, None]
        #jax.debug.print("Any NaNs? {y}", y=jnp.any(jnp.isnan(vol_array)))

        rev_vol_state = jnp.ones_like(vol_state) * vol_array
        #jax.debug.print("Any NaNs? {y}", y=jnp.any(jnp.isnan(rev_vol_state)))
        
        rev_vol_state = jnp.where(is_partial_volume_case1, (lens[:, None] / 2.0 - (ref_mag_state - horiz)) * width * thickness[:, None], rev_vol_state)
        #jax.debug.print("Any NaNs? {y}", y=jnp.any(jnp.isnan(rev_vol_state)))
        
        rev_vol_state = jnp.where(is_partial_volume_case2, (lens[:, None] / 2.0 + (horiz - ref_mag_state)) * width * thickness[:, None], rev_vol_state)
        #jax.debug.print("Any NaNs? {y}", y=jnp.any(jnp.isnan(rev_vol_state)))
       
        #If the partial volume is predicted to be larger than the uncorrected volume, set it back
        rev_vol_state = jnp.where(rev_vol_state > vol_array, vol_array, rev_vol_state)
        #jax.debug.print("Any NaNs? {y}", y=jnp.any(jnp.isnan(rev_vol_state)))

        # Set attributes
        # self.volume_state = vol_state
        # self.reverse_volume_state = rev_vol_state

        #jax.debug.breakpoint()

        return (vol_state, rev_vol_state)

    def allow_damage(self):
        self._allow_damage = True
        return

    
    def setup_discretization(self, thickness:jax.Array):
        
        nodes = self.nodes

        # The lengths of the *elements*
        self.lengths = jnp.array(nodes[1:] - nodes[0:-1]) 

        # The PD nodes are the centroids of the elements
        self.pd_nodes = jnp.array(nodes[0:-1] + self.lengths / 2.0)
        self.num_nodes = self.pd_nodes.shape[0]

        # Create's a kdtree to do nearest neighbor search
        self.tree = scipy.spatial.cKDTree(self.pd_nodes[:,None])


        # Get PD nodes in the neighborhood of support + largest node spacing, this will
        # find all potential partial volume nodes as well. The distances returned from the
        # search turn out to be the reference_magnitude_state, so we'll store them now
        # to avoid needed to calculate later.
        #set k=100, used to be k=6, will trim down later
        reference_magnitude_state, neighborhood = self.tree.query(self.pd_nodes[:,None], 
                k=100, p=2, eps=0.0, distance_upper_bound=(self.horizon + np.max(self.lengths) / 2.0))

        #jax.debug.print("ref_mag_state before: {r}",r=reference_magnitude_state)
        #trying to delete first column of ref_mag_state for broadcasting issue
        reference_magnitude_state = jnp.delete(reference_magnitude_state, 0, 1)  
        #jax.debug.print("ref_mag_state after initial trim: {r}",r=reference_magnitude_state )

        #jax.debug.print("ref_mag_state.shape after: {r}",r=reference_magnitude_state.shape)

        self.num_neighbors = jnp.asarray((neighborhood != self.tree.n).sum(axis=1)) - 1
        self.max_neighbors = np.max((neighborhood != self.tree.n).sum(axis=1))

        # Convert to JAX arrays and trim down excess neighbors
        neighborhood = jnp.asarray(neighborhood[:, :self.max_neighbors])
        #self.reference_magnitude_state = jnp.delete(reference_magnitude_state[:, :self.max_neighbors], 0,0)

        #changed to select just the first row and alleviate the broadcasting error
        #self.reference_magnitude_state = jnp.delete(reference_magnitude_state[1, :self.max_neighbors], 0,0)


        self.reference_magnitude_state = reference_magnitude_state[0:, :self.max_neighbors-1]
        #self.reference_magnitude_state = reference_magnitude_state[:, 1:self.max_neighbors]    
        

        #jax.debug.print("self.ref_mag_state.shape later: {r}",r=self.reference_magnitude_state)
        # Cleanup neighborhood
        row_indices = jnp.arange(neighborhood.shape[0]).reshape(-1, 1)
        neighborhood = jnp.where(neighborhood == self.tree.n, row_indices, neighborhood)
        self.neighborhood = jnp.delete(neighborhood,0,1)

        # Compute the reference_position_state.  Using the terminology of Silling et al. 2007
        self.reference_position_state = self.pd_nodes[self.neighborhood] - self.pd_nodes[:,None]

        # Cleanup reference_magnitude_state
        self.reference_magnitude_state = jnp.where(self.reference_magnitude_state == np.inf, 0.0, self.reference_magnitude_state)
        #jax.debug.print("self.ref_mag_state.shape at end of setup disc {r}",r=self.reference_magnitude_state.shape)

        return



    def introduce_flaw(self, location:float, allow_damage=False):  

        if allow_damage:
            self.allow_damage()

        _, nodes_near_flaw = self.tree.query(location, k=self.max_neighbors, p=2, eps=0.0, 
                                             distance_upper_bound=(self.horizon + np.max(self.lengths)/2))

        # The search above will produce duplicate neighbor nodes, make them into a
        # unique 1-dimensional list
        nodes_near_flaw = np.array(np.unique(nodes_near_flaw), dtype=np.int64)
    
        # Remove the dummy entries
        nodes_near_flaw = nodes_near_flaw[nodes_near_flaw != self.tree.n]

        families = jnp.asarray(self.neighborhood)
        # Loop over nodes near the crack to see if any bonds in the nodes family
        # cross the crack path
        for idx in nodes_near_flaw:
            # Loop over node family
            node_family = families[idx][families[idx] != idx]
            for bond_idx, end_point_idx in enumerate(node_family):
                # Define the bond line segment as the line between the node and its
                # endpoint.
                min_node, max_node = np.sort([self.pd_nodes[idx], self.pd_nodes[end_point_idx]])

                if min_node < location and location < max_node:
                    self.influence_state = self.influence_state.at[idx, bond_idx].set(0.2)
        return

## Importing packages

In [2]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from functools import partial

import jax
import jax.numpy as jnp

import numpy as np
import scipy.spatial
import scipy.optimize
import matplotlib.pyplot as plt


import jax.scipy
import jax.scipy.optimize
from jax.scipy.optimize import minimize
from jax.nn import softplus

from jax import grad, jit


from typing import Union, Tuple

import optax
from jax import jit, vmap, grad, value_and_grad
from jax import lax
from jax import grad
from jax.experimental import checkify

from jax import make_jaxpr

## Global Functions

In [None]:
##### note will need to add damage functions here, but for now just want to see if we can get it to work


def compute_partial_volumes(problem1: PDJAX,
    thickness: jax.Array
) -> tuple[jax.Array, jax.Array]:
    """
    Fully explicit, JAX-compatible partial volume computation.
    All required arrays and constants must be passed as arguments.
    """
    neigh = problem1.neighborhood
    lens = problem1.lengths
    ref_mag_state = problem1.reference_magnitude_state
    horiz = problem1.horizon
    width = problem1.width

    # Initialize the volume_state to the lengths * width * thickness
    vol_state_uncorrected = lens[neigh] * thickness[neigh] * width

    # Zero out entries that are not in the family
    vol_state_uncorrected = jnp.where(ref_mag_state < 1.0e-16, 0.0, vol_state_uncorrected)

    vol_state = jnp.where(ref_mag_state < horiz + lens[neigh] / 2.0, vol_state_uncorrected, 0.0)

    # Check to see if the neighboring node has a partial volume
    is_partial_volume = jnp.abs(horiz - ref_mag_state) < lens[neigh] / 2.0

    # Two different scenarios:
    is_partial_volume_case1 = is_partial_volume * (ref_mag_state >= horiz)
    is_partial_volume_case2 = is_partial_volume * (ref_mag_state < horiz)

    # Compute the partial volumes conditionally
    vol_state = jnp.where(
        is_partial_volume_case1,
        (lens[neigh] / 2.0 - (ref_mag_state - horiz)) * width * thickness[neigh],
        vol_state,
    )
    vol_state = jnp.where(
        is_partial_volume_case2,
        (lens[neigh] / 2.0 + (horiz - ref_mag_state)) * width * thickness[neigh],
        vol_state,
    )

    # If the partial volume is predicted to be larger than the uncorrected volume, set it back
    vol_state = jnp.where(vol_state > vol_state_uncorrected, vol_state_uncorrected, vol_state)

    # Now compute the "reverse volume state"
    vol_array = lens[:, None] * width * thickness[:, None]
    rev_vol_state = jnp.ones_like(vol_state) * vol_array

    rev_vol_state = jnp.where(
        is_partial_volume_case1,
        (lens[:, None] / 2.0 - (ref_mag_state - horiz)) * width * thickness[:, None],
        rev_vol_state,
    )
    rev_vol_state = jnp.where(
        is_partial_volume_case2,
        (lens[:, None] / 2.0 + (horiz - ref_mag_state)) * width * thickness[:, None],
        rev_vol_state,
    )

    rev_vol_state = jnp.where(rev_vol_state > vol_array, vol_array, rev_vol_state)

    return vol_state, rev_vol_state


# Compute the force vector-state using a LPS peridynamic formulation
def compute_force_state_LPS(disp, ref_pos, ref_pos_state, ref_mag_state, inf_state, vol_state, no_damage_region_left,
                            no_damage_region_right, undamaged_influence_state_left, undamaged_influence_state_right,
                            critical_stretch, allow_damage, critical_stretch, K):
    """
    Fully explicit, JAX-compatible force state computation for LPS peridynamics.
    All required arrays and constants must be passed as arguments.
    """
    # Compute deformed positions
    def_pos = ref_pos + disp

    # Compute deformation state
    def_state = def_pos[ref_pos_state.shape[0], ...] - def_pos[:, None]  # ref_pos_state.shape[0] == neighborhood
    # But usually, you want: def_state = def_pos[neigh] - def_pos[:, None]
    # So, pass neigh as an argument and use: def_state = def_pos[neigh] - def_pos[:, None]

    # For clarity, let's assume you pass neigh as an argument:
    # def_state = def_pos[neigh] - def_pos[:, None]

    # Compute deformation magnitude state
    def_mag_state = jnp.sqrt(def_state * def_state)

    # Compute deformation unit state
    def_unit_state = jnp.where(def_mag_state > 1.0e-16, def_state / def_mag_state, 0.0)

    # Compute scalar extension state
    exten_state = def_mag_state - ref_mag_state

    # Compute stretch
    stretch = jnp.where(ref_mag_state > 1.0e-16, exten_state / ref_mag_state, 0.0)

    # Apply critical stretch fracture criteria
    if allow_damage:
        inf_state = jnp.where(stretch > critical_stretch, 0.0, inf_state)
        inf_state = inf_state.at[no_damage_region_left, :].set(undamaged_influence_state_left)
        inf_state = inf_state.at[no_damage_region_right, :].set(undamaged_influence_state_right)

    # Compute shape tensor (weighted volume)
    epsilon = 1e-8
    shape_tens = (inf_state * ref_pos_state * ref_pos_state * vol_state).sum(axis=1)
    shape_tens = jnp.where(jnp.abs(shape_tens) < epsilon, epsilon, shape_tens)

    # Compute scalar force state for elastic constitutive model
    scalar_force_state = 9.0 * K / shape_tens[:, None] * exten_state

    # Bond strain energy calculation
    bond_strain_energy = 9.0 * K / shape_tens[:, None] * exten_state * exten_state * ref_mag_state

    # Compute the force state
    force_state = inf_state * scalar_force_state * def_unit_state

    return force_state, inf_state, bond_strain_energy

def smooth_ramp(self, t, t0, c=1.0, beta=5.0):
    """
    Function that linearly ramps up to c at t0, then smoothly transitions to c.

    Parameters:
    - t: Time variable (scalar or numpy array).
    - t0: Time at which the transition occurs.
    - c: Final constant value after transition.
    - beta: Smoothness parameter (higher values = sharper transition).

    Returns:
    - f: Value of the function at time t.
    """
    # Linear ramp before t0 (with slope c/t0)
    linear_ramp = (c / t0) * t

    # Smooth transition using an exponential decay term
    smooth_transition = c * (1 - jnp.exp(-beta * (t - t0))) + (c / t0) * t0

    # Use `np.where` to define the piecewise function
    f = jnp.where(t < t0, linear_ramp, smooth_transition)
    
    return f

# Internal force calculation
#@partial(jit, static_argnums=(0,))
def compute_internal_force(disp, vol_state, rev_vol_state, inf_state,
                           thickness, time, neighborhood, left_bc_region, right_bc_region,
                            width, prescribed_force, smooth_ramp_fn, num_nodes):
    """
    Fully explicit, JAX-compatible internal force calculation.
    All required arrays and constants must be passed as arguments.
    """


    # Compute force state and bond strain energy
    force_state, inf_state, bond_strain_energy = compute_force_state_LPS(
            disp, ref_pos, ref_pos_state, ref_mag_state, inf_state, vol_state, no_damage_region_left,
            no_damage_region_right, undamaged_influence_state_left, undamaged_influence_state_right,
            critical_stretch, allow_damage, critical_stretch, K)

    # Integrate nodal forces
    force = (force_state * vol_state).sum(axis=1)
    force = force.at[neighborhood].add(-force_state * rev_vol_state)

    # Strain energy per node
    strain_energy = 0.5 * (bond_strain_energy * vol_state).sum(axis=1)

    # Apply boundary forces if needed
    if prescribed_force is not None:
        li = 0
        ramp_force = smooth_ramp_fn(time, t0=1.e-5, c=prescribed_force)
        denom = vol_state[li].sum() + rev_vol_state[li][0]

        # Compute the left boundary condition nodal forces
        left_bc_area = width * thickness[left_bc_region]
        left_bc_nodal_forces = (ramp_force * left_bc_area)/(vol_state[li].sum() + rev_vol_state[li][0])
        force = force.at[left_bc_region].add(-left_bc_nodal_forces)

        # For the leftmost node (if needed)
        left_bc_area_li = width * thickness[li]
        force = force.at[li].add(-ramp_force * left_bc_area_li)

        # Compute the right boundary condition nodal forces
        ri = num_nodes - 1
        right_bc_area = width * thickness[right_bc_region]
        right_bc_nodal_forces = (ramp_force * right_bc_area)/(vol_state[ri].sum() + rev_vol_state[ri][0])
        force = force.at[right_bc_region].add(right_bc_nodal_forces)

        # For the rightmost node (if needed)
        right_bc_area_ri = width * thickness[ri]
        force = force.at[ri].add(-ramp_force * right_bc_area_ri)

    return force, inf_state, strain_energy

def solve_one_step_explicit(vals:Tuple[jax.Array, jax.Array, jax.Array, 
                                    jax.Array, jax.Array, jax.Array, 
                                    jax.Array, jax.Array, jax.Array, jax.Array, 
                                    jax.Array, jax.Array, float, float, float, float, float]):


    (disp, vel, acc, vol_state, rev_vol_state, influence_state, thickness, undamaged_influence_state, 
            strain_energy, left_bc_region, right_bc_region, pd_nodes, time, time_step, prescribed_velocity, bar_length, rho, num_nodes) =  vals
    #compute_internal_force: Callable,
    """
    One step of time integration using Velocity-Verlet, in JAX-compatible functional style.
    """

    vals

    # Mid-step update of displacement
    disp_new = disp + vel * time_step + 0.5 * time_step**2 * acc

    # Apply boundary conditions (left edge fixed)
    disp_new = disp_new.at[left_bc_region].set(0.0)

    # using JIT compilation for compute_internal_force
    compute_internal_force_jit = jax.jit(compute_internal_force)

    # Compute updated internal force
    internal_force, inf_state, strain_energy = compute_internal_force_jit(
                    disp, vol_state, rev_vol_state, inf_state, thickness, time, neighborhood, 
                    left_bc_region, right_bc_region, width, prescribed_force, smooth_ramp_fn, num_nodes)

    # Compute updated acceleration
    acc_new = internal_force / rho

    # Mid-step update of velocity
    vel_new = vel + 0.5 * time_step * (acc + acc_new)

    # Apply prescribed velocity at right edge
    vel_new = vel_new.at[right_bc_region].set(prescribed_velocity)

    # Energy update
    strain_energy_total = jnp.sum(strain_energy)

    # Update time
    time_new = time + time_step

    

    return (disp, vel, acc, vol_state, rev_vol_state, inf_state, thickness,strain_energy_total, time_new)
    

def _solve_explicit(
    problem1: PDJAX,
    thickness: jax.Array,
    max_time: float,
      # A function: vals -> new_vals
    ) -> tuple:
    """
    Pure JAX-compatible version of time integrator using explicit data dependencies.
    """


    lengths = problem1.lengths
    horizon = problem1.horizon
    ref_mag_state = problem1.reference_magnitude_state
    neighborhood = problem1.neighborhood
    influence_state = problem1.influence_state
    undamaged_influence_state = problem1.undamaged_influence_state
    num_nodes = problem1.num_nodes

    # Time stepping
    time_step = 1.0e-6
    num_steps = int(max_time / time_step)

    # Compute volume states
    vol_state, rev_vol_state = compute_partial_volumes(
        thickness, neighborhood, lengths, ref_mag_state, horizon, width=1.0)

    # Initialize fields

    disp = jnp.zeros_like(problem1.nodes)
    vel = jnp.zeros_like(problem1.nodes)
    acc = jnp.zeros_like(problem1.nodes)
    time = 0.0
    strain_energy = 0.0

    #to unpack from problem in hopes of alleviating tracer errors
    prescribed_velocity = problem1.prescribed_velocity
    bar_length =  problem1.bar_length
    left_bc_region = problem1.left_bc_region
    right_bc_region = problem1.right_bc_region
    pd_nodes = problem1.pd_nodes
    rho = problem1.rho

        disp, ref_pos, ref_pos_state, ref_mag_state, inf_state, vol_state, no_damage_region_left,
    no_damage_region_right, undamaged_influence_state_left, undamaged_influence_state_right,
    critical_stretch, allow_damage, critical_stretch, K)

    # Package values into tuple
    vals = (disp, vel, acc, vol_state, rev_vol_state, influence_state, thickness, undamaged_influence_state, 
            strain_energy,time, left_bc_region, right_bc_region, pd_nodes, time, time_step, prescribed_velocity, bar_length, rho, num_nodes)

    # Time loop using fori_loop
    def loop_body(i, vals):
        return solve_one_step_explicit(*vals)

    vals_returned = jax.lax.fori_loop(0, num_steps, loop_body, vals)

    return vals_returned

def solve_explicit(problem1,
    thickness: jax.Array,
    max_time: float,
    # Add any other fields you were using from self._solve
) -> tuple:
    """
    JAX-compatible version of solve that explicitly passes all dependencies.
    """

    vals = _solve_explicit(problem1,
        thickness=thickness,
        max_time=max_time)
    
    displacement = vals[0]
    velocity = vals[1]
    acceleration = vals[2]
    influence_state = vals[5]
    strain_energy_total = vals[8]

    return displacement, velocity, acceleration, influence_state, strain_energy_total


def loss(thickness:jax.Array, problem1:PDJAX, max_time:float):

    disp, vel, acc, influence, strain_energy_total = solve_explicit(
        thickness=thickness,
        pd_nodes=problem1.pd_nodes,
        displacement=problem1.displacement,
        velocity=problem1.velocity,
        acceleration=problem1.acceleration,
        horizon=problem1.horizon,
        family=problem1.family,
        volume=problem1.volume,
        boundary_mask=problem1.boundary_mask,
        max_time=max_time,
    )


    checkify.check(jnp.all(jnp.isfinite(strain_energy_total)), "NaN in solution")


    # Check for NaNs in energy
    def nan_debug_print(x):
        jax.debug.print("NaNs detected in strain_energy_total")
        return x

    def no_nan(x):
        return x


    # This will always execute, but only print if the condition is true
    _ = jax.lax.cond(
        ~jnp.all(jnp.isfinite(strain_energy_total)),
        nan_debug_print,
        no_nan,
        operand=strain_energy_total )

    jax.debug.print("total strain energy: {s}", s=strain_energy_total)


    #loss_value = total_strain_energy + 0.001 * jnp.sum(raw_thickness)

    return strain_energy_total



## Main Program Creating the instance

In [11]:
### Main Program ####
if __name__ == "__main__":

    #Define problem size
    fixed_length = 10.0
    delta_x = 0.25
    fixed_horizon = 3.6 * delta_x

    # Initial parameter (scalar for thickness)
    key = jax.random.PRNGKey(0)  # Seed for reproducibility

    # Create a random array with values between 0.5 and 1.0
    shape = (fixed_length/delta_x,)  # Example shape (adjust as needed)
    #minval = 1.15
    minval = 1.15
    maxval = 1.50
    #thickness = jax.random.uniform(key, shape=shape, minval=minval, maxval=maxval)
    thickness = 1.15
    

    #Instantiate a 1d peridynamic problem with equally spaced nodes
    problem1 = PDJAX(bar_length=fixed_length,
                     density=7850.0,
                     bulk_modulus=200E9,
                     number_of_elements=int(fixed_length/delta_x), 
                     horizon=fixed_horizon,
                     thickness=thickness,
                     prescribed_force=1.0e8)
                     #critical_stretch=1.0e-4)
    


### To run forward problem

In [13]:
    # Set up initial conditions
    #######################################
    #problem1.introduce_flaw(0.0)

    vals = solve_explicit(problem1, thickness, max_time=1e-03)
    print("Displacement:", vals[0])
        # #    

TypeError: solve_explicit() missing 4 required positional arguments: 'displacement', 'velocity', 'acceleration', and 'fixed_horizon'

### Opimization Code

In [None]:
    ################  now using optax to maximize ##########################
    # Initial parameter (scalar for thickness)
    key = jax.random.PRNGKey(0)  # Seed for reproducibility

    # Create a random array with values between 0.5 and 1.0
    shape = (problem1.num_nodes,)  # Example shape (adjust as needed)
    #minval = 1.15
    minval = 1.00
    maxval = 2.0
    #param_int = jax.random.uniform(key, shape=shape, minval=minval, maxval=maxval)
    #init_param = jax.random.uniform(key, shape=shape, minval=minval, maxval=maxval)
    param = jnp.array([1.0])
    #param =  softplus(init_param)
    #param =  jnp.ones(problem1.num_nodes) * # Initial thickness guess
    
    #param = thickness
    #param = jnp.array([2.0])
    learning_rate = 1E-2
    num_steps = 5

    thickness_min= 1.0E-2
    thickness_max = 1.0E2

    #define gradient bounds
    lower = 1E-2
    upper = 20

    # Optax optimizer
    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(param)

    # Loss function (already defined as 'loss')
    loss_and_grad = jax.value_and_grad(loss)

    #print("param init in optax loop: ", param)

    # Clamp function
    def clamp_params(grads):
        lower = 1E-03
        upper = 1.0E2
        jax.debug.print("entering  clamp_params: {t}", t=grads)
        grads = jax.tree_util.tree_map(lambda x: jnp.clip(x, lower, upper), grads)

        #grads = jnp.clip(grads, a_min=lower, a_max=upper)  # Ensure thickness is within bounds
        jax.debug.print("grad after clamping: {t}", t=grads)

        return grads


    # Optimization loop
    for step in range(num_steps):
        def true_fn(thickness):
            jax.debug.print("thickness is all finite.")
            return thickness

        def false_fn(thickness):
            jax.debug.print("Non-finite thickness detected: {t}", t=thickness)
            return thickness


        jax.debug.print("Initial thickness: {t}", t=param)
        assert jnp.all(jnp.isfinite(param)), "Initial thickness contains NaNs!"

        #h = 1.0E-38
        jax.debug.print("initial param: {p}", p=param)
        loss_val, grads = loss_and_grad(param, problem1, max_time=1.0E-3)
        jax.debug.print("grads: {g}", g=grads)
        #loss_val2, grads = loss_and_grad(param + h, problem1)
        #loss_val = loss_val1
        #grads = jnp.abs(loss_val2 - loss_val1) / h 

        #calling clamp function to restrict gradient
        grads = clamp_params(grads)

        ##print(f"Loss difference: {jnp.abs(loss_val2 - loss_val1)}")

        jax.debug.print("Step {step}, loss: {loss}, grads: {grads}", step=step, loss=loss_val, grads=grads)
        #breakpoint_if_nonfinite(grads)
        # if jnp.isnan(loss_val) or jnp.any(jnp.isnan(grads)):
        #     print("NaN detected! Stopping optimization.")
        #     break

        updates, opt_state = optimizer.update(grads, opt_state, param)
        param = optax.apply_updates(param, updates)
        ##print(f"Updated param: {param}")
        #print("updated param: ", param)
        #if step % 20 == 0:
            ##print(f"Step {step}, loss: {loss_val}")
            #print(f"Step {step}, loss: {loss_val}, param: {param}")

    # Use the optimized thickness
    opt_thickness = param

    problem1.solve_explicit(thickness,nodes,displacement,velocity,acceleration,fixed_horizon,max_time=1.0e-3)
    
    vals = problem1._solve_explicit(opt_thickness,
            nodes,displacement,velocity,acceleration,
            fixed_horizon, max_time=1.0e-3)
    
    jax.debug.print("opt_thickness: {t}", t=opt_thickness)
    jax.debug.print("strain energy: {e}", e=vals[8])
    # fig, ax = plt.subplots()
    # ax.plot(problem1.get_nodes(), vals[0], 'ko')
    # ax.set_xlabel(r'$x$')
    # ax.set_ylabel(r'displacement')
    # plt.show()


    

    #init_param = jax.random.uniform(key, shape=shape, minval=minval, maxval=maxval)
    '''
    #to do some sanity checks
    thickness = jnp.ones((problem1.num_nodes,)) * 1.0
    #thickness = jnp.array([1.0])
    problem1.solve(max_time=1.0e-3)
    

    
    loss_val = loss(thickness, problem1, max_time=1e-3)
    print("LOSS VALUE: ", loss_val)
    print("LOSS FINITE? ", jnp.isfinite(loss_val))
    g = grad(loss)(thickness, problem1, max_time=1e-3)
    print("Gradient: ", g)

    print("Gradient finite? ", jnp.all(jnp.isfinite(g)))
        

    ### jax code that shows where nans first appear
    #checked_loss = checkify.checkify(grad(loss))

    #errors, out = checked_loss(thickness, problem1, max_time=1e-3)
    #print("Output:", out)
    #print("Errors:", errors.get())

    fig, ax = plt.subplots()
    ax.plot(problem1.get_nodes(), problem1.get_solution(), 'ko')
    ax.set_xlabel(r'$x$')
    ax.set_ylabel(r'displacement')
    plt.show()

    '''

    # Check loss and its gradient are finite
    #loss_val = loss(thickness, problem1, max_time=1e-3)
    #print("Loss:", loss_val)
    #print("Loss finite?", jnp.isfinite(loss_val))

    #grads = jax.grad(loss)(thickness, problem1, max_time=1e-3)
    #print("Grads:", grads)
    #print("Grads finite?", jax.tree_util.tree_all(jax.tree_map(jnp.isfinite, grads)))