好 — 我会把 **FP64（IEEE-754 double）** 与 **P3109（任意 K-bit、P-precision 紧凑浮点交换格式）** 的细节完整实现出来，不再省略子正规数、精确指数偏置、Inf/NaN 处理、以及常见舍入模式（向零 / 最近（偶数） / 向下 / 向上）等。代码以 **PyTorch（向量化）** 实现，既可用于权重预量化，也可在推理时对中间激活进行严格模拟（通过 forward hook）。

> 注意：此实现是**软件层面的精确模拟**（bit-accurate行为在数学上贴近 P3109 规范中常见的位分配规则），但并非硬件实现。它会比未量化慢很多 —— 仅用于研究/仿真实验。请在小模型或小 batch 上先尝试。

下面给出完整实现（可直接复制到脚本或 notebook 使用） — 包含：

* `float_to_p3109`：将 float 张量按 P3109（K, P）规则量化回浮点值，支持舍入模式和子正规数、Inf/NaN、溢出处理、尾数进位导致的指数调整等；
* `apply_p3109_to_params(model, ...)`：对模型权重与偏置执行一次性的 P3109 量化（可用于部署前权重量化实验）；
* `install_activation_p3109_hooks(model, ...)`：在推理时对每个模块的输出进行 P3109 量化模拟（逐层仿真）；
* `run_inference_with_formats`：整合 fp32、fp64 与 p3109 三种模式的推理流程示例。

---

