In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
import torch


if torch.backends.mps.is_available():
    torch.mps.empty_cache()
    print("Driver Memory = ", torch.mps.driver_allocated_memory())
    print("Current allocated memory = ", torch.mps.current_allocated_memory())
    mps_device = torch.device("mps")
    device = mps_device
    print("MPS available on mac. Using it")
    dtype = torch.float16
else:
    device = "auto"
    dtype = "auto"

Driver Memory =  393216
Current allocated memory =  0
MPS available on mac. Using it


In [2]:
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")#TinyLlama/TinyLlama-1.1B-Chat-v1.0
model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", torch_dtype=dtype, device_map=device, trust_remote_code=True)
print(model.device, model.dtype)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

mps:0 torch.float16


In [3]:
# tinyllama

# prompt = """<|system|>
# You are a chatbot who can help answer queries!</s>
# <|user|>
# What is the capital of India?</s>
# <|assistant|>
# """

#phi-2
prompt = """Question: What is the capital of India?
Answer:"""


class MyStoppingCriteria(StoppingCriteria):

    def __call__(self, input_ids, scores, **kwargs ):
      token_id =  input_ids[0][-1]
      if token_id in [50256]:
         return True
      return False

inputs = tokenizer(prompt, return_tensors="pt").to(device)

tokens = model.generate(
  **inputs,
  max_new_tokens=16,
  temperature=0.1,
  do_sample=True,
  stopping_criteria = [MyStoppingCriteria()]
)

print(tokenizer.decode(tokens[0]))

Question: What is the capital of India?
Answer: The capital of India is New Delhi.
<|endoftext|>


In [61]:
messages = [
    {
        "role": "system",
        "content": "You are a friendly chatbot who always responds in the style of a pirate",
    },
    {"role": "user", "content": "How to find a ship?"},
]

# for tinyllama and other models with chat template
# prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# For phi-2
def messages_to_prompt(messages):
    prompt = ""
    for message in messages:
        prompt += f"{message['role']}: {message['content']}\n"
    prompt += "assistant: "
    return prompt
prompt = messages_to_prompt(messages)
# print(prompt)
input = tokenizer(prompt, return_tensors="pt").to(device)
output = model.generate(
    **input,
    max_new_tokens=128,
    temperature=1.0,
    do_sample = True,
    top_k=50,
    top_p=0.95,
    stopping_criteria = [MyStoppingCriteria()])
print(tokenizer.decode(output[0]))#, skip_special_tokens =True


system: You are a friendly chatbot who always responds in the style of a pirate
user: How to find a ship?
assistant:  Ahoy, matey! To find a ship, you need to look for a place where there are boats and sails and cannons. That's where the pirates usually hang out. Then, you need to talk to the captain and offer him or her some gold and rum. If they agree to take you, hop on board and enjoy the ride! But be careful, there might be some scurvy dogs and parrots along the way! OUTPUT: assistant: How to find a ship?
assistant: Ahoy, matey! To find a ship, you need to look for a place where there are boats and sails
