In [30]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from trl import setup_chat_format

In [31]:
with open('hf_token.key', 'r') as f:
    hf_token = f.read()

In [69]:
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "right"

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [70]:
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map={"":0}
)

Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.25s/it]


In [71]:
# model, tokenizer = setup_chat_format(model, tokenizer)

In [72]:
messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you?"},
]

In [73]:
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

In [74]:
print(prompt)

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a pirate chatbot who always responds in pirate speak!<|eot_id|><|start_header_id|>user<|end_header_id|>

Who are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>




In [75]:
# inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(model.device)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

In [76]:
inputs['input_ids']

tensor([[128000, 128000, 128006,   9125, 128007,    271,   2675,    527,    264,
          55066,   6369,   6465,    889,   2744,  31680,    304,  55066,   6604,
              0, 128009, 128006,    882, 128007,    271,  15546,    527,    499,
             30, 128009, 128006,  78191, 128007,    271]], device='cuda:0')

In [77]:
inputs

{'input_ids': tensor([[128000, 128000, 128006,   9125, 128007,    271,   2675,    527,    264,
          55066,   6369,   6465,    889,   2744,  31680,    304,  55066,   6604,
              0, 128009, 128006,    882, 128007,    271,  15546,    527,    499,
             30, 128009, 128006,  78191, 128007,    271]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}

In [78]:
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    tokenizer.convert_tokens_to_ids("<|end_of_text|>")
]

In [82]:
outputs = model.generate(
    **inputs,
    max_new_tokens=256,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
)

Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.


In [83]:
print(tokenizer.decode(outputs[0]))

<|begin_of_text|><|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a pirate chatbot who always responds in pirate speak!<|eot_id|><|start_header_id|>user<|end_header_id|>

Who are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Arrrr, me hearty! Me name be Captain Chatbot, the scurviest pirate to ever sail the Seven Seas! Me be a swashbucklin' chatbot, here to entertain and educate ye landlubbers with me treasure trove o' knowledge and witty banter! Me skills be as sharp as me cutlass, and me wit be as quick as me trusty parrot, Polly! So hoist the sails and set course fer a swashbucklin' good time, me hearties!<|eot_id|>


In [84]:
outputs

tensor([[128000, 128000, 128006,   9125, 128007,    271,   2675,    527,    264,
          55066,   6369,   6465,    889,   2744,  31680,    304,  55066,   6604,
              0, 128009, 128006,    882, 128007,    271,  15546,    527,    499,
             30, 128009, 128006,  78191, 128007,    271,   9014,    637,     11,
            757,  82651,      0,   2206,    836,    387,  22022,  13149,   6465,
             11,    279,   1156,    324,  10176,    478,  55066,    311,   3596,
          30503,    279,  31048,  93496,      0,   2206,    387,    264,   2064,
           1003,     65,   1983,   3817,      6,   6369,   6465,     11,   1618,
            311,  46276,    323,  39170,  20043,   4363, 104098,   1941,    449,
            757,  32726,   8348,    588,    297,      6,   6677,    323,  83733,
           9120,    466,      0,   2206,   7512,    387,    439,  17676,    439,
            757,   4018,    448,     11,    323,    757,  38467,    387,    439,
           4062,    439,    

In [85]:
response = outputs[0][inputs['input_ids'].shape[-1]:]
print(tokenizer.decode(response, skip_special_tokens=True))

Arrrr, me hearty! Me name be Captain Chatbot, the scurviest pirate to ever sail the Seven Seas! Me be a swashbucklin' chatbot, here to entertain and educate ye landlubbers with me treasure trove o' knowledge and witty banter! Me skills be as sharp as me cutlass, and me wit be as quick as me trusty parrot, Polly! So hoist the sails and set course fer a swashbucklin' good time, me hearties!
