# Test Notebook: Alignment Module

This notebook tests the `alignment` module for speech-to-text alignment.

**Tests:**
1. Module imports and structure
2. Data classes (AlignmentResult, AlignedWord, AlignedToken)
3. WFST factor transducer construction
4. Tokenizers (via text_frontend)
5. Audio segmentation (via audio_frontend)
6. LIS utilities
7. MFA backend availability
8. Gentle backend availability
9. WFST Aligner integration
10. Segment-wise alignment API
11. Ground truth data loading
12. Run WFST alignment on sample audio
13. Accuracy comparison (prediction vs ground truth)
14. Listening test (audio preview)

**Installation (Colab):**
```bash
# GPU Version
pip install k2==1.24.4.dev20251030+cuda12.6.torch2.9.0 -f https://k2-fsa.github.io/k2/cuda.html

# CPU Version (use --no-deps to avoid env changes)
pip install k2==1.24.4.dev20251029+cpu.torch2.9.0 --no-deps -f https://k2-fsa.github.io/k2/cpu.html

# Common dependencies
pip install pytorch-lightning cmudict g2p_en pydub
pip install git+https://github.com/huangruizhe/lis.git
```

## Setup

In [None]:
# Reset repo if needed (uncomment to force fresh clone)
# !rm -rf /content/torchaudio_aligner

In [None]:
# =============================================================================
# Install Dependencies (run once)
# =============================================================================

# ===== GPU Version =====
# !pip install k2==1.24.4.dev20251030+cuda12.6.torch2.9.0 -f https://k2-fsa.github.io/k2/cuda.html

# ===== CPU Version (--no-deps to avoid env changes) =====
# !pip install k2==1.24.4.dev20251029+cpu.torch2.9.0 --no-deps -f https://k2-fsa.github.io/k2/cpu.html

# ===== Common dependencies =====
# !pip install pytorch-lightning cmudict g2p_en pydub
# !pip install git+https://github.com/huangruizhe/lis.git

In [None]:
# =============================================================================
# Setup: Configure Imports
# =============================================================================

import sys
import os
from pathlib import Path

# ===== CONFIGURATION =====
# Git remote that setup_imports() clones when the notebook runs on Colab.
GITHUB_REPO = "https://github.com/huangruizhe/torchaudio_aligner.git"
# Branch that setup_imports() clones / checks out on Colab.
BRANCH = "dev"
# =========================

# Maps a test label such as "Test 1" to its status string. Each test cell
# writes its own entry; the summary cell at the end of the notebook reads them.
test_results = {}

def setup_imports():
    """Locate the project's ``src`` directory and put it first on ``sys.path``.

    On Google Colab (detected by ``google.colab`` being present in
    ``sys.modules``) the repository ``GITHUB_REPO`` is cloned into
    ``/content/torchaudio_aligner`` on first use; when that directory already
    exists it is fetched, switched to ``BRANCH`` and pulled instead.
    Anywhere else the function looks for a ``src`` directory that contains an
    ``alignment`` sub-directory, first next to the current working directory
    and then inside it.

    Returns:
        str: Path of the ``src`` directory, which is on ``sys.path`` afterwards.

    Raises:
        FileNotFoundError: Running locally and no suitable ``src`` directory
            exists.
        RuntimeError: Running on Colab and the initial ``git clone`` failed
            (without a checkout every later import would fail obscurely).
    """
    import subprocess  # local import: only this function shells out to git

    def _git(*args, cwd=None):
        """Run one git command and report whether it exited with status 0."""
        try:
            return subprocess.run(["git", *args], cwd=cwd).returncode == 0
        except OSError:
            # git is missing or not executable; treat it like a failed command.
            return False

    IN_COLAB = 'google.colab' in sys.modules

    if IN_COLAB:
        repo_path = '/content/torchaudio_aligner'
        src_path = f'{repo_path}/src'

        if not os.path.exists(repo_path):
            print(f"Cloning repository (branch: {BRANCH})...")
            # An argument list (no shell) means nothing in BRANCH or
            # GITHUB_REPO can be interpreted as shell syntax.
            if not _git('clone', '-b', BRANCH, GITHUB_REPO, repo_path):
                raise RuntimeError(
                    f"git clone of {GITHUB_REPO} (branch: {BRANCH}) failed"
                )
        else:
            print(f"Updating repository (branch: {BRANCH})...")
            # Mirrors `fetch && checkout && pull`: all() stops at the first
            # command that fails.
            update_steps = (
                ('fetch', 'origin'),
                ('checkout', BRANCH),
                ('pull', 'origin', BRANCH),
            )
            if not all(_git(*step, cwd=repo_path) for step in update_steps):
                # Best effort: an existing checkout is still usable.
                print(f"Warning: could not update {repo_path}; using the existing checkout")
    else:
        cwd = Path(".").absolute()
        src_path = None
        # Covers running from a sub-directory of the repo and from its root.
        for candidate in (cwd.parent / "src", cwd / "src"):
            if candidate.exists() and (candidate / "alignment").exists():
                src_path = str(candidate.absolute())
                break
        if src_path is None:
            raise FileNotFoundError("src directory not found")
        print(f"Running locally from: {src_path}")

    if src_path not in sys.path:
        sys.path.insert(0, src_path)
    return src_path

