In [12]:
# imports
import numpy as np

import torch
from datasets import load_dataset
from transformers import WhisperProcessor, WhisperForConditionalGeneration, TrainingArguments, Seq2SeqTrainer
import soundfile as sf

In [2]:
dataset = load_dataset("edinburghcstr/ami", "ihm")

In [3]:
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['meeting_id', 'audio_id', 'text', 'audio', 'begin_time', 'end_time', 'microphone_id', 'speaker_id'],
        num_rows: 108502
    })
    validation: Dataset({
        features: ['meeting_id', 'audio_id', 'text', 'audio', 'begin_time', 'end_time', 'microphone_id', 'speaker_id'],
        num_rows: 13098
    })
    test: Dataset({
        features: ['meeting_id', 'audio_id', 'text', 'audio', 'begin_time', 'end_time', 'microphone_id', 'speaker_id'],
        num_rows: 12643
    })
})


In [4]:
print(dataset["validation"][0])

{'meeting_id': 'ES2011a', 'audio_id': 'AMI_ES2011a_H03_FEE044_0092784_0093052', 'text': "BUT LIKE MOBILE PHONES HAVE SCREENS AND THEY'RE CHEAP", 'audio': {'path': '/root/.cache/huggingface/datasets/downloads/extracted/39f9cd4b79885620930ee08a0d5d8e320b058663a767622d29d00bb4a90eff94/ES2011a/dev_ami_es2011a_h03_fee044_0092784_0093052.wav', 'array': array([-1.22070312e-04, -9.15527344e-05, -9.15527344e-05, ...,
        1.52587891e-04,  3.05175781e-05,  1.83105469e-04]), 'sampling_rate': 16000}, 'begin_time': 927.8400268554688, 'end_time': 930.52001953125, 'microphone_id': 'H03', 'speaker_id': 'FEE044'}


In [5]:
from collections import defaultdict

meetings = defaultdict(list)
for example in dataset["train"]:
    example["duration"] = abs(example["end_time"] - example["begin_time"])
    meetings[example["meeting_id"]].append(example)    

In [6]:
# Sort the meetings by begin_time
for meeting_id in meetings:
    meetings[meeting_id].sort(key=lambda x: x["begin_time"])

In [21]:
meetings["EN2001a"][:2]