```python
import torch
import torch.nn as nn
import math
import time
from typing import Optional

# -----------------------------
# Helper: rounding functions
# -----------------------------
def _round_tensor(x: torch.Tensor, mode: str):
    """Round tensor x according to mode.
    Supported modes: 'nearest_even', 'toward_zero', 'floor', 'ceil'.
    """
    if mode == "nearest_even":
        # torch.round uses "round to nearest, ties to even"
        return torch.round(x)
    elif mode == "toward_zero":
        return torch.trunc(x)
    elif mode == "floor":
        return torch.floor(x)
    elif mode == "ceil":
        return torch.ceil(x)
    else:
        raise ValueError(f"Unsupported rounding mode: {mode}")

# -----------------------------
# Core: P3109 simulation
# -----------------------------
def float_to_p3109(x: torch.Tensor,
                   K: int = 8,
                   P: int = 3,
                   rounding: str = "nearest_even",
                   preserve_dtype: bool = True,
                   eps: float = 1e-45) -> torch.Tensor:
    """Quantize ``x`` onto the value grid of a K-bit, precision-P float format.

    The simulated layout is IEEE-style: sign(1) | exponent(K - P) |
    fraction(P - 1), with bias ``2**(e_bits - 1) - 1``, stored exponent 0 used
    for zero/subnormals and the all-ones stored exponent reserved for Inf/NaN.
    The result is returned as ordinary floating-point values (not bit
    patterns), so it can be fed straight back into a model.

    Works element-wise on tensors of any shape (including 0-dim and empty).

    Args:
        x: Input tensor.
        K: Total number of bits (sign + exponent + fraction).
        P: Precision, i.e. significand bits INCLUDING the implicit leading 1,
            so the number of stored fraction bits is ``P - 1``.
        rounding: ``"nearest_even"``, ``"toward_zero"``, ``"floor"`` (toward
            -inf) or ``"ceil"`` (toward +inf). The directed modes are applied
            to the signed value, not to its magnitude.
        preserve_dtype: If True, cast the result back to ``x``'s dtype;
            otherwise return the working dtype (float64 for float64 input,
            float32 for everything else).
        eps: Unused; kept only so existing callers keep working.

    Returns:
        Tensor with the same shape as ``x`` whose finite values are exactly
        representable in the (K, P) format. NaN, +/-Inf and +/-0 pass through
        unchanged. Values that round beyond the largest finite number become
        +/-Inf in every rounding mode; values that round to zero keep the
        sign of the input.

    Raises:
        ValueError: If ``P < 1``, if fewer than one exponent bit remains, or
            if ``rounding`` is not a supported mode.
    """
    if P < 1:
        raise ValueError("P must be >= 1")
    e_bits = K - P
    if e_bits < 1:
        raise ValueError("K and P invalid: need at least 1 exponent bit")
    if rounding not in ("nearest_even", "toward_zero", "floor", "ceil"):
        raise ValueError("Unsupported rounding")

    f_bits = P - 1
    # Exponent bias (standard IEEE pattern).
    bias = (2 ** (e_bits - 1)) - 1
    # Largest stored exponent is all-ones minus one (all-ones is Inf/NaN).
    E_max_unbiased = ((2 ** e_bits) - 2) - bias
    # Smallest normal exponent; subnormals share it (stored exponent 0).
    E_min_unbiased = 1 - bias

    orig_dtype = x.dtype
    work_dtype = torch.float64 if orig_dtype == torch.float64 else torch.float32
    x = x.to(work_dtype)
    if x.numel() == 0:
        return x.to(orig_dtype) if preserve_dtype else x

    absx = x.abs()
    negative = torch.signbit(x)
    # Only finite, non-zero entries are quantized; everything else is copied
    # from x at the end, which preserves NaN, signed Inf and signed zero.
    quantizable = torch.isfinite(x) & (absx != 0)

    # Give frexp a harmless value in the lanes that are not quantized so it
    # never sees 0/Inf/NaN.
    safe = torch.where(quantizable, absx, torch.ones_like(absx))
    m, e = torch.frexp(safe)  # safe == m * 2**e with m in [0.5, 1)
    significand = m * 2.0  # in [1, 2)
    exp_unbiased = e.to(torch.long) - 1
    subnormal = exp_unbiased < E_min_unbiased
    # Subnormals are expressed relative to the smallest normal exponent.
    exp_eff = torch.clamp(exp_unbiased, min=E_min_unbiased)

    frac_scale = 2.0 ** f_bits
    # 2**(-E_min) as a Python float; guard against OverflowError for absurdly
    # wide exponent fields (no input can be subnormal in such a format, so
    # the lanes using this factor are never selected).
    if -E_min_unbiased < 1024:
        inv_min_normal = 2.0 ** (-E_min_unbiased)
    else:
        inv_min_normal = float("inf")

    # Fraction in units of 2**-f_bits, in [0, 2**f_bits):
    #   normal:    (significand - 1) * 2**f_bits   (hidden bit is 1)
    #   subnormal: (|x| / 2**E_min)  * 2**f_bits   (hidden bit is 0)
    # Every factor is a power of two, so the scaling is exact.
    frac = torch.where(subnormal, safe * inv_min_normal, significand - 1.0) * frac_scale

    if rounding == "nearest_even":
        frac_q = torch.round(frac)  # ties to even
    elif rounding == "toward_zero":
        frac_q = torch.trunc(frac)
    else:
        # We round the magnitude, so the direction flips for negative values:
        # floor grows |x| when x < 0, ceil grows |x| when x > 0.
        grow = negative if rounding == "floor" else ~negative
        frac_q = torch.where(grow, torch.ceil(frac), torch.floor(frac))

    # Significand after rounding, hidden bit included: [0, 1] for subnormals
    # (1.0 means "rounded up to the smallest normal"), [1, 2] for normals
    # (2.0 means the mantissa carried into the next exponent).
    unit = (~subnormal).to(work_dtype) + frac_q / frac_scale
    exp_final = exp_eff + (unit >= 2.0).to(torch.long)
    # Anything that lands on a normal encoding above E_max has no finite
    # representation. Genuine subnormal results (unit < 1) never overflow.
    overflow = (unit >= 1.0) & (exp_final > E_max_unbiased)

    # 2**exp_eff without calling pow: for normal lanes safe == significand *
    # 2**exp exactly, so the division yields the exact power of two.
    pow2 = torch.where(subnormal,
                       torch.full_like(safe, 2.0 ** E_min_unbiased),
                       safe / significand)
    magnitude = unit * pow2
    magnitude = torch.where(overflow, torch.full_like(magnitude, float("inf")), magnitude)

    # copysign keeps the sign even when the magnitude rounded to zero.
    out = torch.where(quantizable, torch.copysign(magnitude, x), x)

    if preserve_dtype:
        return out.to(orig_dtype)
    return out

# -----------------------------
# Utilities to quantize model
# -----------------------------
def apply_p3109_to_params(model: nn.Module,
                          K: int = 8,
                          P: int = 3,
                          rounding: str = "nearest_even",
                          inplace: bool = True):
    """Replace every parameter of ``model`` by its simulated (K, P) value.

    Each parameter is passed through ``float_to_p3109`` with
    ``preserve_dtype=True``, so parameter dtypes do not change.

    Args:
        model: Module whose parameters (recursively) are quantized.
        K: Total bit width forwarded to ``float_to_p3109``.
        P: Precision forwarded to ``float_to_p3109``.
        rounding: Rounding mode forwarded to ``float_to_p3109``.
        inplace: If True, the quantized values are copied into the existing
            parameter storage; if False, ``.data`` is rebound to the new
            tensor. NOTE: in both cases the passed-in model is modified.

    Returns:
        The same ``model`` object.
    """
    for param in model.parameters(recurse=True):
        if param is None:
            continue
        quantized = float_to_p3109(param.data, K=K, P=P, rounding=rounding, preserve_dtype=True)
        if not inplace:
            # Swap in the freshly created tensor instead of overwriting storage.
            param.data = quantized
            continue
        param.data.copy_(quantized)
    return model

# Activation hooks for per-layer simulation
def _make_hook(K, P, rounding):
    def hook(module, input, output):
        # Quantize activation outputs (output can be tensor or tuple)
        if isinstance(output, torch.Tensor):
            return float_to_p3109(output, K=K, P=P, rounding=rounding, preserve_dtype=True)
        elif isinstance(output, (list, tuple)):
            qitems = []
            for o in output:
                if isinstance(o, torch.Tensor):
                    qitems.append(float_to_p3109(o, K=K, P=P, rounding=rounding, preserve_dtype=True))
                else:
                    qitems.append(o)
            return type(output)(qitems)
        else:
            return output
    return hook

def install_activation_p3109_hooks(model: nn.Module,
                                   K: int = 8,
                                   P: int = 3,
                                   rounding: str = "nearest_even",
                                   modules_to_hook: Optional[tuple] = (nn.Linear, nn.Conv2d, nn.Conv1d, nn.Conv3d)):
    """Attach output-quantizing forward hooks to selected submodules.

    Every submodule of ``model`` (including ``model`` itself) that is an
    instance of one of ``modules_to_hook`` gets a forward hook built by
    ``_make_hook(K, P, rounding)``.

    Args:
        model: Module to instrument.
        K: Total bit width passed to the hook.
        P: Precision passed to the hook.
        rounding: Rounding mode passed to the hook.
        modules_to_hook: Module types to instrument. ``None`` selects the
            default set (Linear, Conv1d, Conv2d, Conv3d). A single type or
            any iterable of types is also accepted.

    Returns:
        List of hook handles; call ``.remove()`` on each to uninstall.
    """
    # isinstance() only accepts a type or a tuple of types, so normalise the
    # argument first (the Optional annotation promises that None is valid).
    if modules_to_hook is None:
        hooked_types = (nn.Linear, nn.Conv2d, nn.Conv1d, nn.Conv3d)
    elif isinstance(modules_to_hook, type):
        hooked_types = (modules_to_hook,)
    else:
        hooked_types = tuple(modules_to_hook)

    return [
        m.register_forward_hook(_make_hook(K, P, rounding))
        for m in model.modules()
        if isinstance(m, hooked_types)
    ]

# -----------------------------
# FP64 handling (straightforward)
# -----------------------------
def apply_fp64_to_model(model: nn.Module, inplace=True):
    """Return ``model`` with its floating-point parameters/buffers as float64.

    Args:
        model: Module to convert.
        inplace: If True, convert ``model`` itself. If False, convert and
            return a deep copy, leaving the passed-in module untouched.

    Returns:
        The converted module (``model`` itself when ``inplace`` is True).
    """
    if not inplace:
        # nn.Module.to()/double() convert the module in place, so a real
        # out-of-place conversion needs an explicit copy first.
        import copy
        model = copy.deepcopy(model)
    return model.double()

# -----------------------------
# Integration: run inference with modes
# -----------------------------
def run_inference_with_formats(model: nn.Module,
                               input_tensor: torch.Tensor,
                               mode: str = "fp32",
                               p3109_K: int = 8,
                               p3109_P: int = 3,
                               p3109_rounding: str = "nearest_even",
                               quantize_activations: bool = True):
    """Run one timed forward pass of ``model`` in the requested number format.

    NOTE: the model is modified in place (switched to eval mode, dtype
    changed, and in ``"p3109"`` mode its parameters are overwritten with
    quantized values). Pass a deep copy if the original must be preserved.

    Args:
        model: Module to evaluate.
        input_tensor: Input passed to the model after conversion.
        mode: ``"fp32"``, ``"fp64"`` or ``"p3109"``. In ``"p3109"`` mode the
            model runs in float32 with parameters and the input quantized to
            the (K, P) format, and optionally with per-layer output
            quantization via forward hooks.
        p3109_K: Total bit width for ``"p3109"`` mode.
        p3109_P: Precision for ``"p3109"`` mode.
        p3109_rounding: Rounding mode for ``"p3109"`` mode.
        quantize_activations: In ``"p3109"`` mode, whether to install the
            activation-quantizing hooks for the duration of the forward pass.

    Returns:
        Tuple ``(output, elapsed_seconds)``; only the forward pass is timed.

    Raises:
        ValueError: If ``mode`` is not one of the supported names.
    """
    model = model.eval()
    # No copy is made here; callers that need one should deepcopy beforehand.
    model_copy = model

    def _timed_forward(net, inputs):
        # perf_counter is monotonic, unlike time.time(), so the measured
        # interval cannot be distorted by wall-clock adjustments.
        start = time.perf_counter()
        with torch.no_grad():
            result = net(inputs)
        return result, time.perf_counter() - start

    if mode == "fp32":
        return _timed_forward(model_copy.float(), input_tensor.float())

    if mode == "fp64":
        model_copy = apply_fp64_to_model(model_copy, inplace=True)
        return _timed_forward(model_copy, input_tensor.double())

    if mode == "p3109":
        # Cast to the float32 compute dtype first so the quantized values are
        # not re-rounded by a later dtype conversion.
        model_copy = model_copy.float()
        apply_p3109_to_params(model_copy, K=p3109_K, P=p3109_P, rounding=p3109_rounding, inplace=True)
        input_t = float_to_p3109(input_tensor.float(), K=p3109_K, P=p3109_P,
                                 rounding=p3109_rounding, preserve_dtype=True)

        handles = []
        try:
            if quantize_activations:
                handles = install_activation_p3109_hooks(model_copy, K=p3109_K, P=p3109_P,
                                                         rounding=p3109_rounding)
            return _timed_forward(model_copy, input_t)
        finally:
            # Always uninstall the hooks, even if the forward pass raised, so
            # the model is not left permanently quantizing its activations.
            for h in handles:
                h.remove()

    raise ValueError("Unsupported mode")

# -----------------------------
# Example usage
# -----------------------------
if __name__ == "__main__":
    # Smoke test on a single small linear layer with random input.
    layer = nn.Linear(10, 5)
    sample = torch.randn(2, 10)

    # Double-precision run (this converts `layer` to float64 in place).
    fp64_result = run_inference_with_formats(layer, sample, mode="fp64")
    print("fp64 time:", fp64_result[1], "out dtype:", fp64_result[0].dtype)

    # Simulated P3109 run; K=8, P=3 gives 5 exponent bits and 2 fraction bits.
    p3109_settings = dict(p3109_K=8, p3109_P=3, p3109_rounding="nearest_even")
    p3109_result = run_inference_with_formats(layer, sample, mode="p3109", **p3109_settings)
    print("p3109 time:", p3109_result[1], "out dtype:", p3109_result[0].dtype)
```