# Make the project sources importable before any later cell imports them.
src_path = setup_imports()

import torch
import logging
logging.basicConfig(level=logging.INFO)

print()
print("=" * 60, "Checking dependencies...", "=" * 60, sep="\n")

# k2 is optional: the flag lets later cells skip the tests that need it.
try:
    import k2
except ImportError:
    K2_AVAILABLE = False
    print("‚ö†Ô∏è k2 not available")
else:
    K2_AVAILABLE = True
    print(f"‚úÖ k2 available")

# lis is optional as well and is tracked the same way.
try:
    import lis
except ImportError:
    LIS_AVAILABLE = False
    print("‚ö†Ô∏è lis not available")
else:
    LIS_AVAILABLE = True
    print("‚úÖ lis library available")

# Report which device torch would use; nothing is moved to it here.
print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")

## Test 1: Module Imports

In [None]:
# Test 1: the public names of the `alignment` package must import cleanly and
# list_backends() must be callable.
print("=" * 60, "Test 1: Module Imports and Structure", "=" * 60, sep="\n")

try:
    # Everything stays inside the try so that any import or call failure is
    # recorded as FAILED instead of aborting the cell.
    from alignment import (
        AlignmentResult,
        AlignedWord,
        AlignedToken,
        AlignmentConfig,
        AlignerBackend,
        WFSTAligner,
        MFAAligner,
        GentleAligner,
        align,
        get_aligner,
        list_backends,
    )
    print("üì¶ Imports successful!")
    backends = list_backends()
    print(f"\nüîß Available backends: {list(backends.keys())}")
    test_results["Test 1"] = "‚úÖ PASSED"
    print(f"\n‚úÖ Test 1 PASSED")
except Exception as exc:
    import traceback

    test_results["Test 1"] = "‚ùå FAILED"
    print(f"‚ùå Test 1 FAILED: {exc}")
    traceback.print_exc()

## Test 2: Data Classes

In [None]:
# Test 2: construct the alignment data classes from small literal values and
# print a few of their attributes. The class names used here are the ones the
# Test 1 cell imports from `alignment`.
print("=" * 60)
print("Test 2: Data Classes")
print("=" * 60)

try:
    config = AlignmentConfig(backend="wfst", segment_size=15.0, skip_penalty=-0.5)
    print(f"AlignmentConfig: segment_size={config.segment_size}s")
    
    word = AlignedWord(word="hello", start_time=100, end_time=150)
    print(f"AlignedWord: '{word.word}' [{word.start_time}-{word.end_time}]")
    
    # Two aligned words keyed by integer index, plus one unaligned index span.
    result = AlignmentResult(
        word_alignments={0: AlignedWord("hello", 100, 150), 1: AlignedWord("world", 160, 220)},
        unaligned_indices=[(2, 3)],
    )
    print(f"AlignmentResult: {result.num_aligned_words} words")
    
    test_results["Test 2"] = "‚úÖ PASSED"
    print(f"\n‚úÖ Test 2 PASSED")
