In [None]:
from src.requirements import *

## Audio Preprocessing
- resample
- stereo to mono
- trim leading/trailing silence
- peak-normalize
- plot

In [None]:
class MemoryMappedAudioDataset(Dataset):
    """Audio dataset served from a flat float32 memory-mapped cache file.

    On first use every file listed in the ``path`` column of the metadata TSV
    is loaded, down-mixed to mono, silence-trimmed with
    ``librosa.effects.trim`` and peak-normalised, then appended to
    ``<cache_dir>/<stem>_audio.dat``.  Per-clip lengths and offsets are stored
    in ``<cache_dir>/<stem>_meta.npz``.  Later constructions reuse the cache
    as long as it still agrees with the metadata (same number of clips, audio
    file of the expected byte size); otherwise the cache is rebuilt.

    NOTE: the cache file names depend only on the metadata file stem, so a
    changed ``top_db`` does not invalidate an existing cache -- delete the
    cache files by hand after changing it.

    Args:
        metadata_path: Tab-separated file with a ``path`` column.
        cache_dir: Directory for the cache files (created if missing).
        top_db: Threshold forwarded to ``librosa.effects.trim``.

    Indexing with an integer returns one clip as a 1-D float32 tensor.
    Negative indices count from the end; out-of-range indices raise
    ``IndexError``.
    """

    def __init__(self, metadata_path, cache_dir='data/cache_mmap', top_db=20):
        super().__init__()
        self.df = pd.read_csv(metadata_path, sep="\t")
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.top_db = top_db

        dataset_name = Path(metadata_path).stem
        meta_file = self.cache_dir / f"{dataset_name}_meta.npz"
        audio_file = self.cache_dir / f"{dataset_name}_audio.dat"

        if not meta_file.exists():
            print("Preprocessing audio (first time only)...")
            self._preprocess_all(audio_file, meta_file)
        elif not self._cache_is_valid(audio_file, meta_file):
            # A cache that no longer lines up with the metadata would hand
            # back the wrong audio for an index, so rebuild instead of trusting it.
            print("Cache is stale or incomplete; rebuilding...")
            self._preprocess_all(audio_file, meta_file)
        else:
            print("Loading metadata from cache...")

        # The arrays are plain int64, so pickle support is not needed; the
        # context manager closes the underlying zip file once they are read.
        with np.load(meta_file) as meta:
            self.audio_shapes = meta['shapes']
            self.audio_offsets = meta['offsets']
            self.total_size = meta['total_size'].item()

        if self.total_size > 0:
            self.audio_mmap = np.memmap(audio_file, dtype='float32', mode='r', shape=(self.total_size,))
        else:
            # Mapping a zero-byte file fails on some NumPy versions.
            self.audio_mmap = np.empty(0, dtype=np.float32)

        print(f"✓ Cache ready! {len(self.df)} samples")
        print(f"  Total audio size: {self.total_size * 4 / (1024**3):.2f} GB")

    def _cache_is_valid(self, audio_file, meta_file):
        """Return True when both cache files exist and agree with ``self.df``."""
        if not (meta_file.exists() and audio_file.exists()):
            return False
        try:
            with np.load(meta_file) as meta:
                n_clips = len(meta['shapes'])
                n_offsets = len(meta['offsets'])
                total_size = int(meta['total_size'])
        except Exception:
            # Any failure to read the metadata means the cache is unusable;
            # the caller reacts by rebuilding it.
            return False
        return (
            n_clips == len(self.df)
            and n_offsets == n_clips + 1
            and audio_file.stat().st_size == total_size * 4  # float32 = 4 bytes
        )

    @staticmethod
    def _peak_normalize(waveform):
        """Scale ``waveform`` so that its largest absolute sample equals 1.

        Empty and all-zero inputs are returned unscaled.
        """
        waveform = np.asarray(waveform, dtype=np.float32)
        if waveform.size == 0:
            return waveform
        peak = np.abs(waveform).max()
        if peak > 0:
            waveform = waveform / peak
        return waveform.astype(np.float32, copy=False)

    def _load_clip(self, path):
        """Load one file and return it as mono, trimmed, peak-normalised float32."""
        # NOTE(review): the sample rate returned by sf.read is discarded and no
        # resampling happens here, so every input is assumed to share one
        # sample rate -- confirm against the data.
        waveform, _ = sf.read(path, always_2d=True)
        waveform = np.array(waveform.T, dtype=np.float32)  # -> (channels, frames)

        # Stereo (or multi-channel) to mono
        mono = waveform.mean(axis=0) if waveform.shape[0] > 1 else waveform[0]

        # Trim leading/trailing silence; skipped for an empty signal.
        if mono.size:
            mono, _ = librosa.effects.trim(mono, top_db=self.top_db)

        return self._peak_normalize(mono)

    def _preprocess_all(self, audio_file, meta_file):
        """Preprocess every clip and write the cache in one streaming pass.

        Clips are appended to ``audio_file`` as they are produced, so peak
        memory is one clip rather than the whole dataset.  The metadata file
        is written last and acts as the "cache complete" marker.
        """
        import time
        start = time.time()

        print("Processing audio files...")

        # Drop a stale marker first so an interrupted rebuild is never
        # mistaken for a finished cache.
        if meta_file.exists():
            meta_file.unlink()

        lengths = []
        n_empty = 0
        with open(audio_file, 'wb') as out:
            for path in tqdm(self.df['path'], desc="Loading & preprocessing"):
                clip = self._load_clip(path)
                if clip.size == 0:
                    n_empty += 1
                out.write(clip.tobytes())
                lengths.append(len(clip))

        if n_empty:
            print(f"Warning: {n_empty} clip(s) are empty after trimming")

        total_size = int(sum(lengths))
        print(f"Total samples: {total_size:,} ({total_size * 4 / (1024**3):.2f} GB)")

        # offsets[i]:offsets[i + 1] is the slice of the flat file holding clip i.
        offsets = np.zeros(len(lengths) + 1, dtype=np.int64)
        if lengths:
            np.cumsum(lengths, out=offsets[1:])

        # Save metadata
        np.savez(meta_file,
                 shapes=np.array(lengths, dtype=np.int64),
                 offsets=offsets,
                 total_size=np.array(total_size, dtype=np.int64))

        print(f"Done in {time.time() - start:.1f}s")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        n_clips = len(self)
        if idx < 0:
            idx = idx + n_clips
        if idx < 0 or idx >= n_clips:
            # Without this check a negative index would wrap around inside the
            # offsets array and silently yield an empty clip.
            raise IndexError(f"index {idx} out of range for dataset of size {n_clips}")
        start = self.audio_offsets[idx]
        end = self.audio_offsets[idx + 1]
        # Copy so the tensor owns writable memory instead of the read-only map.
        return torch.from_numpy(self.audio_mmap[start:end].copy())

