I'm experimenting with an approach to improve how Spoken Dialog Systems (like OpenAI's Advanced Voice Mode) identify when a user is done speaking (the end of their `conversational turn`) so that they respond at appropriate times without interrupting the user mid-thought.

From a user's perspective, current systems appear to rely primarily on pause length from a Voice Activity Detector (VAD). So an `X` second pause mid-sentence is treated the same as `X` seconds of silence at the end of a sentence.

This notebook adds an end-of-turn prediction head to a Whisper model. This prediction head relies on both acoustic and linguistic information (both the Whisper encoder and the Whisper decoder). I will train it from multi-turn conversations between human speakers with labeled transitions. Implicitly, I assume humans do well at knowing when to jump in without interrupting ¯\_(ツ)_/¯

Candidate datasets for training: AMI Meeting Corpus, Switchboard (which has a $3k licensing fee), CALLHOME. If necessary, I could use diarization tools like pyannote-audio or pyAudioAnalysis to identify changes in the speaker.
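As a rough sketch of how diarization output could become turn-end labels: assume the diarization tool yields `(speaker, start, end)` segments (the `Segment` type and `turn_end_times` helper below are hypothetical names, not part of any library). A segment counts as a turn end when the next segment belongs to a different speaker. Overlapping speech is ignored here, which real conversations would not allow.

```python
from typing import List, NamedTuple


class Segment(NamedTuple):
    speaker: str
    start: float  # seconds
    end: float    # seconds


def turn_end_times(segments: List[Segment]) -> List[float]:
    """Return the end time of every segment followed by a different speaker."""
    ordered = sorted(segments, key=lambda s: s.start)
    ends = []
    for current, following in zip(ordered, ordered[1:]):
        if following.speaker != current.speaker:
            ends.append(current.end)
    return ends


segments = [
    Segment("A", 0.0, 4.2),
    Segment("A", 4.9, 7.5),   # same speaker resumes after a pause: not a turn end
    Segment("B", 7.8, 12.0),
    Segment("A", 12.3, 15.0),
]
print(turn_end_times(segments))  # [7.5, 12.0]
```

The final segment gets no label because nothing follows it, so it is unknown whether the speaker was finished.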

A production system might still use a VAD as a computationally cheap way to identify pauses of at least 1 second, and call the updated model from this notebook for final end-of-turn detection and transcription only after such a pause.
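A minimal sketch of that two-stage idea, assuming the VAD emits one speech/non-speech flag per frame. The function names, the 30 ms frame size, and the 0.5 threshold are illustrative assumptions, and `turn_end_probability` stands in for a call to the model built later in this notebook.

```python
from typing import Callable, Sequence


def trailing_silence_seconds(is_speech: Sequence[bool], frame_seconds: float) -> float:
    """Length of the silent run at the end of a sequence of per-frame VAD flags."""
    silent_frames = 0
    for flag in reversed(is_speech):
        if flag:
            break
        silent_frames += 1
    return silent_frames * frame_seconds


def end_of_turn(
    is_speech: Sequence[bool],
    turn_end_probability: Callable[[], float],
    frame_seconds: float = 0.03,      # assumed VAD frame size
    min_pause_seconds: float = 1.0,
    threshold: float = 0.5,           # assumed decision threshold
) -> bool:
    """Run the expensive model only when the cheap VAD reports a long enough pause."""
    if trailing_silence_seconds(is_speech, frame_seconds) < min_pause_seconds:
        return False
    return turn_end_probability() >= threshold


calls = []

def fake_model() -> float:
    """Stand-in for the real model; records that it was called."""
    calls.append(1)
    return 0.9

still_talking = [True] * 100
paused = [True] * 60 + [False] * 40   # 40 frames * 0.03 s = 1.2 s of trailing silence
print(end_of_turn(still_talking, fake_model))  # False, the model is never called
print(end_of_turn(paused, fake_model))         # True
```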

This is an educational project to gain initial experience with audio models. I may be naively underestimating what others before me have done. I will nevertheless run incremental experiments as a learning experience. The main steps are:

- [x] Download a Whisper model from HuggingFace hub and verify that I can run it on a trivial file
- [x] Inspect the architecture and plan how to integrate a prediction head
- [x] Overfit on a single sample to verify I can train
- [ ] Train on a set of ~100 samples with some small number of validation samples to test infrastructure
- [ ] Train on a larger sample to test if I can make something that broadly "works"

# Example Calling A Whisper Model

I start with `whisper-tiny.en`. Later stages may use a large model like `whisper-large-v3` once I have the basic workflow wired up.

In [1]:
from time import time
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "openai/whisper-tiny.en"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
 
processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    device=device,
)

start = time()
result = pipe("./test_data/test.wav")
print(f"Keys: {result.keys() if isinstance(result, dict) else 'not a dict'}")
print(result["text"])
print(f"Time taken: {time() - start}")

long_audio_file = "./test_data/multi-chunk-test.wav"
result2 = pipe(long_audio_file, return_timestamps=True)
print(result2)

Device set to use cpu


Keys: dict_keys(['text'])
 My name is Dan and this is a test audio file.
Time taken: 0.24007105827331543


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


{'text': " Okay, this is going to be a test of recordings that are longer than 30 seconds so that I can test how well this model handles multi chunk. Audio files. So each chunk is going to be 30 seconds. If you have a recording this longer than 30 seconds, it's going to be processed as multiple chunks. And then each chunk is going to be handled by the large-engaged model to make a transcription. And then we're going to concatenate those various transcriptions. .", 'chunks': [{'timestamp': (0.0, 15.0), 'text': ' Okay, this is going to be a test of recordings that are longer than 30 seconds so that I can test how well this model handles multi chunk.'}, {'timestamp': (15.0, 0.0), 'text': ''}, {'timestamp': (7.0, 14.0), 'text': ' Audio files. So each chunk is going to be 30 seconds.'}, {'timestamp': (14.0, 18.0), 'text': " If you have a recording this longer than 30 seconds, it's going to be processed as multiple chunks."}, {'timestamp': (18.0, 23.0), 'text': ' And then each chunk is going

# Inspect the Pipeline and Model

In [2]:
# What's in the pipeline?
print([step for step in pipe.__dict__.keys() if not step.startswith("_")])


['type', 'task', 'model', 'tokenizer', 'feature_extractor', 'image_processor', 'processor', 'modelcard', 'framework', 'device', 'binary_output', 'prefix', 'generation_config', 'call_count']


In [3]:
pipe.feature_extractor

WhisperFeatureExtractor {
  "chunk_length": 30,
  "feature_extractor_type": "WhisperFeatureExtractor",
  "feature_size": 80,
  "hop_length": 160,
  "n_fft": 400,
  "n_samples": 480000,
  "nb_max_frames": 3000,
  "padding_side": "right",
  "padding_value": 0.0,
  "processor_class": "WhisperProcessor",
  "return_attention_mask": false,
  "sampling_rate": 16000
}

The pipeline expects incoming audio at 16 kHz. The preprocessor creates 80-dimensional log-mel features in 0.01-second increments: a hop length of 160 means consecutive frames start 160 audio samples (10 ms) apart, and each frame is computed from a 400-sample (25 ms) window (`n_fft`).
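The numbers in the feature extractor config can be cross-checked with a few lines of arithmetic. The variable names mirror the config keys printed above.

```python
sampling_rate = 16_000  # samples per second
hop_length = 160        # samples between the starts of consecutive frames
n_fft = 400             # samples in each analysis window
chunk_length = 30       # seconds per chunk

seconds_per_frame = hop_length / sampling_rate   # spacing between frames
window_seconds = n_fft / sampling_rate           # duration covered by one frame
n_samples = chunk_length * sampling_rate         # matches "n_samples" in the config
frames_per_chunk = n_samples // hop_length       # matches "nb_max_frames"

print(seconds_per_frame, window_seconds, n_samples, frames_per_chunk)
```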

In [4]:
from torchinfo import summary
decoder_input_ids = torch.tensor([[processor.tokenizer.pad_token_id]]).to(device)
print(summary(pipe.model, input_size=(1, 80, 3000), decoder_input_ids=decoder_input_ids))

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Layer (type:depth-idx)                                  Output Shape              Param #
WhisperForConditionalGeneration                         [1, 1500, 384]            --
├─WhisperModel: 1-1                                     [1, 1500, 384]            --
│    └─WhisperEncoder: 2-1                              [1, 1500, 384]            576,000
│    │    └─Conv1d: 3-1                                 [1, 384, 3000]            92,544
│    │    └─Conv1d: 3-2                                 [1, 384, 1500]            442,752
│    │    └─ModuleList: 3-3                             --                        7,096,320
│    │    └─LayerNorm: 3-4                              [1, 1500, 384]            768
│    └─WhisperDecoder: 2-2                              [1, 6, 1, 64]             --
│    │    └─Embedding: 3-5                              [1, 1, 384]               19,915,776
│    │    └─WhisperPositionalEmbedding: 3-6             [1, 1, 384]               172,032
│    │    └─ModuleList: 3

# Adding TurnEndClassifier to Whisper Model

I add a TurnEndClassifier as an additional prediction head for the Whisper model. Its inputs are:
1. Audio features from the encoder
2. Semantic info from the decoder

## Alignment

Each recording will be classified with a single prediction of whether it ends at a turn-end. The recording has many values from the encoder along the time dimension (50 frames per second: the 100 mel frames per second are halved by the encoder's stride-2 convolution, giving 1500 frames for a 30-second input) and many values from the decoder (1 hidden state per token).

I compress the time and token dimensions from the encoder and decoder respectively into a 1d representation from each. I do these compressions with convolutional layers and then pooling layers (separate layers for the encoder and the decoder compression to 1d). These 1d vectors are concatenated, fed through a small feedforward network, and lead to a binary classification head indicating if this audio finishes with a turn-end.
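To get a feel for how much the convolutions shrink the time axis before pooling, here is a small calculation using the standard output-length formula for a 1D convolution with dilation 1. The kernel size 5, stride 3, and padding 1 are the values used in the `TurnEndClassifier` code in this notebook; `conv1d_out_len` is a helper name introduced only for this illustration.

```python
def conv1d_out_len(length: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
    """Output length of a 1D convolution with dilation 1."""
    return (length + 2 * padding - kernel_size) // stride + 1


length = 1500  # encoder frames for a 30-second input
lengths = [length]
for _ in range(3):  # three stacked encoder-side convolutions
    length = conv1d_out_len(length, kernel_size=5, stride=3, padding=1)
    lengths.append(length)

print(lengths)  # [1500, 500, 166, 55]

# The decoder-side convolutions (kernel 3, stride 1, padding 1) preserve length.
print(conv1d_out_len(20, kernel_size=3, stride=1, padding=1))  # 20
```

The adaptive max pool then collapses whatever length remains (55 on the encoder side) down to 1.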

For training, I'll preprocess conversations to have many samples that are each 10 seconds of audio and that finish at a turn end (target is 1) or that don't (target is 0).
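One possible way to lay out those samples, sketched under assumptions: `turn_ends` are times where the speaker changes, `negative_ends` are times picked inside a turn (for example at mid-turn pauses), and both helper functions are hypothetical. Each example is a `(start, end, label)` window in seconds; windows near the start of a recording come out shorter than 10 seconds.

```python
from typing import List, Sequence, Tuple


def window_ending_at(end_time: float, window_seconds: float = 10.0) -> Tuple[float, float]:
    """Return the (start, end) of a window that finishes at end_time."""
    start = max(0.0, end_time - window_seconds)
    return (start, end_time)


def make_examples(
    turn_ends: Sequence[float],
    negative_ends: Sequence[float],
    window_seconds: float = 10.0,
) -> List[Tuple[float, float, int]]:
    """Build (start, end, label) windows: label 1 ends at a turn end, 0 does not."""
    examples = [(*window_ending_at(t, window_seconds), 1) for t in turn_ends]
    examples += [(*window_ending_at(t, window_seconds), 0) for t in negative_ends]
    return sorted(examples)


examples = make_examples(turn_ends=[7.5, 12.0], negative_ends=[4.2])
print(examples)  # [(0.0, 4.2, 0), (0.0, 7.5, 1), (2.0, 12.0, 1)]
```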

In [5]:
from torch import nn
import torch

class TurnEndClassifier(nn.Module):
    def __init__(self, encoder_dim=384, decoder_dim=384, hidden_dim=64):
        super().__init__()

        # Linear layers to reduce the dimensions of the encoder and decoder outputs.
        # This is a temporary hack. In theory, hidden_dim could be larger than whisper_hidden_dim. Or encoder and decoder could have different dimensions.
        self.encoder_reduce = nn.Linear(encoder_dim, hidden_dim)
        self.decoder_reduce = nn.Linear(decoder_dim, hidden_dim)
        
        # Encoder processing layers
        self.encoder_conv = nn.Sequential(
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, stride=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, stride=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, stride=3, padding=1),
            nn.ReLU()
        )
        self.encoder_pool = nn.AdaptiveMaxPool1d(1)
        
        # Decoder processing layers
        self.decoder_conv = nn.Sequential(
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.decoder_pool = nn.AdaptiveMaxPool1d(1)
        
        # Fully connected layers for classification
        self.fc = nn.Sequential(
            nn.Linear(2*hidden_dim , hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
    
    def forward(self, encoder_outputs, decoder_outputs):
        # Reduce feature dimension first (more efficient for subsequent conv layers)
        # [batch, time, whisper_hidden_dim] -> [batch, time, hidden_dim]
        encoder_outputs = self.encoder_reduce(encoder_outputs)
        
        # [batch, seq_len, whisper_hidden_dim] -> [batch, seq_len, hidden_dim]
        decoder_outputs = self.decoder_reduce(decoder_outputs)
        
        # Process encoder outputs
        # [batch, time, hidden_dim] -> [batch, hidden_dim, time]
        encoder_outputs = encoder_outputs.transpose(1, 2)  # Conv1d expects channels first
        encoder_features = self.encoder_conv(encoder_outputs)  # Now working with reduced dimensions
        encoder_features = self.encoder_pool(encoder_features).squeeze(-1)  # -> [batch, hidden_dim]
        
        # Process decoder outputs
        # [batch, seq_len, hidden_dim] -> [batch, hidden_dim, seq_len]
        decoder_outputs = decoder_outputs.transpose(1, 2)  # Conv1d expects channels first
        decoder_features = self.decoder_conv(decoder_outputs)  # Now working with reduced dimensions
        decoder_features = self.decoder_pool(decoder_features).squeeze(-1)  # -> [batch, hidden_dim]
        
        # Concatenate and classify
        # [batch, hidden_dim] + [batch, hidden_dim] -> [batch, 2*hidden_dim]
        combined_features = torch.cat((encoder_features, decoder_features), dim=1)
        # [batch, 2*hidden_dim] -> [batch, 1]
        output = self.fc(combined_features)
        
        return output

## Custom Model
Now we build the custom model that is a Whisper Model with the extra prediction head. Our goals include:
1. Allow training the layers in the TurnEndClassifier prediction head while keeping all other layers frozen
2. Allow embedding this in a pipeline that reuses parts of the Whisper pipeline (e.g. for preprocessing)

In [6]:
from transformers import WhisperForConditionalGeneration
import torch


class CustomWhisperModel(WhisperForConditionalGeneration):
    def __init__(self, config, turn_end_classifier=None):
        super().__init__(config)
        self.turn_end_classifier = turn_end_classifier # if turn_end_classifier is None, we'll set it later

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, turn_end_classifier, *args, **kwargs):
        model = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
        model.turn_end_classifier = turn_end_classifier
        return model

    def forward_with_turn_end(
        self,
        input_features=None,
        decoder_input_ids=None,
        **kwargs  # Catch all other args to pass through
    ):
        assert self.turn_end_classifier is not None, "TurnEndClassifier must be set before calling forward_with_turn_end"
        
        # whisper_outputs is a Seq2SeqLMOutput with `logits` for each token,
        # encoder_last_hidden_state, decoder_hidden_states, encoder_hidden_states
        # and some other attributes. It's a ModelOutput (an OrderedDict subclass)
        # whose fields are also accessible as attributes.
        whisper_outputs = super().forward(
            input_features=input_features,
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=True,
            **kwargs
        )

        # Get output of last hidden layer from both encoder and decoder.
        encoder_hidden_states = whisper_outputs.encoder_last_hidden_state
        decoder_hidden_states = whisper_outputs.decoder_hidden_states[-1]

        turn_end_predictions = self.turn_end_classifier(
            encoder_hidden_states,
            decoder_hidden_states
        )
        
        return whisper_outputs, turn_end_predictions


## Test the CustomWhisperModel

In [7]:
def test_custom_whisper_model():
    # 1. Create fake data
    batch_size = 2
    sequence_length = 3000
    encoder_dim = 384  # Whisper tiny dimension
    decoder_seq_length = 20
    
    fake_input_features = torch.randn(batch_size, 80, sequence_length)
    fake_decoder_input_ids = torch.randint(0, 100, (batch_size, decoder_seq_length))
    
    # 2. Initialize models
    turn_end_classifier = TurnEndClassifier(
        encoder_dim=encoder_dim,
        decoder_dim=encoder_dim,
        hidden_dim=64
    )
    
    model = CustomWhisperModel.from_pretrained(
        "openai/whisper-tiny.en",
        turn_end_classifier=turn_end_classifier
    )
    
    # 3. Run forward pass

    whisper_outputs, turn_end_predictions = model.forward_with_turn_end(
        input_features=fake_input_features,
        decoder_input_ids=fake_decoder_input_ids,
    )
    print(f"Output type: {type(whisper_outputs)}")
    print("\nAvailable attributes:")
    print([attr for attr in dir(whisper_outputs) if not attr.startswith('_')])
    
    
        
test_custom_whisper_model()

Output type: <class 'transformers.modeling_outputs.Seq2SeqLMOutput'>

Available attributes:
['clear', 'copy', 'cross_attentions', 'decoder_attentions', 'decoder_hidden_states', 'encoder_attentions', 'encoder_hidden_states', 'encoder_last_hidden_state', 'fromkeys', 'get', 'items', 'keys', 'logits', 'loss', 'move_to_end', 'past_key_values', 'pop', 'popitem', 'setdefault', 'to_tuple', 'update', 'values']


# Wrap Custom Model In A Custom Pipeline

Create a pipeline that can take a .wav file as input and returns turn_end_probability. 

As an initial implementation, we will have the pipeline ONLY do prediction for whether the speaker is done speaking (turn-end-classification). The pipeline will not also do transcription.

This is computationally inefficient because we have to compute all the encoder and decoder states in order to do turn-end-prediction, and then they will be repeated in the Whisper pipeline we use for transcription. A more efficient approach would reuse the encoder and decoder states calculated in the turn-end-prediction as a starting point for transcription, so we only need to do the sampling. But I will save that as a potential future enhancement.

In [8]:
from transformers import AutomaticSpeechRecognitionPipeline
import torch
import numpy as np
from typing import Dict
import librosa

class TurnEndPipeline(AutomaticSpeechRecognitionPipeline):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_duration = 30.0  # seconds
        self.sampling_rate = 16000  # Whisper expects 16kHz
    

    def _prepare_audio_features(self, audio):
        """Process audio, taking full audio if <= 30s or last 30s if longer."""

        if isinstance(audio, str):
            # Load from file
            audio_array, sampling_rate = librosa.load(audio, sr=self.sampling_rate)
            assert sampling_rate == 16000, f"Expected sampling rate of 16kHz, but got {sampling_rate}Hz."
        else:
            # Assume audio is already loaded as numpy array
            audio_array = audio
        max_samples = int(self.max_duration * self.sampling_rate)
        audio_array = audio_array[-max_samples:]

        features = self.feature_extractor(
            audio_array, 
            sampling_rate=self.sampling_rate, 
            return_tensors="pt"
        )        
        return features.to(self.device)
    
    def __call__(self, audio):
        """Predict if audio ends at a turn boundary.
        For files > 30 seconds, only examines the last 30 seconds."""
        
        # Process audio
        features = self._prepare_audio_features(audio)
        
        # Prepare model inputs
        model_kwargs = {
            "input_features": features.input_features,
            "decoder_input_ids": torch.tensor([[self.model.config.decoder_start_token_id]]).to(self.device)
        }
        if "attention_mask" in features:
            model_kwargs["attention_mask"] = features.attention_mask
            
        # Forward pass through model (no gradients needed at inference time)
        with torch.no_grad():
            _, turn_end_predictions = self.model.forward_with_turn_end(**model_kwargs)
        
        return {
            "turn_end_probability": turn_end_predictions.squeeze().item()
        }

turn_end_classifier = TurnEndClassifier(
    encoder_dim=384,
    decoder_dim=384,
    hidden_dim=64
).to(device)


turn_end_model = CustomWhisperModel.from_pretrained(
    model_id,
    turn_end_classifier=turn_end_classifier,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True
).to(device)


turn_end_pipe = TurnEndPipeline(
    model=turn_end_model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    device=device,
)


from time import time


start = time()
result = turn_end_pipe("./test_data/test.wav")
print(f"Turn end probability: {result['turn_end_probability']:.2f}")
print(f"Time taken: {time() - start}")


print("---")
start = time()
result = turn_end_pipe(long_audio_file)
print(f"Turn end probability: {result['turn_end_probability']:.2f}")
print(f"Time taken: {time() - start}")


Device set to use cpu


Turn end probability: 0.53
Time taken: 0.47891807556152344
---
Turn end probability: 0.52
Time taken: 0.09000015258789062


# Train/Overfit on Single Sample
This is just a test that basic training works: two samples are cut from one file, and the classifier is trained until it memorizes them.

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
import librosa
import numpy as np
from tqdm import tqdm

def create_training_samples(audio_path):
    """Create two samples from the audio file"""
    # Load audio
    audio, sr = librosa.load(audio_path, sr=16000)
    
    # Sample 1: Full audio (turn end)
    full_audio = audio
    
    # Sample 2: First 4 seconds (not turn end)
    partial_audio = audio[:int(4 * sr)]
    
    return full_audio, partial_audio

def train_single_epoch(model, samples, labels, feature_extractor, tokenizer, optimizer, criterion, device):
    """Train for one epoch on our two samples"""
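    model.eval()  # Eval mode (disables dropout). Whisper's weights stay fixed because the optimizer only receives the classifier's parameters.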
    model.turn_end_classifier.train()  # Train only classifier
    
    epoch_loss = 0
    for audio, label in zip(samples, labels):
        # Reset gradients
        optimizer.zero_grad()
        
        # Preprocess audio
        features = feature_extractor(
            audio, 
            sampling_rate=16000, 
            return_tensors="pt"
        ).to(device)
        
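        # Prepare model inputs
        # Note: this uses pad_token_id, whereas TurnEndPipeline.__call__ uses
        # config.decoder_start_token_id; training and inference should use the same token.
        decoder_input_ids = torch.tensor([[tokenizer.pad_token_id]]).to(device)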
        
        # Forward pass
        _, turn_end_pred = model.forward_with_turn_end(
            input_features=features.input_features,
            decoder_input_ids=decoder_input_ids
        )
        
        # Calculate loss
        target = torch.tensor([[label]], dtype=torch.float).to(device)
        loss = criterion(turn_end_pred, target)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        
    return epoch_loss / len(samples)

def train_model(audio_path, model, feature_extractor, tokenizer, device, num_epochs=51):
    """Main training loop"""
    # Prepare samples
    full_audio, partial_audio = create_training_samples(audio_path)
    samples = [full_audio, partial_audio]
    labels = [1.0, 0.0]  # 1 for turn end, 0 for not turn end
    
    # Setup training
    optimizer = optim.Adam(model.turn_end_classifier.parameters(), lr=0.001)
    criterion = nn.BCELoss()
    
    # Training loop
    losses = []
    for epoch in tqdm(range(num_epochs)):
        loss = train_single_epoch(
            model, samples, labels, 
            feature_extractor, tokenizer, 
            optimizer, criterion, device
        )
        losses.append(loss)
        
        if epoch % 10 == 0:
            print(f"Epoch {epoch}, Loss: {loss:.4f}")
            
            # Print predictions
            with torch.no_grad():
                for audio, label in zip(samples, labels):
                    features = feature_extractor(
                        audio, 
                        sampling_rate=16000, 
                        return_tensors="pt"
                    ).to(device)
                    
                    decoder_input_ids = torch.tensor([[tokenizer.pad_token_id]]).to(device)
                    _, pred = model.forward_with_turn_end(
                        input_features=features.input_features,
                        decoder_input_ids=decoder_input_ids
                    )
                    print(f"Target: {label:.1f}, Prediction: {pred.item():.4f}")
    
    return losses

# Run training
losses = train_model(
    "./test_data/test.wav",
    turn_end_model,
    turn_end_pipe.feature_extractor,
    turn_end_pipe.tokenizer,
    device
)

  2%|▏         | 1/51 [00:00<00:25,  1.94it/s]

Epoch 0, Loss: 0.7394
Target: 1.0, Prediction: 0.5481
Target: 0.0, Prediction: 0.5481


 22%|██▏       | 11/51 [00:04<00:15,  2.53it/s]

Epoch 10, Loss: 0.6957
Target: 1.0, Prediction: 0.5067
Target: 0.0, Prediction: 0.5058


 41%|████      | 21/51 [00:07<00:11,  2.54it/s]

Epoch 20, Loss: 0.6903
Target: 1.0, Prediction: 0.5034
Target: 0.0, Prediction: 0.4975


 61%|██████    | 31/51 [00:11<00:08,  2.49it/s]

Epoch 30, Loss: 0.6219
Target: 1.0, Prediction: 0.5919
Target: 0.0, Prediction: 0.4749


 80%|████████  | 41/51 [00:15<00:04,  2.44it/s]

Epoch 40, Loss: 0.0011
Target: 1.0, Prediction: 0.9994
Target: 0.0, Prediction: 0.0002


100%|██████████| 51/51 [00:19<00:00,  2.67it/s]

Epoch 50, Loss: 0.0000
Target: 1.0, Prediction: 1.0000
Target: 0.0, Prediction: 0.0000
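As a sanity check on the log above, binary cross-entropy can be computed by hand. A classifier that outputs 0.5 for everything scores ln 2 ≈ 0.693, which is roughly where the loss sits for the first 20 epochs. The `bce` helper is a plain-Python illustration, not the PyTorch implementation. The logged epoch loss is averaged over forward passes made during the epoch, while the printed predictions come from after the epoch's updates, so the two will be close but not identical.

```python
import math


def bce(prediction: float, target: float) -> float:
    """Binary cross-entropy for a single prediction strictly between 0 and 1."""
    return -(target * math.log(prediction) + (1 - target) * math.log(1 - prediction))


# Loss of a classifier that always predicts 0.5, averaged over one positive and one negative.
chance = (bce(0.5, 1.0) + bce(0.5, 0.0)) / 2

# Loss implied by the two predictions printed at epoch 10.
epoch_10 = (bce(0.5067, 1.0) + bce(0.5058, 0.0)) / 2

print(f"chance: {chance:.4f}, epoch 10: {epoch_10:.4f}")
```

Both values come out near 0.69, so the classifier was still at chance level at epoch 10 and only separated the two samples between epochs 30 and 40.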