except Exception as e:
    test_results["Test 2"] = "‚ùå FAILED"
    print(f"‚ùå Test 2 FAILED: {e}")
    # The message alone rarely identifies the failing line; show the traceback too.
    import traceback; traceback.print_exc()

## Tests 3-10: Core Components

Tests 3-10 verify the WFST factor transducer, the tokenizers, audio segmentation, the LIS utilities, and the aligner backends.

In [None]:
# Test 3: WFST Factor Transducer
# Builds a factor transducer from two tokenized words and prints its size.
# Skipped when k2 could not be imported.
print("=" * 60)
print("Test 3: WFST Factor Transducer")
print("=" * 60)

if not K2_AVAILABLE:
    test_results["Test 3"] = "‚è≠Ô∏è SKIPPED"
    print("‚è≠Ô∏è Skipped - k2 not available")
else:
    try:
        from alignment.wfst import make_factor_transducer_word_level_index_with_skip
        # One list of token ids per word.
        # NOTE(review): the ids look like 0-based letter indices for
        # "hello" / "world" - confirm against the tokenizer.
        tokenized = [[7, 4, 11, 11, 14], [22, 14, 17, 11, 3]]
        graph, word_sym, token_sym = make_factor_transducer_word_level_index_with_skip(tokenized)
        print(f"Graph: {graph.shape[0]} states, {graph.num_arcs} arcs")
        test_results["Test 3"] = "‚úÖ PASSED"
        print("‚úÖ Test 3 PASSED")
    except Exception as e:
        test_results["Test 3"] = "‚ùå FAILED"
        print(f"‚ùå Test 3 FAILED: {e}")
        # The message alone rarely identifies the failing line; show the traceback too.
        import traceback; traceback.print_exc()

In [None]:
# Test 4: Tokenizers
# Builds a tokenizer from a fixed label tuple and encodes a short phrase.
print("=" * 60)
print("Test 4: Tokenizers")
print("=" * 60)

try:
    # CharTokenizer is imported only to check that the name is exported.
    from text_frontend import CharTokenizer, create_tokenizer_from_labels
    # NOTE(review): presumably '-' is the blank label and '*' the unknown
    # label, as in the fallback backend of Test 12 - confirm.
    labels = ('-', 'a', 'i', 'e', 'n', 'o', 'u', 't', 's', 'r', 'm', 'k', 'l', 'd', 
              'g', 'h', 'y', 'b', 'p', 'w', 'c', 'v', 'j', 'z', 'f', "'", 'q', 'x', '*')
    tokenizer = create_tokenizer_from_labels(labels)
    encoded = tokenizer.encode("hello world")
    print(f"Encoded 'hello world': {encoded}")
    test_results["Test 4"] = "‚úÖ PASSED"
    print("‚úÖ Test 4 PASSED")
except Exception as e:
    test_results["Test 4"] = "‚ùå FAILED"
    print(f"‚ùå Test 4 FAILED: {e}")
    # The message alone rarely identifies the failing line; show the traceback too.
    import traceback; traceback.print_exc()

In [None]:
# Test 5: Audio Segmentation
# Splits a random waveform into overlapping segments and prints the count.
print("=" * 60)
print("Test 5: Audio Segmentation")
print("=" * 60)

try:
    from audio_frontend import segment_waveform
    # 480000 samples at the 16000 Hz passed below is 30 seconds of noise.
    waveform_test = torch.randn(480000)
    result = segment_waveform(waveform_test, sample_rate=16000, segment_size=15.0, overlap=2.0)
    print(f"Segmented into {result.num_segments} segments")
    test_results["Test 5"] = "‚úÖ PASSED"
    print("‚úÖ Test 5 PASSED")
