In [1]:
import os

# Restrict this process to the GPU with index 2. This cell runs before `import torch`
# in the next cell — presumably so that CUDA only ever sees that one device.
# NOTE(review): the index is machine-specific; adjust it for the host you run on.
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

In [None]:
import torch
import triton
import triton.language as tl

@triton.jit
def lora_matmul_kernel(
        input_ptr, w_ptr, w1_ptr, w2_ptr, output_ptr,
        M, N, K, R,
        stride_im, stride_ik,
        stride_wk, stride_wn,
        stride_w1k, stride_w1r,
        stride_w2r, stride_w2n,
        stride_om, stride_on,
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
        BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_R: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
):
    """Fused LoRA matmul: output = input @ W + (input @ W1) @ W2.

    Each program computes one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of the output.
    The indexing below treats the operands as 2D: input (M, K), W (K, N),
    W1 (K, R), W2 (R, N) and output (M, N), each addressed through its own pair
    of element strides.

    The rank dimension is not tiled: `offs_r` only spans [0, BLOCK_SIZE_R), so
    the launcher must pass BLOCK_SIZE_R >= R. Lanes with index >= R are masked
    to zero and contribute nothing.

    Both products are accumulated in float32 over a single pass across K. The
    low-rank intermediate input @ W1 is kept in float32 until the K reduction is
    complete, then cast once to W2's dtype and multiplied by W2 exactly once.
    """
    # Map the 1D program id onto an output tile (pid_m, pid_n). Row-tiles are
    # walked in groups of GROUP_SIZE_M (the grouped ordering used by the Triton
    # matmul tutorial, intended to improve cache reuse between nearby programs).
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The last group may hold fewer than GROUP_SIZE_M row-tiles.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    offs_r = tl.arange(0, BLOCK_SIZE_R)

    # Row / column / rank bounds do not change across the K loop.
    row_in_bounds = offs_m[:, None] < M
    col_in_bounds = offs_n[None, :] < N
    rank_in_bounds = offs_r[None, :] < R

    # float32 accumulators: the base product tile and the low-rank
    # intermediate (input @ W1) tile.
    base_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    xw1_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_R), dtype=tl.float32)

    # Single sweep over K: every input tile is loaded once and feeds both the
    # base product (input @ W) and the low-rank projection (input @ W1).
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        k_offs = k * BLOCK_SIZE_K + offs_k
        k_in_bounds_row = k_offs[None, :] < K
        k_in_bounds_col = k_offs[:, None] < K

        a_ptrs = input_ptr + offs_m[:, None] * stride_im + k_offs[None, :] * stride_ik
        w_ptrs = w_ptr + k_offs[:, None] * stride_wk + offs_n[None, :] * stride_wn
        w1_ptrs = w1_ptr + k_offs[:, None] * stride_w1k + offs_r[None, :] * stride_w1r

        # Out-of-range elements are read as 0 so partial tiles add nothing.
        a = tl.load(a_ptrs, mask=row_in_bounds & k_in_bounds_row, other=0.0)
        b = tl.load(w_ptrs, mask=k_in_bounds_col & col_in_bounds, other=0.0)
        w1 = tl.load(w1_ptrs, mask=k_in_bounds_col & rank_in_bounds, other=0.0)

        base_acc += tl.dot(a, b)
        xw1_acc += tl.dot(a, w1)

    # W2 does not depend on k, so it is loaded and applied once, after the
    # reduction over K has finished.
    w2_ptrs = w2_ptr + offs_r[:, None] * stride_w2r + offs_n[None, :] * stride_w2n
    mask_w2 = (offs_r[:, None] < R) & col_in_bounds
    w2 = tl.load(w2_ptrs, mask=mask_w2, other=0.0)

    # tl.dot needs both operands in the same dtype; cast the float32
    # intermediate to W2's element type for the second product.
    lora_acc = tl.dot(xw1_acc.to(w2.dtype), w2)

    # Combine results and write to output
    output = base_acc + lora_acc
    output_ptrs = output_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    mask_out = row_in_bounds & col_in_bounds
    tl.store(output_ptrs, output, mask=mask_out)

