In [69]:
from transformers import pipeline
import torch

MODEL = "meta-llama/Llama-2-7b-chat-hf"

def ask_llama(question: str):
    try:
        system_message = "You are a robot. You always give short correct direct answers."
        user_prompt = question

        # Create the full prompt using the Llama 2 chat template
        full_prompt = (
            f"<s>[INST] <<SYS>>\n"
            f"{system_message}\n"
            f"<</SYS>>\n\n"
            f"{user_prompt} [/INST]"
        )

        # The `torch_dtype=torch.float16` and `device_map="auto"` are important
        # for running a model of this size efficiently, especially on a GPU.
        generator = pipeline(
            "text-generation",
            model=MODEL,
            torch_dtype=torch.float16,
            device_map="auto"
        )

        response = generator(
            full_prompt,
            max_length=500,
            num_return_sequences=1,
            # Add these parameters to control the output
            do_sample=True,
            top_k=10,
            eos_token_id=generator.tokenizer.eos_token_id,
        )

        # Extract the generated text. Note that the output includes the prompt itself.
        # We slice the string to get just the response.
        answer = response[0]['generated_text']

        # The Llama-2-chat template is unique, so we'll slice the output
        # to remove the prompt and only show the generated response.
        return answer.replace(full_prompt, "").strip()

    except Exception as e:
        return f"An error occurred: {e}"

In [70]:
%%time

my_question = "What is the capital of France?"
answer = ask_llama(my_question)
print(answer)

Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.14s/it]
Some parameters are on the meta device because they were offloaded to the cpu.
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


The capital of France is Paris.
CPU times: total: 14.1 s
Wall time: 15.4 s


In [59]:
%%time

my_question = "tell me a joke"
answer = ask_llama(my_question)
print(answer)

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.89it/s]
Some parameters are on the meta device because they were offloaded to the cpu.
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


Arrrr, here be a joke fer ye:

Why did the scurvy dog cross the plank?

To get to the other side... of the bar! *wink*
CPU times: total: 1min 24s
Wall time: 1min 29s
