In [1]:
from transformers import VoxtralForConditionalGeneration, AutoProcessor, TextStreamer, TextIteratorStreamer
import torch
from threading import Thread

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU so
# the notebook still executes top to bottom on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face Hub id of the checkpoint to load.
# Alternative checkpoint: "mistralai/Voxtral-Small-24B-2507"
repo_id = "mistralai/Voxtral-Mini-3B-2507"

# Upper bound on the number of new tokens produced by each generate() call.
max_new_tokens = 25_000

In [3]:
# Processor belonging to the selected checkpoint; later cells use it to build
# transcription requests and to decode generated token ids.
processor = AutoProcessor.from_pretrained(repo_id)

# Load the model weights in bfloat16 and place them on the configured device.
model = VoxtralForConditionalGeneration.from_pretrained(
    repo_id,
    dtype=torch.bfloat16,
    device_map=device,
)

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
from nbdev.config import get_config
from pathlib import Path

# Resolve the test audio relative to the path nbdev reports for its config,
# rather than relative to the directory the kernel was started from.
config = get_config()
project_dir = config.config_path
test_dir = project_dir / "test_files"
audio_path = test_dir / "short_test_audio.mp3"
# Alternative sample: test_dir / "02 - 1. Laying Plans.mp3"

# Fail early, and say which file is missing, before any model work starts.
assert audio_path.exists(), f"Test audio file not found: {audio_path}"

## Standard transcription

In [5]:
# Build the English transcription request for the audio file and move it to
# the model's device, using bfloat16 to match the model weights.
inputs = processor.apply_transcription_request(
    language="en",
    audio=str(audio_path),
    model_id=repo_id,
).to(device, dtype=torch.bfloat16)

# Greedy decoding (do_sample=False), no streaming, no gradient tracking.
with torch.no_grad():
    chunk_outputs = model.generate(
        **inputs, max_new_tokens=max_new_tokens, streamer=None, do_sample=False
    )
torch.cuda.empty_cache()

# generate() returns the prompt ids followed by the new ids; slice the prompt
# off so only the newly generated tokens are decoded.
result = processor.batch_decode(
    chunk_outputs[:, inputs["input_ids"].shape[1]:],
    skip_special_tokens=True,
)[0]
result

"November the 10th, Wednesday, 9 p.m. I'm standing in a dark alley. After waiting several hours, the time has come. A woman with long dark hair approaches. I have to act and fast before she realises what has happened. I must find out."

## Standard with Streamer

In [6]:
# Streamer that prints decoded text as it is generated, leaving out the
# prompt and any special tokens.
streamer = TextStreamer(
    processor.tokenizer,
    skip_prompt=True,
    skip_special_tokens=True,
)

In [7]:
# Same transcription request as the non-streaming run, moved to the model's
# device in bfloat16.
inputs = processor.apply_transcription_request(
    language="en",
    audio=str(audio_path),
    model_id=repo_id,
).to(device, dtype=torch.bfloat16)

# Passing the streamer makes generate() emit text while it is still running;
# the call itself still blocks until generation has finished.
with torch.no_grad():
    chunk_outputs = model.generate(
        **inputs, max_new_tokens=max_new_tokens, streamer=streamer, do_sample=False
    )
torch.cuda.empty_cache()

# Decode only the tokens that follow the prompt.
result = processor.batch_decode(
    chunk_outputs[:, inputs["input_ids"].shape[1]:],
    skip_special_tokens=True,
)[0]

November the 10th, Wednesday, 9 p.m. I'm standing in a dark alley. After waiting several hours, the time has come. A woman with long dark hair approaches. I have to act and fast before she realises what has happened. I must find out.


## Non-blocking with Streamer

In [8]:
# Iterator-style streamer: generated text is consumed by iterating over it
# (prompt and special tokens are left out).
streamer = TextIteratorStreamer(
    processor.tokenizer,
    skip_prompt=True,
    skip_special_tokens=True,
)

In [9]:
def generate_text_stream(model, inputs, max_new_tokens, streamer, return_full=False):
    """Run ``model.generate`` in a background thread and yield text as it arrives.

    Args:
        model: Object exposing ``generate(**kwargs)``.
        inputs: Mapping of model inputs; unpacked into the ``generate`` call.
        max_new_tokens: Forwarded to ``generate`` as ``max_new_tokens``.
        streamer: Iterable streamer that ``generate`` feeds and this function
            iterates over. It must expose ``end()``, which is called to
            unblock the consumer if generation fails.
        return_full: When true, one extra item is yielded after the last text
            chunk: ``{"complete_text": <all chunks concatenated>}``.

    Yields:
        str: Each text chunk produced by ``streamer``.
        dict: Only when ``return_full`` is true, the final summary item.

    Raises:
        BaseException: Whatever ``model.generate`` raised in the worker
            thread, re-raised here once the streamer has been drained.

    Notes:
        Cleanup (joining the worker and ``torch.cuda.empty_cache()``) runs in
        a ``finally`` block, so it also happens when the caller stops
        iterating early; in that case closing the generator waits for the
        worker thread to finish.
    """
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        streamer=streamer,
        do_sample=False,
    )

    worker_errors = []

    def _run_generation():
        # Without this guard a failure inside generate() would leave the
        # streamer without its end signal and the consumer loop below would
        # block forever, with the error only visible in the thread traceback.
        try:
            model.generate(**generation_kwargs)
        except BaseException as exc:
            worker_errors.append(exc)
            streamer.end()

    thread = Thread(target=_run_generation)
    thread.start()

    chunks = []
    try:
        for new_text in streamer:
            chunks.append(new_text)
            yield new_text

        thread.join()
        if worker_errors:
            raise worker_errors[0]

        # Optionally hand back the complete text as a final item.
        if return_full:
            yield {"complete_text": "".join(chunks)}
    finally:
        # Reached on normal completion, on error, and when the generator is
        # closed early; never release memory while the worker is still running.
        thread.join()
        torch.cuda.empty_cache()

In [10]:
# Usage example:
for text_chunk in generate_text_stream(model, inputs, max_new_tokens, streamer):
    print(text_chunk, end="", flush=True)
    # Or in a web framework like FastAPI/Flask, you could stream this directly

November the 10th, Wednesday, 9 p.m. I'm standing in a dark alley. After waiting several hours, the time has come. A woman with long dark hair approaches. I have to act and fast before she realises what has happened. I must find out.

In [11]:
# Drop this notebook's reference to the model, then ask PyTorch to release the
# cached CUDA memory it is no longer using.
# NOTE(review): `inputs` and `chunk_outputs` from earlier cells are still bound
# at this point — confirm whether they should be deleted as well.
del model
torch.cuda.empty_cache()