Implement BigCode models (StarCoder etc.) #5

Closed

casper-hansen opened this issue Aug 25, 2023 · 7 comments

@casper-hansen (Owner)

https://huggingface.co/bigcode/starcoder

@curname commented Aug 29, 2023

Hi @casper-hansen, thanks for your great work.
Based on the branch you submitted to AWQ but have not yet merged, I ran some experiments with the StarCoder model, but accuracy dropped significantly after quantization: accuracy on HumanEval Python is about 18%, and inference speed is 43 ms/token (both averaged over the HumanEval Python benchmark). I'm not sure what went wrong. Could it be that StarCoder uses multi-query attention (see the config sketch at the end of this comment)? Apart from that, I can't think of any other reason.

Here is my code:

      from .base import BaseAWQForCausalLM
      
      
      class BigCodeAWQForCausalLM(BaseAWQForCausalLM):
          layer_type = "gpt_bigcode"
          max_new_tokens_key = "n_positions"
      
          @staticmethod
          def get_model_layers(model):
              return model.transformer.h
      
          @staticmethod
          def get_act_for_scaling(module):
              return dict(
                  is_scalable=True,
                  scale_name="mlp.act",
                  scale_layer=module.mlp.act,
                  scale_shape=module.mlp.c_fc.out_features
              )
      
          @staticmethod
          def move_embed(model, device):
              model.transformer.wte = model.transformer.wte.to(device)
              model.transformer.drop = model.transformer.drop.to(device)
      
          @staticmethod
          def get_layers_for_scaling(module, input_feat, module_kwargs):
              layers = []
      
              # attention input
              layers.append(dict(
                  prev_op=module.ln_1,
                  layers=[module.attn.c_attn],
                  inp=input_feat['attn.c_attn'],
                  module2inspect=module.attn,
                  kwargs=module_kwargs
              ))
      
              # attention output
              layers.append(dict(
                  prev_op=module.attn.c_attn,
                  layers=[module.attn.c_proj],
                  inp=input_feat['attn.c_proj']
              ))
      
              # linear 1
              layers.append(dict(
                  prev_op=module.ln_2,
                  layers=[module.mlp.c_fc],
                  inp=input_feat['mlp.c_fc'],
                  module2inspect=module.mlp
              ))
      
              # linear 2
              layers.append(dict(
                  prev_op=module.mlp.act,
                  layers=[module.mlp.c_proj],
                  inp=input_feat['mlp.c_proj']
              ))
      
              return layers

I am not sure whether I need to do something special.
Looking forward to your reply!
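
(Aside on the multi-query attention question above: a small sketch of checking how GPT-BigCode packs its attention projection, using attribute names from the transformers GPTBigCodeConfig; the printed values are illustrative.)

      # Illustrative only: StarCoder uses multi-query attention, so attn.c_attn packs
      # a full query projection plus a single shared key/value head.
      from transformers import AutoConfig

      config = AutoConfig.from_pretrained("bigcode/starcoder")
      head_dim = config.n_embd // config.n_head

      print(config.multi_query)            # True for StarCoder
      print(config.n_embd + 2 * head_dim)  # expected out_features of attn.c_attn
      # Because K/V are shared across heads, c_attn is not a plain fused QKV projection,
      # which may be why scaling behaves differently than for standard multi-head models.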

@casper-hansen (Owner, Author)

> accuracy on HumanEval Python is about 18%

It seems something must have gone wrong during conversion. I will look into the layer specification.

@curname Can you paste the code you used to measure the accuracy?

> 43 ms/token

This could be reasonable depending on hardware and model size, but it seems there is room for improvement here.

@curname commented Aug 30, 2023

Hi @casper-hansen,

> accuracy on HumanEval Python is about 18%

The code for measuring HumanEval accuracy comes from the OpenAI human-eval project: https://github.com/openai/human-eval/tree/master/human_eval.
I ran some more experiments. As with BLOOM, I did not scale the attention output of StarCoder, and the latest results not only did not decline but even improved slightly: HumanEval Python accuracy reached 36%, which is really surprising.
The original model URL is https://huggingface.co/bigcode/starcoder/tree/main.

> 43 ms/token

I ran the above implementation on an A100 80GB. The speed of AWQ and GPTQ is almost the same, even though the experiments in the paper show AWQ outperforming GPTQ (the models in the paper are mainly LLaMA, though, not StarCoder). If I want to further improve inference speed to 30 ms/token, or even 20 ms/token, I would appreciate any suggestions you could give.
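
(Aside: a rough sketch of how ms/token can be timed with a plain generate loop; the model path, prompt, and token counts are placeholders.)

    import time

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_path = "bigcode/starcoder"  # placeholder; substitute the checkpoint being benchmarked
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.float16, device_map="auto"
    )

    input_ids = tokenizer("def quicksort(arr):", return_tensors="pt").input_ids.to(model.device)

    model.generate(input_ids, max_new_tokens=8)  # warm-up pass

    torch.cuda.synchronize()
    start = time.time()
    output = model.generate(input_ids, max_new_tokens=200)
    torch.cuda.synchronize()

    new_tokens = output.shape[1] - input_ids.shape[1]
    print(f"{(time.time() - start) * 1000 / new_tokens:.1f} ms/token")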

@curname commented Aug 30, 2023

> As with BLOOM, I did not scale the attention output of StarCoder

And the code looks like this:

    from .base import BaseAWQForCausalLM
    
    
    class BigCodeAWQForCausalLM(BaseAWQForCausalLM):
        layer_type = "gpt_bigcode"
        max_new_tokens_key = "n_positions"
    
        @staticmethod
        def get_model_layers(model):
            return model.transformer.h
    
        @staticmethod
        def get_act_for_scaling(module):
            # return dict(
            #     is_scalable=False
            # )
            return dict(
                is_scalable=True,
                scale_name="mlp.act",
                scale_layer=module.mlp.act,
                scale_shape=module.mlp.c_fc.out_features
            )
    
        @staticmethod
        def move_embed(model, device):
            model.transformer.wte = model.transformer.wte.to(device)
            model.transformer.drop = model.transformer.drop.to(device)
    
        @staticmethod
        def get_layers_for_scaling(module, input_feat, module_kwargs):
            layers = []
    
            # attention input
            layers.append(dict(
                prev_op=module.ln_1,
                layers=[module.attn.c_attn],
                inp=input_feat['attn.c_attn'],
                module2inspect=module.attn,
                kwargs=module_kwargs
            ))
    
            # attention output
            # layers.append(dict(
            #     prev_op=module.attn.c_attn,
            #     layers=[module.attn.c_proj],
            #     inp=input_feat['attn.c_proj']
            # ))
    
            # linear 1
            layers.append(dict(
                prev_op=module.ln_2,
                layers=[module.mlp.c_fc],
                inp=input_feat['mlp.c_fc'],
                module2inspect=module.mlp
            ))
    
            # linear 2
            layers.append(dict(
                prev_op=module.mlp.act,
                layers=[module.mlp.c_proj],
                inp=input_feat['mlp.c_proj']
            ))
    
            return layers

@casper-hansen (Owner, Author)

@curname Did you get it working with better accuracy yet? Also, did you test perplexity before and after quantization on wikitext? A normal increase in wikitext perplexity is between 2-5% (LLaMA 7B is around 2%). I wish I could test these models for you, but unfortunately I do not have many GPU resources available to me because of the cost associated with them.

The code for testing can be found in AutoAWQ/awq/entry.py, lines 134 to 138 at commit 783afe5:

- Run perplexity of quantized model:

      python -m awq.entry --entry_type eval --model_path vicuna-7b-v1.5-awq --quant_file awq_model_w4_g128.pt

- Run perplexity of unquantized FP16 model:

      python -m awq.entry --entry_type eval --model_path lmsys/vicuna-7b-v1.5 --task_use_pretrained
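
For reference, here is a minimal sketch of measuring wikitext perplexity directly with transformers and datasets, independent of the awq.entry script above; the model path, context length, and stride are placeholders, and a quantized checkpoint would need to be loaded through AutoAWQ instead.

    import torch
    from datasets import load_dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_path = "bigcode/starcoder"  # placeholder; substitute the model under test
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.float16, device_map="auto"
    )

    # Concatenate the wikitext-2 test split into one long token stream
    text = "\n\n".join(load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"])
    input_ids = tokenizer(text, return_tensors="pt").input_ids

    max_len = stride = 2048  # placeholder context window
    nlls, n_tokens = [], 0
    for start in range(0, input_ids.size(1) - 1, stride):
        chunk = input_ids[:, start : start + max_len].to(model.device)
        if chunk.size(1) < 2:
            break
        with torch.no_grad():
            # labels are shifted internally; loss is mean NLL per predicted token
            loss = model(chunk, labels=chunk).loss
        nlls.append(loss.float() * (chunk.size(1) - 1))
        n_tokens += chunk.size(1) - 1

    print("perplexity:", torch.exp(torch.stack(nlls).sum() / n_tokens).item())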

@abacaj commented Sep 11, 2023

Hi, I tried this code, and it seems to work in terms of the model generating correctly. But the output is strangely very slow compared to FP16; the model I tried was 1B. Using a 3090, GPU utilization is very low (< 10%) when generating:

    from awq import AutoAWQForCausalLM
    from transformers import AutoTokenizer

    model_path = 'abacaj/starcoderbase-1b-sft'
    quant_path = 'starcoder-1b-awq'
    quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4 }

    # Load model
    model = AutoAWQForCausalLM.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Quantize
    model.quantize(tokenizer, quant_config=quant_config)

    # Save quantized model
    model.save_quantized(quant_path)
    tokenizer.save_pretrained(quant_path)
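
(For completeness, a rough sketch of loading the saved checkpoint back and generating with it, assuming AutoAWQ's from_quantized API of this period; the quant file name and generation settings are illustrative and may differ between versions.)

    from awq import AutoAWQForCausalLM
    from transformers import AutoTokenizer

    quant_path = 'starcoder-1b-awq'
    quant_file = 'awq_model_w4_g128.pt'  # illustrative file name

    # Load the quantized model and tokenizer back from disk
    model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file)
    tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)

    # Generate from a short prompt; the quantized model is expected to live on the GPU
    tokens = tokenizer("def fibonacci(n):", return_tensors="pt").input_ids.to("cuda")
    output = model.generate(tokens, max_new_tokens=64)
    print(tokenizer.decode(output[0], skip_special_tokens=True))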

@casper-hansen (Owner, Author)

> Hi, I tried this code, and it seems to work in terms of the model generating correctly. But the output is strangely very slow compared to FP16; the model I tried was 1B. Using a 3090, GPU utilization is very low (< 10%) when generating.

The process for supporting models is to first support quantization; afterwards, we move on to optimizing inference by fusing layers. Additionally, we have an upcoming PR that will make it much easier to speed up inference.
