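I want to try plugging the kernel into the neighborhood summation to check that the values are as expected. They should be near 1: on a uniform lattice $\sum_j W(r_{ij}, h)\,V_j \approx 1$, where $V_j = \text{spacing}^3$ is the volume per particle. So with `spacing = 1` the raw kernel sum should be close to 1 for particles away from the domain boundary (boundary particles have fewer neighbors and give smaller sums).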

In [2]:
import jax
import jax.numpy as jnp
import numpy as onp
import pyvista
import matplotlib.pyplot as plt
from functools import partial

I need to figure out how best to name variables so that I can distinguish quantities for the entire domain from those for the single particle currently being processed (e.g. `_global` vs `_local` suffixes).

The same goes for vectorized functions. I like appending `_0D`, `_1D`, and so on to function names to indicate how many array dimensions of input they handle.

In [86]:
spacing = 1
edge_size = 100
# Particles per axis. round() guards against edge_size/spacing landing just
# below an integer in floating point (e.g. 0.7/0.1 -> 6.999...).
p = int(round(edge_size/spacing))
# Axis coordinates 0, spacing, 2*spacing, ..., (p-1)*spacing.
# jnp.linspace(0, edge_size, p) would spread p points over [0, edge_size] and
# give an actual spacing of edge_size/(p-1) (1.0101 here) instead of `spacing`,
# which biases the kernel-sum check (particle volume no longer spacing**3).
axis_coords = jnp.arange(p) * float(spacing)
# Full p**3 lattice ordered with x varying slowest and z fastest.
x = jnp.repeat(jnp.repeat(axis_coords,p),p)
y = jnp.repeat(jnp.tile(axis_coords,p),p)
z = jnp.tile(jnp.tile(axis_coords,p),p)
xyz = jnp.stack([x,y,z],axis=1)
xyz_size = xyz.shape[0]

@partial(jax.jit, static_argnames=["hash_map_size","spacing","prime_1","prime_2","prime_3"])
def xyz_to_hash(xyz_location:  jnp.array,
                hash_map_size: int,
                spacing:       float,
                prime_1 = 73856093,
                prime_2 = 19349663,
                prime_3 = 83492791
                ) -> jnp.array:
    '''
    Converts [..., 3] xyz coordinates to [...] integer hash keys in [0, hash_map_size).

    Each point is binned into a cubic cell of side `spacing` (floor of
    xyz/spacing). The integer cell indices are multiplied by three large
    primes, XOR-ed together and reduced modulo hash_map_size.

    Any leading batch shape is accepted:
        [N, 3]    -> [N]        (the original use case)
        [3]       -> scalar     (a single point)
        [A, B, 3] -> [A, B]

    NOTE(review): the cell-index * prime products can exceed the int32 range
    for coordinates far from the origin. The result is still deterministic,
    which is all a hash needs. jnp.mod with a positive divisor returns a
    non-negative key even when the XOR result is negative.
    '''
    primes = jnp.array([prime_1, prime_2, prime_3])
    cell_index = jnp.floor(xyz_location / spacing).astype(int)
    # Broadcasting over the trailing axis scales x, y and z by their own prime.
    alpha = cell_index * primes
    beta = jnp.bitwise_xor(jnp.bitwise_xor(alpha[..., 0], alpha[..., 1]), alpha[..., 2])
    return jnp.mod(beta, hash_map_size)

@partial(jax.jit, static_argnames=["hash_map_size"])
def process_hash(hash_keys:     jnp.array,
                 hash_map_size: int):
    '''
    Groups particles by hash key so that each bucket's members are contiguous.

    Returns [sorted_indices, sorted_hash_keys, counts, offsets]:
        sorted_indices   - particle indices ordered by hash key (jnp.argsort)
        sorted_hash_keys - hash_keys[sorted_indices]
        counts           - number of particles in each of the hash_map_size buckets
        offsets          - position in sorted_indices where each bucket starts
    '''
    order = jnp.argsort(hash_keys)
    # Bucket sizes do not depend on ordering, so count the unsorted keys directly.
    bucket_sizes = jnp.bincount(hash_keys, length=hash_map_size)
    # Exclusive prefix sum: bucket k starts right after all buckets < k.
    bucket_starts = jnp.cumsum(bucket_sizes) - bucket_sizes
    return [order, hash_keys[order], bucket_sizes, bucket_starts]

@jax.jit
def hash_to_count(hash:   int,
                  counts: jnp.array):
    '''Returns the number of particles stored in bucket `hash`, i.e. counts[hash].'''
    return counts[hash]

@jax.jit
def hash_to_xyz(hash:           int,
                increment:      int,
                xyz:            jnp.array,
                sorted_indices: jnp.array,
                offsets:        jnp.array):
    '''
    Returns the coordinates of the `increment`-th particle stored in bucket `hash`.

    The caller is expected to keep increment < counts[hash]. A larger value
    reads past offsets[hash] + counts[hash], i.e. into the following bucket.
    '''
    slot = offsets[hash] + increment         # position within the hash-sorted order
    particle_index = sorted_indices[slot]    # map back to the original particle index
    return xyz[particle_index]

# Hash-table sizing: twice as many buckets as particles (presumably to keep
# collisions rare). Cells are 2*spacing wide, which matches the kernel support
# 2*h when smoothing_length == spacing.
hash_map_size = 2*xyz_size
hash_spacing = 2*spacing
hash_keys = xyz_to_hash(xyz, hash_map_size, hash_spacing)
sorted_indices, sorted_hash_keys, counts, offsets = process_hash(hash_keys, hash_map_size)

# The 27 offsets to the neighbouring cells (-1, 0, +1 cells along each axis),
# ordered with x varying slowest and z fastest.
o = jnp.array([-hash_spacing, 0, hash_spacing])
o_x, o_y, o_z = (grid.ravel() for grid in jnp.meshgrid(o, o, o, indexing="ij"))
o_xyz = jnp.stack([o_x, o_y, o_z], axis=1)

@jax.jit
def kernel(distance, smoothing_length):
    '''
    3D cubic-spline (M4) SPH kernel W(r, h), with q = distance / smoothing_length:

        W = 1/(4*pi*h**3) * { (2-q)**3 - 4*(1-q)**3   for 0 <= q < 1
                            { (2-q)**3                for 1 <= q < 2
                            { 0                       for q >= 2

    The 1/h**3 factor is what makes the kernel integrate to 1 over 3D space.
    Without it the kernel is only normalized when smoothing_length == 1.

    Works elementwise on scalars or arrays of distances of any shape.
    '''
    q = distance / smoothing_length
    outer = (2 - q)**3
    inner = outer - 4*(1 - q)**3
    # jnp.where evaluates both branches; the unused values are finite, so this
    # is safe for gradients too.
    M = jnp.where(q < 1, inner, jnp.where(q < 2, outer, 0.0))
    return M / (4*jnp.pi*smoothing_length**3)

kernel_v = jax.vmap(kernel,in_axes=(0,None))

In [92]:
@jax.jit
def increment_kernel_inner(i,inner_carry):
    '''
    fori_loop body: adds W(distance, smoothing_length) for the i-th particle of one bucket.

    inner_carry = (kernel_tot, center_xyz, xyz, bucket, sorted_indices, offsets, smoothing_length)
    Only kernel_tot changes; every other entry is passed through unchanged.
    '''
    kernel_tot, center_xyz, xyz, bucket, sorted_indices, offsets, smoothing_length = inner_carry
    other_xyz = hash_to_xyz(bucket, i, xyz, sorted_indices, offsets)
    distance = jnp.sqrt(jnp.sum(jnp.square(center_xyz - other_xyz)))
    kernel_tot = kernel_tot + kernel(distance, smoothing_length)
    return (kernel_tot, center_xyz, xyz, bucket, sorted_indices, offsets, smoothing_length)

@jax.jit
def increment_kernel_outer(i,outer_carry):
    '''
    fori_loop body: accumulates the kernel over every particle in the i-th neighbour bucket.

    outer_carry = (kernel_tot, center_xyz, xyz, hashes, counts, sorted_indices,
                   offsets, smoothing_length)

    `offsets` travels in the carry so the bucket lookup always uses tables that
    match `counts`/`sorted_indices`. A module-level global would be baked in by
    jit at first trace and go stale after rehashing.
    '''
    kernel_tot, center_xyz, xyz, hashes, counts, sorted_indices, offsets, smoothing_length = outer_carry
    bucket = hashes[i]
    # Two different neighbour cells can hash to the same bucket. Visiting that
    # bucket again would add its particles twice, so only the first occurrence
    # of a key in `hashes` contributes.
    seen_before = jnp.any((hashes == bucket) & (jnp.arange(hashes.shape[0]) < i))
    count = jnp.where(seen_before, 0, counts[bucket])
    inner_carry = (kernel_tot, center_xyz, xyz, bucket, sorted_indices, offsets, smoothing_length)
    kernel_tot = jax.lax.fori_loop(0, count, increment_kernel_inner, inner_carry)[0]
    return (kernel_tot, center_xyz, xyz, hashes, counts, sorted_indices, offsets, smoothing_length)

@partial(jax.jit, static_argnames=["hash_map_size", "hash_spacing"])
def index_to_kernel(index, xyz, o_xyz, hash_map_size, hash_spacing, sorted_indices, counts, smoothing_length):
    '''
    Kernel sum  sum_j W(|xyz[index] - xyz[j]|, smoothing_length)  for one particle.

    Candidates j are the particles in the buckets of the cells at
    xyz[index] + o_xyz. This includes the particle itself, at distance 0.

    hash_map_size, hash_spacing, sorted_indices and counts must all come from
    the same xyz_to_hash/process_hash pass. hash_map_size and hash_spacing are
    static because xyz_to_hash declares them static, and a traced value cannot
    be used as a static argument.

    When o_xyz holds the 27 offsets of 0/+-hash_spacing, every particle with
    W > 0 is found only if 2*smoothing_length <= hash_spacing.
    '''
    center_xyz = xyz[index,:]
    hashes = xyz_to_hash(center_xyz+o_xyz, hash_map_size, hash_spacing)
    # Bucket start positions, derived from counts exactly as in process_hash.
    offsets = jnp.cumsum(counts) - counts
    outer_carry = (0.0, center_xyz, xyz, hashes, counts, sorted_indices, offsets, smoothing_length)
    return jax.lax.fori_loop(0, o_xyz.shape[0], increment_kernel_outer, outer_carry)[0]

index_to_kernel_v = jax.vmap(index_to_kernel,in_axes=(0,None,None,None,None,None,None,None))

d_tot = index_to_kernel_v(jnp.arange(0,xyz_size),xyz,o_xyz,hash_map_size,hash_spacing,sorted_indices,counts,spacing)
# For a quicker check on a subset, replace jnp.arange(0,xyz_size) with e.g. jnp.arange(1200,1300).

In [89]:
@partial(jax.jit, static_argnames=["hash_map_size","hash_spacing"])
def index_to_hashes(index:         int,
                    xyz_global:    jnp.array,
                    xyz_offsets:   jnp.array,
                    hash_map_size: int,
                    hash_spacing:  float):
    '''
    Hash keys of the cells around particle `index`: one key per row of xyz_offsets.
    '''
    shifted_points = xyz_global[index, :] + xyz_offsets
    return xyz_to_hash(shifted_points, hash_map_size, hash_spacing)

# Batched over particle indices: [N] indices -> [N, len(xyz_offsets)] hash keys.
indices_to_hashes = jax.vmap(index_to_hashes, in_axes=(0, None, None, None, None))

@jax.jit
def hash_to_occurances_0D(hash_local:        int,
                          occurances_global: jnp.array):
    '''Looks up the entry of `occurances_global` for a single hash key.'''
    return occurances_global[hash_local]

# Lift the scalar lookup to arrays of hash keys:
#   _1D maps over axis 0 of the keys      ([N]    -> [N])
#   _2D additionally maps over axis 1     ([N, M] -> [N, M])
hash_to_occurances_1D = jax.vmap(hash_to_occurances_0D, in_axes=(0, None))
hash_to_occurances_2D = jax.vmap(hash_to_occurances_1D, in_axes=(1, None), out_axes=1)

In [90]:
# Neighbour-cell hash keys for every particle: one row per particle, one
# column per row of o_xyz.
particle_indices_global = jnp.arange(xyz_size)
hashes = indices_to_hashes(particle_indices_global, xyz, o_xyz,
                           hash_map_size, hash_spacing)
hashes.shape

(1000000, 27)

I need a minimal reproducible example to highlight the problem I am having.

Can we pre-calculate the offset hashes? I don't think so.

In [8]:
test = dict(v1=int)
test

{'v1': int}