# PyTorch ASR Phoneme Extraction

This notebook demonstrates how to extract phoneme representations from speech using PyTorch and pre-trained ASR models.

## Setup

First, let's install the required libraries if they're not already installed.

In [None]:
# Uncomment and run if you need to install the packages
!pip install torch torchaudio transformers matplotlib numpy soundfile librosa

In [None]:
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import matplotlib.pyplot as plt
import numpy as np
import IPython.display as ipd
import os
import subprocess

# Try to install required packages for audio processing
try:
    import librosa
    import soundfile as sf
except ImportError:
    print("Installing librosa and soundfile for audio processing...")
    !pip install librosa soundfile
    import librosa
    import soundfile as sf

# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## Download Sample Audio

Let's download a sample audio file to work with.

In [None]:
# Download and extract a sample audio file from LibriSpeech
import os
import tarfile
import tempfile
from urllib.request import urlretrieve
import shutil

sample_dir = "sample_data"
os.makedirs(sample_dir, exist_ok=True)

# Target audio file paths - we'll create both FLAC and WAV versions
flac_path = os.path.join(sample_dir, "sample_audio.flac")
wav_path = os.path.join(sample_dir, "sample_audio.wav")

# Source archive and the single utterance we need from it
tarball_url = "https://openslr.elda.org/resources/12/dev-clean.tar.gz"
target_file_path = "LibriSpeech/dev-clean/84/121123/84-121123-0001.flac"


def _download_flac_from_tarball(url, member_name, dst_path):
    """Download the ``.tar.gz`` at ``url`` and copy ``member_name`` to ``dst_path``.

    The archive is written to a temporary file that is deleted even when the
    download or extraction fails. The member is streamed straight out of the
    archive with ``extractfile`` (nothing else is extracted to disk) into
    ``dst_path + ".part"`` and only renamed into place once complete, so an
    interrupted run never leaves a truncated FLAC for a later run to reuse.

    Raises:
        KeyError: if ``member_name`` is not in the archive (from ``getmember``).
        ValueError: if the member is not a regular file.
    """
    # Close the temp file before writing to it by name: an open
    # NamedTemporaryFile cannot be reopened on Windows.
    with tempfile.NamedTemporaryFile(suffix=".tar.gz", delete=False) as temp_file:
        tarball_path = temp_file.name
    partial_path = dst_path + ".part"
    try:
        print(f"Downloading archive from {url}...")
        urlretrieve(url, tarball_path)
        print("Extracting archive...")
        with tarfile.open(tarball_path, "r:gz") as tar:
            src = tar.extractfile(tar.getmember(member_name))
            if src is None:
                raise ValueError(f"{member_name} is not a regular file in {url}")
            with src, open(partial_path, "wb") as dst:
                shutil.copyfileobj(src, dst)
        os.replace(partial_path, dst_path)
    finally:
        os.unlink(tarball_path)
        if os.path.exists(partial_path):
            os.unlink(partial_path)


def _convert_flac_to_wav(src_path, dst_path):
    """Decode ``src_path`` with librosa and write it to ``dst_path`` as WAV.

    The file's native sample rate is kept (``sr=None``). Returns True on
    success; on failure the error is printed, any partially written WAV is
    removed, and False is returned so the caller can fall back to the FLAC.
    """
    print("Converting FLAC to WAV format...")
    try:
        audio_data, file_sample_rate = librosa.load(src_path, sr=None)
        sf.write(dst_path, audio_data, file_sample_rate)
    except Exception as e:
        print(f"Error converting audio: {e}")
        print("Using original FLAC file instead.")
        # Don't leave a half-written WAV that a later run would prefer.
        if os.path.exists(dst_path):
            os.remove(dst_path)
        return False
    print(f"Converted audio saved to {dst_path}")
    return True


# Prefer an existing WAV; otherwise make sure the FLAC exists (downloading it
# if necessary) and try to convert it to WAV for better compatibility.
if os.path.exists(wav_path):
    audio_path = wav_path
    print(f"Using existing WAV file: {wav_path}")
else:
    if os.path.exists(flac_path):
        print(f"Using existing FLAC file: {flac_path}")
    else:
        print("Sample audio not found. Downloading and extracting from archive...")
        _download_flac_from_tarball(tarball_url, target_file_path, flac_path)
        print(f"Sample audio extracted to {flac_path}")
    audio_path = wav_path if _convert_flac_to_wav(flac_path, wav_path) else flac_path

print(f"Using audio file: {audio_path}")

## Loading and Processing Audio

In [None]:
def process_audio(file_path, target_sr=16000):
    """Load an audio file as a mono waveform resampled to ``target_sr``.

    torchaudio is tried first; if it cannot load the file (missing backend,
    unsupported format, ...) librosa is used instead.

    Args:
        file_path: path to the audio file.
        target_sr: output sample rate in Hz (default 16000, what the
            Wav2Vec2 models in this notebook are fed).

    Returns:
        (waveform, sample_rate): ``waveform`` is the mono signal with only the
        channel dimension removed (shape ``(num_samples,)``), and
        ``sample_rate`` equals ``target_sr``.
    """
    try:
        # torchaudio returns a (channels, samples) tensor
        waveform, sample_rate = torchaudio.load(file_path)
    except (RuntimeError, OSError, ImportError) as e:
        # librosa handles various formats including FLAC; mono by default
        print(f"torchaudio failed to load {file_path} ({e}), trying librosa instead...")
        audio_data, sample_rate = librosa.load(file_path, sr=None)
        waveform = torch.from_numpy(audio_data).unsqueeze(0).float()
        print("Successfully loaded audio with librosa")

    # Mix down to mono first so the resampler only processes one channel
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    if sample_rate != target_sr:
        resampler = torchaudio.transforms.Resample(sample_rate, target_sr)
        waveform = resampler(waveform)
        sample_rate = target_sr

    # Drop only the channel axis so a one-sample clip stays 1-D
    return waveform.squeeze(0), sample_rate

# Load the sample and bring it to the 16 kHz mono format process_audio produces
waveform, sample_rate = process_audio(audio_path)

# Display audio information; process_audio squeezes the channel axis away, so
# waveform.shape[0] is the number of samples
print(f"Sample rate: {sample_rate} Hz")
print(f"Waveform shape: {waveform.shape}")
print(f"Audio duration: {waveform.shape[0]/sample_rate:.2f} seconds")

# Play the audio (inline player; rendered because it is the cell's last expression)
ipd.Audio(waveform.numpy(), rate=sample_rate)

## Load Pre-trained ASR Model

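We'll use Facebook's Wav2Vec 2.0 base model, which was pre-trained and then fine-tuned with a CTC head on 960 hours of LibriSpeech. Note that this model transcribes speech into characters (letters plus a `|` word delimiter), not phonemes; a phoneme-level model is introduced later in the notebook.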

In [None]:
# Checkpoint for the character-level CTC model used in the next sections
model_name = "facebook/wav2vec2-base-960h"
print(f"Loading model: {model_name}")

# The processor bundles the feature extractor (raw audio -> input_values)
# and the tokenizer (token ids <-> strings) that match this checkpoint.
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = (
    Wav2Vec2ForCTC
    .from_pretrained(model_name)
    .to(device)
)

print("Model loaded successfully!")

## Extracting Phoneme Probabilities

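Now we'll extract the logits from the model (unnormalized scores for each output token at each time step) and convert them to probabilities with a softmax. For this model the output tokens are characters, so these are character probabilities that serve only as a rough proxy for phonemes.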

In [None]:
def extract_phoneme_probs(waveform, sample_rate=16000, asr_model=None, asr_processor=None, target_device=None):
    """Run the CTC model on ``waveform`` and return frame-level class probabilities.

    Args:
        waveform: 1-D audio signal (tensor or array) sampled at ``sample_rate``.
        sample_rate: sampling rate of ``waveform`` in Hz, passed to the processor.
        asr_model: CTC model to run; defaults to the notebook's global ``model``.
        asr_processor: matching processor; defaults to the global ``processor``.
        target_device: device used for inference; defaults to the global ``device``.

    Returns:
        (probs, decoder): for a single 1-D waveform, ``probs`` is a CPU tensor of
        shape (time_steps, num_classes) from a softmax over the model logits;
        ``decoder`` is the processor tokenizer's id -> token mapping.
    """
    asr_model = model if asr_model is None else asr_model
    asr_processor = processor if asr_processor is None else asr_processor
    target_device = device if target_device is None else target_device

    # The feature extractor expects array-like audio; hand it a CPU numpy array
    # (a CUDA tensor could not be converted).
    audio = waveform.cpu().numpy() if torch.is_tensor(waveform) else waveform
    input_values = asr_processor(audio, sampling_rate=sample_rate, return_tensors="pt").input_values
    input_values = input_values.to(target_device)

    # Inference only: no gradient bookkeeping
    with torch.no_grad():
        logits = asr_model(input_values).logits

    probs = torch.nn.functional.softmax(logits, dim=-1)

    # Drop only the batch axis; a bare squeeze() would also collapse the time
    # axis when the model emits a single frame.
    return probs.cpu().squeeze(0), asr_processor.tokenizer.decoder

# Get frame-level probabilities, shape (time_steps, num_classes), plus the
# tokenizer's id -> token map used for labelling and decoding below
phoneme_probs, decoder = extract_phoneme_probs(waveform)
print(f"Shape of phoneme probabilities: {phoneme_probs.shape}")
print(f"Number of time steps: {phoneme_probs.shape[0]}")
print(f"Number of phoneme classes: {phoneme_probs.shape[1]}")

## Visualizing Phoneme Activations

Let's visualize the top phoneme activations over time.

In [None]:
def plot_phoneme_activations(probs, decoder, top_k=5, frame_rate=50.0, max_frames=200):
    """Plot the probability over time of the ``top_k`` most active output classes.

    Classes are ranked by their peak probability within the plotted window.
    Each line follows ONE class through time, so its legend label (class id
    and the token from ``decoder``) is valid for the whole curve.

    Args:
        probs: tensor of shape (time_steps, num_classes) with per-frame
            probabilities, e.g. from ``extract_phoneme_probs``.
        decoder: mapping from class id (int) to token string.
        top_k: number of classes to plot (clamped to ``num_classes``).
        frame_rate: model output frames per second for the time axis
            (50 is assumed for Wav2Vec 2.0).
        max_frames: number of leading frames to plot (200 frames = 4 s at
            50 fps); ``None`` plots every frame.
    """
    window = probs if max_frames is None else probs[:max_frames]
    if window.shape[0] == 0:
        print("No frames to plot.")
        return

    # Rank classes by how strongly they fire anywhere in the window
    k = min(top_k, window.shape[1])
    peak_probs = window.max(dim=0).values
    top_classes = torch.topk(peak_probs, k=k).indices.tolist()

    time_steps = np.arange(window.shape[0]) / frame_rate

    fig, ax = plt.subplots(figsize=(15, 8))
    for class_id in top_classes:
        # decoder maps int id -> token, so look ids up directly
        token = decoder.get(class_id, '')
        ax.plot(time_steps, window[:, class_id].numpy(), label=f"Class {class_id} ({token})")

    ax.set_xlabel("Time (seconds)")
    ax.set_ylabel("Probability")
    ax.set_title("Top Phoneme Activations Over Time")
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    plt.show()

# Visualize phoneme activations for the loaded sample using the plotting
# function's default settings
plot_phoneme_activations(phoneme_probs, decoder)

## Decoding to Phonemes and Text

Now let's decode the model outputs to both phonemes and text.

In [None]:
def decode_outputs(probs, decoder, blank_id=0):
    """Greedy (argmax) CTC decoding of frame-level probabilities.

    Args:
        probs: tensor of shape (time_steps, num_classes); a 1-D tensor is
            treated as a single frame.
        decoder: mapping from class id (int) to token string; ids missing from
            it are rendered as ``"[id]"``.
        blank_id: id of the CTC blank token (default 0, the id this notebook
            assumes for the Wav2Vec2 blank).

    Returns:
        (phoneme_sequence, collapsed_phonemes, text):
        - phoneme_sequence: the argmax token for every frame, blanks included;
        - collapsed_phonemes: tokens after merging consecutive repeats and
          dropping blanks (a blank between two equal tokens keeps both);
        - text: collapsed tokens joined, with the ``|`` word delimiter mapped
          to a space.
    """
    # One .tolist() instead of per-element .item() on 0-d tensors; reshape(-1)
    # keeps a single-frame input iterable.
    pred_ids = torch.argmax(probs, dim=-1).reshape(-1).tolist()

    def _lookup(token_id):
        return decoder.get(token_id, f"[{token_id}]")

    phoneme_sequence = [_lookup(token_id) for token_id in pred_ids]

    # CTC collapse: merge runs of the same id, then drop blanks
    collapsed_phonemes = []
    prev_id = None
    for token_id in pred_ids:
        if token_id != prev_id and token_id != blank_id:
            collapsed_phonemes.append(_lookup(token_id))
        prev_id = token_id

    text = ''.join(collapsed_phonemes).replace('|', ' ')

    return phoneme_sequence, collapsed_phonemes, text

# Decode outputs: per-frame tokens, CTC-collapsed tokens, and the joined text
phoneme_sequence, collapsed_phonemes, text = decode_outputs(phoneme_probs, decoder)

# The per-frame sequence still contains blanks and repeated tokens, so only a
# prefix is shown
print("Full phoneme sequence (first 50 frames):")
print(phoneme_sequence[:50])
print("\nCollapsed phoneme sequence:")
print(collapsed_phonemes)
print("\nDecoded text:")
print(text)

## Using a Model Fine-tuned for Phoneme Recognition

For a more direct approach to phoneme recognition, we can use a model specifically fine-tuned for phoneme recognition.

In [None]:
# Load a model fine-tuned for phoneme recognition
# Note: This will download a different model
phoneme_model_name = "facebook/wav2vec2-lv-60-espeak-cv-ft"
print(f"Loading phoneme model: {phoneme_model_name}")


def transcribe_to_phonemes(waveform, sample_rate=16000):
    """Greedy-decode ``waveform`` with the phoneme model into a string.

    Uses the globals ``phoneme_processor``, ``phoneme_model`` and ``device``
    set up in this cell; frame-wise argmax ids are decoded with the
    processor's ``batch_decode``.

    Args:
        waveform: 1-D audio signal (tensor or array) sampled at ``sample_rate``.
        sample_rate: sampling rate in Hz passed to the processor.

    Returns:
        The decoded transcription of the single input.
    """
    # The feature extractor expects array-like audio on the CPU
    audio = waveform.cpu().numpy() if torch.is_tensor(waveform) else waveform
    input_values = phoneme_processor(audio, sampling_rate=sample_rate, return_tensors="pt").input_values
    input_values = input_values.to(device)

    with torch.no_grad():
        logits = phoneme_model(input_values).logits

    predicted_ids = torch.argmax(logits, dim=-1)
    return phoneme_processor.batch_decode(predicted_ids)[0]


# Forget any transcription from a previous run so later cells only ever see
# this run's result.
globals().pop("phoneme_transcription", None)

# Loading and transcription are guarded separately so the message reports
# which step actually failed.
phoneme_model = None
try:
    from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor

    # Build the processor from its two parts rather than in a single call
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(phoneme_model_name)
    tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(phoneme_model_name)
    phoneme_processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
    phoneme_model = Wav2Vec2ForCTC.from_pretrained(phoneme_model_name).to(device)
    print("Phoneme model loaded successfully!")
except Exception as e:
    phoneme_model = None
    print(f"Error loading phoneme model: {e}")
    print("Skipping phoneme-specific model demonstration.")

if phoneme_model is not None:
    try:
        phoneme_transcription = transcribe_to_phonemes(waveform)
        print("\nPhoneme transcription:")
        print(phoneme_transcription)
    except Exception as e:
        print(f"Error transcribing with phoneme model: {e}")

## Analyzing Phoneme Distributions

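Let's analyze the distribution of decoded tokens in our sample. Because the base model's tokens are characters, these counts approximate letter frequencies rather than true phoneme frequencies.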

In [None]:
# Count occurrences of the tokens produced by the base model's CTC decoding
# (for this model they are characters, used here as a rough proxy for phonemes).
from collections import Counter

# Tokens that are not sounds: '' and the '|' word delimiter (mapped to a space
# in the decoded text). NOTE(review): the angle-bracket entries assume the usual
# Wav2Vec2 special-token spellings; confirm against `decoder`.
NON_PHONEME_TOKENS = {'', '|', '<pad>', '<s>', '</s>', '<unk>'}
TOP_N_PHONEMES = 15

phoneme_counts = Counter(p for p in collapsed_phonemes if p not in NON_PHONEME_TOKENS)

if not phoneme_counts:
    # zip(*[]) cannot be unpacked into two names, so bail out explicitly
    print("No phonemes were decoded; nothing to plot.")
else:
    top_phonemes = phoneme_counts.most_common(TOP_N_PHONEMES)
    phonemes, counts = zip(*top_phonemes)

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.bar(phonemes, counts)
    ax.set_title(f'Top {len(phonemes)} Phonemes in Sample')
    ax.set_xlabel('Phoneme')
    ax.set_ylabel('Count')
    ax.tick_params(axis='x', labelrotation=45)
    fig.tight_layout()
    plt.show()

## Comparing Phoneme Transcriptions

Let's compare the different phoneme transcriptions side by side to see the differences between methods.

In [None]:
# Create a comparison of the different phoneme transcriptions
import pandas as pd
from IPython.display import display, HTML

PREVIEW_TOKENS = 30  # how many tokens of each sequence to print below

# Store the transcriptions in variables for comparison
# Note: These will be populated when the cells above are run
ctc_text = text  # From the CTC decoding section
try:
    specialized_text = phoneme_transcription  # From the specialized model section
    specialized_available = True
except NameError:
    specialized_text = "[Model failed to load]"
    specialized_available = False

# Create a DataFrame for comparison
comparison_df = pd.DataFrame({
    'Method': ['Wav2Vec2 Base with CTC Decoding', 'Specialized Phoneme Model'],
    'Transcription': [ctc_text, specialized_text]
})

# Display the comparison
display(HTML(comparison_df.to_html(index=False)))


def _preview(tokens, limit=PREVIEW_TOKENS):
    """Join the first ``limit`` tokens with spaces, adding '...' only if truncated."""
    shown = ' '.join(tokens[:limit])
    return shown + "..." if len(tokens) > limit else shown


# Also show a more detailed comparison of the phoneme sequences
print("\nDetailed Phoneme Sequence Comparison:")
print("\nWav2Vec2 Base with CTC Decoding:")
print(_preview(collapsed_phonemes))

if specialized_available:
    # One token per character, with spaces shown as '|' (the base model's word
    # delimiter) instead of being split into stray characters.
    # NOTE(review): multi-character IPA symbols, if the model emits any, are
    # split into separate characters here.
    specialized_phonemes = ['|' if ch == ' ' else ch for ch in specialized_text]
    print("\nSpecialized Phoneme Model:")
    print(_preview(specialized_phonemes))
else:
    print("\nSpecialized Phoneme Model: not available (model failed to load).")

## Conclusion

In this notebook, we've demonstrated how to:

1. Load and process audio files for ASR
2. Extract phoneme probabilities from Wav2Vec 2.0 models
3. Visualize phoneme activations over time
4. Decode phoneme sequences to text
5. Use models specifically fine-tuned for phoneme recognition
6. Compare different phoneme transcription methods

These techniques can be applied to various applications such as:
- Studying pronunciation patterns
- Developing language learning tools
- Creating more interpretable ASR systems
- Analyzing speech disorders