In [None]:
# exp_lut.py
import torch
import exp_lut_cuda  # The compiled extension

def create_exp_lut(device="cuda"):
    """
    Build a lookup table holding exp(x) for every FP16 bit pattern.

    Entry ``i`` of the table is exp of the float16 value whose raw 16-bit
    encoding is ``i``, so the table always has 65536 entries. Entries whose
    bit pattern encodes NaN hold NaN, and exp overflows to +inf for large
    positive inputs.

    Args:
        device: Device the finished table is moved to. Defaults to
            ``"cuda"``, which matches the previous hard-coded behaviour;
            pass ``"cpu"`` to build or inspect the table without a GPU.

    Returns:
        torch.Tensor: 1-D float16 tensor with 65536 entries on ``device``.
    """
    # Enumerate the 16-bit patterns as int16 instead of uint16: unsigned
    # 16-bit tensors are missing from older PyTorch releases and only have
    # limited operator support in newer ones. The signed values -32768..-1
    # share their two's-complement bits with 0x8000..0xFFFF, so concatenating
    # them after 0..32767 yields patterns 0x0000..0xFFFF in index order.
    # The ranges are generated as int32 so no value is near a dtype limit.
    low_half = torch.arange(0, 32768, dtype=torch.int32)
    high_half = torch.arange(-32768, 0, dtype=torch.int32)
    bit_patterns = torch.cat((low_half, high_half)).to(torch.int16)

    # Reinterpret the raw bits as float16 (no numeric conversion happens).
    fp16_values = bit_patterns.view(torch.float16)

    # Evaluate exp in float32 so the only rounding is the final cast.
    exp_values = torch.exp(fp16_values.to(torch.float32))

    # Round back to float16 for storage and move to the requested device.
    return exp_values.to(torch.float16).to(device)

def exp_lut(input_tensor, lut):
    """
    Compute exp using LUT.

    Validates both tensors, allocates an output tensor shaped like the
    input, and hands all three to the compiled ``exp_lut_cuda`` extension.

    Args:
        input_tensor (torch.Tensor): Input tensor of dtype float16 on CUDA.
        lut (torch.Tensor): Precomputed LUT of dtype float16 on CUDA.
            NOTE(review): presumably this must be the 65536-entry table
            produced by ``create_exp_lut``; the size is not checked here
            because the kernel's indexing scheme is not visible from this
            file - confirm against the extension source.

    Returns:
        torch.Tensor: Output tensor with exp applied, dtype float16 on CUDA.

    Raises:
        TypeError: If ``input_tensor`` or ``lut`` is not float16.
        ValueError: If ``input_tensor`` or ``lut`` is not on a CUDA device.
    """
    # Explicit raises instead of ``assert``: assertions are removed under
    # ``python -O``, which would let invalid tensors reach the CUDA kernel.
    if input_tensor.dtype != torch.float16:
        raise TypeError("Input must be float16")
    if not input_tensor.is_cuda:
        raise ValueError("Input must be on CUDA")
    if lut.dtype != torch.float16:
        raise TypeError("LUT must be float16")
    if not lut.is_cuda:
        raise ValueError("LUT must be on CUDA")

    # The extension writes its result into a caller-provided buffer.
    output = torch.empty_like(input_tensor)
    exp_lut_cuda.exp_lut_cuda(input_tensor, output, lut)
    return output


In [None]:
# example_usage.py
import torch
from exp_lut import create_exp_lut, exp_lut

def main():
    """Run the LUT-based exp on random FP16 data and compare it to torch.exp.

    Builds the table, evaluates exp through the LUT path and through
    PyTorch (computed in float32, rounded to float16), then prints the
    largest absolute difference and whether the two agree within
    ``atol=1e-2``.
    """
    print("Creating exp LUT...")
    table = create_exp_lut()

    print("Creating input tensor...")
    num_elements = 1024 * 1024  # roughly one million samples
    samples = torch.randn(num_elements, dtype=torch.float16, device="cuda")

    print("Computing exp using LUT...")
    lut_result = exp_lut(samples, table)

    # Reference path: evaluate in float32, then round to float16.
    print("Computing exp using PyTorch...")
    reference = torch.exp(samples.float()).half()

    print("Comparing results...")
    abs_error = torch.abs(lut_result.float() - reference.float())
    max_diff = abs_error.max()
    print(f"Maximum difference between LUT and PyTorch: {max_diff}")

    # Report whether the two results agree within the absolute tolerance.
    within_tolerance = torch.allclose(lut_result, reference, atol=1e-2)
    verdict = (
        "LUT-based exp matches PyTorch's exp within tolerance."
        if within_tolerance
        else "LUT-based exp does NOT match PyTorch's exp within tolerance."
    )
    print(verdict)


if __name__ == "__main__":
    main()