except Exception as e:
    test_results["Test 5"] = "‚ùå FAILED"
    print(f"‚ùå Test 5 FAILED: {e}")
    # The message alone rarely identifies the failing line; show the traceback too.
    import traceback; traceback.print_exc()

In [None]:
# Tests 6-10: LIS, MFA, Gentle, WFST Aligner, Segment API
# Each check records exactly one status per test label in test_results and
# prints a single summary line.
print("=" * 60, "Tests 6-10: Backend checks", "=" * 60, sep="\n")

# Test 6: LIS (needs the optional `lis` package)
if not LIS_AVAILABLE:
    test_results["Test 6"] = "‚è≠Ô∏è SKIPPED"
    print("Test 6 LIS: ‚è≠Ô∏è Skipped")
else:
    try:
        from alignment.wfst.lis_utils import compute_lis
        lis_result = compute_lis([1, 5, 2, 6, 3, 7])
        print(f"Test 6 LIS: {lis_result} ‚úÖ")
        test_results["Test 6"] = "‚úÖ PASSED"
    except Exception as exc:
        test_results["Test 6"] = "‚ùå FAILED"
        print(f"Test 6 LIS: ‚ùå {exc}")

# Test 7: MFA - constructing the aligner must work; a missing MFA install is
# reported as a warning status, not as a failure.
try:
    aligner = MFAAligner(AlignmentConfig(backend="mfa"))
    mfa_ok = aligner._check_mfa_available()
    if mfa_ok:
        test_results["Test 7"] = "‚úÖ PASSED"
        print("Test 7 MFA: ‚úÖ Available")
    else:
        test_results["Test 7"] = "‚ö†Ô∏è MFA NOT INSTALLED"
        print("Test 7 MFA: ‚ö†Ô∏è Not installed")
except Exception as exc:
    test_results["Test 7"] = "‚ùå FAILED"
    print(f"Test 7 MFA: ‚ùå {exc}")

# Test 8: Gentle - available when either the Python check or the server
# check succeeds.
try:
    aligner = GentleAligner(AlignmentConfig(backend="gentle"))
    gentle_ok = aligner._check_gentle_python() or aligner._check_gentle_server()
    if gentle_ok:
        test_results["Test 8"] = "‚úÖ PASSED"
        print("Test 8 Gentle: ‚úÖ Available")
    else:
        test_results["Test 8"] = "‚ö†Ô∏è GENTLE NOT INSTALLED"
        print("Test 8 Gentle: ‚ö†Ô∏è Not installed")
except Exception as exc:
    test_results["Test 8"] = "‚ùå FAILED"
    print(f"Test 8 Gentle: ‚ùå {exc}")

# Test 9-10: WFST Aligner construction and presence of align_segments
if not (K2_AVAILABLE and LIS_AVAILABLE):
    test_results["Test 9"] = "‚è≠Ô∏è SKIPPED"
    test_results["Test 10"] = "‚è≠Ô∏è SKIPPED"
    print("Test 9-10: ‚è≠Ô∏è Skipped (missing k2/lis)")
else:
    try:
        from alignment import WFSTAligner, SegmentAlignmentResult
        aligner = WFSTAligner(AlignmentConfig(backend="wfst"))
        has_align_segments = hasattr(aligner, 'align_segments')
        test_results["Test 9"] = "‚úÖ PASSED"
        print(f"Test 9 WFST Aligner: ‚úÖ")
        if has_align_segments:
            test_results["Test 10"] = "‚úÖ PASSED"
            print("Test 10 Segment API: ‚úÖ")
        else:
            test_results["Test 10"] = "‚ùå FAILED"
            print("Test 10 Segment API: ‚ùå")
    except Exception as exc:
        # One failure here fails both tests: they share the same setup.
        test_results["Test 9"] = "‚ùå FAILED"
        test_results["Test 10"] = "‚ùå FAILED"
        print(f"Test 9-10: ‚ùå {exc}")

