In [1]:
import torch
import triton
import triton.language as tl
from triton.language.extra import libdevice

DEVICE = 'cuda'

In [2]:
@triton.jit
def frexp(x_ptr, exp_ptr, mantissa_ptr, num_elements, BLOCK_SIZE: tl.constexpr):
    """Split each input element into a mantissa and a power-of-two exponent.

    Hand-rolled variant built from ``libdevice.ilogb`` and ``libdevice.ldexp``:
    for non-zero ``x`` the exponent is ``ilogb(x) + 1`` and the mantissa is
    ``ldexp(x, -exponent)``; for ``x == 0`` both outputs are 0.

    Args:
        x_ptr: pointer to the input values.
        exp_ptr: pointer to the exponent output.
        mantissa_ptr: pointer to the mantissa output.
        num_elements: number of valid elements; lanes at or beyond it are masked.
        BLOCK_SIZE: number of elements handled by one program (compile-time constant).

    NOTE(review): inf/NaN inputs are not special-cased here -- confirm whether
    callers need them handled.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_elements

    # Mask the load so the tail program never reads past the end of x.
    # Masked lanes receive 0.0, which takes the x == 0 branch below.
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)

    y = libdevice.ilogb(x) + 1
    is_zero = x == 0
    exponent = tl.where(is_zero, 0, y)
    mantissa = tl.where(is_zero, 0, libdevice.ldexp(x, -y))

    tl.store(exp_ptr + offsets, exponent, mask=mask)
    tl.store(mantissa_ptr + offsets, mantissa, mask=mask)

In [3]:
%env MLIR_ENABLE_DUMP=1
%env MLIR_DUMP_PATH=dump.txt
!rm -rf ~/.triton

torch.manual_seed(0)
size = 10
x = torch.rand(size, device=DEVICE)
exp_ptr = torch.zeros(size, device=DEVICE)
mantissa_ptr = torch.zeros(size, device=DEVICE)

grid = lambda meta: (triton.cdiv(size, meta['BLOCK_SIZE']), )
frexp[grid](x, exp_ptr, mantissa_ptr, size, 2)

print(exp_ptr)
print(mantissa_ptr)

torch_mantissa, torch_exp = torch.frexp(x)

print(f'The maximum difference between torch and triton is '
      f'{torch.max(torch.max(torch.abs(torch_mantissa - mantissa_ptr)), torch.max(torch.abs(torch_exp - exp_ptr)))}')

env: MLIR_ENABLE_DUMP=1
env: MLIR_DUMP_PATH=dump.txt


tensor([-1.,  0., -5.,  0.,  0.,  0., -1.,  0., -2.,  0.], device='cuda:0')
tensor([0.7981, 0.5167, 0.7978, 0.9401, 0.9459, 0.7967, 0.8300, 0.8203, 0.9162,
        0.9096], device='cuda:0')
The maximum difference between torch and triton is 0.0


In [6]:
@triton.jit
def frexp_real(x_ptr, exp_ptr, mantissa_ptr, num_elements, BLOCK_SIZE: tl.constexpr):
    """Split each input element into mantissa and exponent via ``libdevice.frexp``.

    ``libdevice.frexp`` returns the mantissa and writes the exponent itself
    through the pointer it is given, so that write cannot be masked the way
    ``tl.store`` can. To keep tail lanes in bounds they are redirected onto the
    last valid element: they load the same input and write the same exponent
    as the lane that owns that element, so the duplicate write is harmless.

    Args:
        x_ptr: pointer to the input values.
        exp_ptr: pointer to the exponent output, written by ``libdevice.frexp``.
        mantissa_ptr: pointer to the mantissa output.
        num_elements: number of valid elements; must be >= 1 for any launched
            program, since tail lanes are clamped to ``num_elements - 1``.
        BLOCK_SIZE: number of elements handled by one program (compile-time constant).
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_elements

    # Out-of-range lanes are clamped to the last valid index so that neither
    # the load nor the exponent write inside frexp touches memory past the end.
    safe_offsets = tl.minimum(offsets, num_elements - 1)

    x = tl.load(x_ptr + safe_offsets)

    mantissa = libdevice.frexp(x, exp_ptr + safe_offsets)

    # Only lanes that own an element publish their mantissa.
    tl.store(mantissa_ptr + offsets, mantissa, mask=mask)

In [7]:
%env MLIR_ENABLE_DUMP=1
%env MLIR_DUMP_PATH=dump_2.txt
!rm -rf ~/.triton

#torch.manual_seed(0)
size = 11
x = torch.rand(size, device=DEVICE, dtype=torch.float64)
exp_ptr = torch.zeros(size, device=DEVICE, dtype=torch.int32)
mantissa_ptr = torch.zeros(size, device=DEVICE)

grid = lambda meta: (triton.cdiv(size, meta['BLOCK_SIZE']), )
frexp_real[grid](x, exp_ptr, mantissa_ptr, size, 2)

print(exp_ptr)
print(mantissa_ptr)

torch_mantissa, torch_exp = torch.frexp(x)

print(torch_exp)
print(torch_mantissa)

print(f'The maximum difference between torch and triton is '
      f'{torch.max(torch.max(torch.abs(torch_mantissa - mantissa_ptr)), torch.max(torch.abs(torch_exp - exp_ptr)))}')

env: MLIR_ENABLE_DUMP=1
env: MLIR_DUMP_PATH=dump_2.txt


tensor([-1,  0,  0, -5, -1,  0, -1, -2,  0, -1, -4], device='cuda:0',
       dtype=torch.int32)
tensor([0.6389, 0.7433, 0.8156, 0.9996, 0.9641, 0.9541, 0.7813, 0.5511, 0.5693,
        0.9449, 0.9709], device='cuda:0')
tensor([-1,  0,  0, -5, -1,  0, -1, -2,  0, -1, -4], device='cuda:0',
       dtype=torch.int32)
tensor([0.6389, 0.7433, 0.8156, 0.9996, 0.9641, 0.9541, 0.7813, 0.5511, 0.5693,
        0.9449, 0.9709], device='cuda:0', dtype=torch.float64)
The maximum difference between torch and triton is 2.574586854819927e-08
