In [1]:
import os
import torch
import torchaudio
import torchaudio.transforms as T
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import soundfile as sf
import librosa
import numpy as np

# ==============================================================================
# 1. GPU ENABLED EXTRACTOR
# ==============================================================================
class HandcraftedFeatureExtractor:
    def __init__(self, device, sample_rate=16000, n_mfcc=40, n_mels=80, n_fft=400, hop_length=160):
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.n_fft = n_fft
        self.device = device 

        # --- Move transforms to GPU immediately upon initialization ---
        self.mfcc_transform = T.MFCC(
            sample_rate=sample_rate, n_mfcc=n_mfcc,
            melkwargs={"n_fft": n_fft, "n_mels": n_mels, "hop_length": hop_length, "center": False}
        ).to(self.device)

        self.mel_transform = T.MelSpectrogram(
            sample_rate=sample_rate, n_fft=n_fft, n_mels=n_mels, 
            hop_length=hop_length, center=False
        ).to(self.device)

    def _get_pitch_cpu(self, waveform_cpu):
        # This function runs on CPU because Librosa does not support GPU directly
        wav_numpy = waveform_cpu.squeeze().numpy()
        
        try:
            # Using pyin to ensure quality
            f0, _, _ = librosa.pyin(
                wav_numpy, 
                fmin=60, fmax=500, sr=self.sample_rate, 
                hop_length=self.hop_length, frame_length=self.n_fft, center=False
            )
        except:
            f0 = np.zeros(1)

        f0 = np.nan_to_num(f0)
        return torch.from_numpy(f0).view(1, 1, -1).float()

    def _align_length(self, feat1, feat2):
        min_len = min(feat1.shape[-1], feat2.shape[-1])
        return feat1[..., :min_len], feat2[..., :min_len]

    def _apply_cmvn(self, feature):
        if feature is not None:
            mean = feature.mean(dim=-1, keepdim=True)
            std = feature.std(dim=-1, keepdim=True)
            feature = (feature - mean) / (std + 1e-6)
        return feature

    def extract_all(self, waveform_cpu):
        """
        Extracts ALL 5 modes at once.
        Returns a dictionary: { 'ModeName': Tensor, ... }
        """
        output_dict = {}

        # 1. GPU Processing
        waveform_gpu = waveform_cpu.to(self.device)
        if waveform_gpu.dim() == 1: waveform_gpu = waveform_gpu.unsqueeze(0)

        mfcc_base = self.mfcc_transform(waveform_gpu)
        mel_base = self.mel_transform(waveform_gpu)
        mfbe_base = torch.log(mel_base + 1e-6)

        # 2. CPU Processing (Pitch)
        pitch_cpu = self._get_pitch_cpu(waveform_cpu)
        pitch_base = pitch_cpu.to(self.device) 

        if mfcc_base.dim() == 2: mfcc_base = mfcc_base.unsqueeze(0)
        if mfbe_base.dim() == 2: mfbe_base = mfbe_base.unsqueeze(0)

        # Combine Features
        output_dict["Only MFCC"] = self._apply_cmvn(mfcc_base)
        output_dict["Only MFBE"] = self._apply_cmvn(mfbe_base)
        output_dict["Only Pitch"] = self._apply_cmvn(pitch_base.squeeze(0))

        mfcc_aligned, pitch_aligned_1 = self._align_length(mfcc_base, pitch_base)
        output_dict["MFCC + Pitch"] = self._apply_cmvn(torch.cat([mfcc_aligned, pitch_aligned_1], dim=1))

        mfbe_aligned, pitch_aligned_2 = self._align_length(mfbe_base, pitch_base)
        output_dict["MFBE + Pitch"] = self._apply_cmvn(torch.cat([mfbe_aligned, pitch_aligned_2], dim=1))

        return output_dict

