In [None]:
# exp_lut.py
import torch
import exp_lut_cuda  # The compiled extension

def create_exp_lut(device="cuda"):
    """
    Build a 65536-entry lookup table of ``exp`` over every FP16 bit pattern.

    Entry ``i`` holds ``exp(x)`` rounded to FP16, where ``x`` is the FP16 value
    whose raw unsigned 16-bit pattern is ``i``. ``exp`` is evaluated in FP32 and
    rounded to FP16 once, so entry ``i`` equals ``torch.exp(x.float()).half()``.

    Args:
        device: device to place the LUT on (default ``"cuda"``, as before).

    Returns:
        torch.Tensor: float16 tensor of shape ``(65536,)`` on ``device``.
    """
    # Enumerate unsigned patterns 0..65535 in int32, then fold the upper half
    # (>= 0x8000) into negative int16 values so .view() sees identical bits.
    # torch.uint16 is avoided on purpose: it is missing from older PyTorch
    # releases and has only limited eager-op support in newer ones.
    unsigned_bits = torch.arange(0, 65536, dtype=torch.int32)
    signed_bits = torch.where(
        unsigned_bits >= 0x8000, unsigned_bits - 0x10000, unsigned_bits
    ).to(torch.int16)
    fp16_values = signed_bits.view(torch.float16)

    # Compute exp in higher precision so each entry is rounded only once.
    exp_values = torch.exp(fp16_values.to(torch.float32))

    # Cast back to FP16 and move the LUT to the requested device.
    return exp_values.to(torch.float16).to(device)

def exp_lut(input_tensor, lut):
    """
    Compute exp using LUT.

    Args:
        input_tensor (torch.Tensor): Input tensor of dtype float16 on CUDA.
        lut (torch.Tensor): Precomputed LUT of dtype float16, on the same
            device as ``input_tensor``.

    Returns:
        torch.Tensor: Output tensor with exp applied, dtype float16 on CUDA,
        same shape as ``input_tensor``.

    Raises:
        AssertionError: if the input or the LUT has the wrong dtype/device.
    """
    assert input_tensor.dtype == torch.float16, "Input must be float16"
    assert input_tensor.is_cuda, "Input must be on CUDA"
    # The raw kernel receives the LUT pointer directly, so a wrong dtype or a
    # LUT on another device would read garbage instead of raising.
    assert lut.dtype == torch.float16, "LUT must be float16"
    assert lut.device == input_tensor.device, "LUT must be on the same device as input"

    # NOTE(review): the extension's indexing scheme is not visible here; making
    # the tensors contiguous guards against a flat-index kernel reading a strided
    # view incorrectly. This is a no-op for already-contiguous tensors.
    input_tensor = input_tensor.contiguous()
    lut = lut.contiguous()

    output = torch.empty_like(input_tensor)
    exp_lut_cuda.exp_lut_cuda(input_tensor, output, lut)
    return output


In [None]:
# example_usage.py
import torch
from exp_lut import create_exp_lut, exp_lut

def main(seed=0):
    """
    Demo: build the exp LUT, apply it to random FP16 data on CUDA, and compare
    the result against PyTorch's ``exp`` computed in FP32 and rounded to FP16.

    Args:
        seed: RNG seed for the random input, so repeated runs are reproducible.
    """
    # Initialize LUT
    print("Creating exp LUT...")
    exp_lookup_table = create_exp_lut()

    # Create a sample input tensor (seeded so the comparison is reproducible)
    print("Creating input tensor...")
    torch.manual_seed(seed)
    input_size = 1024 * 1024  # 1 million elements
    input_fp16 = torch.randn(input_size, device='cuda', dtype=torch.float16)

    # Compute exp using LUT
    print("Computing exp using LUT...")
    output_fp16 = exp_lut(input_fp16, exp_lookup_table)

    # Compute exp using PyTorch for verification
    print("Computing exp using PyTorch...")
    output_ref = torch.exp(input_fp16.to(torch.float32)).to(torch.float16)

    # Compare results; .item() prints a plain number instead of a tensor repr.
    print("Comparing results...")
    max_diff = (
        (output_fp16.to(torch.float32) - output_ref.to(torch.float32)).abs().max().item()
    )
    print(f"Maximum difference between LUT and PyTorch: {max_diff}")

    # Check if within acceptable tolerance
    if torch.allclose(output_fp16, output_ref, atol=1e-2):
        print("LUT-based exp matches PyTorch's exp within tolerance.")
    else:
        print("LUT-based exp does NOT match PyTorch's exp within tolerance.")


if __name__ == "__main__":
    main()


In [136]:
import torch

