In [4]:
import os

from datasets import load_from_disk

# Directory holding the dataset previously written with `save_to_disk`.
# The default is the original author's local path; set AMY_LM_DATASET_DIR
# to point somewhere else without editing the notebook.
DATASET_DIR = os.environ.get(
    "AMY_LM_DATASET_DIR",
    "/home/hungphongtrn/Workspace/Amy-LM/data/Amy-LM-Dataset",
)

ds = load_from_disk(DATASET_DIR)

In [7]:
# Display the dataset summary (feature names and number of rows).
ds

Dataset({
    features: ['segment_id', 'llm_feat', 'llm_times', 'wavlm_feat', 'audio'],
    num_rows: 5700
})

In [19]:
import numpy as np

# Pull the first example out of every column for a quick inspection.
segment_id = ds["segment_id"][0]
llm_features, llm_times, wavlm_features = (
    np.array(ds[column][0]) for column in ("llm_feat", "llm_times", "wavlm_feat")
)
audio = ds["audio"][0]



In [21]:
# Report the id, the feature shapes, the raw token time spans and the decoded audio.
print(f"segment_id {segment_id}")
print(f"llm_features {llm_features.shape}")
print(f"llm_times {llm_times}")
print(f"wavlm_features {wavlm_features.shape}")
print(f"audio {audio.get_all_samples()}")


segment_id YOU1000000044_S0000363
llm_features (32, 2048)
llm_times [[ 0.    0.24]
 [ 0.    0.24]
 [ 0.    0.24]
 [ 0.    0.24]
 [ 0.24  0.4 ]
 [ 0.4   0.56]
 [ 0.4   0.56]
 [ 0.4   0.56]
 [ 0.4   0.56]
 [ 0.56  0.72]
 [ 0.72  0.96]
 [ 0.72  0.96]
 [ 0.96  1.12]
 [ 1.12  1.28]
 [ 1.28  1.52]
 [ 1.28  1.52]
 [ 1.52  1.84]
 [ 1.84  2.  ]
 [ 2.    2.16]
 [ 2.16  2.32]
 [ 2.32  2.4 ]
 [ 2.4   2.56]
 [ 2.4   2.56]
 [ 2.56  2.64]
 [ 2.56  2.64]
 [ 2.56  2.64]
 [ 2.56  2.64]
 [ 2.64  2.88]
 [ 2.88 -1.  ]
 [ 2.88 -1.  ]
 [ 2.88 -1.  ]
 [ 2.88 -1.  ]]
wavlm_features (713, 1024)
audio AudioSamples:
  data (shape): torch.Size([1, 48800])
  pts_seconds: 0.0
  duration_seconds: 3.05
  sample_rate: 16000



In [24]:
import numpy as np
import torch

