In [1]:
import torch
import triton
import triton.language as tl

In [None]:
@triton.jit
def quantize_kernel(
    x_ptr,        # Pointer to the input tensor (floating point, contiguous)
    output_ptr,   # Pointer to the output tensor (int8, contiguous)
    scale_ptr,    # Pointer to the scaling factor (first element is used)
    n_elements,   # Number of elements in the tensor
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process
):
    """Quantize one BLOCK_SIZE-wide slice: out = clamp(round(x / scale), -128, 127) as int8.

    Rounding is half away from zero.
    """
    # Program ID
    pid = tl.program_id(axis=0)

    # Per-element offsets for this program's slice. A vector of offsets (rather than a
    # scalar base) is what gives every lane its own element to load and store.
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # Mask off the tail when n_elements is not a multiple of BLOCK_SIZE
    mask = offsets < n_elements

    # Load input x (masked-off lanes get a harmless 0.0 and are never stored)
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)

    # Load scale
    scale = tl.load(scale_ptr).to(tl.float32)

    # Scale and clamp in the float domain first, so the float->int conversion below never
    # sees an out-of-range value.
    q = tl.minimum(tl.maximum(x / scale, -128.0), 127.0)

    # Round half away from zero: add +/-0.5, then the float->int cast truncates toward zero.
    q = tl.where(q >= 0, q + 0.5, q - 0.5)

    # Store output
    tl.store(output_ptr + offsets, q.to(tl.int8), mask=mask)

@triton.jit
def dequantize_kernel(
    quant_ptr,     # Pointer to the quantized tensor (int8, contiguous)
    output_ptr,    # Pointer to the output tensor (float32, contiguous)
    scale_ptr,     # Pointer to the scaling factor (first element is used)
    n_elements,    # Number of elements in the tensor
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process
):
    """Dequantize one BLOCK_SIZE-wide slice: out = float32(quant) * scale."""
    # Program ID
    pid = tl.program_id(axis=0)

    # Per-element offsets for this program's slice. A vector of offsets (rather than a
    # scalar base) is what gives every lane its own element to load and store.
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # Mask off the tail when n_elements is not a multiple of BLOCK_SIZE
    mask = offsets < n_elements

    # Load quantized input (masked-off lanes get 0 and are never stored)
    quant = tl.load(quant_ptr + offsets, mask=mask, other=0)

    # Load scale
    scale = tl.load(scale_ptr).to(tl.float32)

    # Dequantize: widen to float32 explicitly, then multiply by scale
    dequantized = quant.to(tl.float32) * scale

    # Store output
    tl.store(output_ptr + offsets, dequantized, mask=mask)

In [None]:
# Wrapper functions to call the triton kernels
def quantize(x, scale=None):
    """
    Quantize a floating-point tensor to int8 using a per-tensor scaling factor.

    Each element becomes clamp(round(x / scale), -128, 127).

    Args:
        x: Input tensor (float32 expected) on the device the kernel runs on.
           Non-contiguous inputs are copied to a contiguous buffer first, because the
           kernel indexes memory linearly.
        scale: Optional scaling factor (Python number or tensor). If None, it is
           calculated as max(abs(x)) / 127, so the largest-magnitude element maps to +/-127.

    Returns:
        Quantized tensor (int8, same shape as x) and the float32 scale tensor used.
    """
    if scale is None:
        if x.numel() == 0:
            # max() of an empty tensor raises; any positive scale is valid here.
            scale = torch.ones((), device=x.device, dtype=torch.float32)
        else:
            scale = torch.max(torch.abs(x)).to(torch.float32) / 127.0
            # An all-zero input would give scale == 0 and 0/0 = NaN in the kernel.
            # Flooring at the smallest normal float32 makes such inputs quantize to 0.
            scale = scale.clamp_min(torch.finfo(torch.float32).tiny)
    elif isinstance(scale, torch.Tensor):
        # The kernel reads scale from x's device as float32.
        scale = scale.to(device=x.device, dtype=torch.float32)
    else:
        scale = torch.tensor([scale], device=x.device, dtype=torch.float32)

    x_contig = x.contiguous()

    # Output tensor
    output = torch.empty_like(x_contig, dtype=torch.int8)

    n_elements = x_contig.numel()
    if n_elements == 0:
        return output, scale

    # 1-D launch: one program per BLOCK_SIZE-element slice
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)

    # Pass the tensors themselves, not .data_ptr() ints: Triton derives each pointer's
    # element type from the tensor dtype, whereas a raw int is passed as an integer scalar.
    quantize_kernel[grid](
        x_contig,
        output,
        scale,
        n_elements,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return output, scale

In [None]:
def dequantize(x_quant, scale):
    """
    Dequantize an int8 tensor back to float32 as x_quant * scale.

    Args:
        x_quant: Quantized tensor (int8) on the device the kernel runs on.
           Non-contiguous inputs are copied to a contiguous buffer first, because the
           kernel indexes memory linearly.
        scale: Scaling factor used for quantization (Python number or tensor).

    Returns:
        float32 tensor with the same shape as x_quant.
    """
    # Ensure scale is a float32 tensor on the same device as the data
    if isinstance(scale, torch.Tensor):
        scale = scale.to(device=x_quant.device, dtype=torch.float32)
    else:
        scale = torch.tensor([scale], device=x_quant.device, dtype=torch.float32)

    quant_contig = x_quant.contiguous()

    # Output tensor
    output = torch.empty_like(quant_contig, dtype=torch.float32)

    n_elements = quant_contig.numel()
    if n_elements == 0:
        return output

    # 1-D launch: one program per BLOCK_SIZE-element slice
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)

    # Pass the tensors themselves, not .data_ptr() ints: Triton derives each pointer's
    # element type from the tensor dtype, whereas a raw int is passed as an integer scalar.
    dequantize_kernel[grid](
        quant_contig,
        output,
        scale,
        n_elements,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return output

In [None]:
# Create a random tensor (seeded so the reported error is reproducible)
torch.manual_seed(0)
x = torch.randn(1000000, device="cuda")

# Quantize
x_quant, scale = quantize(x)

# Dequantize
x_dequant = dequantize(x_quant, scale)

# Round-to-nearest with max(|x|) mapped to 127 means no element should be off by more
# than half a quantization step; the small relative slack absorbs float32 rounding.
abs_error = torch.abs(x - x_dequant)
assert abs_error.max().item() <= 0.5 * scale.item() * (1 + 1e-4), "quantization error exceeds scale/2"

# Calculate error
error = abs_error.mean()
print(f"Mean absolute error: {error.item()}")