I want to build on Sunday's work by wrapping the tensile-bar code into one function and seeing whether I can speed things up a bit.

In [51]:
import jax.numpy as jnp
import jax

import pyvista
import numpy as onp
# Select pyvista's 'client' Jupyter backend so plots render inline in the notebook as a widget
pyvista.set_jupyter_backend('client')

I think I will have to split things up: one function computes the number of points in each direction, and the other generates the actual point positions. The split is needed because the point counts determine array sizes, which must be concrete Python integers rather than traced values. Keeping the counting out of the position generator lets that function be jitted and differentiated. Maybe the two could be wrapped together if jittability is not important?

In [52]:
def get_tensile_bar_numbers(grip_width:            float,
                            grip_length:           float,
                            gage_width:            float,
                            gage_length:           float,
                            thickness:             float,
                            fillet_major_diameter: float,
                            fillet_minor_diameter: float,
                            ideal_spacing:         float
                            ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, int]:
    """Count the particles of a dog-bone tensile bar and give each one integer ids.

    Coordinates: (0, 0, 0) is the middle of the gage section (and of the thickness), and the
    y axis points along the tensile direction. Bottom to top, the bar is split into five
    sections: 0 = lower grip, 1 = lower fillet, 2 = gage, 3 = upper fillet, 4 = upper grip.
    Each fillet follows an arc of an ellipse with semi-axes fillet_minor_diameter/2 (along x)
    and fillet_major_diameter/2 (along y), centred at
    (gage_width/2 + fillet_minor_diameter/2, +/- gage_length/2).

    The counts decide array sizes, so this function works with concrete Python numbers and is
    not meant to be jitted. get_tensile_bar_positions turns the returned ids into coordinates.

    Row counts per section (s = ideal_spacing):
      * grip:   floor(grip_length / s) rows
      * fillet: floor(fillet_length / s) + 2 rows, including both the gage end and the grip end
      * gage:   floor(gage_length / s) - 1 rows, excluding both ends (the fillets own them)
    A row of width w holds floor(w / s) + 1 points across it. Every row has
    floor(thickness / s) + 1 points through the thickness.

    Returns:
        xyz_ids:     (N, 3) int array of (x index within the row, global row index, z index).
                     Every (x, z) pair of every row appears exactly once, ordered row by row,
                     then by x, then by z.
        section_ids: (R,) section number (0-4) of each row.
        widths:      (R,) number of points across each row.
        heights:     (R,) number of rows in the section that each row belongs to.
        num_p_z:     number of points through the thickness.

    Raises:
        ValueError: unless gage_width <= grip_width <= gage_width + 2*fillet_minor_diameter.
            Outside that range the fillet ellipse never reaches x = grip_width/2.
    """
    # Slack so that ratios such as 0.3 / 0.1 = 2.9999999999999996 still floor to 3.
    slack = 1e-9

    def count(length):
        return int(onp.floor(length / ideal_spacing + slack))

    semi_x = fillet_minor_diameter / 2
    semi_y = fillet_major_diameter / 2
    fillet_center_x = semi_x + gage_width / 2

    # Normalised x-offset of the fillet/grip corner from the ellipse centre. |offset| > 1 would
    # put the corner outside the ellipse (and make the square root below complex).
    corner_offset = (grip_width / 2 - fillet_center_x) / semi_x
    if abs(corner_offset) > 1:
        raise ValueError(
            "grip_width must lie between gage_width and gage_width + 2*fillet_minor_diameter "
            f"for the fillet to meet the grip (got grip_width={grip_width}, "
            f"gage_width={gage_width}, fillet_minor_diameter={fillet_minor_diameter})"
        )
    # length of each fillet region along y
    fillet_length = semi_y * (1 - corner_offset ** 2) ** 0.5

    # number of rows in the different regions along y, and points along z / across the grip and gage
    num_p_grip_y = count(grip_length)
    num_p_fillet_y = count(fillet_length) + 2
    num_p_gage_y = count(gage_length) - 1
    num_p_z = count(thickness) + 1
    num_p_grip_x = count(grip_width) + 1
    num_p_gage_x = count(gage_width) + 1

    # Fillet widths sampled from the gage end (distance 0) to the grip end (fillet_length),
    # endpoints included. Going up, the upper fillet runs gage -> grip and the lower fillet
    # runs grip -> gage, so the lower counts are the upper counts reversed.
    dist = onp.linspace(0.0, fillet_length, num_p_fillet_y)
    root = onp.sqrt(onp.maximum(1 - (dist / semi_y) ** 2, 0.0))  # clamp round-off below zero
    fillet_widths = 2 * (fillet_center_x - semi_x * root)
    upper_fillet_x = onp.floor(fillet_widths / ideal_spacing + slack).astype(int) + 1
    lower_fillet_x = upper_fillet_x[::-1]

    section_rows = [num_p_grip_y, num_p_fillet_y, num_p_gage_y, num_p_fillet_y, num_p_grip_y]
    widths = onp.concatenate([
        onp.full(num_p_grip_y, num_p_grip_x),
        lower_fillet_x,
        onp.full(num_p_gage_y, num_p_gage_x),
        upper_fillet_x,
        onp.full(num_p_grip_y, num_p_grip_x),
    ])
    section_ids = onp.repeat(onp.arange(5), section_rows)
    heights = onp.repeat(section_rows, section_rows)

    # Build the full (x, z) grid of every row in one go. Point k of row r has local index
    # k - (first point of row r), which splits into x = local // num_p_z and z = local % num_p_z.
    # This enumerates every combination exactly once, whatever the row and z counts are.
    points_per_row = widths * num_p_z
    first_point = onp.cumsum(points_per_row) - points_per_row
    y_ids = onp.repeat(onp.arange(widths.size), points_per_row)
    local = onp.arange(points_per_row.sum()) - onp.repeat(first_point, points_per_row)
    xyz_ids = onp.stack([local // num_p_z, y_ids, local % num_p_z], axis=1)

    return (jnp.asarray(xyz_ids), jnp.asarray(section_ids), jnp.asarray(widths),
            jnp.asarray(heights), num_p_z)

In [53]:
def get_tensile_bar_positions(grip_width:            float,
                              grip_length:           float,
                              gage_width:            float,
                              gage_length:           float,
                              thickness:             float,
                              fillet_major_diameter: float,
                              fillet_minor_diameter: float,
                              ideal_spacing:         float,
                              xyz_ids:               jnp.ndarray,
                              section_ids:           jnp.ndarray,
                              widths:                jnp.ndarray,
                              heights:               jnp.ndarray,
                              num_p_z:               int
                              ) -> jnp.ndarray:
    """Turn the integer particle ids of get_tensile_bar_numbers into xyz coordinates.

    Coordinates: (0, 0, 0) is the middle of the gage section and of the thickness, and y runs
    along the tensile direction. With G = grip_length, L = gage_length and F the fillet length
    along y, rows are placed per section as follows:
      * lower grip (0):   from -(L/2 + F + G) upward in steps of G/n, stopping one step below the fillet
      * lower fillet (1): from -(L/2 + F) to -L/2, both ends included
      * gage (2):         strictly inside (-L/2, L/2), in steps of L/(n+1)
      * upper fillet (3): from L/2 to L/2 + F, both ends included
      * upper grip (4):   from one step above the fillet up to L/2 + F + G
    Here n is the row count of the section (``heights``). Within a row, the ``widths[row]``
    points are spread evenly over the local bar width [-w/2, w/2]. The ``num_p_z`` points are
    spread over [-thickness/2, thickness/2]. A single point sits at 0.

    Args:
        grip_width ... fillet_minor_diameter: bar geometry, as passed to get_tensile_bar_numbers.
        ideal_spacing: not used here; accepted so both functions take the same leading arguments.
        xyz_ids, section_ids, widths, heights, num_p_z: outputs of get_tensile_bar_numbers.

    Returns:
        (N, 3) float array of particle positions, one row per row of ``xyz_ids``.

    Only element-wise jax.numpy operations are used (no Python branching on array values),
    so the geometry arguments can be traced. TODO(review): confirm under jax.jit / jax.grad.
    """
    xyz_ids = jnp.asarray(xyz_ids)
    section_ids = jnp.asarray(section_ids)
    widths = jnp.asarray(widths)
    heights = jnp.asarray(heights)

    x_id, row, z_id = xyz_ids[:, 0], xyz_ids[:, 1], xyz_ids[:, 2]
    section = section_ids[row]
    n_x = widths[row]   # points across this particle's row
    n_y = heights[row]  # rows in this particle's section

    # Index of the row within its own section. section_ids runs 0,..,0,1,..,1,...,4, so the
    # first row of section k is the number of rows with a smaller section id.
    section_start = jnp.stack([jnp.sum(section_ids < k) for k in range(5)])
    j = row - section_start[section]

    semi_x = fillet_minor_diameter / 2
    semi_y = fillet_major_diameter / 2
    fillet_center_x = semi_x + gage_width / 2
    half_gage = gage_length / 2
    fillet_length = semi_y * jnp.sqrt(1 - ((grip_width / 2 - fillet_center_x) / semi_x) ** 2)

    # y: fillet rows include both of their ends, gage rows exclude both ends, and grip rows
    # exclude the end touching the fillet, matching the row counts of get_tensile_bar_numbers.
    fillet_step = fillet_length / jnp.maximum(n_y - 1, 1)
    y = jnp.select(
        [section == 0, section == 1, section == 2, section == 3],
        [-half_gage - fillet_length - grip_length + j * grip_length / n_y,  # lower grip
         -half_gage - fillet_length + j * fillet_step,                      # lower fillet
         -half_gage + (j + 1) * gage_length / (n_y + 1),                    # gage section
         half_gage + j * fillet_step],                                      # upper fillet
        half_gage + fillet_length + (j + 1) * grip_length / n_y,            # upper grip
    )

    # Local bar width at each particle's row. Outside the fillets the square-root argument is
    # replaced by 1, so no NaN appears there (not even in gradients through jnp.where).
    in_fillet = (section == 1) | (section == 3)
    ellipse_arg = jnp.where(in_fillet, 1 - ((jnp.abs(y) - half_gage) / semi_y) ** 2, 1.0)
    fillet_width = 2 * (fillet_center_x - semi_x * jnp.sqrt(jnp.maximum(ellipse_arg, 0.0)))
    row_width = jnp.where(in_fillet, fillet_width,
                          jnp.where(section == 2, gage_width, grip_width))

    # x and z: n evenly spaced points spanning [-w/2, w/2], centred on 0
    x = (x_id - (n_x - 1) / 2) * row_width / jnp.maximum(n_x - 1, 1)
    z = (z_id - (num_p_z - 1) / 2) * thickness / jnp.maximum(num_p_z - 1, 1)

    return jnp.stack([x, y, z], axis=1)

In [54]:
# Specimen geometry and target particle spacing (all in the same length unit)
grip_width = 10
grip_length = 15
gage_width = 5
gage_length = 30
thickness = 3
fillet_major_diameter = 20
fillet_minor_diameter = 10
ideal_spacing = 1.0

# Step 1: integer particle ids plus the per-row bookkeeping for this bar
xyz_ids, section_ids, widths, heights, num_p_z = get_tensile_bar_numbers(
    grip_width=grip_width,
    grip_length=grip_length,
    gage_width=gage_width,
    gage_length=gage_length,
    thickness=thickness,
    fillet_major_diameter=fillet_major_diameter,
    fillet_minor_diameter=fillet_minor_diameter,
    ideal_spacing=ideal_spacing,
)

# Step 2: map those ids onto xyz coordinates
xyz_positions = get_tensile_bar_positions(
    grip_width=grip_width,
    grip_length=grip_length,
    gage_width=gage_width,
    gage_length=gage_length,
    thickness=thickness,
    fillet_major_diameter=fillet_major_diameter,
    fillet_minor_diameter=fillet_minor_diameter,
    ideal_spacing=ideal_spacing,
    xyz_ids=xyz_ids,
    section_ids=section_ids,
    widths=widths,
    heights=heights,
    num_p_z=num_p_z,
)

The collection of IDs has now been converted into real coordinates, so next we plot the resulting point cloud.

In [55]:
# Intensive to load here, just forewarning...
# (glyph() copies a sphere of radius ideal_spacing/2 onto every particle, so cost grows with the point count)

pdata = pyvista.PolyData(onp.array(xyz_positions))  # point cloud built from the positions, copied to a NumPy array
sphere = pyvista.Sphere(radius = ideal_spacing/2)
pc = pdata.glyph(scale=False, geom=sphere, orient=False)  # same-size, unrotated sphere at every point
pc.plot()
# pc.save("test.vtk")

Widget(value='<iframe src="http://localhost:64566/index.html?ui=P_0x16c92a4ff50_1&reconnect=auto" class="pyvis…

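The following (commented-out) example took a while to figure out. `jnp.piecewise` returns an array with the same shape as its input, so it cannot map each scalar to a vector (a 1D input to a 2D output). Instead, use `jax.lax.switch` to pick the branch for each element and wrap it in `jnp.vectorize`. Exclude the list of branches from vectorization (`excluded=(1,)`) and give the correct signature, e.g. `'(),()->(n)'` for a scalar input or `'(),(m)->(n)'` for a vector input.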

In [56]:
# i = jnp.arange(0,20)

# indices = jnp.where(i<10,0,1)

# branches = [
#     lambda j: jnp.array([j,j+1]),
#     lambda j: jnp.array([2*j,3*j])
# ]

# jnp.vectorize(jax.lax.switch,excluded=(1,),signature='(),()->(n)')(indices, branches, i)

# branches = [
#     lambda j: jnp.array([j[0],j[1]]),
#     lambda j: jnp.array([j[0],j[1]]),
#     lambda j: jnp.array([j[0],j[1]]),
#     lambda j: jnp.array([j[0],j[1]]),
#     lambda j: jnp.array([j[0],j[1]])
# ]


# jnp.vectorize(jax.lax.switch,excluded=(1,),signature='(),(m)->(n)')(sections, branches, xyzw)