In [38]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import librosa
from torch.utils.data import Dataset, DataLoader
from torch.nn import ReLU, GELU


In [39]:
# data_path = "phase2_data/subset_80k.csv"

# df = pd.read_csv(data_path)

# import os

# # Get list of mp3 files in the directory
# mp3_files = [f for f in os.listdir('phase2_data/subset_80k_audio') if f.endswith('.mp3')]

# # Filter dataframe to keep only entries corresponding to existing mp3 files
# df = df[df['audio_file'].isin(mp3_files)]

# df.to_csv("45K_audio.csv")

In [40]:
# before you build any tensor / module
import torch, os
torch.set_default_dtype(torch.float16)   # makes every new tensor FP16
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # safe CPU fallback

def _nearest_code(x, weight):
    # x: [N, C]  weight: [num_codes, C]
    # ‖x‖² + ‖w‖² − 2·x·wᵀ
    xx = (x * x).sum(-1, keepdim=True)            # [N, 1]
    ww = (weight * weight).sum(-1).unsqueeze(0)   # [1, K]
    dist = xx + ww - 2.0 * x @ weight.T           # [N, K]
    return dist.argmin(-1)                        # [N]




In [41]:
def preprocess_audio(file_path, sample_rate=16000, target_duration=10.0):

    audio, sr = librosa.load(file_path, sr=sample_rate)
    audio = librosa.util.normalize(audio)
    target_length = int(sample_rate * target_duration)
    
    if len(audio) < target_length:
        audio = np.pad(audio, (0, target_length - len(audio)))
    else:
        audio = audio[:target_length]
    return audio

In [49]:
import os
import torch
import pandas as pd
import numpy as np
import librosa
from torch.utils.data import Dataset

class ArabicAudioDataset(Dataset):
    def __init__(self, data_path="45K_audio.csv", audio_folder="phase2_data/subset_80k_audio",
                 sample_rate=16000, target_duration=10.0, n_mels=80):
        self.data = pd.read_csv(data_path)
        self.audio_folder = audio_folder
        self.sample_rate = sample_rate
        self.target_duration = target_duration
        self.n_mels = n_mels
        self.target_length = int(sample_rate * target_duration)

    def __len__(self):
        return len(self.data)

    def load_and_preprocess_audio(self, full_path):
        try:
            audio, sr = librosa.load(full_path, sr=self.sample_rate)
            if len(audio) > self.target_length:
                audio = audio[:self.target_length]
            else:
                pad_length = self.target_length - len(audio)
                audio = np.pad(audio, (0, pad_length), mode='constant')
            return audio
        except Exception as e:
            print(f"Error loading {full_path}: {e}")
            return np.zeros(self.target_length, dtype=np.float32)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        file_name = row['audio_file']
        full_path = os.path.join(self.audio_folder, file_name)

        audio = self.load_and_preprocess_audio(full_path)

        mel_spec = librosa.feature.melspectrogram(
            y=audio,
            sr=self.sample_rate,
            n_mels=self.n_mels,
            n_fft=1024,
            hop_length=256,
            win_length=1024
        )
        mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
        
        return {
            'audio_file': row['audio_file'],
            'text': row['clean_text'],
            'length': float(row['length']),
            'mel_spec': torch.FloatTensor(mel_spec_db),
            'audio': torch.FloatTensor(audio)}

In [50]:
import torch, torch.nn as nn
torch.set_default_dtype(torch.float16)

class RVQ(nn.Module):
    def __init__(self, num_quantizers=4, num_codes=512, latent_dim=256):
        super().__init__()
        self.codebooks = nn.ModuleList([
            nn.Embedding(num_codes, latent_dim) for _ in range(num_quantizers)
        ])

    # @torch.no_grad()          # remove if training
    def _nearest_code(self, x, weight):
        xx = (x * x).sum(-1, keepdim=True)          # [N,1]
        ww = (weight * weight).sum(-1).unsqueeze(0) # [1,K]
        return (xx + ww - 2 * x @ weight.T).argmin(-1)

    def forward(self, z, chunk=4096):
        B, C, T = z.size()
        z = z.permute(0, 2, 1).reshape(-1, C)       # [B*T, C]

        residual = z
        all_q, all_idx = [], []

        for codebook in self.codebooks:
            q_chunks, i_chunks = [], []
            for s in range(0, residual.size(0), chunk):
                e = s + chunk
                idx = self._nearest_code(residual[s:e], codebook.weight)
                i_chunks.append(idx)
                q_chunks.append(codebook(idx))
            idx = torch.cat(i_chunks, 0)
            q   = torch.cat(q_chunks, 0)

            all_q.append(q)
            all_idx.append(idx)
            residual -= q
            

        quantized = torch.stack(all_q).sum(0)              # [B*T, C]
        indices   = torch.stack(all_idx)                   # [Q, B*T]
        quantized = quantized.view(B, T, C).permute(0, 2, 1)
        return quantized, indices