[{'meeting_id': 'EN2001a',
  'audio_id': 'AMI_EN2001a_H04_MEO069_0000334_0000388',
  'text': "'KAY",
  'audio': {'path': '/root/.cache/huggingface/datasets/downloads/extracted/e9041db40e8bea1afc48f66ff890393f21d11a13f8eecd0e1d8fcdc96b1fde48/EN2001a/train_ami_en2001a_h04_meo069_0000334_0000388.wav',
   'array': array([ 0.00012207,  0.00015259,  0.00015259, ..., -0.00128174,
          -0.00125122, -0.00128174]),
   'sampling_rate': 16000},
  'begin_time': 3.3399999141693115,
  'end_time': 3.880000114440918,
  'microphone_id': 'H04',
  'speaker_id': 'MEO069',
  'duration': 0.5400002002716064},
 {'meeting_id': 'EN2001a',
  'audio_id': 'AMI_EN2001a_H00_MEE068_0000557_0000594',
  'text': 'OKAY',
  'audio': {'path': '/root/.cache/huggingface/datasets/downloads/extracted/e9041db40e8bea1afc48f66ff890393f21d11a13f8eecd0e1d8fcdc96b1fde48/EN2001a/train_ami_en2001a_h00_mee068_0000557_0000594.wav',
   'array': array([0.        , 0.        , 0.        , ..., 0.00033569, 0.00030518,
          0.000305

In [22]:
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

In [23]:
def merge_transcripts(utterances):
    merged_text = ""
    prev_speaker = None
    for utt in utterances:
        current_speaker = utt["speaker_id"]
        # Insert a speaker change token in the speaker changes
        if prev_speaker is not None and current_speaker != prev_speaker:
            merged_text += " <|speakerturn|> "
        merged_text += utt["text"].strip() + " "
        prev_speaker = current_speaker
    return merged_text.strip()

In [24]:
def merge_audio(utterances, sampling_rate=16000, gap_duration=0.1):
    gap_samples = int(gap_duration * sampling_rate)
    silence = np.zeros(gap_samples, dtype=np.float32)
    merged_audio_segments = []

    for utt in utterances:
        audio_array = utt["audio"]["array"]
        merged_audio_segments.append(audio_array)
        merged_audio_segments.append(silence)

    if merged_audio_segments:
        merged_audio_segments = merged_audio_segments[:-1]

    return np.concatenate(merged_audio_segments)

In [25]:
import os
import soundfile as sf

def save_audio(merged_audio, sampling_rate, output_dir, meeting_id, segment_idx):
    """
    Save a merged audio clip as a WAV file.

    Args:
        merged_audio (numpy.ndarray): The merged audio clip.
        sampling_rate (int): The sampling rate (e.g., 16000).
        output_dir (str): Directory where the audio file will be saved.
        meeting_id (str): Identifier for the meeting.
        segment_idx (int): The index of the segment within the meeting.

    Returns:
        str: The file path of the saved audio.
    """
    os.makedirs(output_dir, exist_ok=True)
    file_path = os.path.join(output_dir, f"{meeting_id}_segment{segment_idx}.wav")
    sf.write(file_path, merged_audio, sampling_rate)
    return file_path


In [26]:
def segment_meeting(utterances, target_duration=25.0, max_duration=27.0):
    """
    Group utterances into segments of roughly target_duration to max_duration seconds.
    Each segment will have both merged text and merged audio.
    """
    segments = []
    current_segment = []
    current_duration = 0.0

    for utt in utterances:
        utt_duration = utt["duration"]
        if current_duration + utt_duration <= max_duration:
            current_segment.append(utt)
            current_duration += utt_duration
            # If we have reached at least the target duration, finalize the segment.
            if current_duration >= target_duration:
                segments.append(current_segment)
                current_segment = []
                current_duration = 0.0
        else:
            # If current segment is below target but adding the utterance exceeds max_duration,
            # you can choose to add it anyway (slight overage) or finalize the segment.
            if current_duration >= target_duration:
                segments.append(current_segment)
                current_segment = [utt]
                current_duration = utt_duration
            else:
                current_segment.append(utt)
                current_duration += utt_duration
                segments.append(current_segment)
                current_segment = []
                current_duration = 0.0

    if current_segment:
        segments.append(current_segment)
    return segments

In [27]:
# Example: process one meeting
meeting_id = "EN2001a"  # For example
meeting_utterances = meetings[meeting_id]
segments = segment_meeting(meeting_utterances, target_duration=25.0, max_duration=27.0)

# Check one segment
for i, segment in enumerate(segments):
    total_duration = sum(utt["duration"] for utt in segment)
    print(f"Segment {i+1}: Duration = {total_duration:.2f}s")
    print("Merged Text:", merge_transcripts(segment))

Segment 1: Duration = 30.58s
Merged Text: 'KAY  <|speakerturn|> OKAY  <|speakerturn|> GOSH 'KAY  <|speakerturn|> DOES ANYONE WANT TO SEE UH STEVE'S FEEDBACK FROM THE SPECIFICATION  <|speakerturn|> IS THERE MUCH MORE IN IT THAN HE D  <|speakerturn|> I I DRY READ IT THE LAST TIME  <|speakerturn|> RIGHT  <|speakerturn|> IS THERE MUCH MORE IN IT THAN HE SAID YESTERDAY  <|speakerturn|> NOT REALLY UM JUST WHAT HE'S TALKING ABOUT LIKE DUPLICATION OF EFFORT AND  <|speakerturn|> MM HMM HMM  <|speakerturn|> LIKE DUPLICATION OF EFFORT AND STUFF AND UM YEAH HE WAS SAYING THAT WE SHOULD MAYBE UH THINK ABOUT HAVING A PROTOTYPE FOR WEEK SIX WHICH IS NEXT WEEK
Segment 2: Duration = 30.67s
Merged Text: NEXT WEEK  <|speakerturn|> YEAH SO WE SHOULD PROBABLY PRIORITIZE OUR PACKAGES  <|speakerturn|> YEAH NOW I'D SAY IF FOR THE PROTOTYPE IF WE JUST LIKE WHEREVER POSSIBLE P CHUNK IN THE STUFF THAT WE HAVE UM PRE ANNOTATED AND STUFF AND FOR THE STUFF THAT WE DON'T HAVE PRE ANNOTATED WRITE LIKE A STUPID BASELI

In [28]:
from datasets import Dataset

In [30]:
training_samples = []
output_dir = "./merged_audio_clips"
for meeting_id, utterances in meetings.items():
    segments = segment_meeting(utterances, target_duration=24.0, max_duration=26.0)
    for i, segment in enumerate(segments):
        merged_text = merge_transcripts(segment)
        merged_audio = merge_audio(segment, sampling_rate=16000)
        segment_duration = sum(utt["duration"] for utt in segment)

        # Save the merged audio clip to disk
        audio_file_path = save_audio(merged_audio, 16_000, output_dir, meeting_id, i)
        training_samples.append({
            "meeting_id": meeting_id,
            "position": i,
            "audio": audio_file_path,
            "text": merged_text,
            "duration": segment_duration,
        })

In [31]:
from datasets import Value, Features, Audio

In [34]:
from datasets import Dataset, Audio, Features, Value

features = Features({
    "meeting_id": Value("string"),
    "position": Value("int32"),
    "audio": Audio(sampling_rate=16000),  # Lazy loads audio from the file path
    "text": Value("string"),
    "duration": Value("float32")
})

ds = Dataset.from_list(training_samples, features=features)
print(ds)

Dataset({
    features: ['meeting_id', 'position', 'audio', 'text', 'duration'],
    num_rows: 10589
})


In [41]:
ds[500:505]

{'meeting_id': ['EN2001e', 'EN2001e', 'EN2001e', 'EN2001e', 'EN2001e'],
 'position': [78, 79, 80, 81, 82],
 'audio': [{'path': './merged_audio_clips/EN2001e_segment78.wav',
   'array': array([ 6.10351562e-05,  9.15527344e-05,  1.22070312e-04, ...,
          -3.05175781e-05, -3.05175781e-05, -6.10351562e-05]),
   'sampling_rate': 16000},
  {'path': './merged_audio_clips/EN2001e_segment79.wav',
   'array': array([-3.05175781e-05,  0.00000000e+00,  0.00000000e+00, ...,
           0.00000000e+00, -3.05175781e-05,  0.00000000e+00]),
   'sampling_rate': 16000},
  {'path': './merged_audio_clips/EN2001e_segment80.wav',
   'array': array([ 0.00000000e+00, -3.05175781e-05,  3.05175781e-05, ...,
          -5.03540039e-03, -4.54711914e-03, -2.96020508e-03]),
   'sampling_rate': 16000},
  {'path': './merged_audio_clips/EN2001e_segment81.wav',
   'array': array([ 0.00015259,  0.00021362,  0.00018311, ..., -0.00152588,
          -0.00094604,  0.00106812]),
   'sampling_rate': 16000},
  {'path': './me

In [42]:
model_name = "openai/whisper-tiny.en"
processor = WhisperProcessor.from_pretrained(model_name)
model = WhisperForConditionalGeneration.from_pretrained(model_name)

In [43]:
specials = processor.tokenizer.special_tokens_map["additional_special_tokens"]
specials

['<|startoftranscript|>',
 '<|en|>',
 '<|zh|>',
 '<|de|>',
 '<|es|>',
 '<|ru|>',
 '<|ko|>',
 '<|fr|>',
 '<|ja|>',
 '<|pt|>',
 '<|tr|>',
 '<|pl|>',
 '<|ca|>',
 '<|nl|>',
 '<|ar|>',
 '<|sv|>',
 '<|it|>',
 '<|id|>',
 '<|hi|>',
 '<|fi|>',
 '<|vi|>',
 '<|iw|>',
 '<|uk|>',
 '<|el|>',
 '<|ms|>',
 '<|cs|>',
 '<|ro|>',
 '<|da|>',
 '<|hu|>',
 '<|ta|>',
 '<|no|>',
 '<|th|>',
 '<|ur|>',
 '<|hr|>',
 '<|bg|>',
 '<|lt|>',
 '<|la|>',
 '<|mi|>',
 '<|ml|>',
 '<|cy|>',
 '<|sk|>',
 '<|te|>',
 '<|fa|>',
 '<|lv|>',
 '<|bn|>',
 '<|sr|>',
 '<|az|>',
 '<|sl|>',
 '<|kn|>',
 '<|et|>',
 '<|mk|>',
 '<|br|>',
 '<|eu|>',
 '<|is|>',
 '<|hy|>',
 '<|ne|>',
 '<|mn|>',
 '<|bs|>',
 '<|kk|>',
 '<|sq|>',
 '<|sw|>',
 '<|gl|>',
 '<|mr|>',
 '<|pa|>',
 '<|si|>',
 '<|km|>',
 '<|sn|>',
 '<|yo|>',
 '<|so|>',
 '<|af|>',
 '<|oc|>',
 '<|ka|>',
 '<|be|>',
 '<|tg|>',
 '<|sd|>',
 '<|gu|>',
 '<|am|>',
 '<|yi|>',
 '<|lo|>',
 '<|uz|>',
 '<|fo|>',
 '<|ht|>',
 '<|ps|>',
 '<|tk|>',
 '<|nn|>',
 '<|mt|>',
 '<|sa|>',
 '<|lb|>',
 '<|my|>',
 '<|bo

In [44]:
special_tokens = {"additional_special_tokens": ["<|speakerturn|>"]}
print(len(processor.tokenizer))
processor.tokenizer.add_special_tokens(special_tokens, replace_additional_special_tokens=False)
model.resize_token_embeddings(len(processor.tokenizer))
print(len(processor.tokenizer))

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


51864
51865


In [45]:
from peft import LoraConfig, get_peft_model

In [46]:
lora_config = LoraConfig(
    r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"],
    lora_dropout=0.1,
    bias="none"
    )

In [47]:
peft_model = get_peft_model(model, lora_config)

In [48]:
for name, param in peft_model.named_parameters():
    if "lora" in name.lower():
        print(name, param.shape)

base_model.model.model.encoder.layers.0.self_attn.v_proj.lora_A.default.weight torch.Size([8, 384])
base_model.model.model.encoder.layers.0.self_attn.v_proj.lora_B.default.weight torch.Size([384, 8])
base_model.model.model.encoder.layers.0.self_attn.q_proj.lora_A.default.weight torch.Size([8, 384])
base_model.model.model.encoder.layers.0.self_attn.q_proj.lora_B.default.weight torch.Size([384, 8])
base_model.model.model.encoder.layers.1.self_attn.v_proj.lora_A.default.weight torch.Size([8, 384])
base_model.model.model.encoder.layers.1.self_attn.v_proj.lora_B.default.weight torch.Size([384, 8])
base_model.model.model.encoder.layers.1.self_attn.q_proj.lora_A.default.weight torch.Size([8, 384])
base_model.model.model.encoder.layers.1.self_attn.q_proj.lora_B.default.weight torch.Size([384, 8])
base_model.model.model.encoder.layers.2.self_attn.v_proj.lora_A.default.weight torch.Size([8, 384])
base_model.model.model.encoder.layers.2.self_attn.v_proj.lora_B.default.weight torch.Size([384, 8])


In [49]:
def data_collator(features):
    """
    Processes a list of training samples and returns a batch dictionary.
    Each sample should have:
      - "audio": a dict with keys "array" and "sampling_rate"
      - "text": the target transcript (with <|speakerturn|> tokens)
    """
    input_features_list = []
    labels_list = []
    
    for f in features:
        # Process audio into log-mel spectrogram features.
        audio_array = f["audio"]["array"]
        sample_rate = f["audio"]["sampling_rate"]
        inputs = processor(audio_array, sampling_rate=sample_rate, return_tensors="pt", truncation=True, padding="max_length")
        # Squeeze extra dimensions so that input_features is 2D: (feature_length, hidden_size)
        input_features_list.append(inputs.input_features.squeeze(0))
        
        # Tokenize the target text; pad/truncate to a fixed max_length (e.g., 512 tokens).
        tokenized = processor.tokenizer(f["text"], truncation=True, max_length=448, padding="max_length", return_tensors="pt")
        labels_list.append(tokenized.input_ids.squeeze(0))
    
    batch = {
        "input_features": torch.stack(input_features_list),
        "labels": torch.stack(labels_list)
    }
    return batch

# Test the data collator with one sample from your dataset (ds is your dataset from the generator)
# For example, if ds is already created as shown in the previous step:
sample = ds[0]
collated = data_collator([sample])
print("Data collator test:")
print("Input features shape:", collated["input_features"].shape)
print("Labels shape:", collated["labels"].shape)


Data collator test:
Input features shape: torch.Size([1, 80, 3000])
Labels shape: torch.Size([1, 448])


In [3]:
from transformers import TrainingArguments

In [None]:
TrainingArguments()

In [50]:
training_args = TrainingArguments(
    output_dir="./whisper-diarization-lora",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    eval_strategy="no",
    save_strategy="epoch",
    num_train_epochs=1,
    learning_rate=1e-4,
    fp16=True,
    logging_steps=5,
    report_to="none",
    remove_unused_columns=False
)

In [51]:
training_args.generation_config = peft_model.generation_config
training_args.data_collator = data_collator

In [56]:
trainer = Seq2SeqTrainer(
    model=peft_model,
    args=training_args,
    train_dataset=ds,  # your dataset created from the generator
    data_collator=data_collator,
    processing_class=processor,
)
trainer.label_names = ["labels"]

No label_names provided for model class `PeftModel`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [None]:
trainer.train()

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Step,Training Loss
5,1.6986
10,1.2602
15,1.0308
20,0.9315


In [None]:
pe

In [None]:
ds[0]