# In [None]:
import torch
import triton
import triton.language as tl

@triton.jit
def float32_to_bf16_kernel(
    input_ptr,
    output_ptr,
    n_elements,
    BLOCK: tl.constexpr,
    BF16ASM: tl.constexpr
):
    """Convert a flat float32 buffer to bfloat16, BLOCK elements per program.

    Elements are addressed purely by flat offset from the two pointers, so the
    caller is responsible for passing buffers whose memory layout is linear.

    Args:
        input_ptr: Pointer to the source values, read at ``pid * BLOCK + i``.
        output_ptr: Pointer to the destination, written at the same offsets.
        n_elements: Number of elements to convert; offsets at or beyond this
            value are masked out of both the load and the store.
        BLOCK: Compile-time number of elements handled by each program.
        BF16ASM: Compile-time switch. When true the conversion is performed by
            the inline PTX below; otherwise Triton's own cast is used.
    """
    # Flat offsets covered by this program instance.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)

    # Guard the tail block against out-of-range accesses.
    mask = offsets < n_elements

    x = tl.load(input_ptr + offsets, mask=mask)

    if BF16ASM:
        # Software round-to-nearest-even on the raw bit pattern:
        #   rounded = (bits + 0x7FFF + ((bits >> 16) & 1)) >> 16
        # Adding 0x7FFF rounds everything above the halfway point up; the
        # extra "+ lsb" makes an exact tie round up only when the kept
        # mantissa is odd, i.e. ties go to even. A carry out of the mantissa
        # correctly bumps the exponent (and overflows to Inf when needed).
        # NaNs are handled separately because the addition could otherwise
        # turn them into Inf or flip the sign bit.
        x = tl.inline_asm_elementwise(
            asm="""
            {
                .reg .b32 bits, lsb, bias, rounded, magnitude;
                .reg .b16 result;
                .reg .pred is_nan;

                // Raw 32-bit pattern of the float32 input.
                mov.b32 bits, $1;

                // lsb = lowest mantissa bit that survives the truncation.
                shr.u32 lsb, bits, 16;
                and.b32 lsb, lsb, 1;

                // Round-to-nearest-even bias: 0x7FFF + lsb.
                add.u32 bias, lsb, 0x7FFF;
                add.u32 rounded, bits, bias;
                shr.u32 rounded, rounded, 16;

                // NaN <=> magnitude bits are strictly above the Inf pattern.
                and.b32 magnitude, bits, 0x7FFFFFFF;
                setp.gt.u32 is_nan, magnitude, 0x7F800000;

                // Emit a quiet NaN instead of a rounded NaN payload.
                selp.b32 rounded, 0x7FC0, rounded, is_nan;

                // Narrow to 16 bits and hand the result back.
                cvt.u16.u32 result, rounded;
                mov.b16 $0, result;
            }
            """,
            constraints=("=h,r"),  # h: 16-bit register, r: 32-bit register
            args=[x],
            dtype=(tl.bfloat16),  # the 16-bit output is interpreted as bfloat16
            is_pure=True,
            pack=1,
        )
    else:
        # Reference path: let Triton perform the float32 -> bfloat16 cast.
        x = x.to(tl.bfloat16)

    tl.store(output_ptr + offsets, x, mask=mask)

