llama inference test #515

Open
HandH1998 opened this issue Jun 29, 2023 · 3 comments

HandH1998 (Contributor) commented Jun 29, 2023

I built lightseq on CUDA 11.4 successfully, then ran a llama-13B inference test on an A100-80G with max_step=1024. When max_batch_size < 11, it works fine. The problem is that when I set max_batch_size >= 11, I get: lightseq/csrc/ops_new/sampling.cc.cu(73): an illegal memory access was encountered. Running with CUDA_LAUNCH_BLOCKING=1 to locate the problem reports the error at lightseq/csrc/ops_new/sampling.cc.cu(57): an illegal memory access was encountered. Memory usage is about 40G, so it is not an OOM problem. The following is my inference test script. Please help me with the problem.

import time
import argparse
import numpy as np
import torch
import lightseq.inference as lsi
from transformers import LlamaTokenizer, LlamaForCausalLM

def ls_llama(model, inputs):
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    results = model.infer(inputs)
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    return results, end_time - start_time

def ls_generate(model, tokenizer, inputs):
    print("=========lightseq=========")
    print("lightseq generating...")
    ls_res_ids, ls_time = ls_llama(model, inputs)

    ls_res_ids = np.squeeze(ls_res_ids, axis=1)
    # ls_res = tokenizer.batch_decode(ls_res_ids, skip_special_tokens=True)
    ls_res = tokenizer.batch_decode(ls_res_ids)
    print("lightseq results:")
    for sent in ls_res:
        print(sent)

    input_seq_len = inputs.shape[1]
    input_bsz = inputs.shape[0]
    input_total_tokens = input_seq_len * input_bsz

    print("input_seq_len: {}".format(input_seq_len))
    print("input_bsz: {}".format(input_bsz))
    print("input_total_tokens: {}".format(input_total_tokens))

    output_total_tokens = ls_res_ids.size
    gen_total_tokens = output_total_tokens - input_total_tokens
    output_seq_len = [seq.size for seq in ls_res_ids]

    print("output_total_tokens: {}".format(output_total_tokens))
    print("output_seq_len: {}".format(output_seq_len))
    print("gen_total_tokens: {}".format(gen_total_tokens))
    print(f"lightseq time: {ls_time}s")
    print("gen_speed: {} tokens/s".format(gen_total_tokens / ls_time))


def warmup(ls_tokenizer, ls_model, sentences):
    ls_inputs = ls_tokenizer(sentences, return_tensors="pt", padding=True)["input_ids"]
    ls_generate(ls_model, ls_tokenizer, ls_inputs)

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--user_input", action="store_true")
    args = parser.parse_args()
    print("initializing gpt tokenizer...")
    ls_tokenizer = LlamaTokenizer.from_pretrained(
        "/home/zy/lightseq/llama/13b"
    )
    ls_tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    print("creating lightseq model...")
    llama_weight_path = "/home/zy/lightseq/llama_13b.hdf5"
    ls_model = lsi.Llama(llama_weight_path, max_batch_size=11)

    # lightseq gpt perplexity supports batch inference with different lengths,
    # but sampling doesn't
    sentences = [
        "Are you a pig?",
        "I love you, but you say that",
        "I love you, but you say that",
        "I love you, but you say that",
        "I love you, but you say that",
        "I love you, but you say that",
        "I love you, but you say that",
        "I love you, but you say that",
        "Are you a pig?",
        "I love you, but you say that",
        "I love you, but you say that",
    ]
    print("====================START warmup====================")
    warmup(
        ls_tokenizer,
        ls_model,
        sentences,
    )
    print("====================END warmup====================")

    while True:
        if args.user_input:
            sentences = [input("input the masked sentence:\n")]

        print("tokenizing the sentences...")

        ls_inputs = ls_tokenizer(sentences, return_tensors="pt", padding=True)[
            "input_ids"
        ]
        ls_generate(ls_model, ls_tokenizer, ls_inputs)

        if not args.user_input:
            break


if __name__ == "__main__":
    main()
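
For reference, the error locations above come from running with blocking kernel launches so that the reported source line is accurate. Below is a minimal sketch of one way to do that from Python (the original run may simply have exported the variable in the shell); it is an illustration, not part of the script above, and the variable must be set before torch/lightseq create the CUDA context:

import os

# Force synchronous kernel launches so CUDA errors are reported at the
# offending call site instead of at a later, unrelated one. Must be set
# before the CUDA context is created.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch
import lightseq.inference as lsi  # imported only after the env var is set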
ChristineSeven commented Aug 2, 2023

Using your code, I got this error: module 'lightseq.inference' has no attribute 'Llama'. Could you tell me how you got past this? @HandH1998

HandH1998 (Contributor, Author) commented

> Using your code, I got this error: module 'lightseq.inference' has no attribute 'Llama'. Could you tell me how you got past this? @HandH1998

It seems that you didn't compile it correctly.
[screenshot of the lightseq build configuration]
Change use_new_arch to ON.
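
After rebuilding with that option enabled, a quick sanity check can confirm the new-architecture models are actually exported. This is just a sketch using standard Python introspection; nothing here is lightseq-specific beyond the import:

import lightseq.inference as lsi

# If the extension was built without the new architecture, "Llama" will be
# absent here and lsi.Llama raises AttributeError.
print(hasattr(lsi, "Llama"))
print([name for name in dir(lsi) if not name.startswith("_")])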

ChristineSeven commented

@HandH1998 Thanks.
