### Learning Precision on JAX

#### 1. XLA 编译的不确定性 
(a + b) + c   
a + (b + c)
有时为了并行化，并不总是一样，可能会重排求和顺序

In [2]:
import jax

jax.config.update("jax_enable_x64", True)

In [3]:
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.ln = nn.LayerNorm([10000, 1000])  # LayerNorm for 2D input

    def forward(self, x):
        x = self.ln(x)
        return x


model = Model().eval()

x = torch.randn(
    1, 3, 10000, 1000
)  # As `H` and `W` increase, the error might be amplified

inputs = [x]

c_model = torch.compile(model)

output = model(*inputs)

c_output = c_model(*inputs)

print(torch.allclose(output, c_output, 1e-5, 1e-5))  # loose check in fp32
print(torch.max(torch.abs(output - c_output)))

fp_64_ref = c_model(x.double())
print("Eager divergence", torch.max(torch.abs(output - fp_64_ref)))
print("Compile divergence divergence", torch.max(torch.abs(c_output - fp_64_ref)))

True
tensor(9.5367e-07, grad_fn=<MaxBackward1>)
Eager divergence tensor(7.8756e-07, dtype=torch.float64, grad_fn=<MaxBackward1>)
Compile divergence divergence tensor(8.9890e-07, dtype=torch.float64, grad_fn=<MaxBackward1>)


In [4]:
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.ln = nn.LayerNorm([10000, 1000])

    def forward(self, x):
        x = self.ln(x)
        return x


model = Model().eval()
x = torch.randn(1, 3, 10000, 1000)
inputs = [x]
c_model = torch.compile(model)
output = model(*inputs)
c_output = c_model(*inputs)

print(torch.allclose(output, c_output, 1e-5, 1e-5))  # loose check in fp32
print(torch.max(torch.abs(output - c_output)))

fp_64_ref = c_model(x.double())
print("Eager divergence", torch.max(torch.abs(output - fp_64_ref)))
print("Compile divergence", torch.max(torch.abs(c_output - fp_64_ref)))

True
tensor(9.5367e-07, grad_fn=<MaxBackward1>)
Eager divergence tensor(9.2591e-07, dtype=torch.float64, grad_fn=<MaxBackward1>)
Compile divergence tensor(9.6588e-07, dtype=torch.float64, grad_fn=<MaxBackward1>)


In [5]:
import jax.numpy as jnp
from jax.config import config

config.update("jax_enable_x64", True)

a = jnp.array(
    [float("inf"), 0.0, 1, -0.0, float("nan"), -float("inf"), -1, -float("nan")]
).astype(jnp.float32)
b = jnp.sort(a)
print(a)  # [ inf   0.   1.  -0.  nan -inf  -1.  nan]
print(b)  # [-inf  -1.   0.  -0.   1.  inf  nan  nan]

a64 = a.astype(jnp.float64)
print(a64)  # [ inf   0.   1.   0.  nan -inf  -1.  nan]

b64 = jnp.sort(
    a64
)  # XlaRuntimeError: UNIMPLEMENTED: While rewriting computation to not contain X64 element types, XLA encountered an HLO for which this rewriting is not implemented: %bitcast-convert.9 = s64[] bitcast-convert(f64[] %Arg_0.3), metadata={op_name="jit(sort)/jit(main)/bitcast_convert_type[new_dtype=int64]" source_file="<ipython-input-6-9289161c541b>" source_line=1}
print(b64)

ModuleNotFoundError: No module named 'jax.config'

In [None]:
# 灾难性抵消
x = 1234567891.0000001
y = 1234567891.0
x - y

0.0

In [None]:
# 大数吃小数
x = 1e16
y = 1.0
x + y

1e+16

In [None]:
# 浮点数溢出
import numpy as np

x = np.array([2e38], dtype=np.float32)
x * 2

  x * 2


array([inf], dtype=float32)

In [23]:
import torch
from torch.testing._comparison import get_tolerances

precision = [torch.bfloat16, torch.float16, torch.float32, torch.float64]

for dtype in precision:
    tolerance = get_tolerances(dtype, rtol=None, atol=None)
    print(f"Precision: {dtype}, Tolerances: {tolerance}")

Precision: torch.bfloat16, Tolerances: (0.016, 1e-05)
Precision: torch.float16, Tolerances: (0.001, 1e-05)
Precision: torch.float32, Tolerances: (1.3e-06, 1e-05)
Precision: torch.float64, Tolerances: (1e-07, 1e-07)


In [24]:
import torch
import jax
import jax.numpy as jnp
import numpy as np

# 开启 JAX 的 float64 支持（必须在开头）
jax.config.update("jax_enable_x64", True)

# 生成统一的随机输入 (双精度)
input_np = np.random.randn(2, 16, 128).astype(np.float64)
gamma_np = np.ones((128,), dtype=np.float64)
beta_np = np.zeros((128,), dtype=np.float64)
eps = 1e-5

In [25]:
input_pt = torch.from_numpy(input_np)
gamma_pt = torch.from_numpy(gamma_np)
beta_pt = torch.from_numpy(beta_np)

# 使用原生算子进行计算
# 注意：LayerNorm 的参数通常是 normalized_shape
ln_pt = torch.nn.functional.layer_norm(
    input_pt, (128,), weight=gamma_pt, bias=beta_pt, eps=eps
)
res_pt = ln_pt.numpy()

