In [9]:
import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt

One thing I want to be present in JAX-SPH is a automatic mesh generation for some simple shapes. If everything can be written in JAX, then that gives the opportunity to make things end-to-end differentiable.

In this notebook, I want to work on:
1. Given an input shape (something relatively simple) and a known quantity of particles, distribute those particles within the shape.
2. Plot those particles using Paraview.

First, how to describe a shape? Maybe I can start "simple" with a tensile bar. Something complex but useful. Especially knowing the sensitivities w.r.t. thickness, lengths, etc. would be nice. A tensile bar can be described well by lines and ellipses. Plus, it has some symmetry so we only need to worry about one quadrant.

In [10]:
# all units in [mm]
grip_width = 10
grip_length = 10
gage_width = 5
gage_length = 30
thickness = 3
fillet_major_diameter = 20
fillet_minor_diameter = 10
ideal_spacing = 0.5 # ideal distance between particles, will be enforced to be less than this

# get the center point of the fillet in the first quadrant
fillet_center_x = fillet_minor_diameter/2 + gage_width/2
fillet_center_y = gage_length/2

# get the point where the fillet changes to the grip section
fillet_upper_corner_x = grip_width/2
fillet_upper_corner_y = ((1 - (grip_width/2 - fillet_center_x)**2 / (fillet_minor_diameter/2)**2) * (fillet_major_diameter/2)**2)**0.5 + fillet_center_y

# determine the number of particles in the two "constant" directions
num_particles_y = int(2*(fillet_upper_corner_y+grip_length)/ideal_spacing)+1
num_particles_z = int(thickness/ideal_spacing)+1

y = jnp.linspace(-fillet_upper_corner_y-grip_length,fillet_upper_corner_y+grip_length,num_particles_y)
z = jnp.linspace(-thickness/2,thickness/2,num_particles_z)

condlist = [
    jnp.logical_or(y < -fillet_upper_corner_y, y > fillet_upper_corner_y),
    jnp.logical_and(y < -gage_length/2, y >= -fillet_upper_corner_y),
    jnp.logical_and(y > gage_length/2, y <= fillet_upper_corner_y)
]

funclist = [
    grip_width,
    lambda y: (fillet_center_x - ((1 - (y+fillet_center_y)**2 / (fillet_major_diameter/2)**2) * (fillet_minor_diameter/2)**2)**0.5)*2, 
    lambda y: (fillet_center_x - ((1 - (y-fillet_center_y)**2 / (fillet_major_diameter/2)**2) * (fillet_minor_diameter/2)**2)**0.5)*2, 
    gage_width
]

widths = jnp.piecewise(y,condlist,funclist)
num_particles_x = (widths/ideal_spacing).astype(int)+1

def generate_slice(width, thickness, num_width, num_thick, y_coord):
    x = jnp.tile(jnp.linspace(-width/2,width/2,num_width),num_thick)
    z = jnp.repeat(jnp.linspace(-thickness/2,thickness/2,num_thick),num_width)
    y = jnp.repeat(y_coord, num_thick*num_width)
    return jnp.stack([x,y,z],axis=1)

for i in range(num_particles_y):
    slice_i = generate_slice(widths[i], thickness, num_particles_x[i], num_particles_z, y[i])
    if i == 0:
        particles = slice_i
    else:
        particles = jnp.concatenate((particles,slice_i))

particles.shape

(13839, 3)

In [13]:
# Intensive to load here, just forewarning...

import pyvista
import numpy as onp
pyvista.set_jupyter_backend('client')

In [15]:
pdata = pyvista.PolyData(onp.array(particles))
sphere = pyvista.Sphere(radius = ideal_spacing/2)
pc = pdata.glyph(scale=False, geom=sphere, orient=False)
# pc.plot()
pc.save("test.vtk")