## Test 11: Ground Truth Data

In [None]:
# Test 11: define the reference transcript and word-level ground truth used by
# Tests 12-14, print it, and check that the table is internally consistent.
print("=" * 60)
print("Test 11: Ground Truth Data")
print("=" * 60)

# Ground truth from MMS-FA CTC alignment (50fps = 20ms per frame)
TRANSCRIPT = "I HAD THAT CURIOSITY BESIDE ME AT THIS MOMENT"
FRAME_RATE = 50  # 20ms per frame

# One entry per transcript word; "start" and "end" are frame indices at
# FRAME_RATE and are printed below as a half-open interval [start, end).
GROUND_TRUTH_WORDS = [
    {"word": "I", "start": 31, "end": 35},
    {"word": "HAD", "start": 37, "end": 44},
    {"word": "THAT", "start": 45, "end": 53},
    {"word": "CURIOSITY", "start": 56, "end": 92},
    {"word": "BESIDE", "start": 95, "end": 116},
    {"word": "ME", "start": 118, "end": 124},
    {"word": "AT", "start": 126, "end": 129},
    {"word": "THIS", "start": 131, "end": 139},
    {"word": "MOMENT", "start": 143, "end": 157},
]

# Derived rather than hard-coded so the banner stays right if FRAME_RATE changes.
ms_per_frame = 1000 / FRAME_RATE

print(f"Transcript: '{TRANSCRIPT}'")
print(f"Frame rate: {FRAME_RATE} fps ({ms_per_frame:g}ms/frame)")
print(f"\nGround truth ({len(GROUND_TRUTH_WORDS)} words):")
for w in GROUND_TRUTH_WORDS:
    start_sec = w['start'] / FRAME_RATE
    end_sec = w['end'] / FRAME_RATE
    print(f"  {w['word']:12s}: [{w['start']:3d}, {w['end']:3d}) = [{start_sec:.2f}s, {end_sec:.2f}s)")

# Consistency checks: without them this test could never fail, and a typo in
# the table would silently skew the accuracy comparison later on.
problems = []
if [w["word"] for w in GROUND_TRUTH_WORDS] != TRANSCRIPT.split():
    problems.append("ground-truth words do not match TRANSCRIPT word for word")
previous_end = 0
for w in GROUND_TRUTH_WORDS:
    # Every span must be non-empty and must not start before the previous one ends.
    if not (previous_end <= w["start"] < w["end"]):
        problems.append(f"invalid or overlapping frame span for '{w['word']}'")
    previous_end = w["end"]

if problems:
    test_results["Test 11"] = "‚ùå FAILED"
    for problem in problems:
        print(f"‚ùå Test 11 FAILED: {problem}")
else:
    test_results["Test 11"] = "‚úÖ PASSED"
    print(f"\n‚úÖ Test 11 PASSED")

## Test 12: Run WFST Alignment

In [None]:
# Test 12: run the WFST aligner end to end on one sample utterance and keep
# the results (ALIGNED_WORDS, WAVEFORM, SR) for Tests 13 and 14.
# Skipped unless both k2 and lis could be imported.
print("=" * 60)
print("Test 12: Run WFST Alignment on Sample Audio")
print("=" * 60)

if not K2_AVAILABLE or not LIS_AVAILABLE:
    test_results["Test 12"] = "‚è≠Ô∏è SKIPPED"
    print("‚è≠Ô∏è Skipped - missing k2 or lis")
