No acceleration compared with timm vit block #410

Closed
woolpeeker opened this issue Oct 26, 2022 · 11 comments

Comments

@woolpeeker

I used the code below to test the ViT block speed. The output shows that the speed is almost the same between PyTorch and LightSeq.

Did I miss something?

Output for forward only:

timm finished 500 running, avg_time: 76.379987 ms
light_seq finished 500 running, avg_time: 75.543549 ms

The output for forward + backward:

timm finished 500 running, avg_time: 228.803998 ms
light_seq finished 500 running, avg_time: 227.007331 ms

import sys
import time

import torch
import torch.nn as nn
from easydict import EasyDict as edict
from timm.models.vision_transformer import Block
from lightseq.training.ops.pytorch.transformer_encoder_layer import LSTransformerEncoderLayer

sys.path.append('./')


torch.backends.cudnn.benchmark = True


def generate_dummy_data(args):
    inputs = torch.randn([args.bs, args.num_token, args.dim]).cuda()
    return (inputs, )


def get_timm_block(args):
    # Standard timm ViT block (pre-norm); dropout disabled to match the LightSeq config.
    return Block(
        dim=args.dim,
        num_heads=args.num_heads,
        mlp_ratio=args.mlp_ratio,
        qkv_bias=False,
        drop=0.,
        attn_drop=0.,
        init_values=None,
        drop_path=0.,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm
    )

class LSBlockWrapper(LSTransformerEncoderLayer):
    def forward(self, x):
        # LightSeq's encoder layer takes an explicit attention mask;
        # an all-zero mask means nothing is masked out.
        B, N, C = x.shape
        mask = torch.zeros([B, N, N], device=x.device, dtype=x.dtype)
        return super().forward(x, mask)

def get_ls_block(args):
    config = LSBlockWrapper.get_config(
        max_batch_tokens=args.num_token * args.bs,
        max_seq_len=args.num_token,
        hidden_size=args.dim,
        intermediate_size=int(args.mlp_ratio * args.dim),
        nhead=args.num_heads,
        attn_prob_dropout_ratio=0,
        hidden_dropout_ratio=0,
        activation_dropout_ratio=0,
        pre_layer_norm=True,
        fp16=False,
        local_rank=0,
        activation_fn='gelu')
    return LSBlockWrapper(
            config=config,
            initial_weights=None,
            initial_biases=None
        )


def run(module, args, name='Unknown'):
    inputs = generate_dummy_data(args)

    # cudnn warmup
    for _ in range(50):
        if args.backward:
            module(*inputs).sum().backward()
        else:
            module(*inputs)

    torch.cuda.synchronize()
    t0 = time.time()

    for _ in range(args.num_iter):
        if args.backward:
            module(*inputs).sum().backward()
        else:
            module(*inputs)

    torch.cuda.synchronize()
    t1 = time.time()

    avg_time = (t1 - t0) * 1000 / args.num_iter
    print(
        f'>>> {name} finished {args.num_iter} running, avg_time: {avg_time:.6f} ms')
    return avg_time


def main():
    args = edict()
    args.num_iter = 500
    args.backward = False

    args.bs = 128
    args.dim = 1280
    args.num_heads = 16
    args.mlp_ratio = 4.0
    args.num_token = 256

    timm_block = get_timm_block(args).cuda()
    ls_block = get_ls_block(args).cuda()

    run(timm_block, args, name='timm')
    run(ls_block, args, name='light_seq')

    print('Finished.')

if __name__ == '__main__':
    main()
@Taka152
Contributor

Taka152 commented Oct 28, 2022

It seems you are using fp32; could you try fp16?
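
For reference, a minimal sketch of the fp16 changes to the script above (the args.fp16 flag is hypothetical, not part of the original script):

# Sketch: run the same benchmark in half precision.
def generate_dummy_data(args):
    inputs = torch.randn([args.bs, args.num_token, args.dim]).cuda()
    if getattr(args, 'fp16', False):   # hypothetical flag
        inputs = inputs.half()
    return (inputs, )

# in get_ls_block(): pass fp16=True to LSBlockWrapper.get_config(...)
# in main():
#     args.fp16 = True
#     timm_block = get_timm_block(args).cuda().half()  # cast timm weights to fp16
#     ls_block = get_ls_block(args).cuda()             # LightSeq allocates fp16 weights via its config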

@woolpeeker
Author

woolpeeker commented Oct 28, 2022

Sure, I tested it with fp16.

with backward=False

timm finished 500 running, avg_time: 10.408471 ms
light_seq finished 500 running, avg_time: 9.462291 ms

with backward=True

timm finished 500 running, avg_time: 31.718561 ms
light_seq finished 500 running, avg_time: 30.036270 ms
That is only a 1.7 ms difference.

The test environment is:
PyTorch version: 1.12.1
CUDA used to build PyTorch: 11.3
Python version: 3.9.13 (main, Aug 25 2022, 23:26:10) [GCC 11.2.0] (64-bit runtime)
lightseq==2.2.1

@woolpeeker
Author

Do you have official ViT test results comparing native PyTorch and LightSeq?

@Taka152
Contributor

Taka152 commented Nov 2, 2022

We have tested on 8x A100, and this is the result:
[image: benchmark results]
BTW, could you tell me which card you are using? I will check whether I can reproduce your results.

@woolpeeker
Author

I used a single A100-80G GPU

@godweiyang
Collaborator

I used a single A100-80G GPU

Hi, I tested the fp16 precision with your script. Below is the result:
[image: benchmark results]

I ran it multiple times, and the results are the same.

I also ran other dims and batch sizes, and all the results show that LightSeq is faster.

@woolpeeker
Author

Thanks for testing. The LightSeq result is the same as mine, around 9.5 ms.
On our machine, the original timm layer is a little faster than on yours; ours is around 10.4 ms.
LightSeq is always faster than the original layer, but the gap is smaller on our machine.

I will close this issue.

BTW, do you have plans to implement FlashAttention in LightSeq?
I heard it is much faster than previous methods.

@syorami

syorami commented Dec 1, 2022

Hi @woolpeeker, I'm also trying to integrate LightSeq into my project with a timm ViT model. After switching to the LightSeq layer, can we still load the same pretrained model weights?

@woolpeeker
Author

Yes, you just need to organize the weight tensors following LightSeq's docs. I have tested it; the results align with timm's.
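
For reference, a rough sketch of what that mapping could look like for the LSBlockWrapper above, assuming a pre-norm timm Block and an assumed parameter order of (q, k, v, attention output projection, attention LayerNorm, FFN linear 1, FFN linear 2, FFN LayerNorm); check the exact ordering against LightSeq's docs and examples before relying on it:

# Hedged sketch: build initial_weights / initial_biases for LSBlockWrapper from a
# pre-norm timm Block. The assumed list order is q, k, v, attn out proj,
# attn LayerNorm, FFN linear 1, FFN linear 2, FFN LayerNorm -- verify against
# LightSeq's documentation.
def timm_block_to_ls_params(block, dim):
    q_w, k_w, v_w = block.attn.qkv.weight.chunk(3, dim=0)
    if block.attn.qkv.bias is not None:
        q_b, k_b, v_b = block.attn.qkv.bias.chunk(3, dim=0)
    else:
        q_b = k_b = v_b = torch.zeros(dim)
    weights = [q_w, k_w, v_w,
               block.attn.proj.weight,
               block.norm1.weight,
               block.mlp.fc1.weight,
               block.mlp.fc2.weight,
               block.norm2.weight]
    biases = [q_b, k_b, v_b,
              block.attn.proj.bias,
              block.norm1.bias,
              block.mlp.fc1.bias,
              block.mlp.fc2.bias,
              block.norm2.bias]
    return weights, biases

# usage sketch:
#   w, b = timm_block_to_ls_params(timm_block, args.dim)
#   ls_block = LSBlockWrapper(config=config, initial_weights=w, initial_biases=b)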

@syorami

syorami commented Dec 2, 2022

@woolpeeker Thanks! This saves me a lot of time.

@syorami

syorami commented Dec 6, 2022

Hi @woolpeeker, I followed the doc and integrated the LightSeq transformer layer into timm ViT, but the speed improvement is trivial. I'm wondering whether you observe any speedup, as I'm using the same GPU (A100) and the snippet above gives me exactly the same results as yours. I guess the relative improvement heavily depends on the batch size or other hyperparameters.

BTW, switching to FlashAttention gives me a speed boost of around 10% for the whole timm ViT model, and for a single attention block the speedup is 45%. I hope this helps.
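
For reference, that kind of swap can be sketched with PyTorch's fused attention (torch.nn.functional.scaled_dot_product_attention, available in PyTorch 2.0+, which can dispatch to a FlashAttention kernel on supported GPUs). This is not LightSeq code and not necessarily the exact integration used above, just a drop-in attention for a timm-style block:

# Sketch (assumes PyTorch >= 2.0): a timm-style attention module that uses the
# fused scaled-dot-product kernel instead of an explicit softmax(QK^T)V.
import torch
import torch.nn as nn
import torch.nn.functional as F

class FusedAttention(nn.Module):
    def __init__(self, dim, num_heads, qkv_bias=False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)   # each: [B, heads, N, head_dim]
        x = F.scaled_dot_product_attention(q, k, v)       # fused / FlashAttention path
        x = x.transpose(1, 2).reshape(B, N, C)
        return self.proj(x)

The MLP, LayerNorms, and the rest of the timm Block would stay unchanged; only the attention module is replaced.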
