In [None]:
# Imports and Setup
import os

import numpy as np
import plotly.graph_objects as go
import torch
from dotenv import load_dotenv
from plotly.subplots import make_subplots
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from audio_preprocessor import AudioFeatureExtractor

# Load environment variables (if a .env file is present)
load_dotenv()

# Check for GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Hyperparameters: single source of truth for the preprocessor, model and training loop
HYPERPARAMS = {
    'batch_size': 32,
    # NOTE(review): 1e-6 is very small for Adam, and the cosine schedule only
    # lowers it further over training -- confirm this is intended.
    'learning_rate': 1e-6,
    'num_epochs': 10,
    'd_model': 256,
    'nhead': 8,
    'num_layers': 4,
    'dim_feedforward': 512,
    'sample_rate': 16000,
    'n_mels': 128,
    'window_size': 25,  # in milliseconds
    'hop_size': 10,  # in milliseconds
    'chunk_size': 1024,
    'chunk_overlap': 256,
    'patch_size': 32,
    'patch_overlap': 8,
    'seed': 42,
}

# Seed the RNGs so weight initialisation, DataLoader shuffling and the
# train/val split are repeatable across kernel restarts.
torch.manual_seed(HYPERPARAMS['seed'])
np.random.seed(HYPERPARAMS['seed'])

In [None]:
# Audio preprocessor, configured from HYPERPARAMS so the feature settings
# (notably patch_size, which the model also reads) cannot drift apart from
# the rest of the notebook.
preprocessor = AudioFeatureExtractor(
    sample_rate=HYPERPARAMS['sample_rate'],
    n_mels=HYPERPARAMS['n_mels'],
    window_size=HYPERPARAMS['window_size'],
    hop_size=HYPERPARAMS['hop_size'],
    chunk_size=HYPERPARAMS['chunk_size'],
    patch_size=HYPERPARAMS['patch_size'],
    # Falls back to the previously hard-coded 256 if HYPERPARAMS lacks the key.
    chunk_overlap=HYPERPARAMS.get('chunk_overlap', 256),
    patch_overlap=HYPERPARAMS['patch_overlap'],
)

# Dataset class for loading audio files
class AudioDataset(Dataset):
    """In-memory dataset of preprocessed audio chunks.

    Every file is run through ``preprocessor`` once, at construction time, and
    all chunks it yields are kept in ``self.data``. Indexing returns one chunk
    as a newly allocated float32 tensor.

    Args:
        file_paths: iterable of audio file paths to preprocess.
        preprocessor: callable mapping a file path to an iterable of chunks.
    """

    def __init__(self, file_paths, preprocessor):
        self.file_paths = file_paths
        self.preprocessor = preprocessor
        # Flatten the per-file chunk lists into one list of samples
        self.data = [
            piece
            for path in tqdm(self.file_paths, desc="Preprocessing data")
            for piece in self.preprocessor(path)
        ]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        return torch.tensor(sample, dtype=torch.float)

# Custom collate function: stacks equally shaped chunks into one float batch
def collate_fn(batch):
    """Stack a list of samples into a single float32 batch tensor.

    Each item may be a tensor, numpy array or nested list. ``torch.as_tensor``
    is used because ``torch.tensor`` on an existing tensor emits PyTorch's
    copy-construct warning (and the dataset already returns tensors).

    Args:
        batch: list of samples that all have the same shape.

    Returns:
        Tensor of shape ``(len(batch), *item_shape)`` with dtype float32.

    Raises:
        RuntimeError: from ``torch.stack`` if the items differ in shape --
            variable-length chunks are NOT padded here.
    """
    return torch.stack([torch.as_tensor(item, dtype=torch.float) for item in batch])

# Load data. Sorting makes the file order (and therefore the chunk order that
# random_split draws from) independent of os.listdir's arbitrary ordering.
DATA_DIR = '../data/songs1'
MAX_FILES = None  # set to a small int (e.g. 10) for a quick debug run

file_paths = sorted(
    os.path.join(DATA_DIR, f)
    for f in os.listdir(DATA_DIR)
    if f.lower().endswith('.mp3')  # also matches ".MP3"
)
if MAX_FILES:
    file_paths = file_paths[:MAX_FILES]
if not file_paths:
    raise FileNotFoundError(f"No .mp3 files found in {DATA_DIR}")

# Calculate mean and std for dataset and attach them to the preprocessor.
# NOTE(review): the stats cover every file, including those whose chunks end
# up in the validation split below.
mean, std = AudioFeatureExtractor.calculate_dataset_stats(file_paths, preprocessor)