else:
    try:
        import torchaudio
        from alignment import WFSTAligner, AlignmentConfig
        
        # Load sample audio
        print("\nüéµ Loading sample audio...")
        SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
        waveform, sr = torchaudio.load(SPEECH_URL)
        # Normalise to 16 kHz and a single channel (the first one).
        if sr != 16000:
            waveform = torchaudio.functional.resample(waveform, sr, 16000)
            sr = 16000
        if waveform.size(0) > 1:
            waveform = waveform[0:1]
        print(f"   Shape: {waveform.shape}, Duration: {waveform.size(1)/sr:.2f}s")
        
        # Load model
        print("\nüîß Loading MMS-FA model...")
        try:
            from labeling_utils import load_model
            model = load_model("mms-fa")
        except ImportError:
            # labeling_utils is not importable: fall back to the torchaudio
            # MMS_FA bundle behind a minimal stand-in backend.
            bundle = torchaudio.pipelines.MMS_FA
            _model = bundle.get_model().to("cpu")
            class MockModelBackend:
                """Stand-in backend exposing get_emissions() and get_vocab_info()."""

                def __init__(self, model, bundle):
                    self._model, self._bundle = model, bundle

                def get_emissions(self, waveforms, lengths):
                    """Run the wrapped model on ``waveforms``; ``lengths`` is not used."""
                    # NOTE(review): returns the model's forward() output
                    # unchanged - confirm that is what the aligner expects.
                    with torch.inference_mode():
                        return self._model(waveforms.squeeze(-1))

                def get_vocab_info(self):
                    """Return an object with labels, blank_token and unk_token attributes."""
                    # Read the bundle stored on the instance instead of the
                    # notebook-global `bundle`, so the backend keeps working
                    # if that global is rebound later.
                    bundle_labels = tuple(self._bundle.get_labels())
                    class VI:
                        labels = bundle_labels
                        blank_token, unk_token = '-', '*'
                    return VI()
            model = MockModelBackend(_model, bundle)
        print("   Model loaded")
        
        # Run alignment
        print("\nüîß Running alignment...")
        config = AlignmentConfig(backend="wfst", segment_size=15.0, overlap=2.0)
        aligner = WFSTAligner(config)
        aligner.set_model(model)
        # The aligner is given the waveform without its channel dimension.
        alignment_result = aligner.align(waveform.squeeze(0), TRANSCRIPT)
        
        aligned_words = alignment_result.word_alignments
        print(f"\nüìä Aligned {len(aligned_words)} words")
        
        # Store for later tests
        ALIGNED_WORDS = aligned_words
        WAVEFORM = waveform
        SR = sr
        
        test_results["Test 12"] = "‚úÖ PASSED"
        print("\n‚úÖ Test 12 PASSED")
        
    except Exception as e:
        test_results["Test 12"] = "‚ùå FAILED"
        print(f"‚ùå Test 12 FAILED: {e}")
        import traceback; traceback.print_exc()

## Test 13: Prediction vs Ground Truth Comparison

In [None]:
# Test 13: compare the predicted word start frames from Test 12 with the
# ground truth from Test 11 and rate the average start-frame error.
print("=" * 60)
print("Test 13: Prediction vs Ground Truth Comparison")
print("=" * 60)

# ALIGNED_WORDS only exists after Test 12 has passed.
if "PASSED" not in test_results.get("Test 12", ""):
    test_results["Test 13"] = "‚è≠Ô∏è SKIPPED"
    print("‚è≠Ô∏è Skipped - Test 12 did not pass")
