In [42]:
from dataclasses import dataclass
import numpy as np
import jax
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx
import matplotlib.pyplot as plt

In [47]:
@dataclass
class Data:
    """Empty dataclass placeholder: it declares no fields."""

In [None]:
def mesh_points(gauss_point: jnp.ndarray, index: int, domain_length: float = 1.0):
    """Return the bounds of the cell of width ``dx`` centred on one point.

    The spacing is ``dx = domain_length / (len(gauss_point) - 1)``; it is
    derived from the number of points only, not from their values, so the
    points are treated as uniformly spaced over an interval of length
    ``domain_length``.

    Args:
        gauss_point: Array of point coordinates, indexed along its first axis.
        index: Position of the point whose cell is requested. May be a traced
            integer (e.g. when the function is wrapped in ``jax.vmap``).
        domain_length: Total length of the interval covered by the points.
            Defaults to ``1.0``, the value that used to be hard-coded.

    Returns:
        A tuple ``(xL, xR)`` with ``xL = gauss_point[index] - dx / 2`` and
        ``xR = gauss_point[index] + dx / 2``.
    """
    # The shape is static, so this stays a plain Python int even under vmap/jit.
    n_cell = gauss_point.shape[0] - 1
    half_dx = (domain_length / n_cell) / 2

    centre = gauss_point[index]
    return centre - half_dx, centre + half_dx

In [65]:
# Uniform grid of 11 points on [0, 1]; N_cell is the number of intervals between them.
x_gauss = jnp.linspace(0, 1, num=11)
N_cell = x_gauss.shape[0] - 1
print("Gauss points:", x_gauss)

# Single-index check: bounds of the cell centred on the point at index 1.
xL, xR = mesh_points(x_gauss, 1)
print("Left mesh points:", xL)
print("Right mesh points:", xR)

Gauss points: [0.         0.1        0.2        0.3        0.4        0.5
 0.6        0.7        0.8        0.90000004 1.        ]
Left mesh points: 0.05
Right mesh points: 0.15


In [66]:
# Vectorized version of mesh_points: the points array is shared across the
# batch (in_axes None) while the index argument is mapped over its axis 0.
vmap_mesh_points = jax.vmap(
    mesh_points,
    in_axes=(None, 0),
)

In [67]:
# Evaluate the cell bounds for every index from 1 to N_cell - 1 in one batched call.
vmap_mesh_points_result = vmap_mesh_points(
    x_gauss,
    jnp.arange(1, N_cell),
)
# mesh_points returns (xL, xR), so element 0 holds the left bounds and element 1 the right bounds.
print("Vmap Left mesh points:", vmap_mesh_points_result[0])
print("Vmap Right mesh points:", vmap_mesh_points_result[1])

Vmap Left mesh points: [0.05 0.15 0.25 0.35 0.45 0.55 0.65 0.75 0.85]
Vmap Right mesh points: [0.15       0.25       0.35000002 0.45000002 0.55       0.65000004
 0.75       0.85       0.95000005]


In [68]:
def basis_func(x_int: jnp.ndarray, xL: jnp.ndarray, xR: jnp.ndarray):
    """Evaluate two linear functions of ``x_int`` on the interval ``[xL, xR]``.

    The first value grows from 0 at ``xL`` to 2 at ``xR``; the second falls
    from 2 at ``xL`` to 0 at ``xR``. Their sum is 2 (up to rounding) for
    every ``x_int``.

    NOTE(review): a conventional linear basis pair peaks at 1 and sums to 1;
    confirm that the factor 2 is intended.

    Args:
        x_int: Point(s) at which the functions are evaluated.
        xL: Left end of the interval.
        xR: Right end of the interval.

    Returns:
        ``jnp.array([phi_Left, phi_Right])``: the two values stacked along a
        new leading axis.
    """
    width = xR - xL
    offset_from_left = x_int - xL
    offset_to_right = xR - x_int

    # Multiply by 2 before dividing, keeping the original evaluation order.
    rising = offset_from_left * 2 / width
    falling = offset_to_right * 2 / width
    return jnp.array([rising, falling])