
Added SmeLU #263

Merged
blefaudeux merged 16 commits into facebookresearch:main on May 10, 2022

Conversation

@kashif (Contributor) commented Apr 7, 2022

What does this PR do?

Fixes #262.

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

@facebook-github-bot added the "CLA Signed" label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Apr 7, 2022
@blefaudeux (Contributor)

Excellent, thanks @kashif! A couple of comments, I hope that helps; there's a trick with the definition of beta, I think.

@kashif (Contributor, Author) commented Apr 7, 2022

Right, I don't think the beta default arg would work as-is... I was just about to ask you how to deal with that.



@triton.jit
def smelu(x, beta=2.0):
Contributor

I don't think that you can pass a default param with triton actually; it only works with a subset of the Python syntax, and my guess is that this falls outside of it (cc @ptillet). Something worth trying would be a getter for this kernel, like the following:

def get_smelu_kernel(beta: float = 2.0):
    @triton.jit
    def smelu(x):
        pass  # use beta here, but maybe this will fail at the JIT phase

If that does not work:

  • for a start we could have a fixed beta, then iterate on the implementation to expose it (completely fine by me)
  • the activation kernel could take another parameter, which in that case would be the beta value, or we figure out with Phil how to generate the kernel code on the fly with the proper beta

Contributor Author

Thanks @blefaudeux, I'll give it a try... it's a bit late here, so I'll give it a shot in the morning.

Contributor

Checking this with Philippe: the default value should actually work; maybe it needs a : float annotation or similar?

Contributor Author

Ah ok, cool! Let me check.

@ptillet commented Apr 8, 2022

To be clear, what should work (for now) is default arguments for tl.constexpr-annotated arguments, and with triton 2.0 :p I'm not too sure about Triton 1.x.
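
For illustration, a minimal (untested) sketch of that pattern, assuming triton 2.0 and triton.language imported as tl:

    import triton
    import triton.language as tl

    @triton.jit
    def smelu(x, beta: tl.constexpr = 2.0):
        # beta is a compile-time constant, so the default value is resolved
        # when the kernel is JIT-compiled rather than at call time
        pass  # SmeLU body goes here, with beta usable as a constant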

Contributor Author

Ah right... I'm on triton 1.x at the moment...

Contributor

We need to update to triton 2; CI is blocking right now, I hope to get that sorted out this weekend.

@blefaudeux (Contributor)

Rebasing this on top of #272 should work, @kashif.

@kashif (Contributor, Author) commented Apr 19, 2022

Ah cool! Having a look.

@blefaudeux (Contributor)

Hey @kashif, let me know if I can help.

@kashif (Contributor, Author) commented Apr 28, 2022

So I tried:

def get_smelu_kernel(x, beta: float = 2.0):
    @triton.jit
    def smelu(x, beta):
        """
        SmeLU_ activation -  Smooth ReLU

        .. _SmeLU: https://arxiv.org/pdf/2202.06499.pdf
        """
        zero = 0.0
        four = 4.0
        beta = beta.to(x.dtype)

        output = (x + beta) * (x + beta) / (four.to(x.dtype) * beta)
        relu = tl.where(x >= beta, x, zero.to(x.dtype))
        return tl.where(tl.abs(x) <= beta, output, relu)

    smelu(x, beta)

but that didn't work either, so I am setting the beta param to 2.0 for now.
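
For reference, a minimal sketch of the fixed-beta variant described above (an illustrative reconstruction, not the merged code; assuming triton.language is imported as tl):

    import triton
    import triton.language as tl

    @triton.jit
    def smelu(x):
        """
        SmeLU_ activation - Smooth ReLU, with beta hard-coded to 2.0 for now

        .. _SmeLU: https://arxiv.org/pdf/2202.06499.pdf
        """
        beta = 2.0
        # quadratic section (x + beta)^2 / (4 * beta), used where |x| <= beta
        output = (x + beta) * (x + beta) / (4.0 * beta)
        # outside the [-beta, beta] band this reduces to a plain ReLU
        relu = tl.where(x >= beta, x, 0.0)
        return tl.where(tl.abs(x) <= beta, output, relu)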

@blefaudeux (Contributor)

If you can, I would really recommend setting up pre-commit; it helps with all the linting. Some explanations here.

@blefaudeux (Contributor)

Thanks for the updates @kashif, looks good to me; we can always iterate to expose beta down the line. It looks like the errors are unrelated to your changes, maybe due to main having changed since. Could you try to rebase? I can do that as well if you'd like.

@blefaudeux (Contributor) left a comment

LGTM, we need to solve these CI issues (my guess is that it's waiting for a rebase). Thanks a lot @kashif!

@kashif (Contributor, Author) commented May 5, 2022

Thank you! I learned a lot, and now I have a 3090 Ti to test on!

@kashif (Contributor, Author) commented May 5, 2022

I was just heading out to eat... can you kindly rebase?

@blefaudeux (Contributor)

Hmm, thanks for the update Kashif; it looks like the errors are still there, and I don't understand how they can be related. I'll have a look.

@blefaudeux (Contributor)

Hmm, I can repro the CI error, but it should be unrelated to your changes @kashif; it means there's something wrong either in the triton stack or in the CUDA kernels :( I'm trying to sort that out.

@blefaudeux (Contributor) commented May 9, 2022

Ok, at least I got something wrong: @fmassa, if I run CUDA_LAUNCH_BLOCKING=1 pytest tests -x -v in an env which does not have triton (so no triton kernel is launched), I get the output below, which points to an issue in the sparse kernels. Does that ring a bell?

tests/test_block_factory.py::test_xformer_decoder_block[device0-post-MixtureOfExperts-global-False-relu-1-True-0.1-0.0] PASSED                                                                                                                                                                                                                                                                                                          [ 73%]
tests/test_block_factory.py::test_xformer_decoder_block[device0-post-MixtureOfExperts-global-False-relu-1-True-0.1-0.1] FAILED                                                                                                                                                                                                                                                                                                          [ 73%]

================================================================================================================================================================================================================== FAILURES ===================================================================================================================================================================================================================
_________________________________________________________________________________________________________________________________________________________________________ test_xformer_decoder_block[device0-post-MixtureOfExperts-global-False-relu-1-True-0.1-0.1] __________________________________________________________________________________________________________________________________________________________________________

attention_name = 'global', rotary_embeddings = False, feedforward_name = 'MixtureOfExperts', heads = 1, attn_dropout = 0.1, residual_dropout = 0.1, causal = True, activation = 'relu', layer_norm_style = 'post', device = device(type='cuda')

    @pytest.mark.parametrize("attn_dropout", [0.0, 0.1])
    @pytest.mark.parametrize("residual_dropout", [0.0, 0.1])
    @pytest.mark.parametrize("causal", [True, False])
    @pytest.mark.parametrize("heads", [1, 2])
    @pytest.mark.parametrize("activation", [a.value for a in Activation])
    @pytest.mark.parametrize("rotary_embeddings", [False, True])
    @pytest.mark.parametrize("attention_name", ATTENTION_REGISTRY.keys())
    @pytest.mark.parametrize("feedforward_name", FEEDFORWARD_REGISTRY.keys())
    @pytest.mark.parametrize("layer_norm_style", ["pre", "post"])
    @pytest.mark.parametrize("device", DEVICES)
    @pytest.mark.skipif(
        not torch.cuda.is_available(), reason="This test requires a CUDA device"
    )
    def test_xformer_decoder_block(
        attention_name: str,
        rotary_embeddings: bool,
        feedforward_name: str,
        heads: int,
        attn_dropout: float,
        residual_dropout: float,
        causal: bool,
        activation: Activation,
        layer_norm_style: str,
        device: torch.device,
    ):
    
        block_size = 16
    
        attention_config = {
            "name": attention_name,
            "dropout": attn_dropout,
            "causal": causal,
            "window_size": SEQ // 8 + 1,
            "seq_len": SEQ,
            "dim_head": MODEL // heads,
            "attention_query_mask": torch.rand((SEQ, 1)) < GLOBAL_ATTENTION_RATIO,
            "layout": torch.eye(SEQ // block_size, SEQ // block_size, dtype=torch.long),
            "block_size": block_size,
            "num_rules": 2,  # Compositional Attention
        }
    
        multi_head_config = {
            "num_heads": heads,
            "dim_model": MODEL,
            "residual_dropout": residual_dropout,
            "attention": attention_config,
            "use_rotary_embeddings": rotary_embeddings,
        }
    
        feedforward_config = {
            "name": feedforward_name,
            "dim_model": MODEL,
            "dropout": DROPOUT,
            "activation": activation,
            "hidden_layer_multiplier": 4,
            "number_of_experts": 4,
            "gate": "top_2",
        }
    
        if feedforward_name == "MixtureOfExperts":
            init_torch_distributed_local()
    
        position_encoding_config = {
            "name": "sine",
            "dim_model": MODEL,
            "seq_len": SEQ,
            "vocab_size": VOCAB_SIZE,
        }
    
        encoder_block_config = xFormerEncoderConfig(
            dim_model=MODEL,
            multi_head_config=multi_head_config,
            feedforward_config=feedforward_config,
            position_encoding_config=position_encoding_config,
            layer_norm_style=layer_norm_style,
        )
    
        decoder_block_config = xFormerDecoderConfig(
            dim_model=MODEL,
            multi_head_config_masked=multi_head_config,
            multi_head_config_cross=multi_head_config,
            feedforward_config=feedforward_config,
            position_encoding_config=position_encoding_config,
            layer_norm_style=layer_norm_style,
        )
    
        # Test that the whole block can be instantiated
        encoder_block = xFormerEncoderBlock.from_config(encoder_block_config).to(device)
        decoder_block = xFormerDecoderBlock.from_config(decoder_block_config).to(device)
    
        # Check that the dimensions make sense, to a FW pass
        inputs = torch.rand(BATCH, SEQ, device=device)
>       encoded = encoder_block(inputs)

tests/test_block_factory.py:222: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../.conda/envs/xformers_dev/lib/python3.8/site-packages/torch/nn/modules/module.py:1110: in _call_impl
    return forward_call(*input, **kwargs)
xformers/factory/block_factory.py:207: in forward
    x = self.wrap_att(inputs=[q, k, v], att_mask=att_mask)
../../.conda/envs/xformers_dev/lib/python3.8/site-packages/torch/nn/modules/module.py:1110: in _call_impl
    return forward_call(*input, **kwargs)
xformers/components/residual.py:120: in forward
    x = self.sublayer(inputs=inputs, **kwargs)
../../.conda/envs/xformers_dev/lib/python3.8/site-packages/torch/nn/modules/module.py:1110: in _call_impl
    return forward_call(*input, **kwargs)
xformers/components/residual.py:68: in forward
    return residue + self.layer(*inputs, **kwargs)
../../.conda/envs/xformers_dev/lib/python3.8/site-packages/torch/nn/modules/module.py:1110: in _call_impl
    return forward_call(*input, **kwargs)
xformers/components/multi_head_dispatch.py:213: in forward
    y = self.attention(q, k, v, **kw_mask_args)
../../.conda/envs/xformers_dev/lib/python3.8/site-packages/torch/nn/modules/module.py:1110: in _call_impl
    return forward_call(*input, **kwargs)
xformers/components/attention/global_tokens.py:115: in forward
    att = scaled_dot_product_attention(
xformers/components/attention/core.py:230: in scaled_dot_product_attention
    att = scaled_query_key_softmax(q, k, att_mask=att_mask)
xformers/components/attention/core.py:209: in scaled_query_key_softmax
    att = _softmax(att, causal=is_causal)
xformers/components/attention/core.py:98: in _softmax
    return a.softmax()
xformers/components/attention/_sputnik_sparse.py:96: in softmax
    out = torch.nn.functional.softmax(self._mat, -1)
../../.conda/envs/xformers_dev/lib/python3.8/site-packages/torch/nn/functional.py:1779: in softmax
    return handle_torch_function(softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype)
../../.conda/envs/xformers_dev/lib/python3.8/site-packages/torch/overrides.py:1390: in handle_torch_function
    result = torch_func_method(public_api, types, args, kwargs)
xformers/sparse/csr_tensor.py:322: in __torch_function__
    return cls._softmax(args[0], kwargs["dim"])
xformers/sparse/csr_tensor.py:135: in _softmax
    out = _csr_ops._SparseSoftmax.apply(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

ctx = <torch.autograd.function._SparseSoftmaxBackward object at 0x7f36522afe40>, m = 64, n = 64, row_indices = tensor([48, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 32, 49,
        50, 51, 52, 53, 54, 55, 56, 57...15,  0, 17, 18, 19, 20, 21,
        22, 23, 24, 25, 26, 27, 28, 29, 30, 31], device='cuda:0',
       dtype=torch.int32), values = tensor([], device='cuda:0', size=(2, 0), grad_fn=<_sddmmBackward>)
row_offsets = tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0',
       dtype=torch.int32), column_indices = tensor([], device='cuda:0', dtype=torch.int32)

    @staticmethod
    def forward(ctx, m, n, row_indices, values, row_offsets, column_indices):
>       out = torch.ops.xformers.sparse_softmax_sputnik(
            m, n, row_indices, values, row_offsets, column_indices
        )
E       RuntimeError: CUDA error: invalid configuration argument

xformers/sparse/_csr_ops.py:56: RuntimeError
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ Captured stderr call -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
WARNING:root:Local experts no specified but world size of 1
WARNING:root:Assuming that all experts are local
WARNING:root:Local experts no specified but world size of 1
WARNING:root:Assuming that all experts are local
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- Captured log call --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
WARNING  root:mixture_of_experts.py:108 Local experts no specified but world size of 1
WARNING  root:mixture_of_experts.py:109 Assuming that all experts are local
WARNING  root:mixture_of_experts.py:108 Local experts no specified but world size of 1
WARNING  root:mixture_of_experts.py:109 Assuming that all experts are local
============================================================================================================================================================================================================== warnings summary ===============================================================================================================================================================================================================
tests/test_attention_patterns.py::test_local_1d_pattern[50-3]
  /home/lefaudeux/.conda/envs/xformers_dev/lib/python3.8/site-packages/torch/functional.py:568: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at  ../aten/src/ATen/native/TensorShape.cpp:2144.)
    return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]

tests/test_block_factory.py: 9012 tests with warnings
  /home/lefaudeux/Git/xformers/xformers/components/positional_embedding/sine.py:39: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
    div = torch.exp(-math.log(10000) * (2 * (dim // 2) / self.dim_model))

-- Docs: https://docs.pytest.org/en/latest/warnings.html
=========================================================================================================================================================================================================== short test summary info ===========================================================================================================================================================================================================
FAILED tests/test_block_factory.py::test_xformer_decoder_block[device0-post-MixtureOfExperts-global-False-relu-1-True-0.1-0.1] - RuntimeError: CUDA error: invalid configuration argument

@blefaudeux (Contributor)

Note that it happens without fairscale (so without the MixtureOfExperts), same error, on "global" again

@fmassa (Contributor) commented May 10, 2022

I'm looking at the issue.

@fmassa mentioned this pull request on May 10, 2022
@fmassa (Contributor) commented May 10, 2022

This should be fixed with #300

Looks like some of the configurations in the test are generating a fully-empty (all-zeros) matrix. It might be good to have a look to check whether this is intended.

@blefaudeux merged commit 837d7c0 into facebookresearch:main on May 10, 2022
@kashif deleted the smooth-relu branch on May 10, 2022 at 18:24
@blefaudeux (Contributor) commented May 10, 2022

For the record, here are my numbers on a desktop 3080 with an incoming fused linear PR (main numbers should be reasonably close):

--- Type: torch.float16 (fw+bw) ---

Units: TFlops/s                    | B=8, M=512, K=256 | B=8, M=512, K=512 | B=4, M=512, K=1024 | B=2, M=512, K=2048 | B=2, M=512, K=4096 | B=2, M=512, K=8192
pytorch - smelu - no bias - fw+bw  | 2.1  | 3.8  | 5.8  | 8.2  | 10.1 | 10.4
triton - smelu - no bias - fw+bw   | 2.9  | 9.2  | 10.4 | 12.0 | 12.6 | 13.6
pytorch - smelu - bias - fw+bw     | 1.9  | 3.6  | 5.6  | 8.0  | 9.4  | 11.2
triton - smelu - bias - fw+bw      | 2.6  | 8.2  | 9.9  | 11.6 | 10.8 | 13.2

--- Type: torch.float16 (fw only) ---

Units: TFlops/s                    | B=8, M=512, K=256 | B=8, M=512, K=512 | B=4, M=512, K=1024 | B=2, M=512, K=2048 | B=2, M=512, K=4096 | B=2, M=512, K=8192
pytorch - smelu - no bias - fw     | 1.8  | 3.5  | 5.5  | 7.9  | 10.2 | 10.0
triton - smelu - no bias - fw      | 11.4 | 13.6 | 13.9 | 13.9 | 14.1 | 13.7
pytorch - smelu - bias - fw        | 1.7  | 3.2  | 5.1  | 7.5  | 9.9  | 10.7
triton - smelu - bias - fw         | 11.4 | 13.6 | 14.0 | 14.0 | 14.0 | 12.3

Labels: CLA Signed (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed)
Projects: None yet

Successfully merging this pull request may close these issues:
  • [feat] Add smooth relu to the fused linear layer (triton) activations (#262)

5 participants