Added SmeLU #263

Merged
merged 16 commits on May 10, 2022
20 changes: 20 additions & 0 deletions xformers/components/activations.py
@@ -16,6 +16,7 @@ class Activation(str, Enum):
GeLU = "gelu"
LeakyReLU = "leaky_relu"
ReLU = "relu"
SmeLU = "smelu"


# For unit testing / parity comparisons, probably not the fastest way
@@ -28,6 +29,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x_ * x_


class SmeLU(nn.Module):
    """
    SmeLU_ activation (Smooth ReLU): zero below -beta, the identity above beta,
    and a quadratic blend on [-beta, beta].

    .. _SmeLU: https://arxiv.org/pdf/2202.06499.pdf
    """

    def __init__(self, beta: float = 2.0) -> None:
        super().__init__()
        self.beta = beta

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # ReLU-like branches: identity for x >= beta, zero otherwise
        relu = torch.where(
            x >= self.beta,
            x,
            torch.tensor([0.0], device=x.device, dtype=x.dtype),
        )
        # Quadratic section for |x| <= beta, the branches above elsewhere
        return torch.where(
            torch.abs(x) <= self.beta,
            ((x + self.beta) ** 2) / (4.0 * self.beta),
            relu,
        )


class Passthrough(nn.Module):
def __init__(self) -> None:
super().__init__()
@@ -45,4 +64,5 @@ def build_activation(activation: Optional[Activation]):
Activation.GeLU: nn.GELU,
Activation.LeakyReLU: nn.LeakyReLU,
Activation.SquaredReLU: SquaredReLU,
Activation.SmeLU: SmeLU,
}[activation]()
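
A quick usage sketch (not part of the diff) showing the new enum entry going through build_activation; the import path is assumed to match the file shown above:

    import torch
    from xformers.components.activations import Activation, build_activation

    act = build_activation(Activation.SmeLU)  # SmeLU module with the default beta=2.0
    x = torch.linspace(-4.0, 4.0, steps=9)
    y = act(x)  # zero below -beta, quadratic on [-beta, beta], identity above beta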
28 changes: 28 additions & 0 deletions xformers/triton/k_activations.py
@@ -21,6 +21,7 @@ def get_triton_activation_kernel(activation: Optional[Activation]):
Activation.LeakyReLU: leaky_relu,
Activation.GeLU: gelu,
Activation.SquaredReLU: squared_relu,
Activation.SmeLU: smelu,
}[activation]
if activation
else None
@@ -34,6 +35,7 @@ def get_triton_activation_bwd_kernel(activation: Optional[Activation]):
Activation.LeakyReLU: leaky_relu_grad,
Activation.GeLU: gelu_grad,
Activation.SquaredReLU: squared_relu_grad,
Activation.SmeLU: smelu_grad,
}[activation]
if activation
else None
@@ -135,3 +137,29 @@ def gelu_grad(x):
return 0.5 * x * (
(1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)
) + 0.5 * (1 + tanh_out)


@triton.jit
def smelu(x, beta=2.0):
Contributor

I don't think that you can pass a default param with triton actually, it only works with a subset of the Python syntax and my guess is that this is out of it (cc @ptillet). Something could be worth trying: having a getter for this kernel, like the following

    def get_smelu_kernel(beta: float = 2.0):
        @triton.jit
        def smelu(x):
            pass  # use beta here, but maybe that this will fail at the JIT phase

If that does not work,

  • for a start we could have a fixed beta, then iterate on the implementation to expose it (completely fine by me)
  • could be that the activation kernel takes another parameter, which in that case would be the beta value (roughly as sketched below), or that we figure out with Phil how to generate the kernel code on the fly with the proper beta
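
A minimal, purely illustrative sketch of that second option (hypothetical helper name, not code from this PR), keeping the tl.where-based style of the existing kernels; the calling fused kernel would have to forward a beta value explicitly:

    @triton.jit
    def smelu_with_beta(x, beta):
        # beta arrives as an explicit runtime argument forwarded by the caller
        zero = 0.0
        output = (x + beta) * (x + beta) / (4.0 * beta)
        relu = tl.where(x >= beta, x, zero.to(x.dtype))
        return tl.where(tl.abs(x) <= beta, output, relu)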

Contributor Author

thanks @blefaudeux I'll give it a try... a bit late here, so I wanted to give it a shot in the morning

Contributor

checking this with Philippe, the default value should work actually; maybe it needs to be ": float" or similar

Contributor Author

ah ok! cool, let me check

@ptillet (Apr 8, 2022)

To be clear, what should work (for now) is default arguments for tl.constexpr-annotated arguments, and with triton 2.0 :p I'm not too sure about Triton 1.x
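
If that holds, a rough sketch of the signature under that assumption (Triton 2.0, tl.constexpr default; not code from this PR):

    @triton.jit
    def smelu(x, beta: tl.constexpr = 2.0):
        # beta is a compile-time constant here, so no runtime cast of beta is needed
        zero = 0.0
        output = (x + beta) * (x + beta) / (4.0 * beta)
        relu = tl.where(x >= beta, x, zero.to(x.dtype))
        return tl.where(tl.abs(x) <= beta, output, relu)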

Contributor Author

ah right... I'm on triton 1.x at the moment...

Contributor

we need to update to Triton 2; CI is blocking right now, I hope to get that sorted out this weekend

"""
SmeLU_ activation - Smooth ReLU

.. _SmeLU: https://arxiv.org/pdf/2202.06499.pdf
"""
zero = 0.0
four = 4.0
beta = beta.to(x.dtype)
output = (x + beta) * (x + beta) / (four.to(x.dtype) * beta)
relu = tl.where(x >= beta, x, zero.to(x.dtype))
return tl.where(tl.abs(x) <= beta, output, relu)


@triton.jit
def smelu_grad(x, beta=2.0):
    """
    Gradient of SmeLU_: 0 below -beta, (x + beta) / (2 * beta) on [-beta, beta], 1 above beta.

    .. _SmeLU: https://arxiv.org/pdf/2202.06499.pdf
    """
    zero = 0.0
    one = 1.0
    two = 2.0
    beta = beta.to(x.dtype)
    grad = (beta + x) / (two.to(x.dtype) * beta)
    relu_grad = tl.where(x >= beta, one.to(x.dtype), zero.to(x.dtype))
    return tl.where(tl.abs(x) <= beta, grad, relu_grad)