In [1]:
from mistral_common.protocol.instruct.tool_calls import Function, Tool

from mistral_inference.model import Transformer
from mistral_inference.generate import generate

from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest



In [None]:
# Directory holding the extracted Mistral 7B Instruct weights and tokenizer file.
# Change this to wherever the model archive was extracted.
MODEL_DIR = "./mistral_models/7B_instruct"

tokenizer = MistralTokenizer.from_file(f"{MODEL_DIR}/tokenizer.model.v3")
model = Transformer.from_folder(MODEL_DIR)


In [3]:
def _transaction_tool(name: str, description: str) -> Tool:
    """Build a tool definition whose only argument is a required ``transaction_id`` string.

    Both tools exposed in this notebook share this exact parameter schema; only
    the name and description differ.
    """
    return Tool(
        function=Function(
            name=name,
            description=description,
            parameters={
                "type": "object",
                "properties": {
                    "transaction_id": {
                        "type": "string",
                        "description": "The transaction id.",
                    }
                },
                "required": ["transaction_id"],
            },
        )
    )


# Chat request: the tools the model may call plus the user's question.
completion_request = ChatCompletionRequest(
    tools=[
        _transaction_tool("retrieve_payment_status", "Get payment status of a transaction"),
        _transaction_tool("retrieve_payment_date", "Get payment date of a transaction"),
    ],
    messages=[
        UserMessage(content="What's the status of my transaction T1003?"),
    ],
)


In [4]:
# Render the request (tool schemas + user message) into prompt token ids for the model.
encoded_request = tokenizer.encode_chat_completion(completion_request)
tokens = encoded_request.tokens


In [5]:
# Upper bound on the number of generated tokens. Passing eos_id lets generation end
# before this budget is used up. A budget that is too small would truncate the
# tool-call JSON and make the json.loads() parse further down fail, so leave headroom
# and raise it if the expected function call is longer.
MAX_NEW_TOKENS = 64

out_tokens, _ = generate(
    [tokens],
    model,
    max_tokens=MAX_NEW_TOKENS,
    temperature=0.0,  # no sampling randomness, so the tool call is reproducible
    eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id,
)


In [6]:
# Turn the generated ids for the single prompt back into text and trim surrounding
# whitespace, so the JSON parse later on sees only the tool-call payload.
result = tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0]).strip()

In [7]:
# Show the raw decoded model output (parsed below as a JSON list of tool calls).
# NOTE(review): the saved output names T1005 although the prompt asks about T1003;
# re-run to check whether this is stale output or a model mistake.
print(f"{result}")

[{"name": "retrieve_payment_status", "arguments": {"transaction_id": "T1005"}}]


In [8]:
import pandas as pd

# Small in-memory payments table that the tool functions look transactions up in.
data = {
    'transaction_id': ['T1001', 'T1002', 'T1003', 'T1004', 'T1005'],
    'customer_id': ['C001', 'C002', 'C003', 'C002', 'C001'],
    'payment_amount': [125.50, 89.99, 120.00, 54.30, 210.20],
    'payment_date': ['2021-10-05', '2021-10-06', '2021-10-07', '2021-10-05', '2021-10-08'],
    'payment_status': ['Paid', 'Unpaid', 'Paid', 'Paid', 'Pending']
}

# Create DataFrame
df = pd.DataFrame(data)

# The lookup functions resolve a transaction with `.item()`, which requires exactly
# one matching row, so transaction ids must be unique.
assert df["transaction_id"].is_unique, "transaction_id values must be unique"

In [9]:
import json


def _lookup_transaction_field(df: pd.DataFrame, transaction_id: str, column: str, key: str) -> str:
    """Return ``{key: <value of column>}`` for ``transaction_id`` as a JSON string.

    If the id is absent, return ``{"error": "transaction id not found."}`` instead.
    ``.item()`` raises ValueError if more than one row matches the id.
    """
    if transaction_id in df.transaction_id.values:
        value = df.loc[df.transaction_id == transaction_id, column].item()
        return json.dumps({key: value})
    return json.dumps({'error': 'transaction id not found.'})


def retrieve_payment_status(df: pd.DataFrame, transaction_id: str) -> str:
    """Tool implementation: JSON ``{"status": ...}`` for a transaction, or a JSON error."""
    return _lookup_transaction_field(df, transaction_id, 'payment_status', 'status')


def retrieve_payment_date(df: pd.DataFrame, transaction_id: str) -> str:
    """Tool implementation: JSON ``{"date": ...}`` for a transaction, or a JSON error."""
    return _lookup_transaction_field(df, transaction_id, 'payment_date', 'date')



In [10]:
import functools

# Map each tool name the model can emit to its implementation, with the payments
# DataFrame pre-bound so only the model-supplied arguments remain to be passed.
names_to_functions = dict(
    retrieve_payment_status=functools.partial(retrieve_payment_status, df=df),
    retrieve_payment_date=functools.partial(retrieve_payment_date, df=df),
)


In [11]:
import json

# The model's answer is a JSON list of tool calls; this demo only handles the first one.
tool_call = json.loads(result)
first_call = tool_call[0]
function_name = first_call["name"]
function_params = first_call["arguments"]
print("\nfunction_name: ", function_name, "\nfunction_params: ", function_params)


function_name:  retrieve_payment_status 
function_params:  {'transaction_id': 'T1005'}


In [12]:
# Dispatch to the tool the model chose, forwarding the arguments it produced.
selected_tool = names_to_functions[function_name]
function_result = selected_tool(**function_params)
function_result

'{"status": "Pending"}'