In [51]:


class Encoder(nn.Module):

    def __init__(self, input_dim=80):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv1d(input_dim, 128, 3, padding=1),   # 256→128
            nn.ReLU(),
            nn.Conv1d(128, 128, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv1d(128, 256, 3, stride=2, padding=1),# 512→256
            nn.ReLU(),
            nn.Conv1d(256, 256, 3, padding=1),
            nn.ReLU())
        

    def forward(self, x):
        return self.conv_layers(x)

In [56]:
class Decoder(nn.Module):
    def __init__(self, output_dim=80):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.ConvTranspose1d(256, 512, 3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(512, 256, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(256, 256, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(256, output_dim, 3, padding=1)
        )
        
    def forward(self, x):
        return self.conv_layers(x)

In [57]:
class AcousticCodec(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.rvq = RVQ()
        self.decoder = Decoder()
        
    def forward(self, mel_spec):
        encoded = self.encoder(mel_spec)
        quantized, indices = self.rvq(encoded)
        reconstructed = self.decoder(quantized)
        return reconstructed, indices, encoded

In [62]:
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.mps import is_available as mps_ok

# Mixed precision support toggle
use_mps16 = mps_ok()
scaler = torch.cuda.amp.GradScaler(enabled=use_mps16)

def train_acoustic_codec(num_epochs=50, batch_size=16, grad_accum_steps=4, learning_rate=1e-4):
    device = 'mps' if use_mps16 else 'cpu'
    print(f"Using device: {device}")
    
    dataset = ArabicAudioDataset()
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    model = AcousticCodec().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
    for epoch in tqdm(range(num_epochs), desc="Epochs"):
        total_loss = 0
        optimizer.zero_grad()
        
        for step, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)):
            mel_specs = batch['mel_spec'].to(device)

            with torch.autocast(device_type='mps', dtype=torch.float16, enabled=use_mps16):
                # Before loss computation


                reconstructed, indices, encoded = model(mel_specs)
                min_len = min(reconstructed.shape[-1], mel_specs.shape[-1])
                reconstructed = reconstructed[..., :min_len]
                mel_specs = mel_specs[..., :min_len]

                # Now it's safe
                recon_loss = F.mse_loss(reconstructed, mel_specs)
                recon_loss = F.mse_loss(reconstructed, mel_specs)
                loss = recon_loss 
            
            loss = loss / grad_accum_steps
            scaler.scale(loss).backward()
            
            if (step + 1) % grad_accum_steps == 0 or step == len(dataloader) - 1:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
            
            total_loss += loss.item() * grad_accum_steps

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}")
    
    return model


  scaler = torch.cuda.amp.GradScaler(enabled=use_mps16)


In [64]:
if __name__ == "__main__":
    print("Starting training...")

    model = train_acoustic_codec(
        num_epochs=15,
        batch_size=16,
        learning_rate=0.0001
    )
    
    torch.save(model.state_dict(), 'acoustic_codec_final.pth')
    print("Training completed and model saved!")

Starting training...
Using device: mps


Epochs:   7%|▋         | 1/15 [04:43<1:06:12, 283.78s/it]

Epoch [1/15], Average Loss: nan


Epochs:  13%|█▎        | 2/15 [09:40<1:03:05, 291.18s/it]

Epoch [2/15], Average Loss: nan


Epochs:  13%|█▎        | 2/15 [10:16<1:06:44, 308.07s/it]


KeyboardInterrupt: 