## 1. LLM

In [6]:
import torch
from transformers import BitsAndBytesConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain.llms.huggingface_pipeline import HuggingFacePipeline

model_name: str = "/data_hdd_16t/khanhtran/LLM/.hf_models/Phi-3-mini-4k-instruct"

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=nf4_config,
    low_cpu_mem_usage=True,
    trust_remote_code=True
)


tokenizer = AutoTokenizer.from_pretrained(model_name)
max_new_token = 512

model_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=max_new_token,
    pad_token_id=tokenizer.eos_token_id
)

gen_kwargs = {
    "temperature": 0
}

llm = HuggingFacePipeline(
    pipeline=model_pipeline,
    model_kwargs=gen_kwargs
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.35s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
The model 'Phi3ForCausalLM' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FuyuForCausalLM', 'GitForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'LlamaForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MistralForCausa

## 2. Prompt

In [7]:
from langchain_core.prompts import PromptTemplate

prompt = PromptTemplate.from_template(""""<s>[INST] {prompt} [/INST]"""
)

## 3. Chain

In [8]:
chain = prompt | llm

## 4. History chat

In [10]:
def extract_answer(response: str) -> str:
    return response.split("[/INST]")[1].strip()

In [11]:
from langchain.memory import ChatMessageHistory

ephemeral_chat_history = ChatMessageHistory()

In [12]:
question_1 = "Translate this sentence from English to Vietnamese: I love programming."

In [13]:
ephemeral_chat_history.add_user_message(question_1)
ephemeral_chat_history.messages

[HumanMessage(content='Translate this sentence from English to Vietnamese: I love programming.', additional_kwargs={}, response_metadata={})]

## 3.1 Query first times

In [14]:
response = chain.invoke(
    {
        "prompt": ephemeral_chat_history.messages
    }
)

You are not running the flash-attention implementation, expect numerical differences.


In [15]:
response

'"<s>[INST] [HumanMessage(content=\'Translate this sentence from English to Vietnamese: I love programming.\', additional_kwargs={}, response_metadata={})] [/INST] earth'

In [16]:
answer = extract_answer(response)
print(answer)

earth


In [17]:
ephemeral_chat_history.add_ai_message(answer)
ephemeral_chat_history.messages

[HumanMessage(content='Translate this sentence from English to Vietnamese: I love programming.', additional_kwargs={}, response_metadata={}),
 AIMessage(content='earth', additional_kwargs={}, response_metadata={})]

## 3.2 Query second times

In [18]:
question_2 = "What did you said?"

In [19]:
ephemeral_chat_history.add_user_message(question_2)
ephemeral_chat_history.messages

[HumanMessage(content='Translate this sentence from English to Vietnamese: I love programming.', additional_kwargs={}, response_metadata={}),
 AIMessage(content='earth', additional_kwargs={}, response_metadata={}),
 HumanMessage(content='What did you said?', additional_kwargs={}, response_metadata={})]

In [20]:
response = chain.invoke(
    {
        "prompt": ephemeral_chat_history.messages
    }
)

In [21]:
answer = extract_answer(response)
print(answer)

count