---

## 细节说明（实现如何对应规范要点）

1. **位划分（K, P）**：

   * `f_bits = P - 1`（尾数的显式 fraction bits）。
   * `e_bits = K - P`（指数位数）。
   * 指数偏置 `bias = 2^(e_bits-1) - 1`：与 IEEE 风格一致。
   * 存储值 `0`（所有指数位都为 0）用于**零与子正规数**；存储值 `2^{e_bits}-1` (all ones) 为 **Inf/NaN**。

2. **规范化数（normal）**：

   * 使用 `frexp()` 得到 `mantissa ∈ [0.5,1)` 与 `e`；换算为 `mant ∈ [1,2)` 与 `exp_unbiased = e - 1`。
   * 将 `fraction = mant - 1` 量化到 `2^{-f_bits}` 的步长，按所选取的舍入规则进行舍入。
   * 若舍入导致尾数进位（mant -> 2.0），则指数加 1；若指数超过最大可表示范围则置为 Inf。

3. **子正规数（subnormal）**：

   * 当真实 `exp_unbiased < E_min_unbiased` 时，将数视作子正规：其 stored exponent = 0，value = `(integer_fraction / 2^{f_bits}) * 2^{E_min_unbiased}`。
   * 用 `scaled_fraction = x / 2^{E_min_unbiased} * 2^{f_bits}`，对其进行舍入/截断并重建子正规值（并 clamp 到 [0, 2^{f_bits}-1]）。