# ==============================================================================
# 2. DATASET (UPDATED WITH RESUME LOGIC)
# ==============================================================================
class AudioFolderDataset(Dataset):
    def __init__(self, root_dir, output_base_dir, modes_list, extractor, sample_rate=16000):
        self.root_dir = root_dir
        self.extractor = extractor
        self.sample_rate = sample_rate
        self.file_list = []
        
        print(f"-> Scanning input directory: {root_dir}")
        print(f"-> Checking output directory for existing files to resume...")

        total_files = 0
        skipped_files = 0
        
        # Walk through directories
        for root, _, files in os.walk(root_dir):
            for file in files:
                if file.lower().endswith(('.wav', '.flac', '.mp3')):
                    total_files += 1
                    input_path = os.path.join(root, file)
                    
                    # --- RESUME LOGIC ---
                    # 1. Calculate the relative path (e.g., "Speaker1/audio01.wav")
                    rel_path = os.path.relpath(input_path, root_dir)
                    # 2. Change extension to .pt (e.g., "Speaker1/audio01.pt")
                    rel_path_pt = os.path.splitext(rel_path)[0] + ".pt"
                    
                    # 3. Check if THIS file exists in ALL 5 output mode folders
                    all_outputs_exist = True
                    for mode in modes_list:
                        expected_output_path = os.path.join(output_base_dir, mode, rel_path_pt)
                        if not os.path.exists(expected_output_path):
                            all_outputs_exist = False
                            break
                    
                    # 4. If all exist, skip. If any is missing, process it.
                    if all_outputs_exist:
                        skipped_files += 1
                    else:
                        self.file_list.append(input_path)

        print(f"-> Total found: {total_files}")
        print(f"-> Skipped (Already processed): {skipped_files}")
        print(f"-> Remaining to process: {len(self.file_list)}")

    def __len__(self): return len(self.file_list)

    def __getitem__(self, idx):
        path = self.file_list[idx]
        try:
            wav_numpy, sr = sf.read(path)
            waveform = torch.from_numpy(wav_numpy).float()
            
            if waveform.dim() == 1: waveform = waveform.unsqueeze(0)
            else: waveform = waveform.t()
            if waveform.shape[0] > 1: waveform = torch.mean(waveform, dim=0, keepdim=True)
            if sr != self.sample_rate: waveform = T.Resample(sr, self.sample_rate)(waveform)

            features_dict_gpu = self.extractor.extract_all(waveform)
            features_dict_cpu = {k: v.cpu().squeeze(0) for k, v in features_dict_gpu.items()}
            
            return features_dict_cpu, path 
        except Exception as e:
            print(f"Error processing {path}: {e}")
            return None, None

def collate_fn(batch):
    batch = [b for b in batch if b[0] is not None]
    if not batch: return None
    list_of_dicts, paths = zip(*batch)
    batched_output = {}
    keys = list_of_dicts[0].keys()

    for key in keys:
        tensors = [d[key] for d in list_of_dicts]
        max_len = max([t.shape[-1] for t in tensors])
        padded_tensors = [torch.nn.functional.pad(t, (0, max_len - t.shape[-1])) for t in tensors]
        batched_output[key] = torch.stack(padded_tensors)
        
    return batched_output, paths

# ==============================================================================
# 3. EXECUTION
# ==============================================================================
if __name__ == "__main__":
    # --- CONFIGURATION ---
    INPUT_PATH = r"D:\Speech_Verification\cut_audio_5s"
    OUTPUT_BASE_PATH = r"D:\Speech_Verification\cut_audio_5s_features"
    
    # Must list ALL modes here to check for completion
    MODES_TO_SAVE = ["Only MFCC", "Only MFBE", "Only Pitch", "MFCC + Pitch", "MFBE + Pitch"]
    
    # --- SETUP GPU ---
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"üî• Running on GPU: {torch.cuda.get_device_name(0)}")
    else:
        device = torch.device("cpu")
        print("‚ö†Ô∏è GPU not found, running on CPU.")

    extractor = HandcraftedFeatureExtractor(device=device)
    
    # Pass output path and modes to Dataset for Resume Logic
    dataset = AudioFolderDataset(
        root_dir=INPUT_PATH, 
        output_base_dir=OUTPUT_BASE_PATH, 
        modes_list=MODES_TO_SAVE, 
        extractor=extractor
    )
    
    # If everything is processed, exit early
    if len(dataset) == 0:
        print("\n‚úÖ All files have been processed. Nothing to do!")
        exit()

    loader = DataLoader(dataset, batch_size=16, shuffle=False, num_workers=0, collate_fn=collate_fn)

    print(f"\nStarting extraction...")
    
    # Create sub-folders (idempotent)
    for mode in MODES_TO_SAVE:
        os.makedirs(os.path.join(OUTPUT_BASE_PATH, mode), exist_ok=True)

    # Process loop
    for batch in tqdm(loader):
        if batch is None: continue
        
        batched_features, paths = batch
        
        for mode, features_tensor in batched_features.items():
            mode_output_dir = os.path.join(OUTPUT_BASE_PATH, mode)
            
            for i in range(len(paths)):
                rel_path = os.path.relpath(paths[i], INPUT_PATH)
                save_path = os.path.join(mode_output_dir, os.path.splitext(rel_path)[0] + ".pt")
                
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
                torch.save(features_tensor[i].clone(), save_path)

    print(f"\nCompleted! Check: {OUTPUT_BASE_PATH}")

üî• Running on GPU: NVIDIA GeForce RTX 5080
-> Scanning input directory: D:\Speech_Verification\cut_audio_5s
-> Checking output directory for existing files to resume...
-> Total found: 394182
-> Skipped (Already processed): 480
-> Remaining to process: 393702

Starting extraction...


  f0, _, _ = librosa.pyin(
  std = feature.std(dim=-1, keepdim=True)
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 24607/24607 [62:35:43<00:00,  9.16s/it]   


Completed! Check: D:\Speech_Verification\cut_audio_5s_features