def lora_matmul(input, weight, lora_weight1, lora_weight2):
    """
    Compute matrix multiplication with LoRA in a single Triton kernel launch:
    output = input × weight + input × lora_weight1 × lora_weight2

    Arguments:
        input: torch.Tensor of shape (M, K)
        weight: torch.Tensor of shape (K, N)
        lora_weight1: torch.Tensor of shape (K, R)
        lora_weight2: torch.Tensor of shape (R, N)
    Returns:
        output: torch.Tensor of shape (M, N), allocated on input's device with
            input's dtype (the kernel accumulates in float32 before storing).
    Raises:
        AssertionError: if the inner dimensions of the operands do not agree.
    """
    # Check input dimensions
    assert input.shape[1] == weight.shape[0], "Input and weight dimensions mismatch"
    assert input.shape[1] == lora_weight1.shape[0], "Input and LoRA W1 dimensions mismatch"
    assert lora_weight1.shape[1] == lora_weight2.shape[0], "LoRA W1 and W2 dimensions mismatch"
    assert weight.shape[1] == lora_weight2.shape[1], "Weight and LoRA W2 dimensions mismatch"

    # Extract dimensions
    M, K = input.shape
    _, N = weight.shape
    R = lora_weight1.shape[1]

    # Allocate output
    output = torch.empty((M, N), device=input.device, dtype=input.dtype)

    # Tile sizes for the output (M, N) and the reduction (K).
    BLOCK_SIZE_M = 32
    BLOCK_SIZE_N = 32
    BLOCK_SIZE_K = 32
    # The kernel handles the whole rank dimension in one tile, so the tile must
    # cover R. It cannot simply be R: tl.arange needs a power-of-two extent and
    # tl.dot rejects tiles below a minimum size (16 is used as a safe floor —
    # TODO confirm against the installed Triton version). The kernel masks the
    # padded lanes (index >= R) to zero, so padding does not change the result.
    BLOCK_SIZE_R = max(16, triton.next_power_of_2(R))
    GROUP_SIZE_M = 8

    # One program per (BLOCK_SIZE_M, BLOCK_SIZE_N) output tile, on a 1D grid.
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
    )

    # Launch kernel; strides are passed explicitly, in elements.
    lora_matmul_kernel[grid](
        input, weight, lora_weight1, lora_weight2, output,
        M, N, K, R,
        input.stride(0), input.stride(1),
        weight.stride(0), weight.stride(1),
        lora_weight1.stride(0), lora_weight1.stride(1),
        lora_weight2.stride(0), lora_weight2.stride(1),
        output.stride(0), output.stride(1),
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        BLOCK_SIZE_R=BLOCK_SIZE_R,
        GROUP_SIZE_M=GROUP_SIZE_M,
    )

    return output

In [None]:
# Fixed seed so the random operands are reproducible across runs.
torch.manual_seed(0)

# Problem size: input is (M, K), the base weight is (K, N), LoRA rank is R.
M, K, N, R = 128, 512, 256, 128

device = torch.device('cuda')

# Draw the four half-precision operands in the order input, weight, W1, W2.
input, weight, lora_w1, lora_w2 = (
    torch.randn(shape, device=device, dtype=torch.float16)
    for shape in ((M, K), (K, N), (K, R), (R, N))
)

In [None]:
%%time

output_triton = lora_matmul(input, weight, lora_w1, lora_w2)

In [None]:
%%time

output_torch = torch.matmul(input, weight) + torch.matmul(torch.matmul(input, lora_w1), lora_w2)

In [None]:
# Display the result produced by lora_matmul (the Triton kernel) for visual inspection.
output_triton

In [None]:
# Display the PyTorch reference result for visual comparison with output_triton.
output_torch

In [None]:
# Largest element-wise absolute deviation between the Triton and PyTorch results.
print(f"Max difference: {(output_triton - output_torch).abs().max()}")

In [None]:
# Fraction of output elements whose sign is the same in both results.
torch.eq(output_torch.sign(), output_triton.sign()).float().mean()

In [None]:
# Element-wise sign of the PyTorch reference result.
output_torch.sign()

In [None]:
# Element-wise sign of the Triton result, for comparison with the reference signs.
output_triton.sign()

In [None]:
# (row indices, column indices) of the elements whose signs disagree between the two results.
torch.where(output_torch.sign() != output_triton.sign())