In [None]:
import torch
import jax.lax as lax
import jax.numpy as jnp
import numpy as np

In [None]:
# Probe lax.is_finite on an all-finite float32 vector [1., 2., 3., 4.];
# the displayed mask should be True everywhere.
finite_probe = jnp.arange(1, 5, dtype=jnp.float32)
lax.is_finite(finite_probe)

In [None]:
# jax.lax.sort vs torch.ops.aten.sort.default
# Ascending sorts are compared element-for-element with zero tolerance on a
# small float32 matrix containing negatives and a duplicated value (2.5).
x_np = np.array([[3.0, -1.0, 2.5, 2.5], [0.0, -4.2, 8.1, 1.1]], dtype=np.float32)
x_jax = jnp.array(x_np)
x_torch = torch.tensor(x_np)


def _sort_both_and_compare(dim):
    """Sort ``x`` along ``dim`` with JAX and PyTorch and require identical values.

    Returns ``(jax_values, torch_values)``: a NumPy array built from
    ``lax.sort`` and the values tensor from ``aten.sort.default``. The torch
    indices are discarded because the lax call here only yields sorted values.
    Raises ``AssertionError`` if any element differs.
    """
    jax_values = np.array(lax.sort(x_jax, dimension=dim))
    torch_values, _ = torch.ops.aten.sort.default(x_torch, dim=dim, descending=False)
    np.testing.assert_allclose(jax_values, torch_values.numpy(), rtol=0.0, atol=0.0)
    return jax_values, torch_values


# First the last axis (within each row), then a non-last axis (down each column).
jax_sorted_last, torch_sorted_last = _sort_both_and_compare(-1)
jax_sorted_dim0, torch_sorted_dim0 = _sort_both_and_compare(0)

print("jax.lax.sort mapping check passed")

In [5]:
# jax.lax.scatter_mul vs torch.ops.aten.scatter_reduce_.two (reduce='prod')
# Index 1 is hit twice, so both libraries must multiply repeated updates into
# the same slot. include_self=True makes torch fold the original operand value
# into the product, as lax.scatter_mul does.
operand_np = np.array([2.0, 3.0, 4.0, 5.0], dtype=np.float32)
indices_np = np.array([[1], [3], [1]], dtype=np.int32)
updates_np = np.array([10.0, 2.0, 0.5], dtype=np.float32)

# --- JAX side ---
operand_jax, indices_jax, updates_jax = (
    jnp.array(arr) for arr in (operand_np, indices_np, updates_np)
)
# Each row of indices_np is a length-1 index vector into axis 0 of the operand.
# Every update is a scalar, so there are no update window dims and axis 0 is
# an inserted window dim.
dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(),
    inserted_window_dims=(0,),
    scatter_dims_to_operand_dims=(0,),
)
jax_result = lax.scatter_mul(operand_jax, indices_jax, updates_jax, dnums)
out_jax = np.array(jax_result)

# --- PyTorch side ---
operand_torch = torch.tensor(operand_np)
updates_torch = torch.tensor(updates_np)
# scatter_reduce_ takes a 1-D int64 index aligned with updates, so drop the
# trailing index-vector axis.
index_torch = torch.tensor(indices_np[:, 0], dtype=torch.long)
# Operate on a copy so the in-place op leaves operand_torch untouched.
out_torch = operand_torch.clone()
torch.ops.aten.scatter_reduce_.two(
    out_torch, 0, index_torch, updates_torch, "prod", include_self=True
)

np.testing.assert_allclose(out_jax, out_torch.numpy(), rtol=0.0, atol=0.0)
print("jax.lax.scatter_mul mapping check passed")

jax.lax.scatter_mul mapping check passed


In [6]:
import jax.numpy as jnp
import numpy as np
from jax import lax

# 1. Prepare the input data.
x = jnp.array([-5.0, -1.0, 0.5, 3.0, 10.0])

# 2. Call lax.clamp.
# Semantics: limit every element of x to the closed interval [-2.0, 2.0].
# Argument order is (min, operand, max): the data sits in the middle.
result = lax.clamp(-2.0, x, 2.0)

print(f"输入: {x}")
print(f"输出: {result}")

# Verify the expected output instead of only stating it in a comment, matching
# the assert-based checks used for the other ops in this notebook:
# -5 becomes -2 and 10 becomes 2; values already inside the range are unchanged.
expected = np.array([-2.0, -1.0, 0.5, 2.0, 2.0], dtype=np.float32)
np.testing.assert_allclose(np.asarray(result), expected, rtol=0.0, atol=0.0)

输入: [-5.  -1.   0.5  3.  10. ]
输出: [-2.  -1.   0.5  2.   2. ]
