# The fast spectral method for 2D

In [1]:
import numpy as np
from math import pi
from absl import logging
import time
import jax

logging.set_verbosity("info")

In [2]:
from kipack.collision.inelastic import FSInelasticVHSCollision
from kipack.collision import SpectralMesh
from config import get_config

In [3]:
def bkw_f(v, t=0.5):
    """Evaluate the 2D BKW (Bobylev-Krook-Wu) exact solution on a tensor grid.

    The solution is

        f(t, v) = 1 / (2 pi K^2) * exp(-|v|^2 / (2 K))
                  * (2 K - 1 + (1 - K) |v|^2 / (2 K)),

    with K(t) = 1 - exp(-t / 8) / 2.

    Parameters
    ----------
    v : ndarray, shape (n,)
        1D velocity nodes. The solution is evaluated on the n x n grid formed
        by pairing every node with every other node.
    t : float, optional
        Time at which the solution is evaluated (default 0.5, the value the
        notebook benchmarks against).

    Returns
    -------
    ndarray, shape (n, n)
        f(t, v_i, v_j) for every pair of nodes.
    """
    K = 1 - 0.5 * np.exp(-t / 8)
    # Squared speed |v|^2 = v_x^2 + v_y^2 on the 2D tensor grid (via broadcasting).
    v_sq = (v**2)[:, None] + v**2
    return (
        1
        / (2 * pi * K**2)
        * np.exp(-0.5 * v_sq / K)
        * (2 * K - 1 + 0.5 * v_sq * (1 - K) / K)
    )


def ext_Q(v, t=0.5):
    """Analytic reference for the collision term: df/dt of the 2D BKW solution.

    With f(t, v) = 1 / (2 pi K^2) * exp(-|v|^2 / (2 K)) * P,
    P = 2 K - 1 + (1 - K) |v|^2 / (2 K) and K(t) = 1 - exp(-t / 8) / 2,
    the chain rule gives df/dt = (df/dK) * (dK/dt), where

        df/dK = (-2 / K + |v|^2 / (2 K^2)) * f
                + 1 / (2 pi K^2) * exp(-|v|^2 / (2 K)) * (2 - |v|^2 / (2 K^2)),
        dK/dt = exp(-t / 8) / 16.

    f is computed here at the same ``t`` rather than borrowed from another
    helper, so the derivative can never be evaluated against a solution taken
    at a different time.

    Parameters
    ----------
    v : ndarray, shape (n,)
        1D velocity nodes; the result is evaluated on the n x n tensor grid.
    t : float, optional
        Time at which the derivative is evaluated (default 0.5).

    Returns
    -------
    ndarray, shape (n, n)
        df/dt at every grid point.
    """
    decay = np.exp(-t / 8)
    K = 1 - decay / 2
    dK = decay / 16  # dK/dt
    # Squared speed |v|^2 on the 2D tensor grid.
    v_sq = (v**2)[:, None] + v**2
    # Normalised Gaussian factor shared by f and the derivative of P.
    gauss = np.exp(-v_sq / (2 * K)) / (2 * pi * K**2)
    f = gauss * (2 * K - 1 + v_sq * (1 - K) / (2 * K))
    # First term: d/dK of the 1/K^2 prefactor and of the Gaussian exponent.
    # Second term: d/dK of the polynomial factor P.
    df_dK = (-2 / K + v_sq / (2 * K**2)) * f + gauss * (2 - v_sq / (2 * K**2))
    return df_dK * dK

In [4]:
# Load the "2d" configuration and build the spectral velocity mesh from it;
# both collision operators constructed below take this config and mesh.
cfg = get_config("2d")
vmesh = SpectralMesh(cfg)

INFO:absl:2 dimensional collision model.
INFO:absl:Number of velocity cells: 64.
INFO:absl:Velocity domain: [-7.724873734152916, 7.724873734152916].


## On CPU using pyFFTW

In [7]:
coll_pyfftw = FSInelasticVHSCollision(cfg, vmesh, use_pyfftw=True)

# Build the BKW initial datum and the analytic collision term before starting
# the timer, so the measured runtime covers only the collision operator.
f0 = bkw_f(vmesh.center)
Q_exact = ext_Q(vmesh.center)

# perf_counter is the monotonic, high-resolution clock intended for timing
# intervals (time.time can jump and has coarser resolution on some platforms).
# NOTE(review): if pyFFTW builds its FFT plans lazily on the first call, this
# timing includes planning; confirm and add a warm-up call if so.
t_0 = time.perf_counter()
Q = coll_pyfftw(f0)
dt = time.perf_counter() - t_0

print(f"Runtime: {1000 * dt:.2f}ms, error: {np.max(np.abs(Q - Q_exact))}")

INFO:absl:e: 1.0
INFO:absl:Collision model precomputation finishes!


Runtime: 309.40ms, error: 1.7104188186368172e-11


## On GPU using JAX

In [8]:
coll_jax = FSInelasticVHSCollision(cfg, vmesh)


def _collide(f):
    """Apply the JAX-backed collision operator ``coll_jax`` to ``f``."""
    return coll_jax(f)


# Wrap the operator with jax.jit; tracing/compilation happens on the first call.
coll = jax.jit(_collide)

INFO:absl:e: 1.0


In [12]:
# Build the inputs and the analytic reference outside the timed region.
f0 = bkw_f(vmesh.center)
Q_exact = ext_Q(vmesh.center)

# Warm-up: the first call of a jitted function traces and compiles it. Running
# it once here keeps compilation out of the measurement and makes the cell give
# the same result under Restart & Run All (not only when re-executed).
coll(f0).block_until_ready()

t_0 = time.perf_counter()
# JAX dispatches work asynchronously; without blocking, the timer would stop as
# soon as the computation is enqueued rather than when it has finished.
Q = coll(f0).block_until_ready()
dt = time.perf_counter() - t_0

# NOTE(review): JAX computes in float32 unless jax_enable_x64 is set, which may
# explain the larger error relative to the float64 pyFFTW run; confirm.
print(f"Runtime: {1000 * dt:.2f}ms, error: {np.max(np.abs(Q - Q_exact))}")

Runtime: 2.30ms, error: 5.21540641784668e-08
