In [9]:
import torch
import jax.lax as lax
import jax.numpy as jnp
import numpy as np

In [10]:
lax.is_finite(jnp.array([1, 2, 3, 4], dtype=jnp.float32))

Array([ True,  True,  True,  True], dtype=bool)

In [11]:
# jax.lax.sort vs torch.ops.aten.sort.default
x_np = np.array([[3.0, -1.0, 2.5, 2.5], [0.0, -4.2, 8.1, 1.1]], dtype=np.float32)
x_jax = jnp.array(x_np)
x_torch = torch.tensor(x_np)

# Test ascending sort on last dimension
jax_sorted_last = np.array(lax.sort(x_jax, dimension=-1))
torch_sorted_last, _ = torch.ops.aten.sort.default(x_torch, dim=-1, descending=False)
np.testing.assert_allclose(
    jax_sorted_last, torch_sorted_last.numpy(), rtol=0.0, atol=0.0
)

# Test sort on a non-last dimension
jax_sorted_dim0 = np.array(lax.sort(x_jax, dimension=0))
torch_sorted_dim0, _ = torch.ops.aten.sort.default(x_torch, dim=0, descending=False)
np.testing.assert_allclose(
    jax_sorted_dim0, torch_sorted_dim0.numpy(), rtol=0.0, atol=0.0
)

print("jax.lax.sort mapping check passed")

jax.lax.sort mapping check passed


In [12]:
# jax.lax.scatter_mul vs torch.ops.aten.scatter_reduce_.two (reduce='prod')
operand_np = np.array([2.0, 3.0, 4.0, 5.0], dtype=np.float32)
indices_np = np.array([[1], [3], [1]], dtype=np.int32)
updates_np = np.array([10.0, 2.0, 0.5], dtype=np.float32)

operand_jax = jnp.array(operand_np)
indices_jax = jnp.array(indices_np)
updates_jax = jnp.array(updates_np)
dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(),
    inserted_window_dims=(0,),
    scatter_dims_to_operand_dims=(0,),
)
out_jax = np.array(lax.scatter_mul(operand_jax, indices_jax, updates_jax, dnums))

operand_torch = torch.tensor(operand_np)
index_torch = torch.tensor(indices_np.squeeze(-1), dtype=torch.long)
updates_torch = torch.tensor(updates_np)
out_torch = operand_torch.clone()
torch.ops.aten.scatter_reduce_.two(
    out_torch, 0, index_torch, updates_torch, "prod", include_self=True
)

np.testing.assert_allclose(out_jax, out_torch.numpy(), rtol=0.0, atol=0.0)
print("jax.lax.scatter_mul mapping check passed")

jax.lax.scatter_mul mapping check passed


In [13]:
import jax.numpy as jnp
from jax import lax

# 1. 准备数据
x = jnp.array([-5.0, -1.0, 0.5, 3.0, 10.0])

# 2. 调用 lax.clamp
# 语义：把 x 限制在 [-2.0, 2.0] 之间
# 参数顺序：(最小值, 输入数据, 最大值)
result = lax.clamp(-2.0, x, 2.0)

print(f"输入: {x}")
print(f"输出: {result}")

# 输出预期:
# [-2.0, -1.0, 0.5, 2.0, 2.0]
# (-5 变成了 -2, 10 变成了 2)

输入: [-5.  -1.   0.5  3.  10. ]
输出: [-2.  -1.   0.5  2.   2. ]


In [14]:
import jax

jax.lax.sort

<function jax._src.lax.lax.sort(operand: 'Array | Sequence[Array]', dimension: 'int' = -1, is_stable: 'bool' = True, num_keys: 'int' = 1) -> 'Array | tuple[Array, ...]'>

In [15]:
import torch

In [16]:
x = torch.tensor([1, 2, 3, 4, 5])
torch.ops.aten.sort.default(x)

(tensor([1, 2, 3, 4, 5]), tensor([0, 1, 2, 3, 4]))

In [17]:
jax.lax.dot_general

<function jax._src.lax.lax.dot_general(lhs: 'ArrayLike', rhs: 'ArrayLike', dimension_numbers: 'DotDimensionNumbers', precision: 'PrecisionLike' = None, preferred_element_type: 'DTypeLike | None' = None, *, out_sharding=None) -> 'Array'>

In [18]:
jax.lax.dot
jax.lax.dot_general
jax.lax.dot_general_p
jax.lax.batch_matmul

<function jax._src.lax.lax.batch_matmul(lhs: 'Array', rhs: 'Array', precision: 'PrecisionLike' = None) -> 'Array'>

In [19]:
torch.matmul

<function torch._VariableFunctionsClass.matmul>

In [20]:
jnp.matmul

<PjitFunction of <function matmul at 0x133cf1120>>

In [21]:
jax.lax.cbrt

<function jax.lax.cbrt(x: 'ArrayLike', accuracy=None) -> 'Array'>

In [24]:
jax.lax.iota(jnp.float32, 4)

Array([0., 1., 2., 3.], dtype=float32)

In [None]:
jax.lax.top_k

In [25]:
j_x = jnp.arange(24).reshape(2, 3, 4)
t_x = torch.arange(24).reshape(2, 3, 4)

In [26]:
j_x

Array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],

       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]], dtype=int32)

In [27]:
t_x

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])

In [30]:
jax.lax.argmax(j_x, axis=0, index_dtype=jnp.int64)

  jax.lax.argmax(j_x,axis=0,index_dtype=jnp.int64)


Array([[1, 1, 1, 1],
       [1, 1, 1, 1],
       [1, 1, 1, 1]], dtype=int32)

In [32]:
torch.ops.aten.argmax.default(t_x, dim=0)

tensor([[1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1]])