4. **Inf / NaN / 零**：

   * 原值为 `NaN` 的位置保留 NaN；为 `Inf` 的位置保留±Inf（保留符号）。
   * 原为正/负零会保留符号（我们用 `x * 0.0` 的 trick 保留 sign bit）。

5. **舍入模式**：

   * `nearest_even` 使用 `torch.round()`（PyTorch 实现为 ties-to-even）。
   * `toward_zero` 使用 `torch.trunc()`。
   * `floor/ceil` 支持（罕用，但实现了）。

6. **精确性 / 数值稳定性**：

   * 在向量化实现中使用 `float32`（或 `float64`）临时计算以保持稳定性，但最终返回的值为原 dtype（若 `preserve_dtype=True`）。
   * 该实现对每个非零有限值都做了精确的拆解与重建（mantissa、指数），因此在数值上接近真实 K/P 表示。

## 性能与警告

* 该模拟会在每次量化处做较多的逐元素运算与掩码处理（frexp、clamp、round、掩码索引……），速度远慢于原生浮点计算。如果你想用于大型模型或生产，请仅在离线仿真或评估阶段使用，或对关键层（如权重）做一次性量化，而不是对每次激活都做逐元素仿真（activation hooks 可选择性打开）。
* 子正规数与极小值处的舍入会对模型行为有显著影响（尤其 BN、softmax、层归一化等），建议在小批次上验证数值稳定性。