In [None]:
import time

# Build the SSL training set (or reopen its cache); t0/t1 bracket the
# wall-clock time spent in the constructor.
t0 = time.time()
ssl_train_dataset = MemoryMappedAudioDataset(
    'data/metadata_normal.tsv',
    cache_dir='data/cache_mmap/ssl',
    top_db=TOP_DB,
)
t1 = time.time()

In [None]:
class MemoryMappedASRDataset(Dataset):
    """ASR dataset: cached, memory-mapped audio paired with encoded transcripts.

    On first use every file listed in the ``path`` column of the metadata TSV
    is loaded, down-mixed to mono, silence-trimmed with
    ``librosa.effects.trim`` and peak-normalised, then appended to
    ``<cache_dir>/<stem>_audio.dat``.  Per-clip lengths and offsets are stored
    in ``<cache_dir>/<stem>_meta.npz``.  Later constructions reuse the cache
    as long as it still agrees with the metadata (same number of clips, audio
    file of the expected byte size); otherwise the cache is rebuilt.  The
    ``transcript`` column is encoded with ``tokenizer`` on every construction.

    NOTE: peak normalisation rescales every clip with a non-zero peak, which
    matches ``MemoryMappedAudioDataset``.  Earlier code only rescaled clips
    whose peak exceeded 1, so caches written by it hold un-normalised audio;
    delete them to pick up the change.  Likewise the cache is not keyed on
    ``top_db`` -- delete the cache files by hand after changing it.

    Args:
        metadata_path: Tab-separated file with ``path`` and ``transcript`` columns.
        tokenizer: Object whose ``encode(text)`` returns an iterable of int ids.
        cache_dir: Directory for the cache files (created if missing).
        top_db: Threshold forwarded to ``librosa.effects.trim``.

    Indexing with an integer returns ``(waveform, target)``: a 1-D float32
    tensor and a 1-D int64 tensor.  Negative indices count from the end;
    out-of-range indices raise ``IndexError``.
    """

    def __init__(self, metadata_path, tokenizer, cache_dir='data/cache_mmap_asr', top_db=20):
        super().__init__()
        self.df = pd.read_csv(metadata_path, sep="\t")
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.tokenizer = tokenizer
        self.top_db = top_db

        dataset_name = Path(metadata_path).stem
        meta_file = self.cache_dir / f"{dataset_name}_meta.npz"
        audio_file = self.cache_dir / f"{dataset_name}_audio.dat"

        if not meta_file.exists():
            print("Preprocessing audio (first time only)...")
            self._preprocess_all(audio_file, meta_file)
        elif not self._cache_is_valid(audio_file, meta_file):
            # A cache that no longer lines up with the metadata would pair
            # transcripts with the wrong audio, so rebuild instead of trusting it.
            print("Cache is stale or incomplete; rebuilding...")
            self._preprocess_all(audio_file, meta_file)
        else:
            print("Loading metadata from cache...")

        # The arrays are plain int64, so pickle support is not needed; the
        # context manager closes the underlying zip file once they are read.
        with np.load(meta_file) as meta:
            self.audio_shapes = meta['shapes']
            self.audio_offsets = meta['offsets']
            self.total_size = meta['total_size'].item()

        if self.total_size > 0:
            self.audio_mmap = np.memmap(audio_file, dtype='float32', mode='r', shape=(self.total_size,))
        else:
            # Mapping a zero-byte file fails on some NumPy versions.
            self.audio_mmap = np.empty(0, dtype=np.float32)

        print(f"✓ Cache ready! {len(self.df)} samples")
        print(f"  Total audio size: {self.total_size * 4 / (1024**3):.2f} GB")

        # Encode transcripts
        print("Encoding transcripts...")
        self.encode_transcripts = [
            self._encode_transcript(tokenizer, text)
            for text in tqdm(self.df['transcript'].tolist(), desc="Encoding")
        ]

    @staticmethod
    def _encode_transcript(tokenizer, text):
        """Encode one transcript into a 1-D int64 tensor of token ids."""
        # Negative ids are replaced by 0 (presumably the tokenizer's marker for
        # unknown symbols -- confirm against src.tokenizer).
        token_ids = [i if i >= 0 else 0 for i in tokenizer.encode(text)]
        return torch.tensor(token_ids, dtype=torch.long)

    def _cache_is_valid(self, audio_file, meta_file):
        """Return True when both cache files exist and agree with ``self.df``."""
        if not (meta_file.exists() and audio_file.exists()):
            return False
        try:
            with np.load(meta_file) as meta:
                n_clips = len(meta['shapes'])
                n_offsets = len(meta['offsets'])
                total_size = int(meta['total_size'])
        except Exception:
            # Any failure to read the metadata means the cache is unusable;
            # the caller reacts by rebuilding it.
            return False
        return (
            n_clips == len(self.df)
            and n_offsets == n_clips + 1
            and audio_file.stat().st_size == total_size * 4  # float32 = 4 bytes
        )

    @staticmethod
    def _peak_normalize(waveform):
        """Scale ``waveform`` so that its largest absolute sample equals 1.

        Empty and all-zero inputs are returned unscaled.
        """
        waveform = np.asarray(waveform, dtype=np.float32)
        if waveform.size == 0:
            return waveform
        peak = np.abs(waveform).max()
        # ``> 0`` (not ``> 1``): quiet clips must be scaled up too, otherwise
        # this step does nothing for ordinary audio in the [-1, 1] range.
        if peak > 0:
            waveform = waveform / peak
        return waveform.astype(np.float32, copy=False)

    def _load_clip(self, path):
        """Load one file and return it as mono, trimmed, peak-normalised float32."""
        # NOTE(review): the sample rate returned by sf.read is discarded and no
        # resampling happens here, so every input is assumed to share one
        # sample rate -- confirm against the data.
        waveform, _ = sf.read(path, always_2d=True)
        waveform = np.array(waveform.T, dtype=np.float32)  # -> (channels, frames)

        # Stereo (or multi-channel) to mono
        mono = waveform.mean(axis=0) if waveform.shape[0] > 1 else waveform[0]

        # Trim leading/trailing silence; skipped for an empty signal.
        if mono.size:
            mono, _ = librosa.effects.trim(mono, top_db=self.top_db)

        return self._peak_normalize(mono)

    def _preprocess_all(self, audio_file, meta_file):
        """Preprocess every clip and write the cache in one streaming pass.

        Clips are appended to ``audio_file`` as they are produced, so peak
        memory is one clip rather than the whole dataset.  The metadata file
        is written last and acts as the "cache complete" marker.
        """
        import time
        start = time.time()

        print("Processing audio files...")

        # Drop a stale marker first so an interrupted rebuild is never
        # mistaken for a finished cache.
        if meta_file.exists():
            meta_file.unlink()

        lengths = []
        n_empty = 0
        with open(audio_file, 'wb') as out:
            for path in tqdm(self.df['path'], desc="Loading & preprocessing"):
                clip = self._load_clip(path)
                if clip.size == 0:
                    n_empty += 1
                out.write(clip.tobytes())
                lengths.append(len(clip))

        if n_empty:
            print(f"Warning: {n_empty} clip(s) are empty after trimming")

        total_size = int(sum(lengths))
        print(f"Total samples: {total_size:,} ({total_size * 4 / (1024**3):.2f} GB)")

        # offsets[i]:offsets[i + 1] is the slice of the flat file holding clip i.
        offsets = np.zeros(len(lengths) + 1, dtype=np.int64)
        if lengths:
            np.cumsum(lengths, out=offsets[1:])

        # Save metadata
        np.savez(meta_file,
                 shapes=np.array(lengths, dtype=np.int64),
                 offsets=offsets,
                 total_size=np.array(total_size, dtype=np.int64))

        print(f"✓ Done in {time.time() - start:.1f}s")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        n_clips = len(self)
        if idx < 0:
            idx = idx + n_clips
        if idx < 0 or idx >= n_clips:
            # Without this check a negative index would wrap around inside the
            # offsets array and pair an empty clip with a real transcript.
            raise IndexError(f"index {idx} out of range for dataset of size {n_clips}")
        start = self.audio_offsets[idx]
        end = self.audio_offsets[idx + 1]
        # Copy so the tensor owns writable memory instead of the read-only map.
        waveform = torch.from_numpy(self.audio_mmap[start:end].copy())
        target = self.encode_transcripts[idx]
        return waveform, target

In [None]:
from src.tokenizer import *
# Load the tokenizer from data/tokenizer.json using the project's Tokenizer.load.
tokenizer = Tokenizer.load(os.path.join("data", "tokenizer.json"))

In [None]:
# Build the ASR training set. The first run preprocesses the audio into the
# cache directory; subsequent runs load from that cache and are fast.
asr_train_dataset = MemoryMappedASRDataset(
    'data/metadata_normal.tsv',
    tokenizer,
    cache_dir='data/cache_mmap/asr',
    top_db=TOP_DB,
)