# GraniteSpeech Inference with FMS

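This notebook demonstrates speech-to-text inference with the GraniteSpeech model in FMS (foundation-model-stack). It downloads a Granite Speech checkpoint from the Hugging Face Hub, transcribes a few samples from a small LibriSpeech test set, and prints each transcription next to its ground-truth text.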

## Environment Setup (only needed in a fresh environment without FMS installed)

In [None]:
# Clone FMS and move the kernel's working directory into the checkout; the
# following cell runs `pip install -e .` from this directory.
# NOTE(review): not idempotent — re-running this cell after `%cd` has taken
# effect would clone a second, nested copy inside the repo. Restart the kernel
# (or skip this cell) if FMS is already cloned.
# NOTE(review): `main` is a moving branch; pin a commit or tag for
# reproducible results.
!git clone https://github.com/foundation-model-stack/foundation-model-stack.git
%cd foundation-model-stack
!git checkout main

In [None]:
# Install FMS in editable mode from the current directory (the checkout entered
# by the previous cell's `%cd`), then the extra packages this notebook uses:
# `datasets` and `huggingface_hub` are imported below; soundfile/torchaudio/
# torchcodec are presumably audio-decoding backends for `datasets`, and `peft`
# is presumably required by the GraniteSpeech loader — TODO confirm.
# NOTE(review): versions are unpinned; `%pip` would target the kernel's
# environment more reliably than `!pip`.
!pip install -e .
!pip install datasets soundfile torchaudio huggingface_hub peft torchcodec

## Imports

In [None]:
from typing import Dict, Any, Optional, Tuple

import torch
from datasets import load_dataset
from huggingface_hub import snapshot_download

from fms.models import get_model
from fms.models.granite_speech import GraniteSpeechFeatureExtractor, GraniteSpeechProcessor
from fms.utils.generation import generate
from fms.utils.tokenizers import get_tokenizer

## Model Configuration

In [None]:
# Registry of supported checkpoints, keyed by a short config name.
# Each entry holds:
#   model_id        - Hugging Face Hub repository to download
#   variant         - FMS variant name passed to get_model("granite_speech", ...)
#   ignore_patterns - glob patterns snapshot_download should skip (or None)
MODEL_CONFIGS = {
    "3.3-8b": dict(
        model_id="ibm-granite/granite-speech-3.3-8b",
        variant="3.3-8b",
        ignore_patterns=None,
    ),
    "3.3-2b": dict(
        model_id="ibm-granite/granite-speech-3.3-2b",
        variant="3.3-2b",
        # TODO: Remove once IBM deletes orphaned 3-shard files from HF repo
        ignore_patterns=["*-of-00003.safetensors"],
    ),
}

## Prompt Configuration

In [None]:
# Default prompts for GraniteSpeech.
# The system prompt text is reproduced verbatim; the user prompt starts with the
# `<|audio|>` placeholder token (presumably where GraniteSpeechProcessor splices
# in the audio — not verified here).
SYSTEM_PROMPT = """Knowledge Cutoff Date: April 2024.
Today's Date: April 9, 2025.
You are Granite, developed by IBM. You are a helpful AI assistant"""

USER_PROMPT = "<|audio|>can you transcribe the speech into a written format?"


def build_chat_prompt(
    tokenizer: Any,
    system_prompt: str = SYSTEM_PROMPT,
    user_prompt: str = USER_PROMPT,
) -> str:
    """Render a system + user conversation into one prompt string.

    Args:
        tokenizer: Object exposing ``apply_chat_template`` (a HF tokenizer).
        system_prompt: Content of the system turn.
        user_prompt: Content of the user turn.

    Returns:
        The templated prompt as text (``tokenize=False``), with the template's
        generation prompt appended (``add_generation_prompt=True``).
    """
    turns = (("system", system_prompt), ("user", user_prompt))
    messages = [dict(role=role, content=content) for role, content in turns]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

## Helper Functions

In [None]:
def get_model_and_tokenizer(
    model_id: str,
    variant: str,
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16,
    ignore_patterns: Optional[list] = None,
) -> Tuple[torch.nn.Module, Any]:
    """Download a GraniteSpeech checkpoint and load it with FMS.

    Args:
        model_id: Hugging Face Hub repository id of the checkpoint.
        variant: FMS variant name for the ``granite_speech`` architecture.
        device: Device type passed to ``get_model`` as ``device_type``.
        dtype: Parameter dtype passed to ``get_model`` as ``data_type``.
        ignore_patterns: Glob patterns ``snapshot_download`` should skip.

    Returns:
        ``(model, tokenizer)``: the FMS model switched to eval mode, and the
        underlying HF tokenizer (unwrapped from the FMS wrapper when it has a
        ``tokenizer`` attribute).
    """
    checkpoint_dir = snapshot_download(model_id, ignore_patterns=ignore_patterns)

    fms_model = get_model(
        "granite_speech",
        variant,
        model_path=checkpoint_dir,
        source="hf",
        device_type=device,
        data_type=dtype,
    )
    fms_model.eval()

    wrapped = get_tokenizer(model_id)
    # The FMS tokenizer may wrap a HF tokenizer; the chat template and
    # batch_decode calls in this notebook need the HF object itself.
    hf_tokenizer = wrapped.tokenizer if hasattr(wrapped, "tokenizer") else wrapped
    return fms_model, hf_tokenizer

In [None]:
def get_audio(
    dataset_name: str = "hf-internal-testing/librispeech_asr_dummy",
    split: str = "validation",
    sample_index: int = 0,
    config_name: Optional[str] = "clean",
) -> Tuple[torch.Tensor, str, float]:
    """Load one audio sample and its reference transcript from a HF dataset.

    Args:
        dataset_name: Hugging Face Hub dataset id.
        split: Dataset split to read.
        sample_index: Row index within the split.
        config_name: Dataset configuration name. Defaults to ``"clean"``, the
            config used with the default LibriSpeech dummy dataset; pass
            ``None`` for datasets that have no named configurations.

    Returns:
        ``(audio, ground_truth, duration)``: the row's ``audio["array"]`` as a
        float32 tensor, its ``text`` field, and the duration in seconds
        (number of samples / ``audio["sampling_rate"]``).

    Note:
        No resampling is done. NOTE(review): the feature extractor presumably
        expects a specific rate (LibriSpeech is 16 kHz) — confirm before using
        other datasets.
    """
    # The config was previously hard-coded to "clean", so any other
    # `dataset_name` without a "clean" config failed to load.
    dataset = load_dataset(dataset_name, config_name, split=split)
    sample = dataset[sample_index]
    audio_field = sample["audio"]
    audio = torch.tensor(audio_field["array"], dtype=torch.float32)
    ground_truth = sample["text"]
    duration = len(audio) / audio_field["sampling_rate"]
    return audio, ground_truth, duration

In [None]:
def generate_transcript(
    model: torch.nn.Module,
    tokenizer: Any,
    inputs: Dict[str, torch.Tensor],
    max_new_tokens: int = 200,
) -> str:
    """Greedily generate a transcription for the first sequence in ``inputs``.

    Args:
        model: FMS GraniteSpeech model, already on the inputs' device.
        tokenizer: HF tokenizer. Its ``eos_token_id`` (if present) is used to
            stop generation and to trim the decoded output.
        inputs: Processor output. ``input_ids`` and ``input_features`` are
            required; ``input_features_mask`` and ``attention_mask`` are
            forwarded when present (``None`` otherwise).
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded text generated after the prompt for batch row 0, with
        special tokens removed and anything after the first EOS token dropped.
    """
    eos_token_id = getattr(tokenizer, "eos_token_id", None)
    with torch.no_grad():
        output_ids = generate(
            model,
            inputs["input_ids"],
            max_new_tokens=max_new_tokens,
            do_sample=False,
            use_cache=True,
            # FMS generate() only stops early when given an EOS id; without it,
            # decoding continues to max_new_tokens past the end of the answer.
            eos_token_id=eos_token_id,
            extra_kwargs={
                "input_features": inputs["input_features"],
                "input_features_mask": inputs.get("input_features_mask"),
                "attention_mask": inputs.get("attention_mask"),
            },
        )
    # Strip input tokens to get only the generated response.
    input_length = inputs["input_ids"].shape[1]
    generated = output_ids[:, input_length:].tolist()
    if eos_token_id is not None:
        # skip_special_tokens removes the EOS token itself but not the text
        # generated after it, so cut each row at its first EOS explicitly.
        generated = [
            row[: row.index(eos_token_id)] if eos_token_id in row else row
            for row in generated
        ]
    return tokenizer.batch_decode(generated, skip_special_tokens=True)[0]

In [None]:
def process_inputs(
    audio: torch.Tensor,
    tokenizer: Any,
    prompt: str,
    device: str,
) -> Dict[str, torch.Tensor]:
    """Turn raw audio plus a text prompt into model inputs on ``device``.

    Builds a GraniteSpeechProcessor from a default feature extractor and the
    given tokenizer, runs it on a single-prompt batch, and moves every tensor
    in the result to ``device``. Non-tensor entries are passed through as-is.

    Returns:
        A plain dict with the processor's keys.
    """
    feature_extractor = GraniteSpeechFeatureExtractor()
    processor = GraniteSpeechProcessor(feature_extractor, tokenizer)
    batch = processor(text=[prompt], audio=audio, return_tensors="pt")

    on_device = {}
    for name, value in batch.items():
        if isinstance(value, torch.Tensor):
            value = value.to(device)
        on_device[name] = value
    return on_device

In [None]:
# Single-slot cache for the most recently loaded (model, tokenizer) pair.
# Re-running this cell resets it (forcing a reload). Only one model is kept
# resident: switching configs releases the previous model before loading.
_loaded_model_cache: Dict[str, Any] = {}


def _load_model_cached(
    model_config: str,
    device: str,
    dtype: torch.dtype,
) -> Tuple[torch.nn.Module, Any]:
    """Return (model, tokenizer) for ``model_config``, loading only on a cache miss."""
    import gc

    cache_key = (model_config, device, dtype)
    if _loaded_model_cache.get("key") != cache_key:
        # Look up the config first so an unknown key fails without evicting.
        config = MODEL_CONFIGS[model_config]
        # Drop the old model *before* loading the new one so peak memory is
        # one model, not two (matters when moving from 8B to 2B on one GPU).
        _loaded_model_cache.clear()
        gc.collect()
        print(f"Loading model: {config['model_id']}")
        model, tokenizer = get_model_and_tokenizer(
            model_id=config["model_id"],
            variant=config["variant"],
            device=device,
            dtype=dtype,
            ignore_patterns=config["ignore_patterns"],
        )
        _loaded_model_cache.update(key=cache_key, model=model, tokenizer=tokenizer)
    return _loaded_model_cache["model"], _loaded_model_cache["tokenizer"]


def run_inference(
    model_config: str = "3.3-8b",
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16,
    sample_index: int = 0,
    system_prompt: str = SYSTEM_PROMPT,
    user_prompt: str = USER_PROMPT,
    max_new_tokens: int = 200,
) -> Dict[str, Any]:
    """Transcribe one LibriSpeech sample and print it next to the ground truth.

    The model for ``model_config`` is loaded on first use and reused by later
    calls with the same ``(model_config, device, dtype)``. Looping over samples
    therefore loads the weights once instead of once per sample. The most
    recent model stays in memory until a different config is requested or this
    cell is re-run.

    Args:
        model_config: Key into ``MODEL_CONFIGS``.
        device: Device for the model and inputs.
        dtype: Model parameter dtype.
        sample_index: Row of the audio dataset to transcribe.
        system_prompt: System turn of the chat prompt.
        user_prompt: User turn of the chat prompt.
        max_new_tokens: Generation budget.

    Returns:
        ``{"ground_truth": <dataset text>, "transcription": <model output>}``.
        The transcription is returned as decoded; it is upper-cased only when
        printed, for easier comparison with the upper-case ground truth.
    """
    model, tokenizer = _load_model_cached(model_config, device, dtype)

    print("Loading audio...")
    audio, ground_truth, duration = get_audio(sample_index=sample_index)

    # Build chat-formatted prompt
    prompt = build_chat_prompt(tokenizer, system_prompt, user_prompt)

    print("Processing inputs...")
    inputs = process_inputs(audio, tokenizer, prompt, device)

    print("Generating transcription...")
    transcription = generate_transcript(model, tokenizer, inputs, max_new_tokens)

    print(f"\n{'='*60}")
    print(f"Ground Truth:  {ground_truth}")
    print(f"{'='*60}")
    print(f"Transcription: {transcription.upper()}")
    print(f"{'='*60}")

    return {"ground_truth": ground_truth, "transcription": transcription}

## 8B Model Inference

In [None]:
# Choose the accelerator once: bf16 on CUDA, fall back to fp32 on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32

# Transcribe dataset rows 2, 5 and 6 (the rows used by the reference script)
# with the 8B checkpoint, collecting one result dict per row.
results_8b = []
for idx in (2, 5, 6):
    banner = "#" * 30
    print(f"\n{banner} SAMPLE {idx} {banner}")
    result = run_inference(model_config="3.3-8b", device=device, dtype=dtype, sample_index=idx)
    results_8b.append(result)

## 2B Model Inference

In [None]:
# Run inference on multiple samples with the 2B model.
# device/dtype are resolved here as well so this section runs on its own,
# e.g. when the 8B section was skipped for lack of GPU memory (previously this
# cell raised NameError unless the 8B cell had run first).
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# Same dataset rows as the 8B section, so results_2b lines up with results_8b.
results_2b = []
for idx in [2, 5, 6]:
    print(f"\n{'#'*30} SAMPLE {idx} {'#'*30}")
    result = run_inference(
        model_config="3.3-2b",
        device=device,
        dtype=dtype,
        sample_index=idx,
    )
    results_2b.append(result)