In [4]:
# imports
import numpy as np

import torch
from datasets import load_dataset
from transformers import WhisperProcessor, WhisperForConditionalGeneration, TrainingArguments, Seq2SeqTrainer

In [5]:
dataset = load_dataset("edinburghcstr/ami", "ihm")

In [6]:
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['meeting_id', 'audio_id', 'text', 'audio', 'begin_time', 'end_time', 'microphone_id', 'speaker_id'],
        num_rows: 108502
    })
    validation: Dataset({
        features: ['meeting_id', 'audio_id', 'text', 'audio', 'begin_time', 'end_time', 'microphone_id', 'speaker_id'],
        num_rows: 13098
    })
    test: Dataset({
        features: ['meeting_id', 'audio_id', 'text', 'audio', 'begin_time', 'end_time', 'microphone_id', 'speaker_id'],
        num_rows: 12643
    })
})


In [7]:
print(dataset["validation"][0])

{'meeting_id': 'ES2011a', 'audio_id': 'AMI_ES2011a_H03_FEE044_0092784_0093052', 'text': "BUT LIKE MOBILE PHONES HAVE SCREENS AND THEY'RE CHEAP", 'audio': {'path': '/home/fullldiesel/.cache/huggingface/datasets/downloads/extracted/cc01c8370b8fb423a86dbb2bdd5be3a0aca08a711c26c6e1dc79352850c6207f/ES2011a/dev_ami_es2011a_h03_fee044_0092784_0093052.wav', 'array': array([-1.22070312e-04, -9.15527344e-05, -9.15527344e-05, ...,
        1.52587891e-04,  3.05175781e-05,  1.83105469e-04]), 'sampling_rate': 16000}, 'begin_time': 927.8400268554688, 'end_time': 930.52001953125, 'microphone_id': 'H03', 'speaker_id': 'FEE044'}


In [8]:
from collections import defaultdict

meetings = defaultdict(list)
for example in dataset["validation"]:
    example["duration"] = abs(example["end_time"] - example["begin_time"])
    meetings[example["meeting_id"]].append(example)    

In [9]:
# Sort the meetings by begin_time
for meeting_id in meetings:
    meetings[meeting_id].sort(key=lambda x: x["begin_time"])

In [10]:
meetings["ES2011a"][:5]

