Written by [Avihu Dekel](https://huggingface.co/Avihu).

# Finetuning Granite Speech

[Granite Speech](https://huggingface.co/collections/ibm-granite/granite-speech-67e45da088d5092ff6b901c7) is a family of powerful speech models that excel at speech recognition and speech translation.
While [granite-speech-3.3-8b](https://huggingface.co/ibm-granite/granite-speech-3.3-8b) leads the [OpenASR leaderboard](https://huggingface.co/spaces/hf-audio/open_asr_leaderboard) (as of June 2025), [granite-speech-3.3-2b](https://huggingface.co/ibm-granite/granite-speech-3.3-2b) is lighter-weight, which makes it easier to finetune on new data or to extend with new tasks.

In this example, we'll show how to:
1. Run inference with Granite Speech
2. Evaluate the predictions
3. Finetune the model on new data

Specifically, we'll finetune Granite Speech 2B on [GigaSpeech](https://huggingface.co/datasets/speechcolab/gigaspeech), a large dataset of spontaneous, conversational speech that was not part of the model's training data.


## Installing packages


In [None]:
# install packages
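!pip install -q git+https://github.com/huggingface/transformers.git
# openai-whisper provides the English text normalizer; jiwer is required by evaluate's WER metric
!pip install -U -q datasets peft accelerate evaluate openai-whisper jiwer tqdm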


## Dataset loading and preprocessing
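We'll start by downloading the data.
We use the smallest GigaSpeech subset (`xs`) and keep only the first 5,000 training, 200 validation and 200 test examples, so that the notebook runs quickly.
GigaSpeech is gated on the Hugging Face Hub: if the download fails with an authorization error, accept the dataset's terms on its Hub page and log in (e.g. with `huggingface-cli login`).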

In [2]:
from datasets import load_dataset, Audio
# loading small portions for speed
dataset = load_dataset("speechcolab/gigaspeech", "xs")
train_dataset = dataset["train"].take(5000)
val_dataset = dataset["validation"].take(200)
test_dataset = dataset["test"].take(200)

train_dataset[0]["text"]



"AS THEY'RE LEAVING <COMMA> CAN KASH PULL ZAHRA ASIDE REALLY QUICKLY <QUESTIONMARK>"

## Loading the model and processor

In [3]:
import torch
from transformers.models.granite_speech import GraniteSpeechForConditionalGeneration, GraniteSpeechProcessor
model_name = "ibm-granite/granite-speech-3.3-2b"
processor = GraniteSpeechProcessor.from_pretrained(model_name)
model = GraniteSpeechForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.bfloat16)


Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 21.64it/s]


## Data preprocessing
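Next, we preprocess the data:
- Normalize the transcripts: GigaSpeech encodes punctuation as tags, so we replace them (e.g. `<COMMA>` with `,`) and lowercase the text.
- Add an instruction prompt containing the `<|audio|>` token, which marks where the audio goes in the prompt (we use `Please transcribe the following audio to text<|audio|>`).
- Filter out non-verbal examples (e.g. `<noise>`).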
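The transcript cleanup can also be written in a table-driven way, which keeps the tag mapping in one place. The following is a minimal sketch, not part of the original notebook: `normalize_transcript`, `is_verbal`, `PUNCTUATION_TAGS` and `NON_VERBAL_TAGS` are hypothetical names, and the tag sets simply mirror the ones handled in the `prepare_dataset` code. Unlike plain `str.replace`, the regular expression also tolerates a tag with no space (or several spaces) before it.

```python
import re

# GigaSpeech punctuation tags and their replacements (same mapping as
# process_gigaspeech_transcript in the notebook).
PUNCTUATION_TAGS = {
    "<COMMA>": ",",
    "<PERIOD>": ".",
    "<QUESTIONMARK>": "?",
    "<EXCLAMATIONPOINT>": "!",
}
# Transcripts consisting only of one of these tags contain no speech
# (the notebook filters the same tags, lowercased).
NON_VERBAL_TAGS = {"<SIL>", "<MUSIC>", "<NOISE>", "<OTHER>"}

# Optional whitespace followed by any punctuation tag.
_TAG_RE = re.compile(r"\s*(" + "|".join(map(re.escape, PUNCTUATION_TAGS)) + r")")


def normalize_transcript(text: str) -> str:
    """Replace punctuation tags (dropping the preceding whitespace) and lowercase."""
    text = _TAG_RE.sub(lambda m: PUNCTUATION_TAGS[m.group(1)], text)
    return text.lower()


def is_verbal(text: str) -> bool:
    """Return False for transcripts that are only a non-verbal tag (any case)."""
    return text.strip().upper() not in NON_VERBAL_TAGS


raw = "AS THEY'RE LEAVING <COMMA> CAN KASH PULL ZAHRA ASIDE REALLY QUICKLY <QUESTIONMARK>"
print(normalize_transcript(raw))
# as they're leaving, can kash pull zahra aside really quickly?
print(is_verbal("<NOISE>"))  # False
```

With Hugging Face Datasets, a predicate like `is_verbal` could be applied with `ds.filter(is_verbal, input_columns="text")`, so that the function receives only the transcript string.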

In [4]:
def process_gigaspeech_transcript(text):
    text = text.replace(" <COMMA>", ",")
    text = text.replace(" <PERIOD>", ".")
    text = text.replace(" <QUESTIONMARK>", "?")
    text = text.replace(" <EXCLAMATIONPOINT>", "!")
    text = text.lower()
    return text

def prep_example(example, tokenizer):
    instruction = "Please transcribe the following audio to text<|audio|>"
    chat = [dict(role="user", content=instruction)]
    example["prompt"] = tokenizer.apply_chat_template(
        chat,
        add_generation_prompt=True,
        tokenize=False,
    )
    example["text"] = process_gigaspeech_transcript(example["text"])
    return example

def prepare_dataset(ds, processor):
    columns_to_remove = [col for col in ds.column_names if col not in ["audio", "text"]]
    ds = ds.cast_column("audio", Audio(sampling_rate=processor.audio_processor.sampling_rate))
    ds = ds.map(prep_example,
        fn_kwargs=dict(tokenizer=processor.tokenizer),
        remove_columns=columns_to_remove,
    )
    ds = ds.filter(lambda x: x["text"] not in ["<other>", "<noise>", "<music>", "<sil>"])
    return ds




In [5]:
train_dataset = prepare_dataset(train_dataset, processor)
val_dataset = prepare_dataset(val_dataset, processor)
test_dataset = prepare_dataset(test_dataset, processor)


Let's look at a post-processed example:

In [6]:
# import under an alias so it doesn't shadow datasets.Audio, imported earlier
from IPython.display import Audio as PlayAudio
print(train_dataset[0]["text"])
PlayAudio(data=train_dataset[0]["audio"]["array"], rate=train_dataset[0]["audio"]["sampling_rate"])

as they're leaving, can kash pull zahra aside really quickly?


## Running inference + WER computation
Now let's compute the word error rate (WER). For that we need a data collator, which batches the prompts and audio (and, during training, the target transcripts); we'll reuse it for finetuning.
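The key idea in the collator is the layout of each training sequence: the prompt is left-padded, the target transcript is right-padded, and the two are concatenated, with labels set to -100 everywhere except on real target tokens, so that only the transcript contributes to the loss. Left-padding the prompts also means every prompt ends at the same position, which is what lets `compute_wer` strip the prompt from the generated outputs with a single slice. Here is a minimal pure-Python sketch of that layout on toy token ids; `build_batch` and `PAD_ID = 0` are hypothetical (the real collator relies on the tokenizer's padding and also appends the EOS token to each target).

```python
PAD_ID = 0           # hypothetical padding token id
IGNORE_INDEX = -100  # label value ignored by the cross-entropy loss


def build_batch(prompts, targets):
    """Left-pad prompts, right-pad targets, concatenate them, and mask the labels.

    prompts, targets: lists of token-id lists, one per example.
    """
    p_len = max(len(p) for p in prompts)
    t_len = max(len(t) for t in targets)
    input_ids, attention_mask, labels = [], [], []
    for p, t in zip(prompts, targets):
        p_pad = p_len - len(p)
        t_pad = t_len - len(t)
        # Layout: [pad ... prompt][target ... pad]
        input_ids.append([PAD_ID] * p_pad + p + t + [PAD_ID] * t_pad)
        attention_mask.append([0] * p_pad + [1] * (len(p) + len(t)) + [0] * t_pad)
        # Prompt and padding positions are ignored; only target tokens are scored.
        labels.append([IGNORE_INDEX] * p_len + t + [IGNORE_INDEX] * t_pad)
    return input_ids, attention_mask, labels


ids, mask, labels = build_batch(prompts=[[11, 12, 13], [21]], targets=[[31], [41, 42]])
for row in labels:
    print(row)
# [-100, -100, -100, 31, -100]
# [-100, -100, -100, 41, 42]
```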

In [7]:
import evaluate
from whisper.normalizers import EnglishTextNormalizer
from transformers.feature_extraction_utils import BatchFeature
from torch.utils.data import DataLoader
import tqdm

class GraniteCollator:
    def __init__(self, processor, inference_mode=False):
        self.processor = processor
        self.inference_mode = inference_mode

    def __call__(self, examples):
        prompts = [example["prompt"] for example in examples]
        audios = [example["audio"] for example in examples]
        if isinstance(audios[0], dict):
            audios = [audio["array"] for audio in audios]

        # left-pad the prompts so that every prompt ends at the same position
        processed = self.processor(prompts, audios, return_tensors="pt", padding=True, padding_side="left")
        input_ids = processed.input_ids
        attention_mask = processed.attention_mask
        labels = None
        # tokenize targets
        if not self.inference_mode:
            targets = [example["text"] + self.processor.tokenizer.eos_token for example in examples]
            targets = self.processor.tokenizer(targets, return_tensors="pt", padding=True, padding_side="right")
            # combine prompt+targets
            input_ids = torch.cat([input_ids, targets.input_ids], dim=1)
            attention_mask = torch.cat([attention_mask, targets.attention_mask], dim=1)
            labels = targets.input_ids.clone()
            # Set non-target tokens to -100 for loss calculation
            labels[~(targets.attention_mask.bool())] = -100  
            labels = torch.cat([torch.full_like(processed.input_ids, -100), labels], dim=1)

        return BatchFeature(data={
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "input_features": processed.input_features,
            "input_features_mask": processed.input_features_mask
        })

def compute_wer(model, processor, cur_dataset):
    collator = GraniteCollator(processor, inference_mode=True)
    dataloader = DataLoader(cur_dataset, batch_size=16, collate_fn=collator, num_workers=8)
    normalizer = EnglishTextNormalizer()
    wer_metric = evaluate.load("wer")
    model = model.eval().cuda()
    
    all_outputs = []
    for batch in tqdm.tqdm(dataloader, desc="Running inference"):
        batch = batch.to("cuda")
        with torch.inference_mode(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
            outputs = model.generate(**batch, max_new_tokens=400, num_beams=4, early_stopping=True)
        # generate() returns the prompt followed by the new tokens; keep only the new ones
        input_length = batch.input_ids.shape[1]
        outputs = outputs[:, input_length:].cpu()
        for x in outputs:
            all_outputs.append(processor.tokenizer.decode(x, skip_special_tokens=True))
        
    gt_texts = [normalizer(x) for x in cur_dataset["text"]]
    all_outputs = [normalizer(x) for x in all_outputs]
    wer = wer_metric.compute(references=gt_texts, predictions=all_outputs)
    return wer



In [8]:
wer_before_train = compute_wer(model, processor, test_dataset)
print(f"WER before finetuning {wer_before_train*100:.3f}")


Running inference: 100%|██████████| 10/10 [00:24<00:00,  2.41s/it]

WER before finetuning 9.719
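WER counts the minimum number of word substitutions (S), deletions (D) and insertions (I) needed to turn the prediction into the reference, divided by the number of reference words N: WER = (S + D + I) / N. A value of 9.719 therefore means roughly 9.7 word edits per 100 reference words. `compute_wer` delegates this to `evaluate`'s `wer` metric (backed by jiwer), after normalizing both sides with Whisper's `EnglishTextNormalizer`. For intuition, here is a from-scratch sketch of corpus-level WER using word-level edit distance; `word_error_rate` is a hypothetical helper that does no normalization, so it assumes already-normalized text.

```python
def word_error_rate(references, predictions):
    """Corpus-level WER: total word edits divided by total reference words."""
    total_edits, total_words = 0, 0
    for ref, hyp in zip(references, predictions):
        r, h = ref.split(), hyp.split()
        # prev[j] = edit distance between the current prefix of r and h[:j]
        prev = list(range(len(h) + 1))
        for i in range(1, len(r) + 1):
            cur = [i] + [0] * len(h)
            for j in range(1, len(h) + 1):
                cost = 0 if r[i - 1] == h[j - 1] else 1
                cur[j] = min(
                    prev[j] + 1,         # deletion (reference word missing)
                    cur[j - 1] + 1,      # insertion (extra predicted word)
                    prev[j - 1] + cost,  # substitution, or match if cost == 0
                )
            prev = cur
        total_edits += prev[len(h)]
        total_words += len(r)
    return total_edits / total_words


# One substitution (sat -> sit) and one deletion (the) over 6 reference words.
print(round(word_error_rate(["the cat sat on the mat"], ["the cat sit on mat"]), 3))  # 0.333
```

Because insertions are counted, WER can exceed 1 when the prediction is much longer than the reference.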





## Finetuning Granite Speech
Let's finetune the model on our small training set.
We'll train only the LoRA adapters and the projector and keep the rest of the model frozen, which speeds up training and reduces the risk of overfitting.
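Only parameters whose names contain `projector` or `lora` receive gradients, so it is worth checking how much of the model is actually being trained. Below is a small sketch of that bookkeeping: `is_trainable` reproduces the notebook's selection rule, while `trainable_summary` and the toy parameter names and sizes are hypothetical, for illustration only.

```python
def is_trainable(name: str) -> bool:
    # Same selection rule as the notebook: projector or LoRA parameters only.
    return "projector" in name or "lora" in name


def trainable_summary(named_numels):
    """named_numels: iterable of (parameter_name, number_of_elements) pairs."""
    trainable = total = 0
    for name, numel in named_numels:
        total += numel
        if is_trainable(name):
            trainable += numel
    return trainable, total, trainable / total


# Hypothetical parameter names and sizes (not the real Granite Speech layout).
toy = [
    ("encoder.layers.0.attn.weight", 1_000_000),
    ("projector.linear.weight", 50_000),
    ("language_model.layers.0.q_proj.lora_A.weight", 4_000),
    ("language_model.layers.0.q_proj.base_layer.weight", 2_000_000),
]
trainable, total, frac = trainable_summary(toy)
print(f"{trainable:,} / {total:,} trainable ({frac:.2%})")
# 54,000 / 3,054,000 trainable (1.77%)
```

On the real model, the same rule can be applied with `trainable_summary((n, p.numel()) for n, p in model.named_parameters())`.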


In [9]:
from transformers import TrainingArguments, Trainer

for n, p in model.named_parameters():
    # train only the projector and LoRA parameters; freeze everything else
    p.requires_grad = "projector" in n or "lora" in n

args = TrainingArguments(
    output_dir="save_dir",
    remove_unused_columns=False,
    report_to="none",
    bf16=True,
    eval_strategy="steps",
    save_strategy="no",
    eval_steps=0.1,
    dataloader_num_workers=16,
    per_device_train_batch_size=16, 
    per_device_eval_batch_size=16, 
    gradient_accumulation_steps=2,
    num_train_epochs=1.0,
    warmup_ratio=0.2,
    logging_steps=0.1,
    learning_rate=3e-5,
    data_seed=42,
)
data_collator = GraniteCollator(processor)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    processing_class=processor,
)
trainer.train()


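| Step | Training Loss | Validation Loss |
|-----:|--------------:|----------------:|
| 16 | 1.7474 | 0.841541 |
| 32 | 1.3605 | 0.657465 |
| 48 | 0.8650 | 0.531957 |
| 64 | 0.6953 | 0.511987 |
| 80 | 0.6261 | 0.504223 |
| 96 | 0.6246 | 0.500726 |
| 112 | 0.5968 | 0.499541 |
| 128 | 0.5968 | 0.497704 |
| 144 | 0.5949 | 0.498355 |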

TrainOutput(global_step=157, training_loss=0.8314582435948075, metrics={'train_runtime': 99.5279, 'train_samples_per_second': 50.237, 'train_steps_per_second': 1.577, 'total_flos': 1.860231576674304e+16, 'train_loss': 0.8314582435948075, 'epoch': 1.0})
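The evaluation steps in the log follow from `eval_steps=0.1`: in `TrainingArguments`, a value below 1 is interpreted as a fraction of the total number of training steps (the same holds for `logging_steps=0.1`). The sketch below assumes the fraction is rounded up to a whole number of steps, as we believe current `transformers` versions do; `resolve_steps` is a hypothetical helper. With the 157 steps reported above, it reproduces the nine evaluation points in the table.

```python
import math


def resolve_steps(value, max_steps):
    """Interpret a step setting below 1 as a ratio of the total training steps."""
    if value < 1:
        return math.ceil(max_steps * value)  # assumption: rounded up
    return int(value)


total_steps = 157  # global_step reported by trainer.train() above
every = resolve_steps(0.1, total_steps)
eval_points = list(range(every, total_steps + 1, every))
print(every, eval_points)
# 16 [16, 32, 48, 64, 80, 96, 112, 128, 144]
```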

## Checking for improvements
Both the training and validation losses decrease.
Let's check whether this very lightweight finetuning also improves the test WER.

In [10]:
wer_after_train = compute_wer(model, processor, test_dataset)

print(f"WER after finetuning {wer_after_train*100:.3f}")
print(f"WER improvement {(wer_before_train - wer_after_train)*100:.3f}")

Running inference: 100%|██████████| 10/10 [00:25<00:00,  2.56s/it]

WER after finetuning 9.552
WER improvement 0.167
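The printed improvement is absolute, in WER percentage points. Relative to the starting WER, the reduction is:

$$
\frac{\mathrm{WER}_{\text{before}} - \mathrm{WER}_{\text{after}}}{\mathrm{WER}_{\text{before}}} = \frac{9.719 - 9.552}{9.719} = \frac{0.167}{9.719} \approx 1.7\%
$$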





## Summary
Hurray! Quick, lightweight finetuning slightly improved the test WER, from 9.72% to 9.55%.
In this notebook you learned how to:
- Prepare training data for Granite Speech
- Run batched inference with Granite Speech and compute the word error rate (WER)
- Finetune Granite Speech, applying gradient updates only to the LoRA adapter and projector layers

I'd like to thank the following for their help:
Avishai Elmakies, George Saon, Alexander Brooks and Eliyahu Schwartz.