import math
import torch
import intel_extension_for_pytorch as ipex
import triton
import triton.language as tl
class Backend:
device = 'xpu'
def sync(self):
torch.xpu.synchronize()
    def check_device(self, *args):
        return True
backend = Backend()
'''
Sources -
kernel_fma is based on the Triton matmul tutorial at https://github.com/openai/triton/blob/main/python/tutorials/03-matrix-multiplication.py
kernel_colsum is written from scratch
kernel_swiglu_fwd is written from scratch
kernel_swiglu_bwd is written from scratch
'''
# Implements Z = X @ Y + b
@triton.jit
def kernel_fma(
x_ptrs, y_ptrs, b_ptrs, z_ptrs,
M, N, K,
stride_xm, stride_xk,
stride_yk, stride_yn,
stride_zm, stride_zn,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
BIAS_REQD: tl.constexpr,
):
# pid -> pid_m, pid_n
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
# block pointers for x and y
offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_yn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
x_ptrs = x_ptrs + (offs_xm[:, None] * stride_xm + offs_k [None, :] * stride_xk)
y_ptrs = y_ptrs + (offs_k [:, None] * stride_yk + offs_yn[None, :] * stride_yn)
# initialize accumulator to zeros in SLM
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
# loop over k dim
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# load input blocks from HBM
x = tl.load(x_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
y = tl.load(y_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
# accumulate product of input blocks
acc += tl.dot(x, y)
# increment block pointers
x_ptrs += BLOCK_SIZE_K * stride_xk
y_ptrs += BLOCK_SIZE_K * stride_yk
z = acc.to(tl.float16)
# add bias to accumulator
if BIAS_REQD:
bias = tl.load(b_ptrs + offs_yn, mask=offs_yn < N, other=0.0)
z += bias[None, :]
# store output block to HBM
offs_zm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_zn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
z_ptrs = z_ptrs + offs_zm[:, None] * stride_zm + offs_zn[None, :] * stride_zn
z_mask = (offs_zm[:, None] < M) & (offs_zn[None, :] < N)
tl.store(z_ptrs, z, mask=z_mask)
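# For reference, a host-side sketch (illustrative names, plain Python rather than
# Triton) of the grouped program-id mapping used above, so the pid -> (pid_m, pid_n)
# swizzle can be inspected without launching the kernel.
def _pid_to_tile(pid, M, N, BLOCK_SIZE_M=128, BLOCK_SIZE_N=128, GROUP_SIZE_M=8):
    num_pid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
    num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n
# e.g. with M = N = 1024 the first eight programs all map to pid_n == 0 while pid_m
# runs 0..7, so consecutive programs reuse the same block of y (better L2 locality).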
@triton.jit
def kernel_colsum(
x_ptrs, s_ptrs,
M, N,
stride_m, stride_n,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr
):
pid = tl.program_id(axis=0)
offs_m = tl.arange(0, BLOCK_SIZE_M)
offs_n = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
x_ptrs = x_ptrs + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
# initialize accumulator to zeros in SLM
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for m in range(0, tl.cdiv(M, BLOCK_SIZE_M)):
# load input block from HBM
        # earlier attempts kept for reference: `&` binds tighter than `<`,
        # so these combined masks would need parentheses around each comparison
        # x = tl.load(x_ptrs, mask=offs_m[:, None] + m * BLOCK_SIZE_M < M & offs_n < N, other=0.0)
        # x = tl.load(x_ptrs, mask=offs_m[:, None] < M - m * BLOCK_SIZE_M & offs_n < N, other=0.0)
        x = tl.load(x_ptrs, mask=(offs_m[:, None] < M - m * BLOCK_SIZE_M) & (offs_n[None, :] < N), other=0.0)
acc += x
# increment block pointers
x_ptrs += BLOCK_SIZE_M * stride_m
r = tl.sum(acc, axis=0)
    s_ptrs += offs_n
    tl.store(s_ptrs, r, mask=offs_n < N)
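# kernel_colsum is not launched anywhere in this snippet; a minimal host wrapper
# (an assumption about the intended launch, not part of the original repro) would
# look like this, and its result should closely match x.float().sum(dim=0).
def column_sum(X, BLOCK_SIZE_M=128, BLOCK_SIZE_N=128):
    M, N = X.shape
    S = torch.empty(N, device=X.device, dtype=torch.float32)
    grid = (triton.cdiv(N, BLOCK_SIZE_N), )
    kernel_colsum[grid](
        X, S, M, N,
        X.stride(0), X.stride(1),
        BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N,
    )
    return S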
@triton.jit
def kernel_swiglu_fwd(
x_ptrs, y_ptrs, z_ptrs,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptrs + offsets, mask).to(tl.float32)
y = tl.load(y_ptrs + offsets, mask).to(tl.float32)
u = tl.sigmoid(x)
v = x * u # silu
z = v * y # swiglu
z = z.to(tl.float16)
tl.store(z_ptrs + offsets, z, mask)
@triton.jit
def kernel_swiglu_bwd(
x_ptrs, y_ptrs, dz_ptrs, dx_ptrs, dy_ptrs,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptrs + offsets, mask).to(tl.float32)
y = tl.load(y_ptrs + offsets, mask).to(tl.float32)
dz = tl.load(dz_ptrs + offsets, mask).to(tl.float32)
u = tl.sigmoid(x)
v = x * u # silu
dy = dz * v
dt = dz * y # temp
dx = dt * u * (1.0 + x * (1.0 - u))
dx = dx.to(tl.float16)
dy = dy.to(tl.float16)
tl.store(dx_ptrs + offsets, dx, mask)
tl.store(dy_ptrs + offsets, dy, mask)
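# The SwiGLU kernels are likewise unused below; a minimal elementwise launcher
# (again an assumption about the intended call site) is sketched here so the
# forward can be compared against torch.nn.functional.silu(x) * y.
def swiglu_fwd(X, Y, BLOCK_SIZE=1024):
    assert X.shape == Y.shape and X.is_contiguous() and Y.is_contiguous()
    Z = torch.empty_like(X)
    n_elements = X.numel()
    grid = (triton.cdiv(n_elements, BLOCK_SIZE), )
    kernel_swiglu_fwd[grid](X, Y, Z, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    return Z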
def fused_mul_add(X, Y, b, transpose_x, transpose_y):
if transpose_x:
K, M = X.shape
Xstride0, Xstride1 = X.stride(1), X.stride(0)
else:
M, K = X.shape
Xstride0, Xstride1 = X.stride(0), X.stride(1)
if transpose_y:
N, _ = Y.shape
Wstride0, Wstride1 = Y.stride(1), Y.stride(0)
else:
_, N = Y.shape
Wstride0, Wstride1 = Y.stride(0), Y.stride(1)
# Allocates output.
Z = torch.empty((M, N), device=X.device, dtype=X.dtype)
# 1D launch kernel where each block gets its own program.
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
kernel_fma[grid](
X, Y, b, Z,
M, N, K,
Xstride0, Xstride1,
Wstride0, Wstride1,
Z.stride(0), Z.stride(1),
BLOCK_SIZE_M=128,
BLOCK_SIZE_N=128,
BLOCK_SIZE_K=32,
GROUP_SIZE_M=8,
BIAS_REQD=b is not None,
)
return Z
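# A small sanity check (assumes fp16 inputs on backend.device; not called by default)
# comparing fused_mul_add against torch.addmm, which can help isolate whether the
# matmul kernel itself misbehaves before testing the full attention stack.
def _check_fused_mul_add():
    X = torch.randn(256, 128, device=backend.device, dtype=torch.float16)
    Y = torch.randn(128, 64, device=backend.device, dtype=torch.float16)
    b = torch.randn(64, device=backend.device, dtype=torch.float16)
    Z = fused_mul_add(X, Y, b, transpose_x=False, transpose_y=False)
    backend.sync()
    ref = torch.addmm(b, X, Y)
    print('fused_mul_add max abs diff vs torch.addmm:', (Z - ref).abs().max().item())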
# Implements Z = XY
class Matmul(torch.autograd.Function):
@staticmethod
def forward(X, Y):
# Check constraints.
assert X.shape[1] == Y.shape[0], "Incompatible dimensions for X and Y"
assert X.is_contiguous(), "Matrix X must be contiguous"
assert Y.is_contiguous(), "Matrix Y must be contiguous"
return fused_mul_add(X, Y, None, transpose_x=False, transpose_y=False)
@staticmethod
def setup_context(ctx, inputs, output):
ctx.save_for_backward(*inputs)
@staticmethod
def backward(ctx, dZ):
X, Y = ctx.saved_tensors
# dX = dZ @ Y.T # (M x N) x (N x K)
dX = fused_mul_add(dZ, Y, None, transpose_x=False, transpose_y=True)
# dY = X.T @ dZ # (K x M) x (M x N)
dY = fused_mul_add(X, dZ, None, transpose_x=True, transpose_y=False)
return dX, dY
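# A hedged check of the custom autograd path (assumes fp16 on backend.device; not
# called by default): the Triton backward should agree with PyTorch autograd on X @ Y.
def _check_matmul_backward():
    X = torch.randn(128, 64, device=backend.device, dtype=torch.float16, requires_grad=True)
    Y = torch.randn(64, 32, device=backend.device, dtype=torch.float16, requires_grad=True)
    Z = Matmul.apply(X, Y)
    dZ = torch.randn_like(Z)
    Z.backward(dZ)
    Xr = X.detach().clone().requires_grad_(True)
    Yr = Y.detach().clone().requires_grad_(True)
    (Xr @ Yr).backward(dZ)
    backend.sync()
    print('dX max abs diff:', (X.grad - Xr.grad).abs().max().item())
    print('dY max abs diff:', (Y.grad - Yr.grad).abs().max().item())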
@triton.jit
def _softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr):
# The rows of the softmax are independent, so we parallelize across those
row_idx = tl.program_id(0)
# The stride represents how much we need to increase the pointer to advance 1 row
row_start_ptr = input_ptr + row_idx * input_row_stride
# The block size is the next power of two greater than n_cols, so we can fit each
# row in a single block
col_offsets = tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
# Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
# Subtract maximum for numerical stability
row_minus_max = row - tl.max(row, axis=0)
# Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
# Write back output to DRAM
output_row_start_ptr = output_ptr + row_idx * output_row_stride
output_ptrs = output_row_start_ptr + col_offsets
tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
@triton.jit
def _softmax_backward_kernel(grad_input, grad_output, output, grad_input_stride, grad_out_stride, output_row_stride,
n_cols, BLOCK_SIZE: tl.constexpr):
# Parallelization across rows
row_idx = tl.program_id(0)
# Memory pointer calculations
row_start_ptr = grad_input + row_idx * grad_input_stride
grad_output_row_start_ptr = grad_output + row_idx * grad_out_stride
output_row_start_ptr = output + row_idx * output_row_stride
col_offsets = tl.arange(0, BLOCK_SIZE)
# Memmory addresses of all the elements we want to load
grad_output_ptrs = grad_output_row_start_ptr + col_offsets
output_ptrs = output_row_start_ptr + col_offsets
# Load relevant data
o = tl.load(output_ptrs, mask=col_offsets < n_cols)
g = tl.load(grad_output_ptrs, mask=col_offsets < n_cols)
    # Softmax backward (vector-Jacobian product): grad_in = o * (g - sum(g * o))
    # Step 1: Compute the intermediate sum used for the gradient
    s = tl.sum(g * o, 0)
    # Step 2: Compute the gradients
grad_input = o * (g - s)
grad_input_ptrs = row_start_ptr + col_offsets
tl.store(grad_input_ptrs, grad_input, mask=col_offsets < n_cols)
# Autograd wrapper that enqueues the kernels and their (meta-)arguments for a given input tensor.
class Softmax(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
n_rows, n_cols = x.shape
# The block size is the smallest power of two greater than the number of columns in `x`
BLOCK_SIZE = triton.next_power_of_2(n_cols)
# Another trick we can use is to ask the compiler to use more threads per row by
# increasing the number of warps (`num_warps`) over which each row is distributed.
# You will see in the next tutorial how to auto-tune this value in a more natural
# way so you don't have to come up with manual heuristics yourself.
num_warps = 4
if BLOCK_SIZE >= 2048:
num_warps = 8
if BLOCK_SIZE >= 4096:
num_warps = 16
# Allocate output
y = torch.empty_like(x)
        # Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row of the input matrix
_softmax_kernel[(n_rows, )](
y,
x,
x.stride(0),
y.stride(0),
n_cols,
num_warps=num_warps,
BLOCK_SIZE=BLOCK_SIZE,
)
ctx.save_for_backward(y)
return y
@staticmethod
def backward(ctx, grad_out):
(out,) = ctx.saved_tensors
n_rows, n_cols = out.shape
# The block size is the smallest power of two greater than the number of columns in `x`
BLOCK_SIZE = triton.next_power_of_2(n_cols)
        # torch.zeros is measurably slower; torch.empty is enough since the kernel writes every element
grad_in = torch.empty_like(out)
        # Make sure that the tensors are contiguous
grad_in, grad_out, out = map(lambda x: x.contiguous(), [
grad_in, grad_out, out])
_softmax_backward_kernel[(n_rows, )](
grad_in, grad_out, out,
grad_in.stride(0),
grad_out.stride(0),
out.stride(0),
n_cols,
BLOCK_SIZE,
)
return grad_in.reshape_as(grad_out)
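# A hedged sanity check for the Triton softmax (assumes fp16 on backend.device; not
# called by default), comparing forward and backward against torch.softmax.
def _check_softmax():
    x = torch.randn(64, 1000, device=backend.device, dtype=torch.float16, requires_grad=True)
    y = Softmax.apply(x)
    g = torch.rand_like(y)
    y.backward(g)
    xr = x.detach().clone().requires_grad_(True)
    torch.softmax(xr, dim=-1).backward(g)
    backend.sync()
    print('softmax fwd max abs diff:', (y - torch.softmax(x.detach(), dim=-1)).abs().max().item())
    print('softmax bwd max abs diff:', (x.grad - xr.grad).abs().max().item())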
class TritonAttentionHead(torch.nn.Module):
def __init__(self, config, name=None) -> None:
super().__init__()
self.config = config
self.name = name
self.dK = int(self.config.emb_size/self.config.no_of_heads)
self.sqrt_dk = math.sqrt(self.dK)
self.wQ = torch.nn.Parameter(torch.rand([self.config.emb_size, self.dK],
device=backend.device, dtype=self.config.dtype, requires_grad=True))
self.wK = torch.nn.Parameter(torch.rand_like(self.wQ))
self.wV = torch.nn.Parameter(torch.rand_like(self.wQ))
self.matmul_triton = Matmul.apply
self.softmax_triton = Softmax.apply
if self.config.mask:
self.mask = torch.triu(
torch.ones(self.config.no_of_embs, self.config.no_of_embs, device=backend.device,
dtype=self.config.dtype, requires_grad=True),
diagonal=1)
self.mask[self.mask.bool()] = -float('inf')
#print('Mask: ', self.mask)
    def forward(self, embs_for_q, embs_for_k, embs_for_v) -> torch.Tensor:
assert backend.check_device(embs_for_q)
q = self.matmul_triton(embs_for_q, self.wQ)
k = self.matmul_triton(embs_for_k, self.wK)
v = self.matmul_triton(embs_for_v, self.wV)
        # TODO: multiply by 1/sqrt(dK) and use Triton for the qkT computation
qkT = torch.mm(q, torch.transpose(k, 0, 1))
if self.config.mask:
qkT = qkT + self.mask
softmaxed = self.softmax_triton(qkT)
output = self.matmul_triton(softmaxed, v)
return output
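# A plain-PyTorch reference for a single head (illustrative only, not part of the
# repro), handy for checking the Triton path numerically. Note that, per the TODO
# above, the Triton head does not yet scale q @ k.T by 1/sqrt(dK), so this reference
# intentionally mirrors the unscaled behaviour rather than standard attention.
def _reference_head(head, embs):
    q = embs @ head.wQ
    k = embs @ head.wK
    v = embs @ head.wV
    qkT = q @ k.transpose(0, 1)
    if head.config.mask:
        qkT = qkT + head.mask
    return torch.softmax(qkT, dim=-1) @ v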
class TritonMultiHeadAttention(torch.nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
        self.attention_heads = torch.nn.ModuleList()
        for x in range(config.no_of_heads):
            self.attention_heads.append(TritonAttentionHead(config, f"attention_head_{x}"))
self.dK = int(self.config.emb_size/self.config.no_of_heads)
self.wO = torch.nn.Parameter(torch.rand([self.config.emb_size, self.config.emb_size],
device=backend.device, dtype=self.config.dtype, requires_grad=True))
self.matmul_triton = Matmul.apply
def forward(self, embs_for_q, embs_for_k, embs_for_v):
list_of_z = []
for x in range(self.config.no_of_heads):
list_of_z.append(self.attention_heads[x](embs_for_q, embs_for_k, embs_for_v))
print(x)
backend.sync()
print("sync finished")
concatZ = torch.cat(list_of_z, dim=1)
output = self.matmul_triton(concatZ, self.wO)
return output
def fwd_bwd_triton_full_attention(
    embs: torch.Tensor, triton_multi_head_attention: TritonMultiHeadAttention
):
pred_y = triton_multi_head_attention(embs, embs, embs)
grad_output = torch.rand_like(pred_y)
pred_y.backward(grad_output, retain_graph=True)
return pred_y, grad_output
class Config:
no_of_embs = 1024
emb_size = 4096
no_of_heads = 32
mask = True
separate_kernels = True
dtype = torch.float16
def main():
config = Config()
embds = torch.rand(
[config.no_of_embs, config.emb_size],
device=backend.device,
dtype=config.dtype,
requires_grad=True,
)
head = TritonMultiHeadAttention(config)
res = fwd_bwd_triton_full_attention(
embds, head
)
print(res)
if __name__ == '__main__':
main()
Execution hangs with the float16 dtype on xpu. It works with cuda, or with the float32 dtype. The code above is the full reproducer.
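If it helps narrow this down, a minimal sketch along these lines (reusing the Config, backend, Matmul and Softmax defined above; the helper name and sizes are illustrative, not part of the original report) exercises the custom kernels one at a time in float16 on xpu, so the hang can be attributed to a specific kernel rather than to the full multi-head forward/backward:
def isolate_kernels():
    cfg = Config()
    x = torch.rand(cfg.no_of_embs, cfg.emb_size, device=backend.device, dtype=cfg.dtype)
    w = torch.rand(cfg.emb_size, cfg.emb_size // cfg.no_of_heads, device=backend.device, dtype=cfg.dtype)
    print('running Triton matmul ...')
    _ = Matmul.apply(x, w)
    backend.sync()
    print('matmul finished')
    s = torch.rand(cfg.no_of_embs, cfg.no_of_embs, device=backend.device, dtype=cfg.dtype)
    print('running Triton softmax ...')
    _ = Softmax.apply(s)
    backend.sync()
    print('softmax finished')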