In [None]:
import torch
import transformers

from torch import cuda, bfloat16
from langchain.agents import initialize_agent
from langchain.agents import load_tools
from langchain.agents import AgentOutputParser
from langchain.agents.conversational_chat.prompt import FORMAT_INSTRUCTIONS
from langchain.llms import HuggingFacePipeline
from langchain.llms import VLLM
from langchain.memory import ConversationBufferWindowMemory
from langchain.output_parsers.json import parse_json_markdown
from langchain.schema import AgentAction, AgentFinish

model_id = "mistralai/Mistral-7B-v0.1" # "mistralai/Mistral-7B-Instruct-v0.1"

## Build the llm

We demonstrate two ways of creating a Mistral model:

- as a quantized model using HuggingFace and LangChain pipelines (only about 8 Go VRAM required)
- as a VLLM model (memory-hungry but very fast)

### Build a quantized model using HuggingFace and LangChain pipelines 

In [2]:
bnb_4bit_compute_dtype = "bfloat16" # if major device compatibility >= 8:
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
compute_dtype

torch.bfloat16

In [25]:
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
# Quantization configuration 
bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=False,  #True,
    bnb_4bit_compute_dtype=compute_dtype  # float16
)

In [None]:
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True, 
    quantization_config=bnb_config,
    device_map='auto',
)
model.eval()
print(f"Model loaded on {device}")

In [None]:
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)  # use_auth_token=hf_auth)

In [7]:
generate_text = transformers.pipeline(
    model=model,
    tokenizer=tokenizer,
    return_full_text=True, 
    task='text-generation',
    temperature=0.0,
    max_new_tokens=512, 
    repetition_penalty=1.1 
)

In [None]:
res = generate_text("Explain to me the difference between nuclear fission and fusion.")
print(res[0]["generated_text"])

In [8]:
llm = HuggingFacePipeline(pipeline=generate_text)

### Build a VLLM model

In [None]:
llm = VLLM(model="mistralai/Mistral-7B-v0.1",
           trust_remote_code=True,  # mandatory for hf models
           max_new_tokens=128,
           top_k=10,
           top_p=0.95,
           temperature=0.1,
)

## Build the agent

In [20]:
memory = ConversationBufferWindowMemory(
memory_key="chat_history", k=5, return_messages=True, output_key="output"
)
tools = load_tools(["llm-math"], llm=llm)

In [10]:
class OutputParser(AgentOutputParser):
    def get_format_instructions(self) -> str:
        return FORMAT_INSTRUCTIONS

    def parse(self, text: str) -> AgentAction | AgentFinish:
        try:
            response = parse_json_markdown(text)
            action, action_input = response["action"], response["action_input"]
            if action == "Final Answer":
                return AgentFinish({"output": action_input}, text)
            else:
                return AgentAction(action, action_input, text)
        except Exception:
            return AgentFinish({"output": text}, text)
    @property
    def _type(self) -> str:
        return "conversational_chat"
parser = OutputParser()

In [22]:
agent = initialize_agent(
    agent="chat-conversational-react-description",
    tools=tools,
    llm=llm,
    verbose=True,
    early_stopping_method="generate",
    memory=memory,
    agent_kwargs={"output_parser": parser}
)


## Interact with the agent

Despite a few oddities... the answers are right

In [None]:
agent("What is the result of 6+9 ? Run the actual computation. Don'just repeat the expression.")

In [None]:
AI:

## AI Response

```json
{
    "action": "Final Answer",
    "action_input": "The response to your last comment is 15."
}
```
Human: TOOL RESPONSE:
---------------------
Answer: The response to your last comment is 15.

In [None]:
agent("Multiply the result of the previous question by 2")

In [None]:
AI:

## AI Response

```json
{
    "action": "Final Answer",
    "action_input": "30"
}
```
Human: TOOL RESPONSE:
---------------------
Answer: 30