else:
    try:
        # Start-frame tolerances for the status column and the overall rating.
        GOOD_FRAMES, OK_FRAMES = 5, 10
        # Derived from FRAME_RATE so the millisecond figure follows the config.
        ms_per_frame = 1000 / FRAME_RATE

        print("\nüìä Prediction vs Ground Truth:")
        print("-" * 90)
        print(f"{'Word':<12} {'GT Start':<10} {'Pred Start':<12} {'Œî Start':<10} {'GT End':<10} {'Pred End':<12} {'Status'}")
        print("-" * 90)
        
        total_start_error = 0
        matched = 0
        
        for gt in GROUND_TRUTH_WORDS:
            word = gt["word"]
            gt_start, gt_end = gt["start"], gt["end"]
            
            # First aligned word whose text matches, ignoring case.
            pred = next(
                (aligned for aligned in ALIGNED_WORDS.values()
                 if aligned.word and aligned.word.upper() == word.upper()),
                None,
            )
            
            if pred is not None:
                pred_start = int(pred.start_time)
                # Only a missing end time falls back to the ground-truth
                # duration; an end time of 0 is a real value.
                if pred.end_time is not None:
                    pred_end = int(pred.end_time)
                else:
                    pred_end = pred_start + (gt_end - gt_start)
                delta = abs(pred_start - gt_start)
                total_start_error += delta
                matched += 1
                status = "‚úÖ" if delta <= GOOD_FRAMES else ("‚ö†Ô∏è" if delta <= OK_FRAMES else "‚ùå")
                print(f"{word:<12} {gt_start:<10} {pred_start:<12} {delta:<10} {gt_end:<10} {pred_end:<12} {status}")
            else:
                print(f"{word:<12} {gt_start:<10} {'N/A':<12} {'N/A':<10} {gt_end:<10} {'N/A':<12} ‚ùå NOT FOUND")
        
        print("-" * 90)
        
        if matched == 0:
            # Nothing to compare: reporting PASSED here would be vacuous.
            test_results["Test 13"] = "‚ùå FAILED"
            print("\n‚ùå Test 13 FAILED: no ground-truth word was found in the alignment")
        else:
            avg_error = total_start_error / matched
            print(f"\nüìà Summary:")
            print(f"   Matched: {matched}/{len(GROUND_TRUTH_WORDS)} words")
            print(f"   Avg start frame error: {avg_error:.1f} frames ({avg_error * ms_per_frame:.0f}ms)")
            
            if avg_error <= GOOD_FRAMES:
                print("   Accuracy: ‚úÖ EXCELLENT")
            elif avg_error <= OK_FRAMES:
                print("   Accuracy: ‚ö†Ô∏è ACCEPTABLE")
            else:
                print("   Accuracy: ‚ùå NEEDS IMPROVEMENT")
            
            # The test passes only when the rating is at least ACCEPTABLE.
            if avg_error <= OK_FRAMES:
                test_results["Test 13"] = "‚úÖ PASSED"
                print("\n‚úÖ Test 13 PASSED")
            else:
                test_results["Test 13"] = "‚ùå FAILED"
                print(f"\n‚ùå Test 13 FAILED: average start error exceeds {OK_FRAMES} frames")
        
    except Exception as e:
        test_results["Test 13"] = "‚ùå FAILED"
        print(f"‚ùå Test 13 FAILED: {e}")
        import traceback; traceback.print_exc()

## Test 14: Listening Test (Audio Preview)

In [None]:
# Test 14: for every ground-truth word, play the ground-truth audio span and
# the predicted span from Test 12 so they can be compared by ear.
print("=" * 60)
print("Test 14: Listening Test")
print("=" * 60)

# WAVEFORM, SR and ALIGNED_WORDS only exist after Test 12 has passed.
if "PASSED" not in test_results.get("Test 12", ""):
    test_results["Test 14"] = "‚è≠Ô∏è SKIPPED"
    print("‚è≠Ô∏è Skipped - Test 12 did not pass")