In [26]:
# 强制 JAX 使用最高精度，避免 TPU/GPU 上的隐式 bf16 转换
with jax.default_matmul_precision("highest"):

    def jax_layer_norm(x, g, b, epsilon):
        mean = jnp.mean(x, axis=-1, keepdims=True)
        # 使用有偏方差 (ddof=0)
        var = jnp.var(x, axis=-1, keepdims=True, ddof=0)
        return (x - mean) / jnp.sqrt(var + epsilon) * g + b

    res_jax = jax_layer_norm(input_np, gamma_np, beta_np, eps)

In [29]:
# 使用 PyTorch 提供的标准测试工具
res_jax_np = np.array(res_jax)
res_jax_pt = torch.from_numpy(res_jax_np)
try:
    torch.testing.assert_close(
        res_jax_pt,
        torch.from_numpy(res_pt),
        rtol=None,
        atol=None,  # 强制使用官方默认容差
    )
    print("✅ 精度对齐通过！")
except Exception as e:
    print(f"❌ 对齐失败: {e}")

✅ 精度对齐通过！


##### 2D 卷积算子的精度对齐 demo

In [30]:
import torch
import jax
import jax.numpy as jnp
from jax import lax
import numpy as np

# 开启 JAX 高精度支持
jax.config.update("jax_enable_x64", True)

# 定义超参数
batch, in_channels, out_channels = 2, 3, 4
h, w = 10, 10
kernel_size = 3
stride = 1
padding = 1  # PyTorch 的像素填充

# 生成随机输入（双精度）
x_np = np.random.randn(batch, in_channels, h, w).astype(np.float64)
w_np = np.random.randn(out_channels, in_channels, kernel_size, kernel_size).astype(
    np.float64
)

In [31]:
x_torch = torch.from_numpy(x_np)
w_torch = torch.from_numpy(w_np)

# 执行卷积
out_torch = torch.nn.functional.conv2d(x_torch, w_torch, stride=stride, padding=padding)

In [32]:
# 1. 维度对齐：告知 JAX 输入是 NCHW，权重是 OIHW
# lhs_spec: (batch, feature, spatial...) [cite: 177]
# rhs_spec: (out_feature, in_feature, spatial...) [cite: 178]
dn = lax.conv_dimension_numbers(x_np.shape, w_np.shape, ("NCHW", "OIHW", "NCHW"))

# 2. 填充对齐：PyTorch 的 padding=1 等价于 JAX 的 [(1, 1), (1, 1)]
jax_padding = [(padding, padding), (padding, padding)]

# 3. 执行卷积 [cite: 35, 40]
# precision 使用 HIGHEST 以确保在所有硬件上对齐 [cite: 376, 377]
out_jax = lax.conv_general_dilated(
    lhs=jnp.array(x_np),
    rhs=jnp.array(w_np),
    window_strides=(stride, stride),
    padding=jax_padding,
    dimension_numbers=dn,
    precision=lax.Precision.HIGHEST,
)

In [33]:
# 转换类型以便比较
out_jax_torch = torch.from_numpy(np.array(out_jax))

try:
    # 核心：不传入 rtol/atol，使用官方默认值
    torch.testing.assert_close(out_jax_torch, out_torch)
    print("✅ 算子对齐成功！")
    print(f"最大绝对误差: {(out_jax_torch - out_torch).abs().max().item()}")
except Exception as e:
    print(f"❌ 对齐失败: \n{e}")

✅ 算子对齐成功！
最大绝对误差: 5.329070518200751e-15


In [34]:
import torch
import jax.numpy as jnp
import numpy as np
import torch.nn.functional as F

# 生成测试数据 (float64 黄金基准)
x_np = np.random.randn(1024).astype(np.float64) * 10  # 扩大范围测试稳定性

# PyTorch 实现
res_pt = F.silu(torch.from_numpy(x_np)).numpy()

# JAX 实现
res_jax = jnp.array(x_np)
# 注：jax.numpy.nn 包含这些激活函数，也可以用数学公式表达
res_jax = res_jax * jax.nn.sigmoid(res_jax)
res_jax = np.array(res_jax)

# 校验
torch.testing.assert_close(torch.from_numpy(res_pt), torch.from_numpy(res_jax))

In [None]:
# 生成测试数据
x_np = np.random.randn(32, 100).astype(np.float64)

# PyTorch 实现
# dim=-1 表示在最后一个维度（100）上做归一化
res_pt = F.softmax(torch.from_numpy(x_np), dim=-1).numpy()

# JAX 实现
# axis=-1 对应 PyTorch 的 dim=-1


res_jax = jax.nn.softmax(jnp.array(x_np), axis=-1)
res_jax = np.array(res_jax)

# 校验
torch.testing.assert_close(torch.from_numpy(res_pt), torch.from_numpy(res_jax))

In [2]:
import jax.numpy as jnp

jnp.float8_e4m3fn

jax.numpy.float8_e4m3fn

In [3]:
import torch

torch.float8_e4m3fn

torch.float8_e4m3fn

In [4]:
jnp.float8_e5m2

jax.numpy.float8_e5m2

In [5]:
torch.float8_e5m2

torch.float8_e5m2