In [None]:
!pip install -q -U transformers peft torch accelerate einops sentencepiece bitsandbytes

In [None]:
import torch
from peft import PeftModel, PeftConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)

In [None]:
# Hub repository holding the PEFT adapter (and the tokenizer used with it).
# Its adapter config names the base checkpoint the adapter is applied to.
peft_model_id = "dfurman/Mixtral-8x7B-peft-v0.1"
config = PeftConfig.from_pretrained(peft_model_id)

# Load the base weights in 4-bit NF4, running the matmuls in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# NOTE(review): trust_remote_code=True executes code shipped in the Hub repo;
# confirm it is actually required for this tokenizer/base model.
tokenizer = AutoTokenizer.from_pretrained(
    peft_model_id, use_fast=True, trust_remote_code=True
)

# Quantized base model (sharded across available devices), wrapped with the
# adapter weights from `peft_model_id`.
model = PeftModel.from_pretrained(
    AutoModelForCausalLM.from_pretrained(
        config.base_model_name_or_path,
        quantization_config=bnb_config,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    ),
    peft_model_id,
)

In [None]:
messages = [
    {"role": "user", "content": "Tell me a recipe for a mai tai."},
]

print("\n\n*** Prompt:")
input_ids = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    return_tensors="pt",
)
print(tokenizer.decode(input_ids[0]))

print("\n\n*** Generate:")
with torch.autocast("cuda", dtype=torch.bfloat16):
    output = model.generate(
        input_ids=input_ids.to("cuda"),
        max_new_tokens=1024,
        return_dict_in_generate=True,
    )

response = tokenizer.decode(
    output["sequences"][0][len(input_ids[0]) :], skip_special_tokens=True
)
print(response)

In [None]:
messages = [
    {
        "role": "user",
        "content": "Recommend some games to play for 3 year old and 7 year olds.",
    },
]

print("\n\n*** Prompt:")
input_ids = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    return_tensors="pt",
)
print(tokenizer.decode(input_ids[0]))

print("\n\n*** Generate:")
with torch.autocast("cuda", dtype=torch.bfloat16):
    output = model.generate(
        input_ids=input_ids.to("cuda"),
        max_new_tokens=1024,
        return_dict_in_generate=True,
    )

response = tokenizer.decode(
    output["sequences"][0][len(input_ids[0]) :], skip_special_tokens=True
)
print(response)