In [None]:
import json
import logging
import os
import warnings
from pathlib import Path
from typing import List, Optional, Type, Union

import fire  # type: ignore
import torch
import torch.distributed as dist
from mistral_common.protocol.instruct.messages import AssistantMessage, UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.base import Tokenizer
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.tekken import Tekkenizer, SpecialTokenPolicy
from mistral_common.tokens.tokenizers.sentencepiece import is_sentencepiece
from mistral_common.tokens.tokenizers.tekken import is_tekken

from mistral_inference.generate import generate, generate_mamba
from mistral_inference.mamba import Mamba
from mistral_inference.transformer import Transformer

In [None]:
def interactive(
    model_path: str,
    max_tokens: int = 35,
    temperature: float = 0.7,
    num_pipeline_ranks: int = 1,
    instruct: bool = False,
    lora_path: Optional[str] = None,
) -> None:
    """Run an interactive read-prompt / generate / print loop on a local model.

    Rank 0 (or the only process when not launched through torchrun) reads a
    prompt from stdin, generates up to ``max_tokens`` tokens and prints the
    decoded answer.

    - Instruct mode: the conversation so far is encoded as a
      ``ChatCompletionRequest``.
    - Raw mode: the input is appended to a running prompt string, which is
      re-encoded with a BOS token every turn.

    The generated answer is appended to the history so later turns see it.

    The loop ends when stdin reaches end-of-file (e.g. Ctrl+D). Under
    torchrun, rank 0 broadcasts a stop signal so the other ranks also leave
    the loop instead of blocking in the next broadcast.

    Args:
        model_path: Folder holding the model and its tokenizer.
        max_tokens: Maximum number of tokens generated per turn.
        temperature: Sampling temperature forwarded to the generate function.
        num_pipeline_ranks: Kept for CLI backward compatibility but ignored.
            It is replaced by the torchrun world size, or by 1 when running
            as a single process.
        instruct: Use the chat-completion encoding instead of raw text.
        lora_path: Optional LoRA weights loaded on top of the base model.
    """
    distributed = is_torchrun()
    if distributed:
        dist.init_process_group()
        # LOCAL_RANK (set by torchrun) is the GPU index on this node. Fall back
        # to the global rank, which is identical on a single node.
        local_rank = int(os.environ.get("LOCAL_RANK", dist.get_rank()))
        torch.cuda.set_device(local_rank)
        should_print = dist.get_rank() == 0
        # The pipeline spans every process started by torchrun.
        num_pipeline_ranks = dist.get_world_size()
    else:
        should_print = True
        num_pipeline_ranks = 1

    folder = Path(model_path)
    mistral_tokenizer: MistralTokenizer = load_tokenizer(folder)
    tokenizer: Tokenizer = mistral_tokenizer.instruct_tokenizer.tokenizer

    model_cls = get_model_cls(model_path)
    model = model_cls.from_folder(folder, max_batch_size=3, num_pipeline_ranks=num_pipeline_ranks)

    if lora_path is not None:
        model.load_lora(Path(lora_path))

    # The model type does not change between turns, so pick the generator once.
    generate_fn = generate if isinstance(model, Transformer) else generate_mamba

    # Conversation state; only rank 0 reads or updates it.
    prompt: str = ""
    messages: List[UserMessage | AssistantMessage] = []

    # Negative length broadcast by rank 0 to tell every rank to stop.
    stop_length = -1

    while True:
        tokens: List[int] = []
        if should_print:
            try:
                user_input = input("Prompt: ")
            except EOFError:
                # stdin exhausted (Ctrl+D or end of piped input): stop cleanly.
                user_input = None

            if user_input is None:
                length = stop_length
            else:
                if instruct:
                    messages.append(UserMessage(content=user_input))
                    request = ChatCompletionRequest(messages=messages)
                    tokens = mistral_tokenizer.encode_chat_completion(request).tokens
                else:
                    prompt += user_input
                    tokens = tokenizer.encode(prompt, bos=True, eos=False)
                length = len(tokens)
        else:
            length = 0  # overwritten by rank 0's value in the broadcast below

        length_tensor = torch.tensor([length], dtype=torch.int)
        if distributed:
            dist.broadcast(length_tensor, src=0)
        length = int(length_tensor.item())

        if length < 0:
            break

        if not should_print:
            # Only the prompt length is shared, so other ranks pass a
            # zero-filled placeholder of the same length. This presumably
            # works because only rank 0 consumes real token ids - TODO confirm
            # against the pipeline implementation.
            tokens = length * [0]

        generated_tokens, _ = generate_fn(  # type: ignore[operator]
            [tokens],
            model,
            max_tokens=max_tokens,
            temperature=temperature,
            eos_id=tokenizer.eos_id,
        )

        if should_print:
            answer = tokenizer.decode(generated_tokens[0])
            print(answer)
            print("=====================")

            if instruct:
                messages.append(AssistantMessage(content=answer))
            else:
                prompt += answer

    if distributed:
        # Every rank reaches this point after the stop signal; release the group.
        dist.destroy_process_group()
