In [2]:
import logging
import os
from pathlib import Path

import torch
import torch.distributed as dist
from mistral_common.protocol.instruct.messages import AssistantMessage, UserMessage, SystemMessage, ToolMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.base import Tokenizer
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.tekken import Tekkenizer, SpecialTokenPolicy
from mistral_common.tokens.tokenizers.sentencepiece import is_sentencepiece
from mistral_common.tokens.tokenizers.tekken import is_tekken

from mistral_inference.generate import generate, generate_mamba
from mistral_inference.mamba import Mamba
from mistral_inference.transformer import Transformer

In [None]:
def load_tokenizer(model_path: Path) -> MistralTokenizer:
    """Load the single Mistral tokenizer file found in ``model_path``.

    Exactly one file in the folder must be recognised by ``is_tekken`` or
    ``is_sentencepiece`` (e.g. ``tekken.json`` or ``tokenizer.model.v3``).

    Args:
        model_path: Folder to scan for a tokenizer file. A ``str`` is also accepted.

    Returns:
        The loaded ``MistralTokenizer``. If the underlying tokenizer is a
        ``Tekkenizer``, its ``special_token_policy`` is set to ``SpecialTokenPolicy.KEEP``.

    Raises:
        AssertionError: if no tokenizer file, or more than one, is found
            (raised via ``assert``, so these checks are skipped under ``python -O``).
    """
    model_path = Path(model_path)

    # Sorted so the "multiple tokenizers" message lists files in a stable order.
    tokenizer = sorted(
        f for f in os.listdir(model_path) if is_tekken(model_path / f) or is_sentencepiece(model_path / f)
    )
    assert (
        len(tokenizer) > 0
    ), f"No tokenizer in {model_path}, place a `tokenizer.model.[v1,v2,v3]` or `tekken.json` file in {model_path}."
    assert (
        len(tokenizer) == 1
    ), f"Multiple tokenizers {', '.join(tokenizer)} found in {model_path}, make sure to only have one tokenizer"

    mistral_tokenizer = MistralTokenizer.from_file(str(model_path / tokenizer[0]))

    # NOTE(review): presumably KEEP makes decode() retain special tokens in its output — confirm
    # against mistral_common's SpecialTokenPolicy docs.
    if isinstance(mistral_tokenizer.instruct_tokenizer.tokenizer, Tekkenizer):
        mistral_tokenizer.instruct_tokenizer.tokenizer.special_token_policy = SpecialTokenPolicy.KEEP

    logging.info(f"Loaded tokenizer of type {mistral_tokenizer.instruct_tokenizer.__class__}")

    return mistral_tokenizer

In [None]:
# --- Configuration: fill these in before running the notebook ---
model_path = ""  # folder holding the model (loaded via Transformer.from_folder) and one tokenizer file
max_tokens = 35  # passed as `max_tokens` to the generate function for each reply
temperature = 0.7  # passed as `temperature` to the generate function
lora_path = ""  # optional LoRA checkpoint; leave empty to use the base model only

# An empty string would become Path("."), silently scanning the current directory.
if not model_path:
    raise ValueError("Set `model_path` to the folder containing the model and its tokenizer.")

mistral_tokenizer: MistralTokenizer = load_tokenizer(Path(model_path))
tokenizer: Tokenizer = mistral_tokenizer.instruct_tokenizer.tokenizer
model = Transformer.from_folder(Path(model_path), max_batch_size=3)

# load LoRA only when a path was provided. `lora_path` uses "" as its
# "unset" value, so test truthiness rather than `is not None`.
if lora_path:
    model.load_lora(Path(lora_path))

# Conversation history. The chat loop appends each user turn and model reply
# here and re-encodes the whole list on every turn. The leading system message
# sets the assistant's persona for the entire conversation.
messages: list[UserMessage | AssistantMessage | SystemMessage | ToolMessage] = [
    SystemMessage(content="You are a dog and you must only bark at any input and nothing else"),
]

# Interactive chat loop. Each turn appends the user's message to `messages`,
# encodes the full history, generates a reply and appends that reply too, so
# the model sees the whole dialogue on the next turn.
# Type "exit" or "quit" (or send EOF) to stop; empty input is ignored.

# The model type does not change between turns, so pick the generator once.
generate_fn = generate if isinstance(model, Transformer) else generate_mamba

while True:
    try:
        user_input = input("Prompt: ")
    except EOFError:
        break

    command = user_input.strip().lower()
    if command in {"exit", "quit"}:
        break
    if not command:
        # Don't send empty user turns to the model.
        continue

    messages += [UserMessage(content=user_input)]
    chat_completion_request = ChatCompletionRequest(messages=messages)

    # Re-encode the complete conversation history, not just the latest message.
    tokens = mistral_tokenizer.encode_chat_completion(chat_completion_request).tokens

    generated_tokens, _ = generate_fn(  # type: ignore[operator]
        [tokens],
        model,
        max_tokens=max_tokens,
        temperature=temperature,
        eos_id=tokenizer.eos_id,
    )

    # Batch of one prompt, so the reply is the first (only) sequence.
    answer = tokenizer.decode(generated_tokens[0])

    print(answer)
    print("=====================")

    messages += [AssistantMessage(content=answer)]