In [1]:
# semantic_codec=raw_audio -> s-tokens

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset,  random_split
import numpy as np
from transformers import  WhisperProcessor, WhisperForConditionalGeneration
import os
import glob
import librosa
import gc
import logging
from tqdm import tqdm


  from .autonotebook import tqdm as notebook_tqdm


## Extracting Features

### Loading Data and Extracting Features

In [25]:
# Set up logging
logging.basicConfig(level=logging.INFO, filename='training.log', filemode='w')

# Set random seed for reproducibility
torch.manual_seed(42)

# Define the dataset class
class Arabic_Processed_audios(Dataset):
    def __init__(self, audio_path, max_length=160000):
        self.audio_path = audio_path
        self.audio_files = glob.glob(os.path.join(self.audio_path, "*.mp3"))[:20000]
        self.max_length = max_length

    def __getitem__(self, idx):
        file = self.audio_files[idx]
        try:
            audio, _ = librosa.load(file, sr=16000)
            audio = librosa.util.normalize(audio)
            
            if len(audio) < self.max_length:
                audio = np.pad(audio, (0, self.max_length - len(audio)))
            else:
                audio = audio[:self.max_length]
            
            return audio, os.path.basename(file)  # Return filename for saving features
        except Exception as e:
            logging.error(f"Error loading {file}: {str(e)}")
            return None, None

    def __len__(self):
        return len(self.audio_files)

# Load audio files
Audio_files = Arabic_Processed_audios('phase2_data/subset_80k_audio')
print(f"Number of audio files: {len(Audio_files)}")

# Load Whisper processor and model
processor = WhisperProcessor.from_pretrained("openai/whisper-base", cache_dir='./processor_cache')
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base", cache_dir='./model_cache')
device = torch.device("cpu")  # Explicitly set to CPU since MPS is unsupported
model.to(device)

def extract_whisper_features(model, audio, processor, layer=-1, max_length=1500):
    try:
        # For 10-second audio at 16kHz, we need at least 160,000 samples
        min_samples = 160000  # 10 seconds * 16000 Hz
        
        # Pad audio if it's shorter than 10 seconds
        if len(audio) < min_samples:
            audio = np.pad(audio, (0, min_samples - len(audio)))
        else:
            audio = audio[:min_samples]  # Truncate if longer
            
        # Process with Whisper
        inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        
        with torch.no_grad():
            encoder_outputs = model.get_encoder()(inputs["input_features"])
            features = encoder_outputs.last_hidden_state if layer == -1 else encoder_outputs.hidden_states[layer]
            
        # Pad or truncate features to max_length
        if features.shape[1] > max_length:
            features = features[:, :max_length, :]
        elif features.shape[1] < max_length:
            padding = (0, 0, 0, max_length - features.shape[1])
            features = torch.nn.functional.pad(features, padding)
            
        return features.cpu()
    except Exception as e:
        logging.error(f"Error processing audio: {str(e)}")
        return None


# Use DataLoader for batch processing
dataloader = DataLoader(Audio_files, batch_size=16, shuffle=False, num_workers=0)

print("Processing Audio Files")
for batch in tqdm(dataloader, desc="Processing Audio Files"):
    audios, filenames = batch
    for audio, filename in zip(audios, filenames):
        if audio is None:
            continue  
        features = extract_whisper_features(model, audio.numpy(), processor, layer=-1)
        gc.collect()
        torch.cuda.empty_cache() if torch.cuda.is_available() else None

Number of audio files: 20000
Processing Audio Files


Processing Audio Files: 100%|██████████| 1250/1250 [1:12:57<00:00,  3.50s/it]


In [28]:
features.shape

torch.Size([1, 1500, 512])

In [30]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
import glob
import os

# First, load all features from the files
features_dir = 'features_output'
feature_files = glob.glob(os.path.join(features_dir, "*_features.npy"))
features = []

print(f"Loading {len(feature_files)} feature files...")
for feature_file in feature_files:
    feature = np.load(feature_file)  # Shape: (1, 1500, 512)
    features.append(feature)

# Convert list to numpy array
features = np.array(features)  # Shape: (N, 1, 1500, 512) where N is number of files
print(f"Loaded features shape: {features.shape}")