def float32_to_bf16(x: torch.Tensor) -> torch.Tensor:
    """Convert a CUDA float32 tensor to bfloat16 with the Triton kernel.

    Args:
        x: Input tensor; must live on a CUDA device and have dtype float32.
            It may be non-contiguous or empty.

    Returns:
        A new contiguous bfloat16 tensor with the same shape and device as
        ``x``.

    Raises:
        AssertionError: If ``x`` is not a CUDA float32 tensor.
    """
    assert x.is_cuda and x.dtype == torch.float32, "输入必须是CUDA float32 tensor"

    # torch.empty with an explicit shape always yields a contiguous buffer.
    y = torch.empty(x.shape, dtype=torch.bfloat16, device=x.device)

    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to convert; avoid launching a zero-sized grid.
        return y

    # The kernel reads the input by flat offset, so a strided/transposed view
    # would be read in the wrong order. contiguous() is a no-op for tensors
    # that are already contiguous.
    x = x.contiguous()

    # Elements handled by each program instance.
    BLOCK = 1024

    # One program per BLOCK elements, rounding up for the tail.
    grid = ((n_elements + BLOCK - 1) // BLOCK,)

    # Launch on the tensor's own device, which may not be the current one.
    with torch.cuda.device(x.device):
        float32_to_bf16_kernel[grid](
            x,
            y,
            n_elements,
            BLOCK=BLOCK,
            BF16ASM=True,
        )

    return y

def compare_conversion_methods(x: torch.Tensor):
    """Check the Triton conversion against PyTorch's and print a comparison.

    NaN entries are treated as equal when both results are NaN at the same
    position, and positions that match (including matching infinities) do not
    contribute to the reported maximum difference.

    Args:
        x: Input float32 tensor passed to both conversion paths.

    Raises:
        AssertionError: Raised by ``torch.testing.assert_close`` when the two
            results differ.
    """
    # Convert with the Triton kernel and with PyTorch's built-in cast.
    triton_result = float32_to_bf16(x)
    pytorch_result = x.to(torch.bfloat16)

    # equal_nan=True: NaN inputs must map to NaN in both results, but NaN
    # never compares equal to itself under the default settings.
    torch.testing.assert_close(
        triton_result, pytorch_result, atol=1e-6, rtol=1e-6, equal_nan=True
    )

    # Element-wise agreement, counting NaN-vs-NaN as a match.
    both_nan = torch.isnan(triton_result) & torch.isnan(pytorch_result)
    matches = (triton_result == pytorch_result) | both_nan
    triton_vs_pytorch = torch.all(matches)

    # Zero out matching positions so that inf - inf and nan - nan do not turn
    # the maximum into NaN.
    abs_diff = torch.abs(triton_result.float() - pytorch_result.float())
    abs_diff = abs_diff.masked_fill(matches, 0.0)
    triton_vs_pytorch_max_diff = abs_diff.max().item() if abs_diff.numel() > 0 else 0.0

    print("转换结果比较:")
    print(f"Triton实现 vs PyTorch: 是否相同={triton_vs_pytorch}, 最大差异={triton_vs_pytorch_max_diff}")

    # Print the first few elements side by side. The samples are copied to the
    # host once instead of issuing one device sync per printed value.
    n_samples = min(10, x.numel())
    orig_samples = x.flatten()[:n_samples].tolist()
    triton_samples = triton_result.flatten()[:n_samples].float().tolist()
    pytorch_samples = pytorch_result.flatten()[:n_samples].float().tolist()

    print("\n示例对比:")
    print("索引 | 原始float32 | Triton bf16 | PyTorch bf16")
    print("-" * 60)
    rows = zip(orig_samples, triton_samples, pytorch_samples)
    for idx, (orig, triton_val, pytorch_val) in enumerate(rows):
        print(f"{idx:4d} | {orig:12.6f} | {triton_val:12.6f} | {pytorch_val:12.6f}")

def performance_test(x: torch.Tensor, n_iterations=100):
    """Time the Triton conversion against PyTorch's built-in cast.

    Both paths are warmed up first, then each is run ``n_iterations`` times
    between a pair of CUDA events and the mean time per call is printed.

    Args:
        x: Input float32 tensor that both conversions are applied to.
        n_iterations: Number of timed calls per implementation.
    """
    print("\n性能测试:")

    # Warm-up: run both paths a few times before measuring.
    for _ in range(10):
        float32_to_bf16(x)
        x.to(torch.bfloat16)

    torch.cuda.synchronize()

    # One pair of events is shared by both measurements.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    def _mean_ms(fn, arg):
        # Mean elapsed milliseconds of fn(arg) over n_iterations calls.
        start.record()
        for _ in range(n_iterations):
            fn(arg)
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / n_iterations

    # Triton first, then PyTorch.
    triton_time = _mean_ms(float32_to_bf16, x)
    pytorch_time = _mean_ms(x.to, torch.bfloat16)

    print(f"Triton实现: {triton_time:.4f} ms")
    print(f"PyTorch实现: {pytorch_time:.4f} ms")

    # A relative speed is only reported when both timings are positive.
    if not (pytorch_time > 0 and triton_time > 0):
        return

    ratio = pytorch_time / triton_time
    if ratio > 1:
        print(f"Triton比PyTorch快 {ratio:.2f}x")
    else:
        print(f"PyTorch比Triton快 {1/ratio:.2f}x")

# 主函数
def main():
    """Run the correctness comparison and the benchmark on a CUDA device.

    Prints an error and returns immediately when no CUDA device is available.
    """
    if not torch.cuda.is_available():
        print("错误：未检测到CUDA设备，此示例需要GPU才能运行")
        return

    # Fixed seed so the random part of the test data is reproducible.
    torch.manual_seed(42)

    # Hand-picked inputs: ordinary values, large/small magnitudes,
    # inf/-inf/nan, and zeros plus values close to zero.
    edge_cases = [
        1.0, -1.0, 3.14159, -2.71828,
        1e20, -1e20, 1e-20, -1e-20,
        float('inf'), float('-inf'), float('nan'),
        0.0, -0.0, 1e-30, -1e-30,
    ]
    special_values = torch.tensor(edge_cases, dtype=torch.float32, device='cuda')

    # Random values appended after the hand-picked ones.
    random_values = torch.randn(10000000, device='cuda')
    x = torch.cat([special_values, random_values])

    print(f"测试数据大小: {x.shape}")

    # Correctness check against PyTorch.
    compare_conversion_methods(x)

    # Benchmark on a separate, freshly drawn random tensor.
    performance_test(torch.randn(10_000_000, device='cuda'))

# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()