# The fast spectral method for 3D


In [None]:
import numpy as np
from math import pi
from absl import logging
import time

logging.set_verbosity("info")


In [None]:
from kipack.collision.inelastic import FSInelasticVHSCollision
from kipack.collision.vmesh import SpectralMesh
from config import get_config


In [None]:
def isotropic_f(v, t=6.5):
    """Evaluate the isotropic BKW-type test solution on a 3D tensor grid.

    The returned values are

        f(t, v) = exp(-|v|^2 / (2K)) / (2 (2 pi K)^(3/2))
                  * ((5K - 3) / K + (1 - K) / K^2 * |v|^2),

    with K(t) = 1 - exp(-t / 6).  This f has unit mass and second moment
    (integral of |v|^2 f) equal to 3 for every t.  It is nonnegative only when
    K >= 3/5, i.e. t >= 6 ln(5/2) ~= 5.5; the default t = 6.5 satisfies this.

    Args:
        v: 1D array of velocity nodes, reused along each of the three axes.
        t: time at which the solution is evaluated (default 6.5, the value
            previously hard-coded here).

    Returns:
        Array of shape (len(v), len(v), len(v)) with f at the grid points.
    """
    K = 1 - np.exp(-t / 6)
    # Squared speed |v|^2 on the tensor-product grid via broadcasting.
    v_sq = v[:, None, None] ** 2 + v[:, None] ** 2 + v ** 2
    return (
        1
        / (2 * (2 * pi * K) ** (3 / 2))
        * np.exp(-v_sq / (2 * K))
        * ((5 * K - 3) / K + (1 - K) / (K ** 2) * v_sq)
    )


def extQ(v, t=6.5):
    """Exact time derivative df/dt of the isotropic BKW-type solution.

    With K(t) = 1 - exp(-t / 6) and

        f = exp(-|v|^2 / (2K)) / (2 (2 pi K)^(3/2))
            * ((5K - 3) / K + (1 - K) / K^2 * |v|^2),

    this returns (df/dK) * (dK/dt).  The notebook compares it against the
    numerical collision operator output, i.e. it serves as the exact Q for
    this solution.  f is computed inline (not via ``isotropic_f``) so the
    derivative is always taken at the same ``t`` as f itself.

    Args:
        v: 1D array of velocity nodes, reused along each of the three axes.
        t: time at which the derivative is evaluated (default 6.5).

    Returns:
        Array of shape (len(v), len(v), len(v)).
    """
    decay = np.exp(-t / 6)
    K = 1 - decay
    dK = decay / 6  # dK/dt
    v_sq = v[:, None, None] ** 2 + v[:, None] ** 2 + v ** 2
    # Normalized Gaussian factor shared by f and its K-derivative.
    gauss = np.exp(-v_sq / (2 * K)) / (2 * (2 * pi * K) ** (3 / 2))
    f = gauss * ((5 * K - 3) / K + (1 - K) / (K ** 2) * v_sq)
    # Product rule: derivative of the Gaussian prefactor/exponent times f,
    # plus the Gaussian times the derivative of the polynomial factor.
    df_dK = (-3 / (2 * K) + v_sq / (2 * K ** 2)) * f + gauss * (
        3 / (K ** 2) + (K - 2) / (K ** 3) * v_sq
    )
    return df_dK * dK


def anisotropic_f(v):
    """Sum of two isotropic Gaussian bumps on a 3D tensor grid.

    One bump is narrow, with exponent rate 16**(1/3), and centred at
    (2, 2, 2). The other has unit rate and is centred at (-0.5, -0.5, -0.5).
    Both are scaled by 0.8 * pi**(-1.5).

    Args:
        v: 1D array of velocity nodes, reused along each of the three axes.

    Returns:
        Array of shape (len(v), len(v), len(v)).
    """
    near = v - 2
    far = v + 0.5
    # Squared distances to each bump centre, built by broadcasting.
    dist2_near = near[:, None, None] ** 2 + near[:, None] ** 2 + near ** 2
    dist2_far = far[:, None, None] ** 2 + far[:, None] ** 2 + far ** 2
    narrow_bump = np.exp(-(16 ** (1 / 3)) * dist2_near)
    wide_bump = np.exp(-dist2_far)
    return 0.8 * pi ** (-1.5) * (narrow_bump + wide_bump)


def maxwellian(v, rho, u, T):
    """Maxwellian with density rho, bulk velocity u and temperature T.

    Args:
        v: 1D array of velocity nodes, reused along each of the three axes.
        rho: density (mass of the distribution).
        u: indexable with u[0], u[1], u[2]; u[k] shifts grid axis k.
        T: temperature.

    Returns:
        Array of shape (len(v), len(v), len(v)) with
        rho / (2 pi T)^(3/2) * exp(-|v - u|^2 / (2T)).
    """
    dx2, dy2, dz2 = ((v - u[k]) ** 2 for k in range(3))
    dist2 = dx2[:, None, None] + dy2[:, None] + dz2
    prefactor = rho / (2 * pi * T) ** (3 / 2)
    return prefactor * np.exp(-dist2 / (2 * T))


In [None]:
# Load the "3d" configuration, build the spectral velocity mesh from it, and
# construct the collision operator `coll` on that mesh. `cfg`, `vmesh` and
# `coll` are used by the timing cells below.
cfg = get_config("3d")
vmesh = SpectralMesh(cfg)
coll = FSInelasticVHSCollision(cfg, vmesh)


## On CPU using pyFFTW


In [None]:
# Evaluate the initial distribution outside the timed region so that only the
# collision operator call is measured.
f_init = isotropic_f(vmesh.center)

# perf_counter is monotonic and high-resolution, unlike time.time().
t_0 = time.perf_counter()
Q = coll(f_init, device="cpu")
dt = time.perf_counter() - t_0

# Max-norm error against the exact collision operator for this solution.
err = np.max(np.abs(Q - extQ(vmesh.center)))
print(f"Runtime: {1000 * dt:.2f}ms, error: {err}")


## On GPU using CuPy


In [None]:
import jax


def _gpu_collision(x):
    """Apply the collision operator to ``x`` with ``device="gpu"``."""
    return coll(x, device="gpu")


# NOTE(review): jax.jit traces the wrapper with abstract values, so this only
# works if coll's GPU path is JAX-traceable. The section heading mentions
# CuPy, which cannot consume JAX tracers — confirm which backend coll uses.
coll_gpu = jax.jit(_gpu_collision)

In [None]:
def _wait(x):
    """Block until an asynchronously dispatched array is ready; no-op otherwise."""
    return x.block_until_ready() if hasattr(x, "block_until_ready") else x


# Evaluate the initial distribution outside the timed region.
f_init = isotropic_f(vmesh.center)

# Warm-up call: the first call of a jax.jit-wrapped function includes tracing
# and compilation, which would otherwise dominate the measured runtime.
_wait(coll_gpu(f_init))

t_0 = time.perf_counter()
# JAX dispatches asynchronously; wait for the result before stopping the clock
# so the measurement covers the actual device computation.
Q = _wait(coll_gpu(f_init))
dt = time.perf_counter() - t_0

# Max-norm error against the exact collision operator for this solution.
err = np.max(np.abs(Q - extQ(vmesh.center)))
print(f"Runtime: {1000 * dt:.2f}ms, error: {err}")