def next_fp16(x, dtype=torch.float16, increment=1):
    """
    Step ``x`` by ``increment`` representable values of a 16-bit float dtype.

    ``increment=1`` gives the next value toward +inf and ``increment=-1`` the
    previous value toward -inf, for both positive and negative ``x``. Stepping
    is done on an integer ordering of the sign-magnitude bit patterns, so
    ``+0.0`` and ``-0.0`` are treated as the same point.

    Args:
        x: Python number or tensor; converted to ``dtype``.
        dtype: 16-bit floating dtype (``torch.float16`` or ``torch.bfloat16``).
        increment: number of representable steps to move (may be negative).

    Returns:
        torch.Tensor of ``dtype`` with the same shape as ``x``. NaN inputs are
        returned unchanged. Stepping past +/-inf yields NaN patterns.
    """
    if not isinstance(x, torch.Tensor):
        x = torch.tensor(x, dtype=dtype)
    else:
        x = x.to(dtype)

    # Raw 16-bit pattern as an unsigned value in int32 (avoids int16 overflow).
    bits = x.view(torch.int16).to(torch.int32) & 0xFFFF
    is_negative = (bits & 0x8000) != 0
    magnitude = bits & 0x7FFF

    # Sign-magnitude -> monotone integer order, step, clamp to the 15-bit range.
    ordered = torch.where(is_negative, -magnitude, magnitude) + increment
    ordered = ordered.clamp(-0x7FFF, 0x7FFF)

    # Monotone integer -> sign-magnitude, then reinterpret as signed int16.
    new_bits = torch.where(ordered < 0, (-ordered) | 0x8000, ordered)
    new_bits = torch.where(new_bits >= 0x8000, new_bits - 0x10000, new_bits).to(torch.int16)
    next_x = new_bits.view(dtype)

    # NaN has no neighbours; propagate it unchanged.
    return torch.where(torch.isnan(x), x, next_x)


def get_allowed_exp_range(dtype=torch.bfloat16, low_threshold=1e-5, high_threshold=65536 * 16):
    # Create all FP16 values
    bit_patterns = torch.arange(-65536//2,65536//2, dtype=torch.int16)
    fp16_values = bit_patterns.view(dtype)

    # exp
    exp_fp16 = torch.exp(fp16_values) 

    # where it goes to inf
    if high_threshold is not None:
        is_inf = exp_fp16 > high_threshold
    is_inf = is_inf | torch.isinf(exp_fp16) #| torch.isnan(exp_fp16)

    # where it goes to zero
    zero_thresh = low_threshold or torch.finfo(dtype).tiny
    is_zero = exp_fp16 < zero_thresh

    # where it goes to one,
    # find number representable just above 1
    next_fp16_one = next_fp16(1, dtype=dtype, increment=1)
    previous_fp16_one = next_fp16(1, dtype=dtype, increment=-1)
    is_one = (exp_fp16 < next_fp16_one) & (exp_fp16 > previous_fp16_one)

    # Determine the thresholds
    x_inf = fp16_values[is_inf].min()
    x_zero = fp16_values[is_zero].max()
    x_one = fp16_values[is_one].max()

    is_nan = torch.isnan(fp16_values)

    # remaining available values
    mask = ~(is_inf | is_zero | is_one | is_nan)
    allowed_fp16_values = fp16_values[mask]
    allowed_exp_fp16 = exp_fp16[mask]
    allowed_bits = bit_patterns[mask]

    print(f"Threshold for exp(x) = +inf: x >= {x_inf.item()}")
    print(f"Threshold for exp(x) = 0: x <= {x_zero.item()}")
    print(f"Threshold for exp(x) = 1: x <= {x_one.item()}")

    lut_indices = torch.arange(0, allowed_fp16_values.numel(), dtype=torch.int16)
    mapping_table = torch.zeros(65536, dtype=torch.int16)
    mapping_table[allowed_bits] = lut_indices

    # Define special case codes:
    # -1: +inf, -2: 1, -3: 0, -4 nan
    SPECIAL_INF = -1
    SPECIAL_ONE = -2
    SPECIAL_ZERO = -3
    SPECIAL_NAN = -4

    mapping_table[is_inf] = -1
    mapping_table[is_one] = -2
    mapping_table[is_zero] = -3
    mapping_table[is_nan] = -4

    return allowed_fp16_values, allowed_exp_fp16, allowed_bits, is_inf, is_zero, is_one, is_nan, mapping_table



Threshold for exp(x) = +inf: x >= 13.875
Threshold for exp(x) = 0: x <= -11.5625
Threshold for exp(x) = 1: x <= 0.0038909912109375


TypeError: 'torch.dtype' object is not callable