In [9]:
import jax
import jax.numpy as jnp
from jax import lax
import torch
import numpy as np

In [2]:
### abs

inp = jnp.array([-1.0, 2.0, -3.0])
out = jnp.abs(inp)
print("JAX abs:", out)

JAX abs: [1. 2. 3.]


In [3]:
jax.lax.acos(inp)

Array([3.1415925,       nan,       nan], dtype=float32)

In [4]:
lax.max(inp, inp)

Array([-1.,  2., -3.], dtype=float32)

In [5]:
x1 = np.random.rand(3, 4).astype(np.float32)

In [6]:
lax.argmax(x1, axis=0, index_dtype=jnp.int32)

Array([0, 1, 1, 1], dtype=int32)

In [7]:
torch.argmax(torch.tensor(x1), dim=0, index_type=torch.int32)

TypeError: argmax() got an unexpected keyword argument 'index_type'

In [None]:
lax.abs(x=inp)

Array([1., 2., 3.], dtype=float32)

In [None]:
lax.argmax(operand=x1, axis=0, index_dtype=jnp.int32)

Array([2, 2, 0, 0], dtype=int32)

In [None]:
lax.approx_max_k(operand=inp, k=2, reduction_dimension=0, recall_target=0.5)

[Array([ 2., -1.], dtype=float32), Array([1, 0], dtype=int32)]

In [None]:
import inspect


inspect.signature(lax.approx_max_k)

<Signature (operand: jax.Array, k: int, reduction_dimension: int = -1, recall_target: float = 0.95, reduction_input_size_override: int = -1, aggregate_to_topk: bool = True) -> tuple[jax.Array, jax.Array]>

In [None]:
inspect.signature(torch.topk)

ValueError: no signature found for builtin <built-in method topk of type object at 0x1105ff6d0>

In [None]:
file_path = "/Users/haifeng/Program/FUEL-JAX/input/abs/01.npz"

In [None]:
x = np.random.rand(3, 4).astype(np.float32)
np.savez_compressed(file_path, x=x)

In [None]:
data = np.load("/Users/haifeng/Program/FUEL-JAX/output/abs/test.npz")

In [None]:
data

NpzFile '/Users/haifeng/Program/FUEL-JAX/output/abs/test.npz' with keys: out

In [None]:
data["out"]

array([[0.62513006, 0.6944465 , 0.5826604 , 0.85714823],
       [0.22155215, 0.16254286, 0.36576876, 0.6015664 ],
       [0.39712742, 0.23268753, 0.14461608, 0.5430027 ]], dtype=float32)

In [None]:
data.keys()

KeysView(NpzFile '/Users/haifeng/Program/FUEL-JAX/output/abs/test.npz' with keys: out)

In [8]:
def get_file(op_name):
    return f"/Users/haifeng/Program/FUEL-JAX/input/{op_name}/test_00.npz"

In [16]:
x = np.random.randn(16, 256, 128)
np.savez_compressed(get_file("acos"), x=x)
out_1 = jax.lax.acos(x=x)
out_2 = torch.acos(torch.from_numpy(x))

In [None]:
x_tensor = torch.from_numpy(x)
x_tensor = x_tensor.to(torch.float8_e4m3fn)

torch.asin(x_tensor)

NotImplementedError: "asin_vml_cpu" not implemented for 'Float8_e4m3fn'

In [25]:
x_Array = jnp.array(x, dtype=jnp.float8_e4m3fn)

In [27]:
lax.asin(x_Array)

Array([[[-0.5, -0.40625, -0.3125, ..., -0.015625, -0.5, -0.375],
        [-0.1875, 0.0703125, nan, ..., -0.28125, nan, nan],
        [-0.0429688, nan, 0.6875, ..., -1.625, nan, 0.0546875],
        ...,
        [nan, 0.75, -0.0546875, ..., 0.4375, 0.28125, nan],
        [-0.03125, -0.34375, 1.625, ..., -1.25, 0.5, 0.0078125],
        [nan, 1.25, 0.8125, ..., 0.75, -1, 0.375]],

       [[0.101562, 0.171875, nan, ..., -0.375, -0.6875, nan],
        [0.5, -0.140625, nan, ..., 0.6875, -1, nan],
        [-0.8125, 0.3125, nan, ..., nan, -0.0351562, -1],
        ...,
        [1, nan, 0.375, ..., -0.8125, -0.4375, -0.5],
        [0.625, -0.75, -0.28125, ..., nan, 0.75, -0.00390625],
        [nan, 0.28125, 0.15625, ..., -0.5, 0.0078125, -0.0859375]],

       [[1, 1, nan, ..., 0.1875, 1.25, -0.4375],
        [-0.625, -0.09375, -0.34375, ..., -0.109375, 0.5, nan],
        [1.625, -0.3125, -0.6875, ..., nan, 0.25, -1.625],
        ...,
        [0.8125, nan, -0.25, ..., nan, nan, 0.109375],
        

In [30]:
torch.acos(x_tensor)

NotImplementedError: "acos_vml_cpu" not implemented for 'Float8_e4m3fn'

In [29]:
torch.add(x_tensor, x_tensor)

NotImplementedError: "add_stub" not implemented for 'Float8_e4m3fn'

In [31]:
y = np.random.randn(16, 256, 128)

In [32]:
np.savez_compressed(get_file("atan2"), x=x, y=y)

In [33]:
x

array([[[-4.73241283e-01, -4.11359501e-01, -3.11700548e-01, ...,
         -1.61781858e-02, -4.78956225e-01, -3.69178510e-01],
        [-1.89940399e-01,  7.25882480e-02,  1.37861612e+00, ...,
         -2.69413517e-01, -1.11238766e+00, -1.29293983e+00],
        [-4.13579656e-02, -1.18696461e+00,  6.04095931e-01, ...,
         -9.82967452e-01, -1.38086993e+00,  5.55792367e-02],
        ...,
        [ 1.49510017e+00,  6.78972240e-01, -5.35712809e-02, ...,
          4.49387700e-01,  2.79040008e-01, -1.76804866e+00],
        [-2.96271370e-02, -3.31341499e-01,  1.01734186e+00, ...,
         -9.43385766e-01,  5.12830836e-01,  7.70437141e-03],
        [ 2.10552794e+00,  9.44846506e-01,  7.23677982e-01, ...,
          6.94373072e-01, -8.32334313e-01,  3.76059839e-01]],

       [[ 9.85273230e-02,  1.64086263e-01,  1.59509710e+00, ...,
         -3.60353687e-01, -6.43435083e-01,  1.35267098e+00],
        [ 5.11439079e-01, -1.42203240e-01, -1.19765086e+00, ...,
          6.04148442e-01, -7.91803961e

In [35]:
torch.rsqrt(x_tensor)

NotImplementedError: "rsqrt_cpu" not implemented for 'Float8_e4m3fn'