In [None]:
import torch
from FastChat.fastchat.model import load_model, get_conversation_template
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM

!nvidia-smi -L

print("PyTorch version:", torch.__version__)
print("CUDA version:", torch.version.cuda)
print("#GPUs:", torch.cuda.device_count())

In [None]:
# Model identifier handed to FastChat's load_model / get_conversation_template
# and to the transformers from_pretrained calls further down.
MODEL = "lmsys/vicuna-7b-v1.5"
# Forwarded as num_gpus to FastChat's load_model.
NUM_GPUS = 1
# When True, the FastChat cell wraps PROMPT in the model's conversation
# template (and prints the result) before generating; when False, PROMPT is
# sent to the model verbatim.
CONVERSATIONAL = False

# Judge-style prompt used as the generation input for both loaders.
# NOTE(review): the text announces "some context related to the question",
# but no context section is included below — confirm this is intentional.
PROMPT = """You are an expert judge of a content. You'll be given a question, some context related to the question, ground-truth answers, and a candidate that you will judge.

Question: what is the name of the compound p4010?
Answer: "Phosphorus pentoxide"
Candidate: Unknown.

Is candidate correct?
"""

In [None]:
def gen(prompt: str, model, tokenizer):
    """Sample a completion for ``prompt`` and return only the newly generated text.

    Args:
        prompt: Text sent to the model as-is; no chat template is applied here.
        model: Object exposing a Hugging Face style ``generate(**inputs, ...)``
            that returns a ``(batch, sequence)`` tensor of token ids whose rows
            begin with the input ids (the behavior of decoder-only models).
        tokenizer: Callable returning a mapping that contains ``input_ids`` and
            supports ``.to(device)``; must also provide ``decode``.

    Returns:
        The decoded continuation, with special tokens skipped and surrounding
        whitespace stripped. Sampling (``top_p=0.9``) is enabled, so the text
        can differ between calls.
    """
    # Follow the model's own device when it exposes one; otherwise keep the
    # previous behavior of placing the inputs on "cuda".
    device = getattr(model, "device", "cuda")
    inputs = tokenizer([prompt], return_tensors="pt").to(device)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            do_sample=True,
            top_p=0.9,
            max_new_tokens=100,
        )

    # The continuation starts right after the full input width. Counting
    # non-pad tokens instead breaks when the tokenizer has no pad token
    # (pad_token_id is None) and gives the wrong offset for padded inputs.
    prompt_length = inputs["input_ids"].shape[-1]
    # output_ids is 2-D (batch, sequence): take row 0, then drop the prompt.
    new_token_ids = output_ids[0, prompt_length:]
    return tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()

# FastChat

In [None]:
# Load the model and tokenizer through FastChat, then switch to inference mode.
fc_model, fc_tokenizer = load_model(
    MODEL,
    device="cuda",
    num_gpus=NUM_GPUS,
    max_gpu_memory=None,
    load_8bit=False,
    cpu_offloading=False,
    revision="main",
    debug=False,
)
fc_model.eval()

# Default to the raw prompt; in conversational mode replace it with the
# model's conversation template (user turn filled in, assistant turn left
# open) and show the final text that will be sent.
prompt = PROMPT
if CONVERSATIONAL:
    conv = get_conversation_template(MODEL)
    for role, message in ((conv.roles[0], PROMPT), (conv.roles[1], None)):
        conv.append_message(role, message)
    prompt = conv.get_prompt()

    print(prompt)

print(gen(prompt, fc_model, fc_tokenizer))

# Hugging Face

In [None]:
# Load the same checkpoint directly through transformers: config first, then
# the tokenizer, then the model (placed by device_map="auto") in eval mode.
config = AutoConfig.from_pretrained(MODEL, return_dict=True)
hf_tokenizer = AutoTokenizer.from_pretrained(MODEL, use_fast=True)
hf_model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    config=config,
    device_map="auto",
    low_cpu_mem_usage=True,
)
hf_model.eval()

# This cell always sends the raw prompt, regardless of CONVERSATIONAL.
prompt = PROMPT
completion = gen(prompt, hf_model, hf_tokenizer)
print(completion)