Support for mosaicml/mpt-30b-instruct model #491

Closed
maziyarpanahi opened this issue Jun 23, 2023 · 19 comments · Fixed by #514

Comments

@maziyarpanahi
Contributor

Feature request

I was wondering if there will be support for the newly released mpt-30b-instruct.

Motivation

It's not possible to use the mosaicml/mpt-30b-instruct model:

ValueError: sharded is not supported for AutoModel

Your contribution

I am not sure how support for new LLM models can be added. (If there is a step-by-step guide on where to start, that would be great, and I could contribute.)

@mantrakp04

Did you try --trust-remote-code while running the Docker container?

@tim-a-davis

Did you try --trust-remote-code while running the Docker container?

It's very slow. This model is not supported for sharding at the moment in text-generation-inference.

@mantrakp04

Then try a rudimentary implementation of it: you can use Rust or JS for the router and Python for inference. Copy the custom kernels from the repo and modify them as needed. MPT already has a flash attention implementation in its "remote code" file, so use that, with batch_encode_plus for tokenization and batch_decode for decoding. Implement batching on the router server and voilà, you have your own server ready for inference.
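The Python inference piece could look roughly like the untested sketch below (the generate_batch helper and the settings are just placeholders, not a TGI implementation, and device_map='auto' assumes accelerate is installed):

import torch
import transformers

name = 'mosaicml/mpt-30b-instruct'

tokenizer = transformers.AutoTokenizer.from_pretrained(name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # MPT's tokenizer has no pad token by default

model = transformers.AutoModelForCausalLM.from_pretrained(
    name,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map='auto',  # assumes accelerate is installed
)

def generate_batch(prompts, max_new_tokens=128):
    # batch_encode_plus pads the prompts in the batch to a common length
    enc = tokenizer.batch_encode_plus(prompts, return_tensors='pt', padding=True)
    enc = {k: v.to(model.device) for k, v in enc.items()}
    out = model.generate(**enc, max_new_tokens=max_new_tokens, pad_token_id=tokenizer.pad_token_id)
    # batch_decode turns the generated ids back into strings
    return tokenizer.batch_decode(out, skip_special_tokens=True)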

@tim-a-davis

Then try a rudimentary implementation of it: you can use Rust or JS for the router and Python for inference. Copy the custom kernels from the repo and modify them as needed. MPT already has a flash attention implementation in its "remote code" file, so use that, with batch_encode_plus for tokenization and batch_decode for decoding. Implement batching on the router server and voilà, you have your own server ready for inference.

Maybe you could write me one as an example?

@mantrakp04

mantrakp04 commented Jun 23, 2023

I'm working on one right now, if you would like to help out (Discord: mantrakp).

@SinanAkkoyun

I am also very interested in this.
I know the router side, but how do you actually batch-compute multiple requests at once "on the fly" with transformers?
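The naive approach I can picture (purely illustrative, not what TGI does) is to collect whatever requests arrive within a short window and run them through one padded generate call, where generate_batch below stands for any function mapping a list of prompts to a list of texts:

import queue
import threading
import time

request_queue = queue.Queue()  # items: (prompt, holder dict)

def batching_loop(generate_batch, max_batch_size=8, window_s=0.05):
    while True:
        batch = [request_queue.get()]              # block until the first request arrives
        deadline = time.monotonic() + window_s
        while len(batch) < max_batch_size:         # then gather more requests for a short window
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            try:
                batch.append(request_queue.get(timeout=remaining))
            except queue.Empty:
                break
        texts = generate_batch([prompt for prompt, _ in batch])
        for (_, holder), text in zip(batch, texts):
            holder['text'] = text                  # hand the result back to the caller
            holder['done'].set()

def submit(prompt):
    # Called from the router side; blocks until the batcher has answered.
    holder = {'done': threading.Event()}
    request_queue.put((prompt, holder))
    holder['done'].wait()
    return holder['text']

But real continuous batching interleaves requests token by token, which is the part I don't see how to do with plain transformers.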

@SinanAkkoyun

(And can we expect an optimized TGI implementation soon?)

@Narsil
Collaborator

Narsil commented Jun 26, 2023

Maybe take example from the other models we have done in server/text_generation_server/models/custom_modeling/*.py?

There are also some files in server/text_generation_server/models/*.py. Those declare the model as being flash-enabled (the batching happens differently when a model supports flash).

If you succeed, PRs are welcome!

@louis030195

Is MPT even supported? (#290)

@Narsil
Collaborator

Narsil commented Jun 28, 2023

It's supported on a "best effort" basis.

I started some work to actually support it, but it means rewriting flash attention (the CUDA version) with an added bias, which may take some time.
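(For context, the "added bias" is MPT's ALiBi: a per-head linear bias added to the attention scores before the softmax. In eager PyTorch, ignoring the fused kernel entirely, it looks roughly like the sketch below; the function names are only for illustration.)

import math
import torch

def alibi_slopes(num_heads):
    # Geometric slopes from the ALiBi paper; assumes num_heads is a power of two.
    start = 2 ** (-8 / num_heads)
    return torch.tensor([start ** (i + 1) for i in range(num_heads)])

def attention_with_alibi_bias(q, k, v):
    # q, k, v: [batch, heads, seq, head_dim]
    _, heads, seq, head_dim = q.shape
    scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)            # [batch, heads, seq, seq]
    pos = torch.arange(seq, device=q.device)
    distance = pos[None, :] - pos[:, None]                            # j - i, negative for past tokens
    bias = alibi_slopes(heads).to(q.device)[:, None, None] * distance # [heads, seq, seq]
    scores = scores + bias.to(scores.dtype)   # this extra additive term is what the fused kernel lacks
    causal = torch.triu(torch.ones(seq, seq, dtype=torch.bool, device=q.device), diagonal=1)
    scores = scores.masked_fill(causal, float('-inf'))
    return torch.softmax(scores, dim=-1) @ v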

@mantrakp04

Sad news: I didn't succeed. The MPT model is a bit different; I tried loading it, but it didn't work as expected and keeps mixing up tokens. I am looking forward to your implementation, Narsil, sorry for the wait.

Thanks in advance to Narsil :D

@ankit201

It's supported on a "best effort" basis.

I started some work to actually support it, but it means rewriting flash attention (the CUDA version) with an added bias, which may take some time.

Can you give some guidance on how you started writing the flash attention part, and what are your thoughts on implementing dynamic batching for this, as it only supports 1 concurrent request for now on AutoModel?
A little guidance would be really great; maybe we can collaborate and try this out.

@Narsil
Collaborator

Narsil commented Jun 29, 2023

on implementing dynamic batching for this as it only supports 1 concurrent request for now on AutoModel.

This won't require work once we have flash attention.

@ankit201

on implementing dynamic batching for this as it only supports 1 concurrent request for now on AutoModel.

This won't require work once we have flash attention.

Please correct me if I'm wrong, but do we need to implement this, since the mpt-30b models already have flash attention usage built into their config?
mpt-30b-chat

import torch
import transformers

name = 'mosaicml/mpt-30b-chat'

config = transformers.AutoConfig.from_pretrained(name, trust_remote_code=True)
config.attn_config['attn_impl'] = 'triton'  # change this to use triton-based FlashAttention
config.init_device = 'cuda:0' # For fast initialization directly on GPU!

model = transformers.AutoModelForCausalLM.from_pretrained(
  name,
  config=config,
  torch_dtype=torch.bfloat16, # Load model weights in bfloat16
  trust_remote_code=True
)

@Narsil
Collaborator

Narsil commented Jun 30, 2023

Because it doesn't implement the flash attention we want.

This is Triton's flash attention, which doesn't support "unpadded" batching, which is the one needed to work nicely with TGI (removing padding removes a LOT of issues and unnecessary memory, and speeds up inference much more than flash by itself).

Flash attention actually doesn't play that big of a role in speeding things up at inference, since most of the time is spent in decode where it doesn't really help. But the no padding thing is extremely important.
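To make the "unpadded" layout concrete (purely illustrative, not the actual TGI code): instead of padding every sequence in the batch to the longest one, the tokens of all sequences are concatenated into one flat tensor, and the kernel is told where each sequence starts and ends via cumulative lengths:

import torch

# Three requests of different lengths, as token-id tensors (made-up ids).
seqs = [torch.tensor([5, 9, 2]), torch.tensor([7, 1]), torch.tensor([4, 4, 8, 3, 6])]

# Padded layout: a [batch, max_len] tensor, largely filled with wasted pad positions.
padded = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=0)

# Unpadded ("varlen") layout: one flat [total_tokens] tensor plus cumulative
# sequence lengths, which is what variable-length flash attention kernels consume.
input_ids = torch.cat(seqs)                                  # shape [3 + 2 + 5] = [10]
cu_seqlens = torch.tensor([0, 3, 5, 10], dtype=torch.int32)  # start offset of each sequence
max_seqlen = max(len(s) for s in seqs)                       # 5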

@Narsil
Collaborator

Narsil commented Jul 1, 2023

Here is the non-flash version (as a temporary measure, since modifying the kernel is taking more time than I anticipated): #514

This should enable sharding at least.

@Narsil Narsil mentioned this issue Jul 1, 2023
@ankit201

ankit201 commented Jul 1, 2023

Here is the non-flash version (as a temporary measure, since modifying the kernel is taking more time than I anticipated): #514

This should enable sharding at least.

Many thanks for this. Looking forward to the flash class too.
Cheers!

OlivierDehaene pushed a commit that referenced this issue Jul 3, 2023
# What does this PR do?

This adds a non-flash version of MPT.
Flash is harder because we need to create a bias-ready CUDA kernel of
flash attention.

Fixes #361
Fixes #491
Fixes #290
@ConProgramming

Because it doesn't implement the flash attention we want.

This is Triton's flash attention, which doesn't support "unpadded" batching, which is the one needed to work nicely with TGI (removing padding removes a LOT of issues and unnecessary memory, and speeds up inference much more than flash by itself).

Flash attention actually doesn't play that big of a role in speeding things up at inference, since most of the time is spent in decode where it doesn't really help. But the no padding thing is extremely important.

Triton is the only flash attention implementation that supports ALiBi, if I understand this correctly.

So for TGI, if we want to use MPT with ALiBi, does that leave us with just the native PyTorch implementation?
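(To frame the question a bit more: in stock PyTorch 2.x, torch.nn.functional.scaled_dot_product_attention accepts an additive float attn_mask, so the ALiBi bias plus the causal mask can at least be expressed there. A rough sketch, without any of the unpadded batching discussed above:)

import torch
import torch.nn.functional as F

def sdpa_with_alibi(q, k, v, slopes):
    # q, k, v: [batch, heads, seq, head_dim]; slopes: [heads] tensor of ALiBi slopes.
    seq = q.shape[-2]
    pos = torch.arange(seq, device=q.device)
    distance = pos[None, :] - pos[:, None]                       # j - i
    bias = slopes.to(q.device)[:, None, None] * distance         # [heads, seq, seq]
    causal = torch.triu(torch.full((seq, seq), float('-inf'), device=q.device), diagonal=1)
    attn_mask = (bias + causal).to(q.dtype)                      # additive float mask
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)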

@OlivierDehaene
Member

OlivierDehaene commented Jul 4, 2023

We will fork and add it ourselves to the flash attention cuda kernels.
