In [36]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import librosa
from torch.utils.data import Dataset, DataLoader

In [37]:
# data_path = "phase2_data/subset_80k.csv"

# df = pd.read_csv(data_path)

# import os

# # Get list of mp3 files in the directory
# mp3_files = [f for f in os.listdir('phase2_data/subset_80k_audio') if f.endswith('.mp3')]

# # Filter dataframe to keep only entries corresponding to existing mp3 files
# df = df[df['audio_file'].isin(mp3_files)]

# df.to_csv("45K_audio.csv")

In [38]:
def preprocess_audio(file_path, sample_rate=16000, target_duration=10.0):

    audio, sr = librosa.load(file_path, sr=sample_rate)
    audio = librosa.util.normalize(audio)
    target_length = int(sample_rate * target_duration)
    
    if len(audio) < target_length:
        audio = np.pad(audio, (0, target_length - len(audio)))
    else:
        audio = audio[:target_length]
    return audio

In [46]:
from tqdm import tqdm

class ArabicAudioDataset(Dataset):
    def __init__(self, data_path="45K_audio.csv", audio_folder="phase2_data/subset_80k_audio", sample_rate=16000, target_duration=10.0, n_mels=80):
        self.data = pd.read_csv(data_path)
        self.audio_folder = audio_folder
        self.sample_rate = sample_rate
        self.target_duration = target_duration
        self.n_mels = n_mels
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        audio_path = "phase2_data/subset_80k_audio"
        
        audio = preprocess_audio(audio_path, sample_rate=self.sample_rate, target_duration=self.target_duration)
        
        mel_spec = librosa.feature.melspectrogram(
            y=audio,
            sr=self.sample_rate,
            n_mels=self.n_mels,
            n_fft=1024,
            hop_length=256,
            win_length=1024
        )
        mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
        
        return {
            'audio_file': row['audio_file'],
            'text': row['clean_text'],
            'length': float(row['length']),
            'mel_spec': torch.FloatTensor(mel_spec_db),
            'audio': torch.FloatTensor(audio)}


In [47]:
class RVQ(nn.Module):
    def __init__(self, num_quantizers=8, num_codes=1024, latent_dim=512):
        super().__init__()
        self.num_quantizers = num_quantizers
        self.codebooks = nn.ModuleList([
            nn.Embedding(num_codes, latent_dim)
            for _ in range(num_quantizers)
        ])
        
    def forward(self, z):
        B, C, T = z.size()
        z = z.permute(0, 2, 1).reshape(-1, C)  # [B*T, C]
        
        quantized = []
        indices = []
        residual = z
        
        for codebook in self.codebooks:
            distances = torch.cdist(residual.unsqueeze(1), codebook.weight.unsqueeze(0))
            min_indices = torch.argmin(distances, dim=-1)
            quantized_vectors = codebook(min_indices)
            
            quantized.append(quantized_vectors)
            indices.append(min_indices)
            
            residual = residual - quantized_vectors
            
        quantized = torch.stack(quantized).sum(0)
        indices = torch.stack(indices)
        
        quantized = quantized.reshape(B, T, C).permute(0, 2, 1)
        return quantized, indices

In [48]:
class Encoder(nn.Module):
    def __init__(self, input_dim=80):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv1d(input_dim, 256, 3, padding=1),
            nn.ReLU(),
            nn.Conv1d(256, 256, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv1d(256, 512, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv1d(512, 512, 3, padding=1),
            nn.ReLU()
        )
        
    def forward(self, x):
        return self.conv_layers(x)

In [49]:
class Decoder(nn.Module):
    def __init__(self, output_dim=80):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.ConvTranspose1d(512, 512, 3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(512, 256, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(256, 256, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(256, output_dim, 3, padding=1)
        )
        
    def forward(self, x):
        return self.conv_layers(x)

In [50]:
class AcousticCodec(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.rvq = RVQ()
        self.decoder = Decoder()
        
    def forward(self, mel_spec):
        encoded = self.encoder(mel_spec)
        quantized, indices = self.rvq(encoded)
        reconstructed = self.decoder(quantized)
        return reconstructed, indices, encoded

In [51]:
from tqdm import tqdm

def train_acoustic_codec(num_epochs=50, batch_size=16, grad_accum_steps=4, learning_rate=1e-4):

    device = 'mps'
    print(f"Using device: {device}")
    
    dataset = ArabicAudioDataset()
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
    )
    
    model = AcousticCodec().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
    for epoch in tqdm(range(num_epochs), desc="Epochs"):
        total_loss = 0
        optimizer.zero_grad()
        for step, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)):
            mel_specs = batch['mel_spec'].to(device)
            
            reconstructed, indices, encoded = model(mel_specs)
            
            recon_loss = F.mse_loss(reconstructed, mel_specs)
            commitment_loss = F.mse_loss(encoded, reconstructed.detach())
            loss = recon_loss + 0.25 * commitment_loss
            loss = loss / grad_accum_steps
            
            loss.backward()
            
            if (step + 1) % grad_accum_steps == 0 or step == len(dataloader) - 1:
                optimizer.step()
                optimizer.zero_grad()
            
            total_loss += loss.item() * grad_accum_steps
            
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}")
        
    return model

In [52]:
if __name__ == "__main__":
    print("Starting training...")

    model = train_acoustic_codec(
        num_epochs=50,
        batch_size=16,
        learning_rate=1e-4
    )
    
    torch.save(model.state_dict(), 'acoustic_codec_final.pth')
    print("Training completed and model saved!")

Starting training...
Using device: mps


  audio, sr = librosa.load(file_path, sr=sample_rate)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)
Epochs:   0%|          | 0/30 [00:01<?, ?it/s]


FileNotFoundError: [Errno 2] No such file or directory: 'processed_qrXHw1_1648.mp3'