class WhisperFeaturesDataset(Dataset):
    def __init__(self, features):
        self.features = features
    
    def __getitem__(self, idx):
        feature = self.features[idx]
        feature = feature.squeeze(0)  # Shape: (1500, 512)
        return torch.FloatTensor(feature)
    
    def __len__(self):
        return len(self.features)

# Create dataset
features_dataset = WhisperFeaturesDataset(features)
print(f"Number of features in dataset: {len(features_dataset)}")

# Create train/test split (80% train, 20% test)
train_size = int(0.8 * len(features_dataset))
test_size = len(features_dataset) - train_size
train_dataset, test_dataset = random_split(features_dataset, [train_size, test_size])

# Create dataloaders
train_dataloader = DataLoader(
    train_dataset, 
    batch_size=16, 
    shuffle=True,
    num_workers=0  # Set to 0 to avoid multiprocessing issues
)

test_dataloader = DataLoader(
    test_dataset, 
    batch_size=16, 
    shuffle=False,
    num_workers=0
)

# Verify the shapes
for batch in train_dataloader:
    print("Train batch shape:", batch.shape)  # Should be (16, 1500, 512)
    break

for batch in test_dataloader:
    print("Test batch shape:", batch.shape)  # Should be (16, 1500, 512)
    break

print(f"\nNumber of training samples: {len(train_dataset)}")
print(f"Number of test samples: {len(test_dataset)}")

Loading 145 feature files...
Loaded features shape: (145, 1, 1500, 512)
Number of features in dataset: 145
Train batch shape: torch.Size([16, 1500, 512])
Test batch shape: torch.Size([16, 1500, 512])

Number of training samples: 116
Number of test samples: 29


## Semantic Codec Arch

In [20]:
# class FeaturesDataset(Dataset):
#     def __init__(self, features_dir):
#         self.features_dir = features_dir
#         self.feature_files = glob.glob(os.path.join(features_dir, "*_features.npy"))
        
#     def __getitem__(self, idx):
#         feature_file = self.feature_files[idx]
#         features = np.load(feature_file)  # Shape: (1, 1500, 512)
#         # Remove the batch dimension since DataLoader will handle batching
#         features = features.squeeze(0)  # Shape: (1500, 512)
#         return torch.FloatTensor(features)
    
#     def __len__(self):
#         return len(self.feature_files)

# # Create dataset
# features_dataset = FeaturesDataset('features_output')
# print(f"Number of feature files: {len(features_dataset)}")

# # Create train/test split
# train_size = int(0.8 * len(features_dataset))
# test_size = len(features_dataset) - train_size
# train_dataset, test_dataset = random_split(features_dataset, [train_size, test_size])

# # Create dataloaders
# train_dataloader = DataLoader(
#     train_dataset, 
#     batch_size=16, 
#     shuffle=True,
#     num_workers=0  
# )

# test_dataloader = DataLoader(
#     test_dataset, 
#     batch_size=16, 
#     shuffle=False,
#     num_workers=0
# )

Number of feature files: 145


In [31]:
# Test the dataloader
for batch in train_dataloader:
    print("Batch shape:", batch.shape)  # Should be (16, 1500, 512)
    break

Batch shape: torch.Size([16, 1500, 512])


##### Conv Block