def process_batch(batch, target_fps=25):
    """
    Resample the LLM and WavLM features of every item onto a shared frame grid.

    Processing function compatible with ds.map(..., batched=True).

    Args:
        batch: Mapping of column name -> list with one entry per item. Columns read:
            'segment_id' (only its length is used, as the batch size),
            'audio'      (each item is indexed with ['array'] and ['sampling_rate']),
            'llm_feat'   (per item: 2-D, one feature row per token),
            'llm_times'  (per item: one (start, end) pair per token, in seconds;
                          an end of -1 is replaced by the audio duration),
            'wavlm_feat' (per item: 2-D, one feature row per source frame).
        target_fps: Number of output frames per second of audio.

    Returns:
        dict with three lists, one entry per item:
            'llm_feat':   float32 array (num_frames, llm_dim), mean of the tokens
                          selected for each frame,
            'wavlm_feat': float32 array (num_frames, wavlm_dim), adaptive average
                          pooling of the source frames,
            'num_frames': int, ceil(audio duration * target_fps).
    """
    batch_size = len(batch['segment_id'])

    # Initialize output lists
    out_llm_feats = []
    out_wavlm_feats = []
    out_lengths = []

    # Iterate through each sample in the batch
    for i in range(batch_size):
        # -------------------------------------------------------------
        # 1. Setup Individual Item Data
        # -------------------------------------------------------------
        # Audio info (fetch the item once, then read both fields from it)
        audio_item = batch['audio'][i]
        audio_array = audio_item['array']
        sr = audio_item['sampling_rate']
        num_samples = len(audio_array)
        duration = num_samples / sr

        # Calculate target frames for this specific item.
        # Multiply before the single division: `ceil(duration * target_fps)`
        # rounds twice and can overshoot an exact frame count by one
        # (e.g. 4480 / 16000 * 25 == 7.000000000000001 -> 8 frames instead of 7).
        num_frames = int(np.ceil(num_samples * target_fps / sr))
        out_lengths.append(num_frames)

        # Create Time Grid: frame k covers [k / fps, (k + 1) / fps]
        frame_indices = np.arange(num_frames)
        grid_starts_A = frame_indices / target_fps
        grid_ends_B = (frame_indices + 1) / target_fps

        # -------------------------------------------------------------
        # 2. Process LLM Features (Integral Image Alignment)
        # -------------------------------------------------------------
        llm_feat = np.array(batch['llm_feat'][i])
        llm_times = np.array(batch['llm_times'][i])

        # Fix -1 in end times (replace with actual duration).
        # The column slice is a view into `llm_times`, so copy before editing.
        t_ends = llm_times[:, 1].copy()
        t_ends[t_ends == -1] = duration
        t_starts = llm_times[:, 0]

        # --- Vectorized Binary Search ---
        # np.searchsorted requires sorted input, so both t_starts and t_ends
        # are expected to be non-decreasing.

        # A: Closest smaller start time -> First token with that start
        idx_closest_start = np.searchsorted(t_starts, grid_starts_A, side='right') - 1
        idx_closest_start = np.clip(idx_closest_start, 0, len(t_starts) - 1)
        val_closest_start = t_starts[idx_closest_start]
        final_idx_starts = np.searchsorted(t_starts, val_closest_start, side='left')

        # B: Closest larger end time -> Last token with that end
        idx_closest_end = np.searchsorted(t_ends, grid_ends_B, side='left')
        idx_closest_end = np.clip(idx_closest_end, 0, len(t_ends) - 1)
        val_closest_end = t_ends[idx_closest_end]
        final_idx_ends = np.searchsorted(t_ends, val_closest_end, side='right') - 1

        # --- Integral Image Pooling ---
        # Prefix sum (pad with 0 at start). Accumulate in float64 explicitly:
        # with the default dtype a float16 input would be summed in float16,
        # which overflows to inf / loses precision after a few hundred tokens.
        feat_cumsum = np.concatenate(
            [
                np.zeros((1, llm_feat.shape[1]), dtype=np.float64),
                np.cumsum(llm_feat, axis=0, dtype=np.float64),
            ],
            axis=0,
        )

        # Sum over tokens [start, end] = Cumulative[End+1] - Cumulative[Start]
        sums = feat_cumsum[final_idx_ends + 1] - feat_cumsum[final_idx_starts]

        # Mean = Sum / Count
        counts = (final_idx_ends - final_idx_starts + 1).reshape(-1, 1)
        counts = np.maximum(counts, 1)  # Prevent division by zero
        aligned_llm = sums / counts

        out_llm_feats.append(aligned_llm.astype(np.float32))

        # -------------------------------------------------------------
        # 3. Process WavLM Features (Adaptive Pooling)
        # -------------------------------------------------------------
        wavlm_tensor = torch.tensor(batch['wavlm_feat'][i]).float()  # (Src_Len, Channel)

        # Transpose to (1, Channel, Time) for pooling
        wavlm_tensor = wavlm_tensor.transpose(0, 1).unsqueeze(0)

        # Pool to exact number of frames
        wavlm_aligned = torch.nn.functional.adaptive_avg_pool1d(wavlm_tensor, num_frames)

        # Transpose back: (Time, Channel)
        wavlm_aligned = wavlm_aligned.squeeze(0).transpose(0, 1)

        out_wavlm_feats.append(wavlm_aligned.numpy())

    # Return dictionary of lists (updates the dataset columns)
    return {
        "llm_feat": out_llm_feats,      # List of (T, llm_dim)
        "wavlm_feat": out_wavlm_feats,  # List of (T, wavlm_dim)
        "num_frames": out_lengths       # List of ints
    }

# --- How to apply ---
# updated_ds = ds.map(
#     lambda x: process_batch(x, target_fps=25), 
#     batched=True, 
#     batch_size=32, 
#     num_proc=4,   # Safe to use multiprocessing with this logic
#     remove_columns=["audio", "llm_times"] # Optional: clean up columns you don't need
# )

In [28]:
# Smoke-test the function on a one-item batch (ds[:1] keeps the dict-of-lists layout).
processed_output = process_batch(ds[:1])

In [29]:
# Display the aligned features returned for the one-item batch.
processed_output

{'llm_feat': [array([[ 14.856689 ,  -9.355469 ,  14.222656 , ..., -13.0217285,
          -12.165039 ,   7.4833984],
         [ 14.856689 ,  -9.355469 ,  14.222656 , ..., -13.0217285,
          -12.165039 ,   7.4833984],
         [ 14.856689 ,  -9.355469 ,  14.222656 , ..., -13.0217285,
          -12.165039 ,   7.4833984],
         ...,
         [ 27.441406 ,  11.6953125,  14.611328 , ..., -43.984375 ,
          -23.296875 ,  -1.7578125],
         [ 27.441406 ,  11.6953125,  14.611328 , ..., -43.984375 ,
          -23.296875 ,  -1.7578125],
         [ 27.441406 ,  11.6953125,  14.611328 , ..., -43.984375 ,
          -23.296875 ,  -1.7578125]], shape=(77, 2048), dtype=float32)],
 'wavlm_feat': [array([[ 0.04122524,  0.19869384, -0.13108978, ..., -0.10124207,
           0.06408997, -0.1493988 ],
         [ 0.09528179, -0.07081299, -0.18926391, ..., -0.07943535,
          -0.1532257 ,  0.14864807],
         [ 0.01480007, -0.16293183,  0.06273498, ..., -0.02133637,
          -0.08496094,  0