preprocessor.mean = mean
preprocessor.std = std

print(f"Mean: {mean}, Std: {std}")

dataset = AudioDataset(file_paths, preprocessor)

# Create an 80/20 train/val split over chunks, seeded so it is identical on every run.
# NOTE(review): chunks of the same song can land in both splits; split
# file_paths instead if validation should measure performance on unseen songs.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

split_generator = torch.Generator().manual_seed(HYPERPARAMS.get('seed', 42))
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=split_generator
)

train_dataloader = DataLoader(train_dataset, batch_size=HYPERPARAMS['batch_size'], shuffle=True, collate_fn=collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=HYPERPARAMS['batch_size'], shuffle=False, collate_fn=collate_fn)

# Print dataset info (the dataset holds chunks, not files)
print(f"Total number of audio files: {len(file_paths)}")
print(f"Total number of chunks: {len(dataset)}")
print(f"Number of batches in training set: {len(train_dataloader)}")
print(f"Number of batches in validation set: {len(val_dataloader)}")

In [None]:
# Inspect the shape of the chunked (but not patched) features for one file
file_path = file_paths[0]
features = preprocessor(file_path, do_patch=False, do_chunk=True)
feature_shape = np.array(features).shape
print(f"Chunk, Mel_bin, Time : {feature_shape}")

In [None]:
# Shape of the features produced with the preprocessor's default options;
# the printed label expects them to be chunked and patched.
file_path = file_paths[0]
features = preprocessor(file_path)
print("Chunk, Number of Patches, Mel_bin, Time: {}".format(np.array(features).shape))

In [None]:
# Load an example file
file_path = file_paths[0]

# Extract the normalized full spectrogram, the first chunk, and the patches
# of that first chunk.
spec = preprocessor(file_path, do_chunk=False, do_normalize=True)
chunk = preprocessor(file_path, do_chunk=True, do_patch=False)[0]
patches = preprocessor(file_path, do_chunk=True, do_patch=True)[0]  # patches of the first chunk

# Figure 1: full spectrogram (row 1) above the first chunk (row 2)
fig = make_subplots(rows=2, cols=1, subplot_titles=("Full spectrogram", "First Chunk"))
for row, z in enumerate((spec, chunk), start=1):
    fig.add_trace(go.Heatmap(z=z, colorscale='Viridis'), row=row, col=1)

fig.update_layout(
    title_text=f"Mel Spectrogram for {os.path.basename(file_path)}",
    height=800,  # Adjust height as needed
    showlegend=False,
)

# Pin the full spectrogram's y-axis (row 1 only) to its number of rows
fig.update_yaxes(range=[0, np.shape(spec)[0]], row=1, col=1)
fig.show()


# Figure 2: the first and last (up to) 4 patches on a shared colour scale.
# min() guards chunks with fewer than 4 patches, which previously raised IndexError.
n_show = min(4, len(patches))
first_idx = list(range(n_show))
last_idx = list(range(len(patches) - n_show, len(patches)))
# Pad titles so row-2 titles stay on row 2 when fewer than 4 patches are shown
title_pad = [""] * (4 - n_show)
subplot_titles = (
    [f"Patch {i+1}" for i in first_idx] + title_pad
    + [f"Patch {i+1}" for i in last_idx] + title_pad
)
fig = make_subplots(rows=2, cols=4, subplot_titles=subplot_titles)
min_val = np.min([np.min(patch) for patch in patches])
max_val = np.max([np.max(patch) for patch in patches])

for row, indices in ((1, first_idx), (2, last_idx)):
    for col, i in enumerate(indices, start=1):
        fig.add_trace(
            go.Heatmap(
                z=patches[i],
                colorscale='Viridis',
                zmin=min_val,
                zmax=max_val,
            ),
            row=row,
            col=col,
        )

fig.update_layout(
    title_text=f"First/last {n_show} patches for {os.path.basename(file_path)}",
    height=600,  # Adjust height as needed
    showlegend=False,
)
fig.show()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from custom_model import AudioTransformerModel

# Build the transformer from the shared hyperparameters and move it to `device`
model = AudioTransformerModel(
    d_model=HYPERPARAMS['d_model'],
    num_heads=HYPERPARAMS['nhead'],
    num_layers=HYPERPARAMS['num_layers'],
    dim_feedforward=HYPERPARAMS['dim_feedforward'],
    patch_size=HYPERPARAMS['patch_size'],
).to(device)