In [32]:
class ConvNextBlock(nn.Module):
    def __init__(self, dim, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv1d(dim, dim, kernel_size, padding=kernel_size//2, groups=dim)  # Depthwise
        self.norm = nn.LayerNorm(dim)
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.act = nn.GELU()

    def forward(self, x):
        # x: (batch, dim, seq_len)
        residual = x
        x = self.conv(x)
        # Transpose for LayerNorm
        x = x.transpose(1, 2)  # (batch, seq_len, dim)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        # Transpose back
        x = x.transpose(1, 2)  # (batch, dim, seq_len)
        return x + residual

##### Vector Quantization Layer

In [33]:
# Vector Quantization Layer
class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings=8192, embedding_dim=8, commitment_cost=0.25):
        
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost
        self.embeddings = nn.Parameter(torch.randn(num_embeddings, embedding_dim))
        self.register_buffer('ema_count', torch.zeros(num_embeddings))
        self.register_buffer('ema_weight', self.embeddings.clone())

    def forward(self, x):

        flat_x = x.reshape(-1, self.embedding_dim)
        distances = torch.cdist(flat_x, self.embeddings)
        encoding_indices = torch.argmin(distances, dim=1)
        quantized = self.embeddings[encoding_indices].reshape(x.shape)
        codebook_loss = F.mse_loss(quantized.detach(), x)
        commitment_loss = self.commitment_cost * F.mse_loss(quantized, x.detach())
        loss = codebook_loss + commitment_loss
        quantized = x + (quantized - x).detach()
        
        if self.training:
            with torch.no_grad():
                one_hot = F.one_hot(encoding_indices, self.num_embeddings).float()
                self.ema_count = 0.999 * self.ema_count + 0.001 * torch.sum(one_hot, dim=0)
                n = torch.sum(self.ema_count)
                self.ema_count = (self.ema_count + 1e-8) / (n + self.num_embeddings * 1e-8) * n
                dw = torch.matmul(one_hot.transpose(0, 1), flat_x)
                self.ema_weight = 0.999 * self.ema_weight + 0.001 * dw
                self.embeddings.data = (self.ema_weight / (self.ema_count.unsqueeze(-1) + 1e-8))
        
        return quantized, loss, encoding_indices

#### Semantic Codec model

In [34]:
# VQ-VAE Model
# Custom Lambda module for applying arbitrary functions (e.g., transpose)
class Lambda(nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        return self.func(x)
class VQVAE(nn.Module):
    def __init__(self, input_dim=512, hidden_dim=384, codebook_size=8192, codebook_dim=8):
        super().__init__()
        self.encoder = nn.Sequential(
                    nn.Conv1d(input_dim, hidden_dim, kernel_size=7, padding=3),
                    Lambda(lambda x: x.transpose(1, 2)),  # To (batch_size, sequence_length, hidden_dim)
                    nn.LayerNorm(hidden_dim),
                    Lambda(lambda x: x.transpose(1, 2)),  # Back to (batch_size, hidden_dim, sequence_length)
                    # Assume ConvNextBlock is defined and works with (batch_size, hidden_dim, sequence_length)
                    *[ConvNextBlock(hidden_dim) for _ in range(6)],
                    nn.Conv1d(hidden_dim, codebook_dim, kernel_size=1)
                )
        self.quantizer = VectorQuantizer(num_embeddings=codebook_size, embedding_dim=codebook_dim)
        self.decoder = nn.Sequential(
            nn.Conv1d(codebook_dim, hidden_dim, kernel_size=7, padding=3),
            *[ConvNextBlock(hidden_dim) for _ in range(6)],
            nn.Conv1d(hidden_dim, input_dim, kernel_size=1)
        )

    def forward(self, x):
            x = x.transpose(1, 2)  # From (batch_size, sequence_length, input_dim) to (batch_size, input_dim, sequence_length)
            z = self.encoder(x)  # Output: (batch_size, codebook_dim, sequence_length)
            z = z.transpose(1, 2)  # To (batch_size, sequence_length, codebook_dim) for quantizer
            quantized, vq_loss, indices = self.quantizer(z)  # quantized: (batch_size, sequence_length, codebook_dim)
            quantized = quantized.transpose(1, 2)  # To (batch_size, codebook_dim, sequence_length)
            recon = self.decoder(quantized)  # Input to decoder: (batch_size, codebook_dim, sequence_length)
            return recon, quantized, vq_loss, indices

#### Dataloader

In [35]:
# # Load features dataset
# features_dataset = FeaturesDataset('features_output')
# print(f"Number of feature files: {len(features_dataset)}")

# # Create train/test split
# train_size = int(0.8 * len(features_dataset))
# test_size = len(features_dataset) - train_size
# train_dataset, test_dataset = random_split(features_dataset, [train_size, test_size])

# # Create dataloaders
# train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
# test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False)

In [36]:
# Initialize VQVAE model
device='mps'
model = VQVAE(input_dim=512, hidden_dim=384, codebook_size=8192, codebook_dim=8)
model = model.to(device)

# Initialize optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

#### Training and evaluation of semantic codec

In [39]:
def train_model(model, train_dataloader, test_dataloader, num_epochs):   
    model.train()
    for epoch in tqdm(range(num_epochs), desc="Epochs"):
        epoch_loss = 0
        epoch_recon_loss = 0
        epoch_vq_loss = 0
        
        for batch in tqdm(train_dataloader, desc="Training", leave=False):
            features = batch.to(device)  # Shape: (batch_size, sequence_length, input_dim)
            
            # Ensure features are in the correct shape
            if len(features.shape) == 2:
                features = features.unsqueeze(0)  # Add batch dimension if missing
            
            recon, quantized, vq_loss, indices = model(features)
            recon = recon.transpose(1, 2)  # Back to (batch_size, sequence_length, input_dim)
            
            # Normalize the loss by batch size
            recon_loss = torch.nn.functional.mse_loss(recon, features, reduction='mean')
            loss = recon_loss + vq_loss
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # Track individual losses
            epoch_recon_loss += recon_loss.item()
            epoch_vq_loss += vq_loss.item()
            epoch_loss += loss.item()

        # Print epoch statistics
        avg_recon_loss = epoch_recon_loss / len(train_dataloader)
        avg_vq_loss = epoch_vq_loss / len(train_dataloader)
        avg_total_loss = epoch_loss / len(train_dataloader)
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print(f"  Training - Recon Loss: {avg_recon_loss:.4f}, VQ Loss: {avg_vq_loss:.4f}, Total Loss: {avg_total_loss:.4f}")

        # Testing loop
        model.eval()
        with torch.no_grad():
            test_recon_loss = 0
            test_vq_loss = 0
            test_total_loss = 0
            
            for batch in tqdm(test_dataloader, desc="Testing"):
                features = batch.to(device)
                if len(features.shape) == 2:
                    features = features.unsqueeze(0)
                    
                recon, quantized, vq_loss, indices = model(features)
                recon = recon.transpose(1, 2)
                recon_loss = torch.nn.functional.mse_loss(recon, features, reduction='mean')
                total_loss = recon_loss + vq_loss
                
                test_recon_loss += recon_loss.item()
                test_vq_loss += vq_loss.item()
                test_total_loss += total_loss.item()
            
            # Print test statistics
            avg_test_recon_loss = test_recon_loss / len(test_dataloader)
            avg_test_vq_loss = test_vq_loss / len(test_dataloader)
            avg_test_total_loss = test_total_loss / len(test_dataloader)
            print(f"  Testing - Recon Loss: {avg_test_recon_loss:.4f}, VQ Loss: {avg_test_vq_loss:.4f}, Total Loss: {avg_test_total_loss:.4f}\n")
    
    return model

In [40]:
# Start training
model = train_model(model, train_dataloader, test_dataloader, num_epochs=8)



Epochs:   0%|          | 0/8 [00:00<?, ?it/s]


Epoch 1/8
  Training - Recon Loss: 1.8586, VQ Loss: 0.7210, Total Loss: 2.5796


Testing: 100%|██████████| 2/2 [00:00<00:00,  4.30it/s]
Epochs:  12%|█▎        | 1/8 [00:08<00:59,  8.45s/it]

  Testing - Recon Loss: 1.5068, VQ Loss: 0.1194, Total Loss: 1.6262






Epoch 2/8
  Training - Recon Loss: 1.3822, VQ Loss: 0.1072, Total Loss: 1.4894


Testing: 100%|██████████| 2/2 [00:00<00:00,  4.44it/s]
Epochs:  25%|██▌       | 2/8 [00:14<00:41,  6.86s/it]

  Testing - Recon Loss: 1.2460, VQ Loss: 0.1024, Total Loss: 1.3483






Epoch 3/8
  Training - Recon Loss: 1.2031, VQ Loss: 0.0864, Total Loss: 1.2895


Testing: 100%|██████████| 2/2 [00:00<00:00,  4.49it/s]
Epochs:  38%|███▊      | 3/8 [00:19<00:31,  6.35s/it]

  Testing - Recon Loss: 1.1380, VQ Loss: 0.0876, Total Loss: 1.2256






Epoch 4/8
  Training - Recon Loss: 1.1317, VQ Loss: 0.0756, Total Loss: 1.2073


Testing: 100%|██████████| 2/2 [00:00<00:00,  4.50it/s]
Epochs:  50%|█████     | 4/8 [00:25<00:24,  6.10s/it]

  Testing - Recon Loss: 1.0954, VQ Loss: 0.0692, Total Loss: 1.1646






Epoch 5/8
  Training - Recon Loss: 1.0930, VQ Loss: 0.0731, Total Loss: 1.1661


Testing: 100%|██████████| 2/2 [00:00<00:00,  4.49it/s]
Epochs:  62%|██████▎   | 5/8 [00:31<00:17,  5.97s/it]

  Testing - Recon Loss: 1.0572, VQ Loss: 0.0741, Total Loss: 1.1314






Epoch 6/8
  Training - Recon Loss: 1.0588, VQ Loss: 0.0723, Total Loss: 1.1311


Testing: 100%|██████████| 2/2 [00:00<00:00,  4.47it/s]
Epochs:  75%|███████▌  | 6/8 [00:37<00:11,  5.89s/it]

  Testing - Recon Loss: 1.0318, VQ Loss: 0.0659, Total Loss: 1.0978






Epoch 7/8
  Training - Recon Loss: 1.0243, VQ Loss: 0.0616, Total Loss: 1.0859


Testing: 100%|██████████| 2/2 [00:00<00:00,  4.48it/s]
Epochs:  88%|████████▊ | 7/8 [00:42<00:05,  5.83s/it]

  Testing - Recon Loss: 1.0023, VQ Loss: 0.0514, Total Loss: 1.0537






Epoch 8/8
  Training - Recon Loss: 1.0049, VQ Loss: 0.0569, Total Loss: 1.0619


Testing: 100%|██████████| 2/2 [00:00<00:00,  4.48it/s]
Epochs: 100%|██████████| 8/8 [00:48<00:00,  6.07s/it]

  Testing - Recon Loss: 0.9777, VQ Loss: 0.0560, Total Loss: 1.0337






In [17]:
# Save the trained model
torch.save(model.state_dict(), 'vqvae_model_train2.pth')

In [42]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np

# Assuming you have a validation dataset (val_dataset) and corresponding dataloader (val_dataloader)
# If not, you'll need to create one.  This example assumes you have one.

def evaluate_model(model, dataloader, device):
    model.eval()  # Set the model to evaluation mode
    total_recon_loss = 0.0
    total_vq_loss = 0.0
    total_loss = 0.0
    num_batches = len(dataloader)

    with torch.no_grad():  # Disable gradient calculation for evaluation
        for batch in tqdm(dataloader, desc="Evaluating"):
            features = batch[0].to(device)  # Assuming your data loader returns a list, and the first element is your feature tensor
            if len(features.shape) == 2:
                features = features.unsqueeze(0)  # Add batch dimension if missing
            recon, quantized, vq_loss, indices = model(features)
            recon = recon.transpose(1, 2)
            recon_loss = torch.nn.functional.mse_loss(recon, features, reduction='mean')
            loss = recon_loss + vq_loss

            total_recon_loss += recon_loss.item()
            total_vq_loss += vq_loss.item()
            total_loss += loss.item()

    avg_recon_loss = total_recon_loss / num_batches
    avg_vq_loss = total_vq_loss / num_batches
    avg_total_loss = total_loss / num_batches

    print(f"Evaluation Results:")
    print(f"  Reconstruction Loss: {avg_recon_loss:.4f}")
    print(f"  VQ Loss: {avg_vq_loss:.4f}")
    print(f"  Total Loss: {avg_total_loss:.4f}")

    return avg_recon_loss, avg_vq_loss, avg_total_loss

# Example usage (assuming you have a validation dataloader named 'val_dataloader')
# Make sure your model is on the correct device before evaluating
device = 'mps'
model = model.to(device)  # Ensure model is on the correct device

recon_loss, vq_loss, total_loss = evaluate_model(model, test_dataloader, device)

# You can also evaluate on the training data to check for overfitting:
train_recon_loss, train_vq_loss, train_total_loss = evaluate_model(model, train_dataloader, device)



Evaluating: 100%|██████████| 2/2 [00:00<00:00,  3.82it/s]


Evaluation Results:
  Reconstruction Loss: 0.9472
  VQ Loss: 0.0544
  Total Loss: 1.0015


Evaluating: 100%|██████████| 8/8 [00:00<00:00, 39.72it/s]

Evaluation Results:
  Reconstruction Loss: 0.9797
  VQ Loss: 0.0543
  Total Loss: 1.0340





In [43]:
torch.save(model.state_dict(), 'semantic_codec_final_20k_2.pth') #saved with state