In [8]:
import os
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments, pipeline

In [None]:
# Directory name of the base model checkpoint; LLM.load_model() resolves it
# as "../models/<MODEL_NAME>" relative to the notebook's working directory.
MODEL_NAME = "Mistral-7B-Instruct-v0.1"

In [9]:
def load_peft_model(base_model: str, peft_model: Optional[str]):
    """Load a 4-bit quantized causal LM and its tokenizer, optionally with a PEFT adapter.

    Args:
        base_model: Path (or locally cached identifier) of the base model.
            Both the model and the tokenizer are loaded with
            ``local_files_only=True``, so nothing is fetched from the network.
        peft_model: Directory expected to contain ``adapter_config.json``.
            Pass ``None`` to skip adapter loading. If the directory has no
            adapter config, a warning is printed and the plain base model
            is returned.

    Returns:
        A ``(model, tokenizer)`` tuple. The tokenizer's pad token is set to
        its EOS token.
    """
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        return_dict=True,
        quantization_config=bnb_config,
        device_map="cuda:0",
        local_files_only=True,
    )
    if peft_model is not None:
        adapter_config_path = os.path.join(peft_model, "adapter_config.json")
        if os.path.exists(adapter_config_path):
            print(f"Loading PEFT {peft_model}")
            model.load_adapter(peft_model)
        else:
            # Best effort: fall back to the base model, but say so loudly.
            print("WARNING: PEFT_MODEL NOT EXISTS!!!")
    # Keep the tokenizer offline too, consistent with the model load above.
    tokenizer = AutoTokenizer.from_pretrained(base_model, local_files_only=True)
    # Reuse EOS as the padding token.
    tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer


In [7]:
def create_llama2_generation_prompt(system_message, question: str):
    """Build a Llama-2 style ``[INST]`` prompt string.

    Args:
        system_message: Text placed inside a ``<<SYS>>`` block, or ``None``
            to omit the system block entirely.
        question: The user input placed before the closing ``[/INST]`` tag.

    Returns:
        The prompt string, starting with a literal ``<s>`` marker.
    """
    # Guard clause: without a system message only the user turn is emitted.
    if system_message is None:
        return f"<s>[INST] {question} [/INST]"

    system_block = f"<<SYS>>\n{system_message}\n<</SYS>>\n\n"
    return f"<s>[INST] {system_block}{question} [/INST]"


def ask_llama2_instruction_prompt(model, generation_config, tokenizer, device, question: str):
    """Ask the model a single question and return only the generated answer text.

    Args:
        model: Causal LM exposing ``generate``.
        generation_config: Generation settings forwarded to ``model.generate``.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        device: Device the encoded prompt tensors are moved to.
        question: The user's question.

    Returns:
        The decoded continuation produced by the model, without the prompt.
    """
    system_msg = (
        "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
        "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
        "Please ensure that your responses are socially unbiased and positive in nature.\n"
        "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
        "If you don't know the answer to a question, please don't share false information.")
    prompt = create_llama2_generation_prompt(system_msg, question)
    # The prompt already begins with a literal "<s>", so do not let the
    # tokenizer prepend a second BOS token.
    encoding = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(device)

    outputs = model.generate(
        input_ids=encoding.input_ids,
        attention_mask=encoding.attention_mask,
        generation_config=generation_config
    )

    # generate() returns prompt tokens followed by the new tokens. Slice the
    # prompt off by length: stripping it by string replacement never matched,
    # because skip_special_tokens removes the "<s>" the prompt string contains.
    prompt_length = encoding.input_ids.shape[-1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

In [10]:
class LLM:
    """Convenience wrapper bundling the model, tokenizer and generation settings."""

    def __init__(self):
        """Load the model via ``load_model`` and store the pieces on the instance."""
        self.model, self.tokenizer, self.generation_config = self.load_model()
        self.device = 'cuda'

    def load_model(self):
        """Load the base model named by ``MODEL_NAME`` and configure generation.

        Returns:
            A ``(model, tokenizer, generation_config)`` tuple, where
            ``generation_config`` is the model's own config object updated
            in place with the sampling settings below.
        """
        model, tokenizer = load_peft_model(f"../models/{MODEL_NAME}", None)

        config = model.generation_config
        overrides = {
            "max_new_tokens": 1024,
            "temperature": 0.2,
            "do_sample": True,
            "top_p": 0.9,
            "num_return_sequences": 1,
            # Padding and end-of-sequence both use the tokenizer's EOS id.
            "pad_token_id": tokenizer.eos_token_id,
            "eos_token_id": tokenizer.eos_token_id,
        }
        for field, value in overrides.items():
            setattr(config, field, value)

        return model, tokenizer, config

    def ask(self, user_input: str):
        """Send ``user_input`` to the model and return the helper's response."""
        return ask_llama2_instruction_prompt(
            model=self.model,
            generation_config=self.generation_config,
            tokenizer=self.tokenizer,
            device=self.device,
            question=user_input,
        )


In [None]:
# Instantiating LLM runs load_model() right away, which loads the 4-bit
# quantized checkpoint onto "cuda:0" — so this cell needs a CUDA GPU.
llm = LLM()