# Show the module tree, then the count of trainable parameters
print(model)
print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

In [None]:
import plotly.graph_objects as go
from IPython.display import clear_output
import os

# Plotting function for real-time loss visualization
def plot_losses(train_losses, val_losses=None, epoch_batches=None, save_path=None):
    """Redraw the training (per batch) and validation (per epoch) loss curves.

    The cell's previous output is cleared first, so repeated calls during
    training replace the figure rather than stacking new ones.

    Args:
        train_losses: one loss value per training batch, plotted against batch index.
        val_losses: optional sequence of per-epoch validation losses. Epoch ``i``
            is drawn at batch ``(i + 1) * epoch_batches - 1``, and the first
            value is repeated at batch 0 so the curve starts at the left edge.
            The caller's sequence is not modified.
        epoch_batches: training batches per epoch. If omitted while
            ``val_losses`` is given, it is estimated as
            ``len(train_losses) // len(val_losses)`` (at least 1).
        save_path: optional image path. Its parent directory is created when
            the path has one; a bare filename is written to the working directory.
            Uses plotly static image export (``fig.write_image``).
    """
    clear_output(wait=True)

    fig = go.Figure()

    fig.add_trace(go.Scatter(
        x=list(range(len(train_losses))),
        y=train_losses,
        mode='lines',
        name='Training Loss'
    ))

    if val_losses:
        if epoch_batches is None:
            epoch_batches = max(1, len(train_losses) // len(val_losses))
        # Prepend a point at batch 0 (duplicating the first value) so the
        # validation curve spans the same x-range as the training curve.
        val_x = [0] + [(i + 1) * epoch_batches - 1 for i in range(len(val_losses))]
        val_y = [val_losses[0]] + list(val_losses)
        fig.add_trace(go.Scatter(
            x=val_x,
            y=val_y,
            mode='lines',
            name='Validation Loss'
        ))

    fig.update_layout(
        title='Training and Validation Loss over time',
        xaxis_title='Batch',
        yaxis_title='Loss',
        legend=dict(x=0.1, y=0.9),
        height=600,  # Adjust height as needed
        width=800,  # Adjust width as needed
        showlegend=True
    )

    if save_path:
        save_dir = os.path.dirname(save_path)
        if save_dir:  # os.makedirs('') raises, so skip for bare filenames
            os.makedirs(save_dir, exist_ok=True)
        fig.write_image(save_path)

    fig.show()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TripletLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        distance_positive = (anchor - positive).pow(2).sum(1)
        distance_negative = (anchor - negative).pow(2).sum(1)
        losses = F.relu(distance_positive - distance_negative + self.margin)
        return losses.mean()

def create_triplets(embeddings, batch_size):
    """Split a batch of embeddings into (anchor, positive, negative) triplets.

    The first half of the batch is used as anchors and the second half as
    their positives. Negatives are the positives rolled by one position, so
    anchor ``i`` is contrasted with the positive of pair ``i - 1``. With an odd
    ``batch_size`` the last embedding is dropped.

    NOTE(review): pairing is purely positional (row ``i`` with row ``i + B/2``).
    Nothing here makes a positive related to its anchor (e.g. another chunk of
    the same song); with a shuffled loader of independent chunks the
    "positive" is as random as the "negative". Confirm the data pipeline
    produces matched halves, or construct pairs explicitly.
    NOTE: with exactly one pair the rolled negative equals the positive, so
    positive and negative distances are identical.

    Args:
        embeddings: tensor whose first dimension indexes the batch.
        batch_size: number of leading rows of ``embeddings`` to use.

    Returns:
        Tuple ``(anchors, positives, negatives)``, each with ``batch_size // 2`` rows.

    Raises:
        ValueError: if ``batch_size < 2``. No triplet can be formed, and the
            empty splits would make a mean-reduced loss NaN.
    """
    if batch_size < 2:
        raise ValueError(f"create_triplets needs at least 2 embeddings, got batch_size={batch_size}")

    half = batch_size // 2
    anchors = embeddings[:half]
    # Bounded slice: drops an odd trailing row and ignores any rows past batch_size
    positives = embeddings[half:2 * half]

    # Create negative samples by rolling the positives
    negatives = torch.roll(positives, shifts=1, dims=0)

    return anchors, positives, negatives

In [None]:
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn.utils import clip_grad_norm_

def _batch_triplet_loss(model, batch, criterion):
    """Move ``batch`` to ``device``, embed it, and return the triplet loss tensor."""
    batch = batch.to(device)
    embeddings = model(batch).squeeze()
    anchors, positives, negatives = create_triplets(embeddings, batch.size(0))
    return criterion(anchors, positives, negatives)


def train_model(model, train_dataloader, val_dataloader, optimizer, scheduler, num_epochs, temperature=0.5):
    """Train ``model`` with a triplet-margin objective, validating after every epoch.

    Each batch is embedded and split into (anchor, positive, negative) triplets
    by ``create_triplets``. The training loss is plotted after every batch and
    the validation loss after every epoch. A checkpoint is written to
    ``saved_models/`` each epoch, and the final plot is saved as an image.

    Batches with fewer than 2 samples are skipped: no triplet can be formed
    from them, and ``.squeeze()`` would also remove the batch dimension of a
    single-sample batch.

    Args:
        model: module called as ``model(batch)`` to produce embeddings.
        train_dataloader: yields training batch tensors.
        val_dataloader: yields validation batch tensors.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        num_epochs: number of passes over ``train_dataloader``.
        temperature: unused; kept so existing calls keep working.

    Returns:
        The trained ``model`` (the same object, updated in place).
    """
    criterion = TripletLoss(margin=1.0)
    train_losses = []
    avg_val_losses = []
    num_batches = len(train_dataloader)

    for epoch in range(num_epochs):
        model.train()
        epoch_train_loss = 0.0
        train_steps = 0

        for batch_idx, batch in enumerate(tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")):
            if batch.size(0) < 2:
                continue

            optimizer.zero_grad()
            loss = _batch_triplet_loss(model, batch, criterion)
            loss.backward()
            clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            batch_loss = loss.item()
            epoch_train_loss += batch_loss
            train_steps += 1
            train_losses.append(batch_loss)

            # Plot training loss after each batch
            plot_losses(train_losses, avg_val_losses, num_batches)

            print(f"Epoch {epoch+1}, Batch {batch_idx+1}, Training Loss: {batch_loss:.4f}")

        # Average over the batches actually used; nan if none were usable
        avg_epoch_train_loss = epoch_train_loss / train_steps if train_steps else float('nan')

        # Validation step
        model.eval()
        epoch_val_loss = 0.0
        val_steps = 0
        with torch.no_grad():
            for batch in val_dataloader:
                if batch.size(0) < 2:
                    continue
                epoch_val_loss += _batch_triplet_loss(model, batch, criterion).item()
                val_steps += 1

        # nan instead of ZeroDivisionError when there is no usable validation batch
        avg_epoch_val_loss = epoch_val_loss / val_steps if val_steps else float('nan')
        avg_val_losses.append(avg_epoch_val_loss)

        scheduler.step()

        # Plot both training and validation losses after each epoch
        plot_losses(train_losses, avg_val_losses, num_batches)

        # Save the model after each epoch
        os.makedirs('saved_models', exist_ok=True)
        torch.save(model.state_dict(), f'saved_models/audio_embedding_model_epoch_{epoch+1}.pth')
        print(f"Model saved successfully for epoch {epoch+1}.")

        print(f"Epoch {epoch+1}, Avg Train Loss: {avg_epoch_train_loss:.4f}, Val Loss: {avg_epoch_val_loss:.4f}")

    # Save the final plot
    plot_losses(train_losses, avg_val_losses, num_batches, save_path='loss_plot/final_loss_plot.png')

    return model


# Start training from a freshly initialised model (this rebinds `model`)
model = AudioTransformerModel(
    d_model=HYPERPARAMS['d_model'],
    num_heads=HYPERPARAMS['nhead'],
    num_layers=HYPERPARAMS['num_layers'],
    dim_feedforward=HYPERPARAMS['dim_feedforward'],
    patch_size=HYPERPARAMS['patch_size'],
).to(device)

# Module tree and trainable-parameter count
print(model)
print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

# Adam with a cosine-annealed learning rate over num_epochs scheduler steps
optimizer = torch.optim.Adam(model.parameters(), lr=HYPERPARAMS['learning_rate'])
scheduler = CosineAnnealingLR(optimizer, T_max=HYPERPARAMS['num_epochs'])

# Run training; train_model also checkpoints each epoch under saved_models/
trained_model = train_model(
    model,
    train_dataloader,
    val_dataloader,
    optimizer,
    scheduler,
    HYPERPARAMS['num_epochs'],
)

# Persist the final weights
torch.save(trained_model.state_dict(), 'audio_embedding_model.pth')
print("Model saved successfully.")