# The fast spectral method for 3D


In [1]:
import numpy as np
from math import pi
from absl import logging
import time
import jax

from kipack.collision.inelastic import FSInelasticVHSCollision
from kipack.collision.vmesh import SpectralMesh
from config import get_config

logging.set_verbosity("info")

In [2]:
def isotropic_f(v):
    t = 6.5
    K = 1 - np.exp(-t / 6)
    v_norm = v[:, None, None] ** 2 + v[:, None] ** 2 + v ** 2
    return (
        1
        / (2 * (2 * pi * K) ** (3 / 2))
        * np.exp(-(v_norm) / (2 * K))
        * ((5 * K - 3) / K + (1 - K) / (K ** 2) * (v_norm))
    )


def extQ(v):
    t = 6.5
    K = 1 - np.exp(-t / 6)
    dK = np.exp(-t / 6) / 6
    v_norm = v[:, None, None] ** 2 + v[:, None] ** 2 + v ** 2
    df = (-3 / (2 * K) + (v_norm) / (2 * K ** 2)) * isotropic_f(v) + 1 / (
        2 * (2 * pi * K) ** (3 / 2)
    ) * np.exp(-v_norm / (2 * K)) * (3 / (K ** 2) + (K - 2) / (K ** 3) * v_norm)
    return df * dK


def anisotropic_f(v):
    return (
        0.8
        * pi ** (-1.5)
        * (
            np.exp(
                -(16 ** (1 / 3))
                * (
                    (v - 2)[:, None, None] ** 2
                    + (v - 2)[:, None] ** 2
                    + (v - 2) ** 2
                )
            )
            + np.exp(
                -(v + 0.5)[:, None, None] ** 2
                - (v + 0.5)[:, None] ** 2
                - (v + 0.5) ** 2
            )
        )
    )


def maxwellian(v, rho, u, T):
    v_u = (
        ((v - u[0]) ** 2)[:, None, None]
        + ((v - u[1]) ** 2)[:, None]
        + (v - u[2]) ** 2
    )
    return rho / (2 * pi * T) ** (3 / 2) * np.exp(-v_u / (2 * T))


In [3]:
cfg = get_config("3d")
vmesh = SpectralMesh(cfg)

INFO:absl:3 dimensional collision model.
INFO:absl:Number of velocity cells: 32.
INFO:absl:Velocity domain: [-8.82842712474619, 8.82842712474619].


## On CPU using pyFFTW


In [8]:
coll_pyfftw = FSInelasticVHSCollision(cfg, vmesh, use_pyfftw=True)

t_0 = time.time()
Q = coll_pyfftw(isotropic_f(vmesh.center))
dt = time.time() - t_0

print(
    f"Runtime: {1000 * dt:.2f}ms, error: {np.max(np.abs(Q - extQ(vmesh.center)))}"
)


INFO:absl:e: 1.0
INFO:absl:Collision model precomputation finishes!


Runtime: 1132.86ms, error: 2.6901211073158564e-06


## On GPU using Cupy


In [9]:
coll_jax = FSInelasticVHSCollision(cfg, vmesh)

coll = jax.jit(lambda x: coll_jax(x))

INFO:absl:e: 1.0


In [14]:
t_0 = time.time()
Q = coll(isotropic_f(vmesh.center))
dt = time.time() - t_0

print(
    f"Runtime: {1000 * dt:.2f}ms, error: {np.max(np.abs(Q - extQ(vmesh.center)))}"
)


Runtime: 2.15ms, error: 2.6918714866042137e-06
