# Tested on a g5.4xlarge

In [1]:
%pip install -U torch==2.0.1 \
  transformers==4.33.0 \
  sentencepiece==0.1.99 \
  accelerate==0.22.0 # needed for low_cpu_mem_usage parameter

Collecting torch==2.0.1
  Downloading torch-2.0.1-cp310-cp310-manylinux1_x86_64.whl.metadata (24 kB)
Collecting transformers==4.33.0
  Downloading transformers-4.33.0-py3-none-any.whl.metadata (119 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m119.9/119.9 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting sentencepiece==0.1.99
  Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.7 kB)
Collecting accelerate==0.22.0
  Downloading accelerate-0.22.0-py3-none-any.whl.metadata (17 kB)
Collecting nvidia-cuda-nvrtc-cu11==11.7.99 (from torch==2.0.1)
  Downloading nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu11==11.7.99 (from torch==2.0.1)
  Downloading nvidia_cuda_runtime_cu11-11.7.99-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cuda-cupti-cu11==11.7.101 (from torch==2.0.1)
  Downloading nvidia_cuda_cupti_

In [2]:
import torch
from transformers import LlamaTokenizer

# Checkpoint id shared by the tokenizer here and the model weights loaded later.
# NOTE(review): the id has no "-chat" suffix, so this is presumably the base
# (non-chat) Llama 2 model, while the prompt built later uses the [INST]/<<SYS>>
# chat markers — confirm that this pairing is intentional.
model_checkpoint = "NousResearch/Llama-2-7b-hf"
# Load the tokenizer files published with this checkpoint.
tokenizer = LlamaTokenizer.from_pretrained(model_checkpoint)

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/435 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/746 [00:00<?, ?B/s]

In [3]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [4]:
# Echo the selected device so the notebook records which hardware was picked.
device

device(type='cuda')

In [5]:
# based on https://github.com/viniciusarruda/llama-cpp-chat-completion-wrapper/blob/1c9e29b70b1aaa7133d3c7d7b59a92d840e92e6d/llama_cpp_chat_completion_wrapper.py

from typing import List
from typing import Literal
from typing import TypedDict

from transformers import PreTrainedTokenizer

# Speaker roles a message may carry.
Role = Literal["system", "user", "assistant"]


class Message(TypedDict):
    """A single chat turn: who is speaking and what they said."""

    role: Role
    content: str


# One conversation, in chronological order.
MessageList = List[Message]

# Markers wrapped around each user instruction and around the optional system prompt.
BEGIN_INST, END_INST = "[INST] ", " [/INST] "
BEGIN_SYS, END_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"


def _format_conversation(message_list: MessageList, tokenizer: "PreTrainedTokenizer") -> str:
    """Render one conversation as a single prompt string ending with an open user turn."""
    if not message_list:
        raise ValueError("Each conversation must contain at least one message")

    if message_list[0]["role"] == "system":
        if len(message_list) < 2:
            raise ValueError("A 'system' message must be followed by a 'user' message")
        # Fold the system prompt into the message that follows it. A new list is
        # built so the caller's conversation is never mutated.
        system_msg, first_msg = message_list[0], message_list[1]
        merged = "".join([BEGIN_SYS, system_msg["content"], END_SYS, first_msg["content"]])
        message_list = [{"role": first_msg["role"], "content": merged}] + message_list[2:]

    # After the optional system message, even positions must be user turns and
    # odd positions assistant turns.
    prompts, answers = message_list[::2], message_list[1::2]
    roles_alternate = all(msg["role"] == "user" for msg in prompts) and all(
        msg["role"] == "assistant" for msg in answers
    )
    if not roles_alternate:
        raise ValueError(
            "Format must be in this order: 'system', 'user', 'assistant' roles.\nAfter that, you can alternate between user and assistant multiple times"
        )

    if message_list[-1]["role"] != "user":
        raise ValueError(f"Last message must be from user role. Instead, you sent from {message_list[-1]['role']} role")

    bos, eos = tokenizer.bos_token, tokenizer.eos_token

    # Every completed (user, assistant) exchange is wrapped in its own bos ... eos
    # pair; zip() drops the trailing unanswered user turn, which is added below.
    completed_turns = [
        "".join([bos, BEGIN_INST, prompt["content"].strip(), END_INST, answer["content"].strip(), eos])
        for prompt, answer in zip(prompts, answers)
    ]
    open_turn = "".join([bos, BEGIN_INST, message_list[-1]["content"].strip(), END_INST])
    return "".join(completed_turns) + open_turn


def convert_list_of_message_lists_to_input_prompt(list_of_message_lists: List[MessageList], tokenizer: "PreTrainedTokenizer") -> List[str]:
    """Convert a batch of conversations into [INST]/<<SYS>>-formatted prompt strings.

    Args:
        list_of_message_lists: Conversations to convert. Each conversation may
            start with one "system" message, must then alternate "user" and
            "assistant" messages, and must end with a "user" message.
        tokenizer: Only its ``bos_token`` and ``eos_token`` strings are read;
            nothing is tokenized here.

    Returns:
        One prompt string per conversation, in input order. An empty batch
        yields an empty list.

    Raises:
        ValueError: If a conversation is empty, consists only of a system
            message, does not alternate user/assistant roles, or does not end
            with a user message.
    """
    return [_format_conversation(message_list, tokenizer) for message_list in list_of_message_lists]