In [33]:
j_x

Array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],

       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]], dtype=int32)

In [36]:
jax.lax.cumsum(j_x, 1)

Array([[[ 0,  1,  2,  3],
        [ 4,  6,  8, 10],
        [12, 15, 18, 21]],

       [[12, 13, 14, 15],
        [28, 30, 32, 34],
        [48, 51, 54, 57]]], dtype=int32)

In [51]:
jax.lax.reduce_max(
    j_x,
    axes=(
        2,
        1,
    ),
)

Array([11, 23], dtype=int32)

In [40]:
j_x

Array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],

       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]], dtype=int32)

In [50]:
torch.amax(t_x, dim=(2, 1))

tensor([11, 23])

In [54]:
torch.ops.aten.amax.default(
    t_x,
    dim=(
        2,
        1,
    ),
)

tensor([11, 23])

In [55]:
jax.lax.reduce_or(j_x, axes=(1, 2))

Array([15, 31], dtype=int32)

In [56]:
j_x_1 = jnp.array([[1.4, 2.4, 1.4], [1.1, 2.3, 2.5]])
jax.lax.reduce_or(j_x_1, axes=(0, 1))

TypeError: logical reduction requires operand dtype bool or int, got float32.

In [62]:
type(jax.lax.top_k(j_x, 1))

list

In [59]:
torch.topk(t_x, 1)

torch.return_types.topk(
values=tensor([[[ 3],
         [ 7],
         [11]],

        [[15],
         [19],
         [23]]]),
indices=tensor([[[3],
         [3],
         [3]],

        [[3],
         [3],
         [3]]]))

In [63]:
jax.lax.linalg.cholesky(j_x)

TypeError: add got incompatible shapes for broadcasting: (2, 3, 4), (2, 4, 3).

In [64]:
def make_spd_matrix_numpy(n, epsilon=1e-5):
    """
    生成一个 (n, n) 的对称正定矩阵
    """
    # 1. 生成随机矩阵 X
    # 使用 randn (正态分布) 或 rand (均匀分布) 都可以
    X = np.random.randn(n, n)

    # 2. 构造 Gram 矩阵: A = X @ X.T
    # 这一步保证了 A 是对称半正定
    A = np.dot(X, X.T)  # 或者 X @ X.T

    # 3. 添加对角线抖动 (Diagonal Jitter)
    # 这一步保证了 A 是严格正定 (特征值全部 > 0)
    # 没有这一步，np.linalg.cholesky 经常会因为浮点误差报错
    A += np.eye(n) * epsilon

    return A

In [65]:
j_x = jnp.array(make_spd_matrix_numpy(10))
j_x

Array([[ 8.85925   , -7.3990436 , -0.7349571 , -0.51188713, -6.956379  ,
         2.5022967 ,  1.2976339 ,  4.183757  , -1.8656918 ,  1.6225336 ],
       [-7.3990436 , 10.948375  ,  1.2316022 ,  0.32384533,  7.109779  ,
        -3.7424417 , -2.3546789 , -0.29677695,  4.271814  , -2.8162992 ],
       [-0.7349571 ,  1.2316022 ,  9.723164  , -3.9999845 ,  2.070975  ,
        -1.4916484 ,  2.9989853 ,  3.0603955 , -1.8261552 , -0.934536  ],
       [-0.51188713,  0.32384533, -3.9999845 , 14.525578  , -0.14456035,
        -0.38652018, -3.403186  , -4.4579296 , -1.9781421 ,  5.458454  ],
       [-6.956379  ,  7.109779  ,  2.070975  , -0.14456035, 10.807528  ,
        -3.0774956 , -2.165046  , -5.1064677 ,  2.6150033 , -1.2000664 ],
       [ 2.5022967 , -3.7424417 , -1.4916484 , -0.38652018, -3.0774956 ,
        11.761526  ,  5.7378354 ,  1.0636797 ,  0.02985495,  1.5216056 ],
       [ 1.2976339 , -2.3546789 ,  2.9989853 , -3.403186  , -2.165046  ,
         5.7378354 ,  8.157122  ,  5.1486387 

In [None]:
out_jax = jax.lax.linalg.cholesky(j_x)

Array([[ 2.9764493 ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [-2.4858625 ,  2.1837726 ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [-0.24692412,  0.28289703,  3.0955067 ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [-0.17197913, -0.04747342, -1.3015705 ,  3.5776613 ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [-2.33714   ,  0.59528637,  0.42819294,  0.01092444,  2.1925943 ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.8406986 , -0.7567549 , -0.34565455, -0.20341699, -0.23349062,
         3.2041695 ,  0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.4359671 , -0.58198583,  1.0567827 , -0.5535349 , -0.56834114,
         1.5763456 ,  1.8431709 ,  0.        

In [67]:
t_x = torch.from_numpy(np.array(j_x))

In [68]:
torch.ops.aten.linalg_cholesky.default(t_x)

tensor([[ 2.9764,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [-2.4859,  2.1838,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [-0.2469,  0.2829,  3.0955,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [-0.1720, -0.0475, -1.3016,  3.5777,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [-2.3371,  0.5953,  0.4282,  0.0109,  2.1926,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.8407, -0.7568, -0.3457, -0.2034, -0.2335,  3.2042,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.4360, -0.5820,  1.0568, -0.5535, -0.5683,  1.5763,  1.8432,  0.0000,
          0.0000,  0.0000],
        [ 1.4056,  1.4642,  0.9670, -0.8073, -1.4130,  0.2591,  1.4691,  0.9800,
          0.0000,  0.0000],
        [-0.6268,  1.2426, -0.7535, -0.8407,  0.3385,  0.3573, -0.0249,  0.4180,
          0.9767,  0.0000],
        [ 0.5451, -