# The fast spectral method in 2D: runtime and accuracy against the exact BKW solution

In [None]:
import numpy as np
from math import pi
from absl import logging
import time

# Raise absl logging verbosity to INFO so INFO-level absl messages emitted
# while running the cells below are shown.
logging.set_verbosity("info")

In [None]:
from kipack.collision.inelastic import FSInelasticVHSCollision
from kipack.collision import SpectralMesh
from config import get_config

In [None]:
def bkw_f(v, t=0.5):
    """Evaluate the 2D BKW (Bobylev-Krook-Wu) exact solution at time ``t``.

    With ``K(t) = 1 - exp(-t/8) / 2`` and ``|v|^2 = v_x^2 + v_y^2``::

        f(t, v) = exp(-|v|^2 / (2K)) / (2 pi K^2)
                  * (2K - 1 + |v|^2 (1 - K) / (2K))

    Parameters
    ----------
    v : numpy.ndarray
        1D array of velocity grid points. The solution is evaluated on the
        tensor-product grid ``v x v`` (first axis ``v_x``, second ``v_y``).
    t : float, optional
        Time at which to evaluate the solution. Defaults to 0.5, the value
        this function previously hard-coded.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(v), len(v))`` with the BKW distribution.
    """
    K = 1 - 0.5 * np.exp(-t / 8)
    # Squared speed on the 2D grid via broadcasting (N, 1) + (N,) -> (N, N).
    v_sq = (v**2)[:, None] + v**2
    return (
        1
        / (2 * pi * K**2)
        * np.exp(-0.5 * v_sq / K)
        * (2 * K - 1 + 0.5 * v_sq * (1 - K) / K)
    )


def ext_Q(v, t=0.5):
    """Exact reference value for the collision operator applied to ``bkw_f``.

    Returns the analytic time derivative ``df/dt`` of the 2D BKW solution,
    computed by the chain rule ``df/dt = (df/dK) * (dK/dt)`` with
    ``K(t) = 1 - exp(-t/8) / 2`` and ``dK/dt = exp(-t/8) / 16``.

    The BKW distribution is built here from the same ``K`` used for the
    derivative, so ``f`` and ``dK`` can never be evaluated at different times.

    Parameters
    ----------
    v : numpy.ndarray
        1D array of velocity grid points. The result lives on the
        tensor-product grid ``v x v``.
    t : float, optional
        Time at which to evaluate. Defaults to 0.5, the value this function
        previously hard-coded.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(v), len(v))``.
    """
    K = 1 - np.exp(-t / 8) / 2
    dK = np.exp(-t / 8) / 16  # dK/dt
    v_sq = (v**2)[:, None] + v**2
    # Common factor exp(-|v|^2 / (2K)) / (2 pi K^2).
    gauss = np.exp(-v_sq / (2 * K)) / (2 * pi * K**2)
    f = gauss * (2 * K - 1 + v_sq * (1 - K) / (2 * K))
    # d(gauss)/dK = (-2/K + |v|^2/(2K^2)) * gauss; the polynomial factor has
    # dP/dK = 2 - |v|^2/(2K^2).
    df_dK = (-2 / K + v_sq / (2 * K**2)) * f + gauss * (2 - v_sq / (2 * K**2))
    return df_dK * dK

In [None]:
# Build the "2d" configuration, the spectral velocity mesh, and the fast
# spectral collision operator that the timing cells below call.
# NOTE(review): vmesh.center is assumed to be the 1D array of velocity grid
# points that bkw_f / ext_Q expect — confirm against SpectralMesh.
cfg = get_config("2d")
vmesh = SpectralMesh(cfg)
coll = FSInelasticVHSCollision(cfg, vmesh)

## On CPU using pyFFTW

In [None]:
# Build the initial data outside the timed region so only the collision
# operator is measured.
f0 = bkw_f(vmesh.center)

# perf_counter is a monotonic, high-resolution clock meant for measuring
# intervals; time.time() follows the system clock and can jump.
t_0 = time.perf_counter()
Q = coll(f0, device="cpu")
dt = time.perf_counter() - t_0

print(
    f"Runtime: {1000 * dt:.2f}ms, error: {np.max(np.abs(Q - ext_Q(vmesh.center)))}"
)

## On GPU using CuPy

In [None]:
import jax


def coll_gpu(x):
    """Apply the collision operator ``coll`` on its GPU backend.

    ``coll`` is called directly instead of under ``jax.jit``. ``jax.jit``
    replaces ``x`` with an abstract tracer. A non-JAX backend (this section
    targets CuPy) cannot consume that tracer. Even for a traceable backend,
    jitting would fold tracing and XLA compilation into the first call.

    Parameters
    ----------
    x : array_like
        Distribution on the velocity grid, e.g. ``bkw_f(vmesh.center)``.

    Returns
    -------
    Whatever ``coll(x, device="gpu")`` returns. NOTE(review): this may be a
    device array (e.g. CuPy); confirm before mixing with NumPy arrays.
    """
    return coll(x, device="gpu")

In [None]:
# Build the initial data outside the timed region so only the collision
# operator is measured.
f0 = bkw_f(vmesh.center)

# Warm-up call: the first call may include one-off costs that should not be
# reported as the operator's runtime. Examples: JIT tracing/compilation if
# coll_gpu is wrapped in jax.jit, or GPU context / FFT plan setup.
coll_gpu(f0)

t_0 = time.perf_counter()
Q = coll_gpu(f0)
# GPU work is launched asynchronously, so copy the result to host memory
# before stopping the timer; the copy waits for the computation to finish.
# CuPy arrays expose .get() and refuse implicit np.asarray conversion.
# NumPy and JAX arrays go through np.asarray.
Q = Q.get() if hasattr(Q, "get") else np.asarray(Q)
dt = time.perf_counter() - t_0

print(
    f"Runtime: {1000 * dt:.2f}ms, error: {np.max(np.abs(Q - ext_Q(vmesh.center)))}"
)