---

如果你愿意，我可以做进一步两件事（你选一项或都要）：

1. 把上面代码扩展为可选的「逐层选择量化」脚本 —— 例如只量化线性/卷积权重并只在前 N 层做激活量化，方便对不同策略做 ablation。
2. 用你提供的模型（或我用一个标准 Transformer / ResNet 示例）跑一次对比：`fp32` vs `fp64` vs `p3109`（给出输出差异统计：L2、max abs、top-k 保持率，及时间对比），帮助你评估 P3109 模拟对模型性能的影响。

你要我先做哪一项？或者把你要测试的模型/输入发给我，我直接替你跑一次（若你允许我修改权重与注入 hooks）？


In [None]:
# -----------------------------
# Core: P3109 simulation
# -----------------------------
def float_to_p3109(x: torch.Tensor,
                   K: int = 8,
                   P: int = 3,
                   rounding: str = "nearest_even",
                   preserve_dtype: bool = True,
                   eps: float = 1e-45) -> torch.Tensor:
    """Quantize ``x`` onto the value grid of a K-bit, precision-P float format.

    The simulated layout is IEEE-style: sign(1) | exponent(K - P) |
    fraction(P - 1), with bias ``2**(e_bits - 1) - 1``, stored exponent 0 used
    for zero/subnormals and the all-ones stored exponent reserved for Inf/NaN.
    The result is returned as ordinary floating-point values (not bit
    patterns), so it can be fed straight back into a model.

    Works element-wise on tensors of any shape (including 0-dim and empty).

    Args:
        x: Input tensor.
        K: Total number of bits (sign + exponent + fraction).
        P: Precision, i.e. significand bits INCLUDING the implicit leading 1,
            so the number of stored fraction bits is ``P - 1``.
        rounding: ``"nearest_even"``, ``"toward_zero"``, ``"floor"`` (toward
            -inf) or ``"ceil"`` (toward +inf). The directed modes are applied
            to the signed value, not to its magnitude.
        preserve_dtype: If True, cast the result back to ``x``'s dtype;
            otherwise return the working dtype (float64 for float64 input,
            float32 for everything else).
        eps: Unused; kept only so existing callers keep working.

    Returns:
        Tensor with the same shape as ``x`` whose finite values are exactly
        representable in the (K, P) format. NaN, +/-Inf and +/-0 pass through
        unchanged. Values that round beyond the largest finite number become
        +/-Inf in every rounding mode; values that round to zero keep the
        sign of the input.

    Raises:
        ValueError: If ``P < 1``, if fewer than one exponent bit remains, or
            if ``rounding`` is not a supported mode.
    """
    if P < 1:
        raise ValueError("P must be >= 1")
    e_bits = K - P
    if e_bits < 1:
        raise ValueError("K and P invalid: need at least 1 exponent bit")
    if rounding not in ("nearest_even", "toward_zero", "floor", "ceil"):
        raise ValueError("Unsupported rounding")

    f_bits = P - 1
    # Exponent bias (standard IEEE pattern).
    bias = (2 ** (e_bits - 1)) - 1
    # Largest stored exponent is all-ones minus one (all-ones is Inf/NaN).
    E_max_unbiased = ((2 ** e_bits) - 2) - bias
    # Smallest normal exponent; subnormals share it (stored exponent 0).
    E_min_unbiased = 1 - bias

    orig_dtype = x.dtype
    work_dtype = torch.float64 if orig_dtype == torch.float64 else torch.float32
    x = x.to(work_dtype)
    if x.numel() == 0:
        return x.to(orig_dtype) if preserve_dtype else x

    absx = x.abs()
    negative = torch.signbit(x)
    # Only finite, non-zero entries are quantized; everything else is copied
    # from x at the end, which preserves NaN, signed Inf and signed zero.
    quantizable = torch.isfinite(x) & (absx != 0)

    # Give frexp a harmless value in the lanes that are not quantized so it
    # never sees 0/Inf/NaN.
    safe = torch.where(quantizable, absx, torch.ones_like(absx))
    m, e = torch.frexp(safe)  # safe == m * 2**e with m in [0.5, 1)
    significand = m * 2.0  # in [1, 2)
    exp_unbiased = e.to(torch.long) - 1
    subnormal = exp_unbiased < E_min_unbiased
    # Subnormals are expressed relative to the smallest normal exponent.
    exp_eff = torch.clamp(exp_unbiased, min=E_min_unbiased)

    frac_scale = 2.0 ** f_bits
    # 2**(-E_min) as a Python float; guard against OverflowError for absurdly
    # wide exponent fields (no input can be subnormal in such a format, so
    # the lanes using this factor are never selected).
    if -E_min_unbiased < 1024:
        inv_min_normal = 2.0 ** (-E_min_unbiased)
    else:
        inv_min_normal = float("inf")

    # Fraction in units of 2**-f_bits, in [0, 2**f_bits):
    #   normal:    (significand - 1) * 2**f_bits   (hidden bit is 1)
    #   subnormal: (|x| / 2**E_min)  * 2**f_bits   (hidden bit is 0)
    # Every factor is a power of two, so the scaling is exact.
    frac = torch.where(subnormal, safe * inv_min_normal, significand - 1.0) * frac_scale

    if rounding == "nearest_even":
        frac_q = torch.round(frac)  # ties to even
    elif rounding == "toward_zero":
        frac_q = torch.trunc(frac)
    else:
        # We round the magnitude, so the direction flips for negative values:
        # floor grows |x| when x < 0, ceil grows |x| when x > 0.
        grow = negative if rounding == "floor" else ~negative
        frac_q = torch.where(grow, torch.ceil(frac), torch.floor(frac))

    # Significand after rounding, hidden bit included: [0, 1] for subnormals
    # (1.0 means "rounded up to the smallest normal"), [1, 2] for normals
    # (2.0 means the mantissa carried into the next exponent).
    unit = (~subnormal).to(work_dtype) + frac_q / frac_scale
    exp_final = exp_eff + (unit >= 2.0).to(torch.long)
    # Anything that lands on a normal encoding above E_max has no finite
    # representation. Genuine subnormal results (unit < 1) never overflow.
    overflow = (unit >= 1.0) & (exp_final > E_max_unbiased)

    # 2**exp_eff without calling pow: for normal lanes safe == significand *
    # 2**exp exactly, so the division yields the exact power of two.
    pow2 = torch.where(subnormal,
                       torch.full_like(safe, 2.0 ** E_min_unbiased),
                       safe / significand)
    magnitude = unit * pow2
    magnitude = torch.where(overflow, torch.full_like(magnitude, float("inf")), magnitude)

    # copysign keeps the sign even when the magnitude rounded to zero.
    out = torch.where(quantizable, torch.copysign(magnitude, x), x)

    if preserve_dtype:
        return out.to(orig_dtype)
    return out