In [1]:
from datasets import load_dataset

# Load the "dev" split of every source used in this notebook.
# NOTE(review): later cells read row 0 from each dataset, presumably assuming the
# four datasets are row-aligned (same utterance at the same index) — TODO confirm.
wavlm_feature = load_dataset("hungphongtrn/wavlm-features", split="dev")
llm_hidden_states = load_dataset("hungphongtrn/llm-features", split="dev")
# Backward-compatible alias: the original variable name was misspelled with three
# l's; keep it so existing cells that reference it continue to run.
lllm_hidden_states = llm_hidden_states
# For GigaSpeech the second positional argument is the dataset configuration name,
# which here happens to share its name with the split.
gigaspeech = load_dataset("fixie-ai/gigaspeech", "dev", split="dev")
speech_time_alignment = load_dataset("hungphongtrn/speech-time-alignment", split="dev")

  from .autonotebook import tqdm as notebook_tqdm


In [24]:
# Index the row first, then the column, so only the first clip is read/decoded.
# (On `datasets` versions where column access returns a plain list,
# `ds['audio'][0]` would decode every clip in the split before picking one.)
gigaspeech[0]['audio'].get_all_samples()

AudioSamples:
  data (shape): torch.Size([1, 48800])
  pts_seconds: 0.0
  duration_seconds: 3.05
  sample_rate: 16000

In [None]:
# Load WavLM features
import torch

# Row-first indexing: fetch record 0, then its feature column, so a single
# example is materialised rather than the whole `wavlm_feat` column.
wavlm_feat_sample = torch.tensor(wavlm_feature[0]['wavlm_feat'], dtype=torch.float16)
# Displayed as (frames, feature_dim).
wavlm_feat_sample.shape

torch.Size([713, 1024])

In [27]:
# Load audio
import torchaudio

TARGET_SAMPLE_RATE = 24_000  # Hz; rate the waveform is resampled to

# Decode the first clip exactly once and reuse the result for both the waveform
# and its native sample rate (the previous version decoded it twice).
audio_samples = gigaspeech[0]['audio'].get_all_samples()
audio_data = audio_samples.data
resampled_audio = torchaudio.functional.resample(
    audio_data,
    orig_freq=audio_samples.sample_rate,
    new_freq=TARGET_SAMPLE_RATE,
)

In [28]:
# Sanity check on the resampling ratio: the first clip has 48800 samples at
# 16 kHz, so at 24 kHz we expect 48800 * 24000 / 16000 = 73200 samples.
resampled_audio.shape

torch.Size([1, 73200])

In [4]:
# Load LLM features
# Row-first indexing: fetch record 0, then its feature column, so a single
# example is materialised rather than the whole `llm_feat` column.
llm_feat_sample = torch.tensor(lllm_hidden_states[0]['llm_feat'], dtype=torch.float16)
# Displayed as (positions, hidden_dim).
llm_feat_sample.shape


torch.Size([16, 2048])

In [6]:
import json 
# The alignment for each utterance is stored as a JSON string; pull out the
# string for the first record, then decode it.
alignment_json_str = speech_time_alignment[0]['alignment_json']
sample_alignment = json.loads(alignment_json_str)

# The 'char' entry is a list of dicts, each carrying the text piece together
# with its start/end offsets and start/end timestamps.
timestamp_aligned = sample_alignment['char']
timestamp_aligned

[{'char': ['So'],
  'start_offset': 3,
  'end_offset': 5,
  'start': 0.24,
  'end': 0.4},
 {'char': ['I'],
  'start_offset': 7,
  'end_offset': 9,
  'start': 0.56,
  'end': 0.72},
 {'char': ['don'],
  'start_offset': 9,
  'end_offset': 10,
  'start': 0.72,
  'end': 0.8},
 {'char': ["'"],
  'start_offset': 10,
  'end_offset': 10,
  'start': 0.8,
  'end': 0.8},
 {'char': ['t'],
  'start_offset': 11,
  'end_offset': 12,
  'start': 0.88,
  'end': 0.96},
 {'char': ['know'],
  'start_offset': 12,
  'end_offset': 14,
  'start': 0.96,
  'end': 1.12},
 {'char': ['if'],
  'start_offset': 14,
  'end_offset': 16,
  'start': 1.12,
  'end': 1.28},
 {'char': ['there'],
  'start_offset': 16,
  'end_offset': 17,
  'start': 1.28,
  'end': 1.36},
 {'char': ["'"],
  'start_offset': 17,
  'end_offset': 17,
  'start': 1.36,
  'end': 1.36},
 {'char': ['s'],
  'start_offset': 18,
  'end_offset': 19,
  'start': 1.44,
  'end': 1.52},
 {'char': ['somet'],
  'start_offset': 19,
  'end_offset': 21,
  'start': 1.52

In [None]:
from torch.nn import functional as F

def align_wavlm_features(features, factor=4):
    """
    Downsample frame-level features along the time axis with average pooling.

    With the default ``factor=4`` this reduces 50Hz WavLM features to 12.5Hz
    (4x reduction). Each output frame is the mean of ``factor`` consecutive,
    non-overlapping input frames. Trailing frames that do not fill a complete
    window are dropped, so the output length is ``Time // factor``.

    Args:
        features (torch.Tensor): Shape (Time, Dim) or (Batch, Time, Dim).
                                 Example: [713, 1024]
        factor (int): Pooling window and stride, i.e. the downsampling ratio.
                      Must be a positive integer. Defaults to 4.

    Returns:
        torch.Tensor: Downsampled features with the same rank as the input,
                      shape (Time // factor, Dim) or (Batch, Time // factor, Dim).
                      Example: [178, 1024]

    Raises:
        ValueError: If ``factor`` is not a positive integer, if ``features`` is
                    not 2-D or 3-D, or if it has fewer than ``factor`` frames.
    """
    # bool is a subclass of int, so reject it explicitly.
    if isinstance(factor, bool) or not isinstance(factor, int) or factor < 1:
        raise ValueError(f"factor must be a positive integer, got {factor!r}")

    if features.dim() not in (2, 3):
        raise ValueError(
            "features must have shape (Time, Dim) or (Batch, Time, Dim), "
            f"got {features.dim()} dimension(s)"
        )

    # 1. Handle unbatched input (Time, Dim) -> (1, Time, Dim)
    is_unbatched = features.dim() == 2
    batched = features.unsqueeze(0) if is_unbatched else features

    # Pooling needs at least one full window; fail with a clear message
    # instead of an opaque error from the pooling kernel.
    if batched.size(1) < factor:
        raise ValueError(
            f"need at least {factor} time frames to pool, got {batched.size(1)}"
        )

    # 2. avg_pool1d pools over the last axis, so move Time there:
    #    (Batch, Time, Dim) -> (Batch, Dim, Time)
    channels_first = batched.permute(0, 2, 1)

    # 3. Non-overlapping windows: kernel_size == stride == factor
    pooled = F.avg_pool1d(channels_first, kernel_size=factor, stride=factor)

    # 4. Restore the (Batch, Time, Dim) layout
    output = pooled.permute(0, 2, 1)

    # 5. Remove the batch dim again if the caller passed unbatched input
    return output.squeeze(0) if is_unbatched else output




torch.Size([713, 256])