In [2]:
import jax

# Sanity check: list the devices JAX can dispatch work to. The saved output
# of this cell shows a single CUDA device, i.e. the GPU backend was found.
available_devices = jax.devices()
available_devices

[CudaDevice(id=0)]

In [3]:
import dolfinx as dlx
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from eikonax import tensorfield
from mpi4py import MPI

from cardiac_electrophysiology import fiberfield

# Apply seaborn's "ticks" theme globally: it updates matplotlib's rcParams,
# so it affects every figure drawn afterwards in this kernel.
sns.set_theme(style="ticks")

In [None]:
# Unit square [0, 1] x [0, 1] with 2 x 2 subdivisions, meshed with triangles.
mesh = dlx.mesh.create_rectangle(
    MPI.COMM_WORLD,
    [np.array([0, 0]), np.array([1, 1])],
    [2, 2],
    cell_type=dlx.mesh.CellType.triangle,
)

# The geometry dofmap has one row per locally stored cell, so its length is
# the size every per-cell array below must have.
# NOTE(review): in a parallel run this presumably includes ghost cells — confirm.
num_cells = len(mesh.geometry.dofmap)

In [None]:
# Homogeneous configuration: every cell gets the same mean fiber angle (pi/4),
# the same (x, y) basis and unit conduction velocities.
mean_angle_vector = np.full(num_cells, np.pi / 4)

# NOTE(review): arctanh(cos(angle)) appears to map the angle into the
# unconstrained parameter space expected by `fiberfield.FiberTensor` — confirm.
# It diverges when cos(angle) = +/-1 (angle a multiple of pi), so guard against
# non-finite parameters instead of passing inf downstream.
mean_parameter_vector = np.arctanh(np.cos(mean_angle_vector))
assert np.all(np.isfinite(mean_parameter_vector)), (
    "mean angle must not be a multiple of pi (arctanh(cos(angle)) is infinite)"
)

# Float literals on purpose: `np.repeat([[1, 0]], ...)` and `np.full(n, 1)`
# would produce integer arrays, which then propagate into the JAX tensor
# computation (integer dtypes break e.g. differentiation w.r.t. these inputs).
first_basis_vector = np.repeat([[1.0, 0.0]], num_cells, axis=0)
second_basis_vector = np.repeat([[0.0, 1.0]], num_cells, axis=0)
longitudinal_velocity_vector = np.full(num_cells, 1.0)
transverse_velocity_vector = np.full(num_cells, 1.0)

In [None]:
# Assemble the fiber conduction tensor field from the per-cell inputs above.
fiber_field = fiberfield.FiberTensor(
    # Read the spatial dimension off the basis vectors (columns = components)
    # instead of hard-coding it, so the two cannot drift out of sync.
    dimension=first_basis_vector.shape[1],
    mean_parameter_vector=mean_parameter_vector,
    basis_vectors=[first_basis_vector, second_basis_vector],
    conduction_velocities=[longitudinal_velocity_vector, transverse_velocity_vector],
)