In [1]:
import torch
import triton
import triton.language as tl

# Device string passed as `device=` when this notebook allocates its tensors.
DEVICE = 'cuda'

In [8]:
@triton.jit
def scalar_mult_duplicate(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Scale one block of the input by 4 and store it into two output rows.

    Each program instance handles the ``BLOCK_SIZE`` consecutive elements that
    start at ``program_id(0) * BLOCK_SIZE``. The output is addressed as a flat
    buffer in which the second row starts ``n_elements`` entries after the
    first row.

    Args:
        x_ptr: Pointer to the input elements.
        output_ptr: Pointer to the output buffer; offsets ``0`` through
            ``2 * n_elements - 1`` are written.
        n_elements: Number of input elements.
        BLOCK_SIZE: Compile-time number of elements handled by one program.
    """
    block_origin = tl.program_id(axis=0) * BLOCK_SIZE
    lane = tl.arange(0, BLOCK_SIZE)

    # The first output row uses exactly the same offsets as the input block.
    row0_offsets = block_origin + lane
    row0_in_bounds = row0_offsets < n_elements

    # Load the block once, view it as a single row, then repeat that row twice.
    loaded = tl.load(x_ptr + row0_offsets, mask=row0_in_bounds)
    as_row = loaded.reshape([1, BLOCK_SIZE])
    tile = as_row.broadcast_to(2, BLOCK_SIZE)
    scaled = 4 * tile

    # The second output row is the same block shifted by one full row length.
    row1_offsets = block_origin + lane + n_elements
    row1_in_bounds = row1_offsets < (n_elements * 2)

    # tl.join pairs the two vectors along a new trailing axis; transposing
    # yields the [2, BLOCK_SIZE] layout that matches `scaled`.
    out_offsets = tl.trans(tl.join(row0_offsets, row1_offsets))
    out_mask = tl.trans(tl.join(row0_in_bounds, row1_in_bounds))

    tl.store(output_ptr + out_offsets, scaled, mask=out_mask)

In [9]:
def mult(x: torch.Tensor):
    """Return a ``(2, n)`` tensor whose two rows both equal ``4 * x``.

    Args:
        x: 1-D tensor with ``n`` elements. It is handed to a Triton kernel, so
            it is expected to live on a GPU device.

    Returns:
        A new tensor of shape ``(2, n)`` on the same device and with the same
        dtype as ``x``.

    Raises:
        ValueError: If ``x`` is not 1-D. The kernel writes ``2 * x.numel()``
            values, which would overrun an output sized from ``x.shape[0]``.
    """
    if x.dim() != 1:
        raise ValueError(f'mult expects a 1-D tensor, got shape {tuple(x.shape)}')

    # The kernel addresses x as a dense flat buffer, so strided views must be
    # materialised first (this is a no-op for already-contiguous input).
    x = x.contiguous()
    n_elements = x.numel()

    # Allocate next to the input rather than on a notebook-global device, and
    # size the rows from the same element count that the kernel is given.
    output = torch.empty((2, n_elements), device=x.device, dtype=x.dtype)

    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    scalar_mult_duplicate[grid](x, output, n_elements, BLOCK_SIZE=1024)
    return output

In [7]:
# Remove any MLIR dump file left over from an earlier run.
!rm -rf dump.txt

In [10]:
%env MLIR_ENABLE_DUMP=1
%env MLIR_DUMP_PATH=dump.txt
!rm -rf ~/.triton

torch.manual_seed(0)
size = 98432
x = torch.rand(size, device=DEVICE)
output_torch = 4 * x.broadcast_to([2, size])
output_triton = mult(x)
print(output_torch)
print(output_triton)
print(f'The maximum difference between torch and triton is '
      f'{torch.max(torch.abs(output_torch - output_triton))}')

env: MLIR_ENABLE_DUMP=1
env: MLIR_DUMP_PATH=dump.txt


LETSFUCKINGGO
ttir <function CUDABackend.add_stages.<locals>.<lambda> at 0x7f3e4400bd90>
 
"builtin.module"() ({
  "tt.func"() <{arg_attrs = [{tt.divisibility = 16 : i32}, {tt.divisibility = 16 : i32}, {tt.divisibility = 16 : i32}], function_type = (!tt.ptr<f32>, !tt.ptr<f32>, i32) -> (), sym_name = "scalar_mult_duplicate", sym_visibility = "public"}> ({
  ^bb0(%arg0: !tt.ptr<f32> , %arg1: !tt.ptr<f32> , %arg2: i32 ):
    %0 = "arith.constant"() <{value = dense<4.000000e+00> : tensor<2x1024xf32>}> : () -> tensor<2x1024xf32> 
    %1 = "arith.constant"() <{value = 2 : i32}> : () -> i32 
    %2 = "arith.constant"() <{value = 1024 : i32}> : () -> i32 
    %3 = "tt.get_program_id"() <{axis = 0 : i32}> : () -> i32 
    %4 = "arith.muli"(%3, %2) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32 
    %5 = "tt.make_range"() <{end = 1024 : i32, start = 0 : i32}> : () -> tensor<1024xi32> 
    %6 = "tt.splat"(%4) : (i32) -> tensor<1024xi32> 
    %7 = "arith.addi"(%6, %5) <{overflowFlag