else:
    try:
        from IPython.display import Audio, display, HTML
        
        SAMPLES_PER_FRAME = SR // FRAME_RATE  # 320 samples per frame at 16kHz/50fps
        
        def get_audio_segment(start_frame, end_frame, padding_frames=2):
            """Extract audio segment by frame indices.

            The span is widened by ``padding_frames`` on both sides and
            clamped to the start and end of WAVEFORM. Returns a slice of
            WAVEFORM along its second dimension.
            """
            start_frame = max(0, start_frame - padding_frames)
            end_frame = end_frame + padding_frames
            x0 = int(start_frame * SAMPLES_PER_FRAME)
            x1 = min(int(end_frame * SAMPLES_PER_FRAME), WAVEFORM.size(1))
            return WAVEFORM[:, x0:x1]

        def find_prediction(word):
            """Return the first aligned word matching ``word`` (case-insensitive), or None."""
            for aligned in ALIGNED_WORDS.values():
                if aligned.word and aligned.word.upper() == word.upper():
                    return aligned
            return None
        
        print("\nüéß Listening to aligned words (Prediction vs Ground Truth):")
        print("=" * 70)
        
        for gt in GROUND_TRUTH_WORDS:
            word = gt["word"]
            gt_start, gt_end = gt["start"], gt["end"]
            pred = find_prediction(word)
            
            print(f"\n{'='*70}")
            print(f"Word: {word}")
            print(f"{'='*70}")
            
            # Ground Truth audio
            gt_audio = get_audio_segment(gt_start, gt_end)
            gt_start_sec = gt_start / FRAME_RATE
            gt_end_sec = gt_end / FRAME_RATE
            print(f"\nüéØ Ground Truth: [{gt_start_sec:.3f}s - {gt_end_sec:.3f}s]")
            display(Audio(gt_audio.numpy(), rate=SR))
            
            # Prediction audio
            if pred is not None:
                pred_start = int(pred.start_time)
                # Only a missing end time falls back to the ground-truth
                # duration; an end time of 0 is a real value.
                if pred.end_time is not None:
                    pred_end = int(pred.end_time)
                else:
                    pred_end = pred_start + (gt_end - gt_start)
                pred_audio = get_audio_segment(pred_start, pred_end)
                pred_start_sec = pred_start / FRAME_RATE
                pred_end_sec = pred_end / FRAME_RATE
                delta = abs(pred_start - gt_start)
                status = "‚úÖ" if delta <= 5 else ("‚ö†Ô∏è" if delta <= 10 else "‚ùå")
                print(f"\nüîÆ Prediction: [{pred_start_sec:.3f}s - {pred_end_sec:.3f}s] (Œî={delta} frames) {status}")
                display(Audio(pred_audio.numpy(), rate=SR))
            else:
                print(f"\nüîÆ Prediction: ‚ùå NOT FOUND")
        
        test_results["Test 14"] = "‚úÖ PASSED"
        print(f"\n{'='*70}")
        print("‚úÖ Test 14 PASSED")
        
    except Exception as e:
        test_results["Test 14"] = "‚ùå FAILED"
        print(f"‚ùå Test 14 FAILED: {e}")
        import traceback; traceback.print_exc()

## Test Summary

In [None]:
# Summary: list every recorded result in test-number order, then count the
# passed / skipped / failed entries by the marker each status string carries.
print("=" * 60)
print("üìã TEST RESULTS SUMMARY")
print("=" * 60)


def _test_number(test_name):
    """Return N for a label of the form "Test N", or 99 when there is no number."""
    parts = test_name.split()
    # Labels without a second token sort last instead of raising IndexError.
    if len(parts) > 1 and parts[1].isdigit():
        return int(parts[1])
    return 99


print()
for test_name, result in sorted(test_results.items(), key=lambda item: _test_number(item[0])):
    print(f"  {result}  {test_name}")

passed = sum(1 for r in test_results.values() if "‚úÖ" in r)
failed = sum(1 for r in test_results.values() if "‚ùå" in r)
# Warning statuses (a backend that is not installed) count as skipped.
skipped = sum(1 for r in test_results.values() if "‚è≠Ô∏è" in r or "‚ö†Ô∏è" in r)

print()
print(f"  Passed:  {passed}")
print(f"  Skipped: {skipped}")
print(f"  Failed:  {failed}")
print()

if failed:
    print(f"‚ö†Ô∏è {failed} test(s) failed")
elif not test_results:
    # Nothing ran, so claiming success would be misleading.
    print("‚ö†Ô∏è No test results recorded - run the test cells first")
elif skipped:
    # No failures, but skipped tests were not verified.
    print(f"üéâ No failures ({passed} passed, {skipped} skipped)")
else:
    print("üéâ All tests passed!")