In [26]:
from maxpy.utils import *

size = 32
lin = np.linspace(-3,3,size)

x,y,z = np.meshgrid(lin,lin,lin, indexing="ij")
flow = np.array([-y, x, z/z])*np.exp(-(x**2+y**2)) 
Jac = get_jacobian(flow, lin, lin, lin)

In [44]:
import jax.numpy as np
import jax

def f(x):
    return np.array([x[0]**2,x[1]**2])

x = np.array([[3.,11.],[5.,13.],[7.,17.]])
print(x.shape)

jac = jax.jacobian(f)
vmap_jac = jax.vmap(jac)
result = np.linalg.det(vmap_jac(x))
print(result)

(3, 2)
[132. 260. 476.]


In [58]:
from maxpy.utils import *
import jax.numpy as jnp
from jax import grad, jit, vmap, jacfwd, jacobian
from jax import random

size = 32
lin = np.linspace(-3,3,size)

@jit
def eddy(x):
    return jnp.asarray([-x[1], x[0], x[2]/x[2]])*jnp.exp(-(x[0]**2+x[1]**2)) 

x,y,z = jnp.meshgrid(lin,lin,lin, indexing="ij")

jac = jacobian(eddy)

Jac = vmap(jac)

inputx = np.array([x.flatten(),y.flatten(),z.flatten()])
print(inputx.shape)
result = Jac(inputx)
print(result.reshape(3,3,size,size,size))

@jit
def get_jacobian(vec, dx, dy, dz):
    dudx, dudy, dudz = jnp.gradient(vec[0])
    dvdx, dvdy, dvdz = jnp.gradient(vec[1])
    dwdx, dwdy, dwdz = jnp.gradient(vec[2])

    return jnp.array([[dudx, dudy, dudz], [dvdx, dvdy, dvdz], [dwdx, dwdy, dwdz]])

(3, 32768)
[[[[[ 2.7413967e-07  2.5890969e-07  0.0000000e+00 ...  0.0000000e+00
      0.0000000e+00  0.0000000e+00]
    [ 0.0000000e+00  0.0000000e+00  0.0000000e+00 ...  0.0000000e+00
      0.0000000e+00  0.0000000e+00]
    [ 0.0000000e+00  0.0000000e+00  0.0000000e+00 ...  0.0000000e+00
      0.0000000e+00  0.0000000e+00]
    ...
    [ 0.0000000e+00  0.0000000e+00  0.0000000e+00 ...  0.0000000e+00
      0.0000000e+00  0.0000000e+00]
    [ 0.0000000e+00  0.0000000e+00  0.0000000e+00 ...  0.0000000e+00
      0.0000000e+00  0.0000000e+00]
    [ 0.0000000e+00  0.0000000e+00  0.0000000e+00 ...  0.0000000e+00
      0.0000000e+00  0.0000000e+00]]

   [[ 0.0000000e+00  0.0000000e+00  0.0000000e+00 ...  0.0000000e+00
      0.0000000e+00  0.0000000e+00]
    [ 0.0000000e+00  0.0000000e+00  0.0000000e+00 ...  0.0000000e+00
      0.0000000e+00  0.0000000e+00]
    [ 0.0000000e+00  0.0000000e+00  0.0000000e+00 ...  0.0000000e+00
      0.0000000e+00  0.0000000e+00]
    ...
    [ 0.0000000e+00  0.000

In [35]:
@jit
def get_parallel_vector_operator(vec, x, y, z):
    sizex, sizey, sizez = vec.shape[1], vec.shape[2], vec.shape[3]

    Jac = get_jacobian(vec, x, y, z)
    
    # Reshape vec for vectorized operations
    vec_reshaped = vec.reshape(vec.shape[0], -1)  
    Jac_reshaped = Jac.reshape(Jac.shape[0], Jac.shape[1], -1)  

    Jv = jnp.einsum("ijk,jk->ik", Jac_reshaped, vec_reshaped)

    cross_product = jnp.cross(vec_reshaped.T, Jv.T)
    parallel = jnp.linalg.norm(cross_product, axis=1).reshape(sizex, sizey, sizez)

    return parallel

In [36]:
%timeit parallel = get_parallel_vector_operator(flow, lin ,lin, lin)

62.1 µs ± 2.23 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [19]:
%timeit parallel = get_parallel_vector_operator(flow, lin ,lin, lin)

617 µs ± 3.41 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [4]:
fig = go.Figure(data=go.Isosurface(
    x=x.flatten(),
    y=y.flatten(),
    z=z.flatten(),
    value=parallel.flatten(),
    opacity=0.6,
    isomin=parallel.min(),
    isomax=parallel.min()+1e-6,
    surface_count=10,
    caps=dict(x_show=False, y_show=False)
    ))
fig.show(renderer="browser")

In [None]:
save_as_vti(parallel, "../../Data/VTK/", "parallel")