Commit 23389bd

[feat] layernorm: add non-affine (#53)

author: @ClashLuke
ClashLuke committed Oct 29, 2021
1 parent f650ab7
Showing 2 changed files with 141 additions and 55 deletions.
187 changes: 135 additions & 52 deletions xformers/triton/k_layer_norm.py
_triton_registered_warnings = False


# fmt: off
@triton.jit
def _affine(W, B, N, x, META):
cols = tl.arange(0, META["BLOCK_SIZE_N"])

w = tl.load(W + cols, mask=cols < N, other=1.0)
zero = 0.0
zero = zero.to(w.dtype) # Triton bug workarounds
w = tl.where(cols < N, w, zero)

b = tl.load(B + cols, mask=cols < N, other=0.0)
b = tl.where(cols < N, b, zero)
y = x * w + b
return y


@triton.jit
def _store(y, Y, stride, N, META):
row = tl.program_id(0)
cols = tl.arange(0, META["BLOCK_SIZE_N"])

y_ptrs = Y + row * stride + cols
tl.store(y_ptrs, y, mask=cols < N)


@triton.jit
def _layer_norm_non_affine(X, M, V, stride, N, eps, META):
# fmt: on
"""
Fused layernorm kernel over a 3d tensor.
Expand All @@ -36,7 +59,7 @@ def layer_norm_fw(X, Y, W, B, M, V, stride, N, eps, **META):
# Move to this row
x_ptrs = X + row * stride + cols
x = tl.load(x_ptrs, mask=cols < N, other=0.0).to(tl.float32)
    x = tl.where(cols < N, x, 0.0)  # Triton bug workarounds

# Compute variance
x_mean = tl.sum(x, axis=0) / N
    x_zm = tl.where(cols < N, x - x_mean, 0.0)
    x_var = tl.sum(x_zm * x_zm, axis=0) / N
    x_inv_sigma = 1.0 / tl.sqrt(x_var + eps)

    # write per-row mean and reciprocal standard deviation for the backward pass
tl.store(M + row, x_mean)
tl.store(V + row, x_inv_sigma)

    return x_zm * x_inv_sigma

# fmt: off
@triton.jit
def _layer_norm_non_affine_fw(X, Y, M, V, stride, N, eps, **META):
_store(_layer_norm_non_affine(X, M, V, stride, N, eps, META), Y, stride, N, META)


# fmt: off
@triton.jit
def _layer_norm_fw(X, Y, W, B, M, V, stride, N, eps, **META):
# fmt: on
"""
Fused layernorm kernel over a 3d tensor.
The layer norm is applied over the last dimension.
Compute
y = (x - E(x))/(sqrt(var(x) + epsilon)) * gamma + beta
"""
y = _layer_norm_non_affine(X, M, V, stride, N, eps, META)
y = _affine(W, B, N, y, META)

_store(y, Y, stride, N, META)

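For reference, the forward kernels above compute the same quantities as this plain PyTorch sketch (illustrative only, not part of the commit; it assumes a 2d input of shape [rows, N] and mirrors the kernel's biased variance, sum / N):

import torch

def layer_norm_reference(x, weight=None, bias=None, eps=1e-5):
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)  # biased, as in the kernel
    y = (x - mean) / torch.sqrt(var + eps)             # _layer_norm_non_affine
    if weight is not None:
        y = y * weight + bias                          # _affine
    return y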

# Backward pass (DX + partial DW + partial DB)
# fmt: off
@triton.jit
def _layer_norm_bwd_dx_fused(
    DX, DY, DW, DB,
    Y, W, B, V,
    Lock, stride, N,
    **META
):
# fmt: on

    GROUP_SIZE_M = META["GROUP_SIZE_M"]
    BLOCK_SIZE_N = META["BLOCK_SIZE_N"]

    # position of elements processed by this program
    row = tl.program_id(0)
cols = tl.arange(0, BLOCK_SIZE_N)

# offset data pointers to start at the row of interest
    y_ptrs = Y + row * stride + cols
dy_ptrs = DY + row * stride + cols
w_ptrs = W + cols
b_ptrs = B + cols

# offset locks and weight/bias gradient pointer
# each kernel instance accumulates partial sums for
    # dw / db for a group of rows at a time
    lock_id = row % GROUP_SIZE_M
    Lock += lock_id
Count = Lock + GROUP_SIZE_M

# load data to SRAM
    y = tl.load(y_ptrs, mask=cols < N, other=0).to(tl.float32)
dy = tl.load(dy_ptrs, mask=cols < N, other=0).to(tl.float32)
w = tl.load(w_ptrs, mask=cols < N, other=0).to(tl.float32)
b = tl.load(b_ptrs, mask=cols < N, other=0).to(tl.float32)

rstd = tl.load(V + row)

# compute dx
    xhat = (y - b) / w
wdy = w * dy
xhat = tl.where(cols < N, xhat, 0.0)
wdy = tl.where(cols < N, wdy, 0.0)
    mean1 = tl.sum(xhat * wdy, axis=0) / N
    mean2 = tl.sum(wdy, axis=0) / N
dx = (wdy - (xhat * mean1 + mean2)) * rstd

# write-back dx
    _store(dx, DX, stride, N, META)

# accumulate partial sums for dw/db
partial_dw = (dy * xhat).to(w.dtype)
    partial_db = dy.to(w.dtype)

# - wait for a lock on the accumulated dw/db
while tl.atomic_cas(Lock, 0, 1) == 1:
        pass

    # [...] first-writer check and accumulation into the per-group partial
    # dw / db buffers (unchanged, collapsed in the diff)

    # release the lock
tl.atomic_xchg(Lock, 0)


@triton.jit
def _layer_norm_no_affine_bwd(
DX, DY,
Y, V,
stride, N,
**META
):
# fmt: on

# position of elements processed by this program
row = tl.program_id(0)
cols = tl.arange(0, META["BLOCK_SIZE_N"])

# offset data pointers to start at the row of interest
y_ptrs = Y + row * stride + cols
dy_ptrs = DY + row * stride + cols

# load data to SRAM
y = tl.load(y_ptrs, mask=cols < N, other=0).to(tl.float32)
dy = tl.load(dy_ptrs, mask=cols < N, other=0).to(tl.float32)

rstd = tl.load(V + row)

# compute dx
xhat = tl.where(cols < N, y, 0.0)
wdy = tl.where(cols < N, dy, 0.0)
mean1 = tl.sum(xhat * wdy, axis=0) / N
mean2 = tl.sum(wdy, axis=0) / N
dx = (wdy - (xhat * mean1 + mean2)) * rstd

# write-back dx
_store(dx, DX, stride, N, META)

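Both backward kernels implement the same row-wise dx formula; transcribed into plain PyTorch for readability (a sketch, not part of the commit; rstd is the per-row reciprocal standard deviation saved by the forward pass):

import torch

def layer_norm_dx_reference(y, dy, rstd, weight=None, bias=None):
    # recover the normalized input exactly as the kernels do
    xhat = y if weight is None else (y - bias) / weight
    wdy = dy if weight is None else weight * dy
    N = y.shape[-1]
    mean1 = (xhat * wdy).sum(dim=-1, keepdim=True) / N
    mean2 = wdy.sum(dim=-1, keepdim=True) / N
    return (wdy - (xhat * mean1 + mean2)) * rstd.unsqueeze(-1)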

# Backward pass (total DW + total DB)
# fmt: off
# [...] @triton.jit kernel computing the total DW / DB reduction
# (unchanged, collapsed in the diff)
# fmt: on


class _LayerNorm(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16 if _triton_layernorm_fp16_enabled else None)
    def forward(ctx, x, weight, bias, eps):
        # allocate output
y = torch.empty_like(x)

        # [...] flatten x to 2d (M rows by N), allocate the mean / rstd buffers
        # and pick BLOCK_SIZE_N / num_warps (unchanged, collapsed in the diff)

# enqueue kernel
# fmt: off
if weight is None:
_layer_norm_non_affine_fw[(M,)](
x_arg, y, mean, rstd,
x_arg.stride(0),
N,
eps,
num_warps=num_warps,
BLOCK_SIZE_N=BLOCK_SIZE_N
)
else:
_layer_norm_fw[(M,)](
x_arg, y, weight, bias, mean, rstd,
x_arg.stride(0),
N,
eps,
num_warps=num_warps,
BLOCK_SIZE_N=BLOCK_SIZE_N
)
# fmt: on

        ctx.save_for_backward(y, rstd, weight, bias)
ctx.BLOCK_SIZE_N = BLOCK_SIZE_N
ctx.num_warps = num_warps
ctx.eps = eps
ctx.N = N

return y.reshape_as(x)

@staticmethod
@custom_bwd
def backward(ctx, dy):
        y, var, weight, bias = ctx.saved_tensors

# heuristics for amount of parallel reduction stream for DG/DB
        N = y.size(-1)
GROUP_SIZE_M = 64
if N <= 8192:
GROUP_SIZE_M = 96
        # [...] further GROUP_SIZE_M heuristics (unchanged, collapsed in the diff)

# flatten the batch dimension, if any.
# We're interested in 'samples' x norm_dimension
        y = y.reshape(-1, y.size(-1))
        M, N = y.size()

# allocate output
locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device="cuda")
t_args = {"dtype": y.dtype, "device": y.device}
_dw = torch.empty((GROUP_SIZE_M, y.size(-1)), **t_args)
_db = torch.empty((GROUP_SIZE_M, y.size(-1)), **t_args)
dw = torch.empty((y.size(-1),), **t_args)
db = torch.empty((y.size(-1),), **t_args)
dy = dy.contiguous()
dx = torch.empty_like(dy)

# Check the tensor shapes and layouts
# we suppose in the kernel that they have the same size and are contiguous
assert dx.numel() == y.numel(), \
"Something is wrong in the backward graph, possibly because of an inplace operation after the layernorm"

# enqueue kernel using forward pass heuristics
# also compute partial sums for DW and DB

# fmt: off
meta = {"BLOCK_SIZE_N": ctx.BLOCK_SIZE_N,
"GROUP_SIZE_M": GROUP_SIZE_M,
"num_warps": ctx.num_warps}
if weight is None:
_layer_norm_no_affine_bwd[(M,)](dx, dy, y, var, y.stride(0), N, **meta)
return dx, None, None, None

_layer_norm_bwd_dx_fused[(M,)](
dx, dy, _dw, _db,
            y, weight, bias, var,
locks,
            y.stride(0),
N,
            **meta
)

# fmt: on

def grid(meta):
            return [triton.cdiv(N, meta["BLOCK_SIZE_N"])]

        # fmt: off
        # [...] launch of the total DW / DB reduction kernel over `grid`
        # (unchanged, collapsed in the diff)
        # fmt: on

        dx = dx.reshape_as(dy)
return dx, dw, db, None
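A note on the bookkeeping change: forward now saves (y, rstd, weight, bias) instead of (x, weight, mean, rstd), so backward no longer needs the saved mean and instead reconstructs the normalized input from the output. In the affine case that inversion is xhat = (y - b) / w, which assumes the learned weight stays away from zero; a small sketch of the identity with hypothetical tensors:

import torch

N = 8
xhat = torch.randn(3, N)         # normalized activations
w = torch.rand(N) + 0.5          # affine weight, bounded away from zero
b = torch.randn(N)               # affine bias
y = xhat * w + b                 # what the forward pass stores
assert torch.allclose((y - b) / w, xhat, atol=1e-6)
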
9 changes: 6 additions & 3 deletions xformers/triton/layer_norm.py
class FusedLayerNorm(nn.Module):
    """
    [...] (class docstring collapsed in the diff)
    """

    def __init__(self, normalized_shape, affine=True, eps=1e-05):
super().__init__()
if affine:
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
else:
self.weight = self.bias = None
self.epsilon = eps

def forward(self, x):
        # [...] (body unchanged, collapsed in the diff)
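A minimal usage sketch for the new flag (illustrative, not part of the commit: the CUDA device and tolerances are assumptions, and the kernels require a GPU since they are Triton; the reference is standard torch.nn.LayerNorm with elementwise_affine=False):

import torch
from xformers.triton.layer_norm import FusedLayerNorm

x = torch.randn(2, 16, 64, device="cuda")
fused = FusedLayerNorm(64, affine=False).to("cuda")
ref = torch.nn.LayerNorm(64, elementwise_affine=False, eps=fused.epsilon).to("cuda")
assert torch.allclose(fused(x), ref(x), atol=1e-4, rtol=1e-4)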