In [6]:
# Build the conversation: a system instruction followed by the user's question.
# Each Message is created with both required keys at once instead of being
# created empty and filled in afterwards.
system_message = Message(role="system", content="Answer only with emojis")
print(system_message)

user_message = Message(role="user", content="Who won the 2016 baseball World Series?")
print(user_message)

# A single conversation, wrapped in an outer list because the converter takes a
# list of conversations.
list_of_messages = [system_message, user_message]
list_of_message_lists = [list_of_messages]

# `prompt` is a list with one formatted string per conversation.
prompt = convert_list_of_message_lists_to_input_prompt(list_of_message_lists, tokenizer)
print(prompt)

{'role': 'system', 'content': 'Answer only with emojis'}
{'role': 'user', 'content': 'Who won the 2016 baseball World Series?'}
<class 'list'>
<class 'list'>
['<s>[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] ']


In [7]:
from transformers import LlamaForCausalLM

# Load the causal-LM weights for the same checkpoint as the tokenizer, stored
# as bfloat16 tensors.
model = LlamaForCausalLM.from_pretrained(
    model_checkpoint,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,  # relies on the accelerate package installed above
)
# Move the weights to the selected device. As the cell's last expression, this
# also displays the module tree.
model.to(device)
# NOTE(review): explicit eval() left disabled — confirm whether it is needed
# before generation.
# model = model.eval()

2024-04-22 01:43:41.055037: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-22 01:43:41.055139: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-22 01:43:41.188328: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


config.json:   0%|          | 0.00/583 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/179 [00:00<?, ?B/s]



LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNo

In [8]:
from transformers import pipeline

# Tokenize the prompt list and place the resulting ids on the selected device.
# NOTE(review): the prompt text already begins with the tokenizer's BOS string;
# check whether tokenizer(prompt) prepends a second BOS id, which would inflate
# the count printed below by one.
tokenized_prompt = tokenizer(prompt)
input_ids = tokenized_prompt["input_ids"]
input_ids_tensor = torch.tensor(input_ids).to(device)
# Report the token length of the first prompt in the list.
print(f'prompt is {len(input_ids[0])} tokens')

prompt is 41 tokens


In [9]:
# NOTE(review): this moves the model back to the CPU, undoing the earlier
# model.to(device); generation below presumably then runs on the CPU — confirm
# this is intentional.
model = model.to("cpu")

In [10]:
from transformers import GenerationConfig
# Import the factory under an alias: the name `pipeline` is rebound below to the
# pipeline *instance*, so calling a plain `pipeline(...)` factory here would
# fail the second time this cell is run.
from transformers import pipeline as build_pipeline

# Upper bound on the number of tokens generated after the prompt.
generation_config = GenerationConfig(max_new_tokens=2000)

# The variable keeps the name `pipeline` because the following cell calls
# pipeline(prompt).
pipeline = build_pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    generation_config=generation_config,
)

In [11]:
# Generate a continuation for every prompt in the list. The returned text
# includes the prompt itself (see the output below).
pipeline(prompt)

[[{'generated_text': '<s>[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only wi

In [12]:
response = [[{'generated_text': '<s>[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? 
[/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? 
[/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016 baseball World Series? [/INST] \n\n[INST] <<SYS>>\nAnswer only with emojis\n<</SYS>>\n\nWho won the 2016'}]]
# print() renders the escaped "\n" sequences as real line breaks, which makes
# the repeated [INST] blocks in the generated text easy to see.
print(response[0][0]['generated_text'])

<s>[INST] <<SYS>>
Answer only with emojis
<</SYS>>

Who won the 2016 baseball World Series? [/INST] 

[INST] <<SYS>>
Answer only with emojis
<</SYS>>

Who won the 2016 baseball World Series? [/INST] 

[INST] <<SYS>>
Answer only with emojis
<</SYS>>

Who won the 2016 baseball World Series? [/INST] 

[INST] <<SYS>>
Answer only with emojis
<</SYS>>

Who won the 2016 baseball World Series? [/INST] 

[INST] <<SYS>>
Answer only with emojis
<</SYS>>

Who won the 2016 baseball World Series? [/INST] 

[INST] <<SYS>>
Answer only with emojis
<</SYS>>

Who won the 2016 baseball World Series? [/INST] 

[INST] <<SYS>>
Answer only with emojis
<</SYS>>

Who won the 2016 baseball World Series? [/INST] 

[INST] <<SYS>>
Answer only with emojis
<</SYS>>

Who won the 2016 baseball World Series? [/INST] 

[INST] <<SYS>>
Answer only with emojis
<</SYS>>

Who won the 2016 baseball World Series? [/INST] 

[INST] <<SYS>>
Answer only with emojis
<</SYS>>

Who won the 2016 baseball World Series? [/INST] 

[INST] 