In [None]:
import torchaudio
from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor
import torch

class LID:
    def __init__(self, device='cpu'):
        model_id = "facebook/mms-lid-256"

        self.processor = AutoFeatureExtractor.from_pretrained(model_id)
        self.model = Wav2Vec2ForSequenceClassification.from_pretrained(model_id)
        self.device = torch.device(device)
        self.model.to(device)
    def infer_lid_distribution_batch(self, waveform_chunks, sample_rate, threshold=1):
        """
        Returns the language probability distribution for multiple waveform chunks.

        Args:
            waveform_chunks (List[Tensor] or Tensor): List of 1D tensors or 2D tensor (batch_size, sequence_length) of audio samples.
            sample_rate (int): Sampling rate of the audio.
            threshold (float): top languages that explain certain percentages of the distribution. Default as 1 to return all languages.

        Returns:
            List[Dict[str, float]]: List of mappings from language code to probability for each chunk.
        """
        # Handle both list of tensors and batched tensor inputs
        if isinstance(waveform_chunks, list):
            # Process as a batch - pad sequences to same length if needed
            inputs = self.processor(waveform_chunks, sampling_rate=sample_rate, return_tensors="pt", padding=True)
        else:
            # Assume it's already a batched tensor
            inputs = self.processor(waveform_chunks, sampling_rate=sample_rate, return_tensors="pt")

        with torch.no_grad():
            logits = self.model(**inputs).logits  # Shape: (batch_size, num_classes)
            probs = torch.softmax(logits, dim=-1)  # Shape: (batch_size, num_classes)

        batch_results = []

        # Process each sample in the batch
        for batch_idx in range(probs.shape[0]):
            sample_probs = probs[batch_idx]

            # Map probabilities to language labels for this sample
            lang_distribution = {
                self.model.config.id2label[idx]: prob.item()
                for idx, prob in enumerate(sample_probs)
            }

            sorted_langs = sorted(lang_distribution.items(), key=lambda x: x[1], reverse=True)
            cumulative_prob = 0.0
            selected_langs = {}

            for lang, prob in sorted_langs:
                selected_langs[lang] = prob
                cumulative_prob += prob
                if cumulative_prob >= threshold:
                    break

            batch_results.append(selected_langs)

        return batch_results


In [4]:
from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor
import torch

model_id = "facebook/mms-lid-256"

model = Wav2Vec2ForSequenceClassification.from_pretrained(model_id)
model.to('cpu')

Wav2Vec2ForSequenceClassification(
  (wav2vec2): Wav2Vec2Model(
    (feature_extractor): Wav2Vec2FeatureEncoder(
      (conv_layers): ModuleList(
        (0): Wav2Vec2LayerNormConvLayer(
          (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation): GELUActivation()
        )
        (1-4): 4 x Wav2Vec2LayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation): GELUActivation()
        )
        (5-6): 2 x Wav2Vec2LayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation): GELUActivation()
        )
      )
    )
    (feature_projection): Wav2Vec2FeatureProjection(
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=

In [2]:
import torch
import time
import numpy as np
from typing import List, Dict
import matplotlib.pyplot as plt

from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor

model_id = "facebook/mms-lid-256"
processor = AutoFeatureExtractor.from_pretrained(model_id)
model = Wav2Vec2ForSequenceClassification.from_pretrained(model_id)


def create_dummy_batch_tensor(batch_size: int, duration_seconds: float = 5.0, sample_rate: int = 16000):
    """
    Create a batch of dummy audio as a single batched tensor.

    Returns:
        2d matrix in list[list[float]]
    """
    num_samples = int(duration_seconds * sample_rate)
    return (torch.randn(batch_size, num_samples) * 0.1).tolist()


# Batch function
def infer_lid_distribution_batch(waveform_chunks, sample_rate, threshold=1):
    """Batch version"""
    if isinstance(waveform_chunks, list):
        inputs = processor(waveform_chunks, sampling_rate=sample_rate, return_tensors="pt", padding=True)
    else:
        inputs = processor(waveform_chunks, sampling_rate=sample_rate, return_tensors="pt")

    with torch.no_grad():
        logits = model(**inputs).logits
        probs = torch.softmax(logits, dim=-1)


def benchmark_single_vs_batch(batch_sizes: List[int], num_runs: int = 3):
    """
    Benchmark single vs batch processing across different batch sizes.

    Args:
        batch_sizes: List of batch sizes to test
        num_runs: Number of runs to average over
    """


    print(f"Benchmarking Language ID: Single vs Batch Processing")
    print(f"Audio: 30s chunks at 16kHz ({30*16000:,} samples per chunk)")
    print(f"Runs per test: {num_runs}")
    print("-" * 60)

    for batch_size in batch_sizes:
        print(f"\nTesting batch size: {batch_size}")

        # Create test data

        waveform_tensor = create_dummy_batch_tensor(batch_size)


        # Time batch processing
        batch_times = []
        for run in range(num_runs):
            start_time = time.time()
            infer_lid_distribution_batch(waveform_tensor, 16000)
            end_time = time.time()
            batch_times.append(end_time - start_time)

        avg_batch_time = np.mean(batch_times)

        print(f"Time: {avg_batch_time:.3f}s ")


# Example usage:
if __name__ == "__main__":

    # Full benchmark across different batch sizes
    batch_sizes_to_test = [1]

    print("\n" + "="*60)
    print("RUNNING FULL BENCHMARK")
    print("="*60)

    benchmark_single_vs_batch(
        batch_sizes=batch_sizes_to_test,
        num_runs=1
    )




RUNNING FULL BENCHMARK
Benchmarking Language ID: Single vs Batch Processing
Audio: 30s chunks at 16kHz (480,000 samples per chunk)
Runs per test: 1
------------------------------------------------------------

Testing batch size: 1
Time: 1.022s 


Quick test with batch size 2:
list
Created 2 waveforms, each with 480,000 samples
Batch processing took: 13.467s
Got 2 results
Sample result keys: ['ara', 'cmn', 'eng', 'spa', 'fra']...

============================================================
RUNNING FULL BENCHMARK
============================================================
Benchmarking Language ID: Single vs Batch Processing
Audio: 30s chunks at 16kHz (480,000 samples per chunk)
Threshold: 0.95
Runs per test: 3
------------------------------------------------------------

Testing batch size: 1
  Single processing: 2.529s (±0.006s)
  Batch processing:  2.697s (±0.088s)
  Speedup: 0.94x

Testing batch size: 2
  Single processing: 5.413s (±0.130s)
  Batch processing:  4.944s (±0.011s)
  Speedup: 1.10x

Testing batch size: 4
  Single processing: 10.646s (±0.146s)
  Batch processing:  9.600s (±0.058s)
  Speedup: 1.11x

Testing batch size: 8
  Single processing: 21.354s (±0.094s)
  Batch processing:  18.376s (±0.064s)
  Speedup: 1.16x

Testing batch size: 16
  Single processing: 45.138s (±0.306s)
  Batch processing:  42.045s (±0.465s)
  Speedup: 1.07x