[{'meeting_id': 'ES2011a',
  'audio_id': 'AMI_ES2011a_H00_FEE041_0003427_0003714',
  'text': 'HERE WE GO',
  'audio': {'path': '/home/fullldiesel/.cache/huggingface/datasets/downloads/extracted/cc01c8370b8fb423a86dbb2bdd5be3a0aca08a711c26c6e1dc79352850c6207f/ES2011a/dev_ami_es2011a_h00_fee041_0003427_0003714.wav',
   'array': array([-2.74658203e-04, -3.05175781e-04, -2.13623047e-04, ...,
           9.15527344e-05,  6.10351562e-05,  0.00000000e+00]),
   'sampling_rate': 16000},
  'begin_time': 34.27000045776367,
  'end_time': 37.13999938964844,
  'microphone_id': 'H00',
  'speaker_id': 'FEE041',
  'duration': 2.8699989318847656},
 {'meeting_id': 'ES2011a',
  'audio_id': 'AMI_ES2011a_H00_FEE041_0003714_0003915',
  'text': 'WELCOME EVERYBODY',
  'audio': {'path': '/home/fullldiesel/.cache/huggingface/datasets/downloads/extracted/cc01c8370b8fb423a86dbb2bdd5be3a0aca08a711c26c6e1dc79352850c6207f/ES2011a/dev_ami_es2011a_h00_fee041_0003714_0003915.wav',
   'array': array([-3.05175781e-05, -3.0

In [11]:
meetings["ES2011a"][7]["audio"]["array"].shape

(67200,)

In [12]:
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

In [13]:
processor.tokenizer.special_tokens_map

{'bos_token': '<|endoftext|>',
 'eos_token': '<|endoftext|>',
 'unk_token': '<|endoftext|>',
 'pad_token': '<|endoftext|>',
 'additional_special_tokens': ['<|startoftranscript|>',
  '<|en|>',
  '<|zh|>',
  '<|de|>',
  '<|es|>',
  '<|ru|>',
  '<|ko|>',
  '<|fr|>',
  '<|ja|>',
  '<|pt|>',
  '<|tr|>',
  '<|pl|>',
  '<|ca|>',
  '<|nl|>',
  '<|ar|>',
  '<|sv|>',
  '<|it|>',
  '<|id|>',
  '<|hi|>',
  '<|fi|>',
  '<|vi|>',
  '<|iw|>',
  '<|uk|>',
  '<|el|>',
  '<|ms|>',
  '<|cs|>',
  '<|ro|>',
  '<|da|>',
  '<|hu|>',
  '<|ta|>',
  '<|no|>',
  '<|th|>',
  '<|ur|>',
  '<|hr|>',
  '<|bg|>',
  '<|lt|>',
  '<|la|>',
  '<|mi|>',
  '<|ml|>',
  '<|cy|>',
  '<|sk|>',
  '<|te|>',
  '<|fa|>',
  '<|lv|>',
  '<|bn|>',
  '<|sr|>',
  '<|az|>',
  '<|sl|>',
  '<|kn|>',
  '<|et|>',
  '<|mk|>',
  '<|br|>',
  '<|eu|>',
  '<|is|>',
  '<|hy|>',
  '<|ne|>',
  '<|mn|>',
  '<|bs|>',
  '<|kk|>',
  '<|sq|>',
  '<|sw|>',
  '<|gl|>',
  '<|mr|>',
  '<|pa|>',
  '<|si|>',
  '<|km|>',
  '<|sn|>',
  '<|yo|>',
  '<|so|>',
  '<

In [14]:
s = "I went to the shop. <|speakerturn|> Oh that's nice! <|speakerturn|>"

In [15]:
processor(text=s)

{'input_ids': [50257, 50362, 40, 1816, 284, 262, 6128, 13, 1279, 91, 47350, 861, 700, 91, 29, 3966, 326, 338, 3621, 0, 1279, 91, 47350, 861, 700, 91, 29, 50256], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [16]:
processor.tokenizer.tokenize(s)

['I',
 'Ġwent',
 'Ġto',
 'Ġthe',
 'Ġshop',
 '.',
 'Ġ<',
 '|',
 'speak',
 'ert',
 'urn',
 '|',
 '>',
 'ĠOh',
 'Ġthat',
 "'s",
 'Ġnice',
 '!',
 'Ġ<',
 '|',
 'speak',
 'ert',
 'urn',
 '|',
 '>']

In [17]:
processor.tokenizer.add_special_tokens({"additional_special_tokens": ["<|speakerturn|>"]}, replace_additional_special_tokens=False)

1

In [18]:
processor.tokenizer.special_tokens_map["additional_special_tokens"]

['<|startoftranscript|>',
 '<|en|>',
 '<|zh|>',
 '<|de|>',
 '<|es|>',
 '<|ru|>',
 '<|ko|>',
 '<|fr|>',
 '<|ja|>',
 '<|pt|>',
 '<|tr|>',
 '<|pl|>',
 '<|ca|>',
 '<|nl|>',
 '<|ar|>',
 '<|sv|>',
 '<|it|>',
 '<|id|>',
 '<|hi|>',
 '<|fi|>',
 '<|vi|>',
 '<|iw|>',
 '<|uk|>',
 '<|el|>',
 '<|ms|>',
 '<|cs|>',
 '<|ro|>',
 '<|da|>',
 '<|hu|>',
 '<|ta|>',
 '<|no|>',
 '<|th|>',
 '<|ur|>',
 '<|hr|>',
 '<|bg|>',
 '<|lt|>',
 '<|la|>',
 '<|mi|>',
 '<|ml|>',
 '<|cy|>',
 '<|sk|>',
 '<|te|>',
 '<|fa|>',
 '<|lv|>',
 '<|bn|>',
 '<|sr|>',
 '<|az|>',
 '<|sl|>',
 '<|kn|>',
 '<|et|>',
 '<|mk|>',
 '<|br|>',
 '<|eu|>',
 '<|is|>',
 '<|hy|>',
 '<|ne|>',
 '<|mn|>',
 '<|bs|>',
 '<|kk|>',
 '<|sq|>',
 '<|sw|>',
 '<|gl|>',
 '<|mr|>',
 '<|pa|>',
 '<|si|>',
 '<|km|>',
 '<|sn|>',
 '<|yo|>',
 '<|so|>',
 '<|af|>',
 '<|oc|>',
 '<|ka|>',
 '<|be|>',
 '<|tg|>',
 '<|sd|>',
 '<|gu|>',
 '<|am|>',
 '<|yi|>',
 '<|lo|>',
 '<|uz|>',
 '<|fo|>',
 '<|ht|>',
 '<|ps|>',
 '<|tk|>',
 '<|nn|>',
 '<|mt|>',
 '<|sa|>',
 '<|lb|>',
 '<|my|>',
 '<|bo

In [19]:
def merge_transcripts(utterances):
    merged_text = ""
    prev_speaker = None
    for utt in utterances:
        current_speaker = utt["speaker_id"]
        # Insert a speaker change token in the speaker changes
        if prev_speaker is not None and current_speaker != prev_speaker:
            merged_text += " <|speakerturn|> "
        merged_text += utt["text"].strip() + " "
        prev_speaker = current_speaker
    return merged_text.strip()

In [20]:
def merge_audio(utterances, sampling_rate=16000, gap_duration=0.1):
    gap_samples = int(gap_duration * sampling_rate)
    silence = np.zeros(gap_samples, dtype=np.float32)
    merged_audio_segments = []

    for utt in utterances:
        audio_array = utt["audio"]["array"]
        merged_audio_segments.append(audio_array)
        merged_audio_segments.append(silence)

    if merged_audio_segments:
        merged_audio_segments = merged_audio_segments[:-1]

    return np.concatenate(merged_audio_segments)

In [21]:
def segment_meeting(utterances, target_duration=25.0, max_duration=27.0):
    """
    Group utterances into segments of roughly target_duration to max_duration seconds.
    Each segment will have both merged text and merged audio.
    """
    segments = []
    current_segment = []
    current_duration = 0.0

    for utt in utterances:
        utt_duration = utt["duration"]
        if current_duration + utt_duration <= max_duration:
            current_segment.append(utt)
            current_duration += utt_duration
            # If we have reached at least the target duration, finalize the segment.
            if current_duration >= target_duration:
                segments.append(current_segment)
                current_segment = []
                current_duration = 0.0
        else:
            # If current segment is below target but adding the utterance exceeds max_duration,
            # you can choose to add it anyway (slight overage) or finalize the segment.
            if current_duration >= target_duration:
                segments.append(current_segment)
                current_segment = [utt]
                current_duration = utt_duration
            else:
                current_segment.append(utt)
                current_duration += utt_duration
                segments.append(current_segment)
                current_segment = []
                current_duration = 0.0

    if current_segment:
        segments.append(current_segment)
    return segments

In [22]:
# Example: process one meeting
meeting_id = "ES2011a"  # For example
meeting_utterances = meetings[meeting_id]
segments = segment_meeting(meeting_utterances, target_duration=25.0, max_duration=27.0)

# Check one segment
for i, segment in enumerate(segments):
    total_duration = sum(utt["duration"] for utt in segment)
    print(f"Segment {i+1}: Duration = {total_duration:.2f}s")
    print("Merged Text:", merge_transcripts(segment))

Segment 1: Duration = 26.26s
Merged Text: HERE WE GO WELCOME EVERYBODY UM I'M ABIGAIL CLAFLIN YOU CAN CALL ME ABBIE 'S SEE POWERPOINT THAT'S NOT IT THERE WE GO SO THIS IS OUR KICK OFF MEETING UM AND I GUESS WE SHOULD ALL GET ACQUAINTED LET'S SHALL WE ALL INTRODUCE OURSELVES
Segment 2: Duration = 26.75s
Merged Text: HI I'M CHIARA I'M THE UM MARKETING EXPERT UM WOULD YOU LIKE ME TO TALK ABOUT MY AIMS AT THE MOMENT OR WOULD YOU LIKE ME TO JUST SAY MY NAME AND THEN WE CAN TALK ABOUT BUSINESS LATER  <|speakerturn|> I THINK WE'LL GET AROUND TO THAT YEAH  <|speakerturn|> WE'LL GET ROUND TO THAT LATER  <|speakerturn|> SO THIS IS JUST INTRODUCTIONS YEAH  <|speakerturn|> MY NAME IS CHIARA AND I'M THE MARKETING EXPERT  <|speakerturn|> OKAY I FORGOT TO S SAY I'M THE PROJECT MANAGER BUT I FIGURED YOU ALL KNEW THAT ALREADY UM SO
Segment 3: Duration = 31.37s
Merged Text: I'M STEPHANIE AND I AM THE USER INTERFACE DESIGNER  <|speakerturn|> I'M KRISTA AND I'M THE INDUSTRIAL DESIGNER  <|speakerturn|> OKA

In [23]:
from datasets import Dataset

In [24]:
training_samples = []
for meeting_id, utterances in meetings.items():
    segments = segment_meeting(utterances, target_duration=24.0, max_duration=26.0)
    for i, segment in enumerate(segments):
        merged_text = merge_transcripts(segment)
        merged_audio = merge_audio(segment, sampling_rate=16000)
        segment_duration = sum(utt["duration"] for utt in segment)
        training_samples.append({
            "meeting_id": meeting_id,
            "position": i,
            "audio": {"array": merged_audio, "sampling_rate": 16000},
            "text": merged_text,
            "duration": segment_duration,
        })

In [25]:
len(training_samples)

1230

In [26]:
training_samples[:2]

[{'meeting_id': 'ES2011a',
  'position': 0,
  'audio': {'array': array([-2.74658203e-04, -3.05175781e-04, -2.13623047e-04, ...,
           9.15527344e-05,  9.15527344e-05,  9.15527344e-05]),
   'sampling_rate': 16000},
  'text': "HERE WE GO WELCOME EVERYBODY UM I'M ABIGAIL CLAFLIN YOU CAN CALL ME ABBIE 'S SEE POWERPOINT THAT'S NOT IT THERE WE GO SO THIS IS OUR KICK OFF MEETING UM AND I GUESS WE SHOULD ALL GET ACQUAINTED",
  'duration': 24.149993896484375},
 {'meeting_id': 'ES2011a',
  'position': 1,
  'audio': {'array': array([ 9.15527344e-05,  6.10351562e-05,  3.05175781e-05, ...,
          -4.57763672e-04, -5.18798828e-04, -2.74658203e-04]),
   'sampling_rate': 16000},
  'text': "LET'S SHALL WE ALL INTRODUCE OURSELVES  <|speakerturn|> HI I'M CHIARA I'M THE UM MARKETING EXPERT UM WOULD YOU LIKE ME TO TALK ABOUT MY AIMS AT THE MOMENT OR WOULD YOU LIKE ME TO JUST SAY MY NAME AND THEN WE CAN TALK ABOUT BUSINESS LATER  <|speakerturn|> I THINK WE'LL GET AROUND TO THAT YEAH  <|speakerturn|>

In [27]:
from datasets import Value, Features, Audio

In [28]:
def sample_generator(samples):
    for sample in samples:
        yield sample
features = Features({
    "meeting_id": Value("string"),
    "position": Value("int32"),
    "audio": Audio(sampling_rate=16000),
    "text": Value("string"),
    "duration": Value("float32")
})
ds = Dataset.from_generator(lambda: sample_generator(training_samples), features=features)

In [29]:
model_name = "openai/whisper-tiny.en"
processor = WhisperProcessor.from_pretrained(model_name)
model = WhisperForConditionalGeneration.from_pretrained(model_name)

In [30]:
specials = processor.tokenizer.special_tokens_map["additional_special_tokens"]
specials

['<|startoftranscript|>',
 '<|en|>',
 '<|zh|>',
 '<|de|>',
 '<|es|>',
 '<|ru|>',
 '<|ko|>',
 '<|fr|>',
 '<|ja|>',
 '<|pt|>',
 '<|tr|>',
 '<|pl|>',
 '<|ca|>',
 '<|nl|>',
 '<|ar|>',
 '<|sv|>',
 '<|it|>',
 '<|id|>',
 '<|hi|>',
 '<|fi|>',
 '<|vi|>',
 '<|iw|>',
 '<|uk|>',
 '<|el|>',
 '<|ms|>',
 '<|cs|>',
 '<|ro|>',
 '<|da|>',
 '<|hu|>',
 '<|ta|>',
 '<|no|>',
 '<|th|>',
 '<|ur|>',
 '<|hr|>',
 '<|bg|>',
 '<|lt|>',
 '<|la|>',
 '<|mi|>',
 '<|ml|>',
 '<|cy|>',
 '<|sk|>',
 '<|te|>',
 '<|fa|>',
 '<|lv|>',
 '<|bn|>',
 '<|sr|>',
 '<|az|>',
 '<|sl|>',
 '<|kn|>',
 '<|et|>',
 '<|mk|>',
 '<|br|>',
 '<|eu|>',
 '<|is|>',
 '<|hy|>',
 '<|ne|>',
 '<|mn|>',
 '<|bs|>',
 '<|kk|>',
 '<|sq|>',
 '<|sw|>',
 '<|gl|>',
 '<|mr|>',
 '<|pa|>',
 '<|si|>',
 '<|km|>',
 '<|sn|>',
 '<|yo|>',
 '<|so|>',
 '<|af|>',
 '<|oc|>',
 '<|ka|>',
 '<|be|>',
 '<|tg|>',
 '<|sd|>',
 '<|gu|>',
 '<|am|>',
 '<|yi|>',
 '<|lo|>',
 '<|uz|>',
 '<|fo|>',
 '<|ht|>',
 '<|ps|>',
 '<|tk|>',
 '<|nn|>',
 '<|mt|>',
 '<|sa|>',
 '<|lb|>',
 '<|my|>',
 '<|bo

In [31]:
special_tokens = {"additional_special_tokens": ["<|speakerturn|>"]}
print(len(processor.tokenizer))
processor.tokenizer.add_special_tokens(special_tokens, replace_additional_special_tokens=False)
model.resize_token_embeddings(len(processor.tokenizer))
print(len(processor.tokenizer))

51864


The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


51865


In [32]:
from peft import LoraConfig, get_peft_model

In [33]:
lora_config = LoraConfig(
    r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"],
    lora_dropout=0.1,
    bias="none"
    )

In [34]:
peft_model = get_peft_model(model, lora_config)

In [35]:
for name, param in peft_model.named_parameters():
    if "lora" in name.lower():
        print(name, param.shape)

base_model.model.model.encoder.layers.0.self_attn.v_proj.lora_A.default.weight torch.Size([8, 384])
base_model.model.model.encoder.layers.0.self_attn.v_proj.lora_B.default.weight torch.Size([384, 8])
base_model.model.model.encoder.layers.0.self_attn.q_proj.lora_A.default.weight torch.Size([8, 384])
base_model.model.model.encoder.layers.0.self_attn.q_proj.lora_B.default.weight torch.Size([384, 8])
base_model.model.model.encoder.layers.1.self_attn.v_proj.lora_A.default.weight torch.Size([8, 384])
base_model.model.model.encoder.layers.1.self_attn.v_proj.lora_B.default.weight torch.Size([384, 8])
base_model.model.model.encoder.layers.1.self_attn.q_proj.lora_A.default.weight torch.Size([8, 384])
base_model.model.model.encoder.layers.1.self_attn.q_proj.lora_B.default.weight torch.Size([384, 8])
base_model.model.model.encoder.layers.2.self_attn.v_proj.lora_A.default.weight torch.Size([8, 384])
base_model.model.model.encoder.layers.2.self_attn.v_proj.lora_B.default.weight torch.Size([384, 8])


In [36]:
def data_collator(features):
    """
    Processes a list of training samples and returns a batch dictionary.
    Each sample should have:
      - "audio": a dict with keys "array" and "sampling_rate"
      - "text": the target transcript (with <|speakerturn|> tokens)
    """
    input_features_list = []
    labels_list = []
    
    for f in features:
        # Process audio into log-mel spectrogram features.
        audio_array = f["audio"]["array"]
        sample_rate = f["audio"]["sampling_rate"]
        inputs = processor(audio_array, sampling_rate=sample_rate, return_tensors="pt", truncation=True, padding="max_length")
        # Squeeze extra dimensions so that input_features is 2D: (feature_length, hidden_size)
        input_features_list.append(inputs.input_features.squeeze(0))
        
        # Tokenize the target text; pad/truncate to a fixed max_length (e.g., 512 tokens).
        tokenized = processor.tokenizer(f["text"], truncation=True, max_length=448, padding="max_length", return_tensors="pt")
        labels_list.append(tokenized.input_ids.squeeze(0))
    
    batch = {
        "input_features": torch.stack(input_features_list),
        "labels": torch.stack(labels_list)
    }
    return batch

# Test the data collator with one sample from your dataset (ds is your dataset from the generator)
# For example, if ds is already created as shown in the previous step:
sample = ds[0]
collated = data_collator([sample])
print("Data collator test:")
print("Input features shape:", collated["input_features"].shape)
print("Labels shape:", collated["labels"].shape)


Data collator test:
Input features shape: torch.Size([1, 80, 3000])
Labels shape: torch.Size([1, 448])


In [69]:
training_args = TrainingArguments(
    output_dir="./whisper-diarization-lora",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    eval_strategy="no",
    save_strategy="epoch",
    num_train_epochs=1,
    learning_rate=1e-4,
    fp16=True,
    logging_steps=5,
    report_to="none",
    remove_unused_columns=False
)

In [70]:
training_args.generation_config = peft_model.generation_config
training_args.data_collator = data_collator

In [71]:
trainer = Seq2SeqTrainer(
    model=peft_model,
    args=training_args,
    train_dataset=ds,  # your dataset created from the generator
    data_collator=data_collator,
    processing_class=processor,
)

In [72]:
trainer.train()

Step,Training Loss
5,0.6405
10,0.6728
15,0.6611
20,0.6114


KeyboardInterrupt: 

In [None]:
pef

In [55]:
ds[0]

{'meeting_id': 'ES2011a',
 'position': 0,
 'audio': {'path': None,
  'array': array([-2.74658203e-04, -3.05175781e-04, -2.13623047e-04, ...,
          9.15527344e-05,  9.15527344e-05,  9.15527344e-05]),
  'sampling_rate': 16000},
 'text': "HERE WE GO WELCOME EVERYBODY UM I'M ABIGAIL CLAFLIN YOU CAN CALL ME ABBIE 'S SEE POWERPOINT THAT'S NOT IT THERE WE GO SO THIS IS OUR KICK OFF MEETING UM AND I GUESS WE SHOULD ALL GET ACQUAINTED",
 'duration': 24.149993896484375}