In [None]:
# Imports and Setup
import os
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from tqdm import tqdm
from dotenv import load_dotenv
from audio_preprocessor import AudioFeatureExtractor
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Load environment variables from a local .env file, if one is present
load_dotenv()

# Check for GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Hyperparameters: the single place a reader should need to edit to tune a run
HYPERPARAMS = {
    # Optimisation
    'batch_size': 16,
    'learning_rate': 1e-4,
    'num_epochs': 10,
    # Transformer encoder
    'd_model': 256,
    'nhead': 8,
    'num_layers': 4,
    'dim_feedforward': 512,
    # Audio front end (passed to AudioFeatureExtractor)
    'sample_rate': 16000,
    'n_mels': 128,
    'window_size': 25,  # in milliseconds
    'hop_size': 10,  # in milliseconds
    'chunk_size': 1024,  # unit not visible here (presumably spectrogram frames) - TODO confirm
    'chunk_overlap': 256,  # same value the preprocessor cell passes as chunk_overlap
    'patch_size': 32,
    'patch_overlap': 8,
    # Reproducibility
    'seed': 42,
}

# Seed the random number generators so that weight initialisation, shuffling
# and dropout are repeatable after "Restart Kernel -> Run All".
np.random.seed(HYPERPARAMS['seed'])
torch.manual_seed(HYPERPARAMS['seed'])

In [None]:
# Audio preprocessor, configured from HYPERPARAMS instead of repeating the
# same numbers as literals, so the two cannot silently drift apart.
preprocessor = AudioFeatureExtractor(
    sample_rate=HYPERPARAMS['sample_rate'],
    n_mels=HYPERPARAMS['n_mels'],
    window_size=HYPERPARAMS['window_size'],
    hop_size=HYPERPARAMS['hop_size'],
    chunk_size=HYPERPARAMS['chunk_size'],
    patch_size=HYPERPARAMS['patch_size'],
    # Fall back to 256 (the value this cell always used) when HYPERPARAMS
    # does not define a chunk overlap.
    chunk_overlap=HYPERPARAMS.get('chunk_overlap', 256),
    patch_overlap=HYPERPARAMS['patch_overlap'],
)

# Dataset class for loading audio files
class AudioDataset(Dataset):
    """In-memory dataset of preprocessed audio chunks.

    Each path in ``file_paths`` is passed through ``preprocessor`` once, at
    construction time, and everything it returns is pooled into a single flat
    list. One audio file therefore contributes several dataset items, and
    indexing never touches the audio files again.
    """

    def __init__(self, file_paths, preprocessor):
        """Preprocess every file eagerly and collect the resulting chunks."""
        self.file_paths = file_paths
        self.preprocessor = preprocessor
        self.data = []

        progress = tqdm(self.file_paths, desc="Preprocessing data")
        for audio_path in progress:
            file_chunks = self.preprocessor(audio_path)
            # Flatten: items from all files share one index space.
            self.data.extend(file_chunks)

    def __len__(self):
        """Return the number of pooled chunks (not the number of files)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the chunk stored at ``idx`` as a new float32 tensor."""
        stored_chunk = self.data[idx]
        return torch.tensor(stored_chunk, dtype=torch.float)

# Custom collate function: stacks equally-shaped chunks into one float batch
def collate_fn(batch):
    """Stack a list of samples into a single float32 batch tensor.

    Args:
        batch: Sequence of samples (tensors, numpy arrays or nested lists).
            All samples must have the same shape: ``torch.stack`` does not
            pad, so differently-sized chunks raise a ``RuntimeError``.

    Returns:
        A float32 tensor of shape ``(len(batch), *sample_shape)``.
    """
    # torch.as_tensor reuses an existing tensor instead of copy-constructing
    # it; torch.tensor(tensor) would copy every item a second time and emit a
    # UserWarning for each one.
    return torch.stack([torch.as_tensor(item, dtype=torch.float32) for item in batch])

# Load data
# os.listdir returns entries in arbitrary order, so sort the listing to make
# the selected subset of songs the same on every run and every machine.
song_dir = '../data/songs'
file_paths = sorted(os.path.join(song_dir, f) for f in os.listdir(song_dir) if f.endswith('.mp3'))
file_paths = file_paths[:100]  # Use only the first 100 files for now
if not file_paths:
    raise FileNotFoundError(f"No .mp3 files found in {song_dir}")

# Calculate mean and std for dataset
# NOTE(review): the statistics are computed over all selected files, i.e.
# before the train/validation split below.
mean, std = AudioFeatureExtractor.calculate_dataset_stats(file_paths, preprocessor)

# Store the statistics on the preprocessor before the dataset is built.
preprocessor.mean = mean
preprocessor.std = std

print(f"Mean: {mean}, Std: {std}")

dataset = AudioDataset(file_paths, preprocessor)

# Create train/val split (80/20) with a seeded generator so that the same
# items land in the validation set on every run.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

split_generator = torch.Generator().manual_seed(HYPERPARAMS.get('seed', 42))
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=split_generator
)

train_dataloader = DataLoader(train_dataset, batch_size=HYPERPARAMS['batch_size'], shuffle=True, collate_fn=collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=HYPERPARAMS['batch_size'], shuffle=False, collate_fn=collate_fn)


# Print dataset info. The dataset holds chunks, not files, so report both.
print(f"Total number of audio files: {len(file_paths)}")
print(f"Total number of chunks: {len(dataset)}")
print(f"Number of batches in training set: {len(train_dataloader)}")
print(f"Number of batches in validation set: {len(val_dataloader)}")

In [None]:
# Shape check: preprocess the first file with chunking on and patching off
file_path = file_paths[0]
features = preprocessor(file_path, do_patch=False, do_chunk=True)
print(f"Chunk, Mel_bin, Time : {np.asarray(features).shape}")

In [None]:
# Shape check: preprocess the first file with the preprocessor's default options
file_path = file_paths[0]
features = preprocessor(file_path)
print(f"Chunk, Number of Patches, Mel_bin, Time: {np.asarray(features).shape}")

In [None]:
# Load an example file
file_path = file_paths[0]

# Run the preprocessor three ways on the same file:
#   spec    - chunking disabled (the full spectrogram)
#   chunk   - first chunk, patching disabled
#   patches - patches of the first chunk (used by the patch grid further down)
spec = preprocessor(file_path, do_chunk=False, do_normalize=True)
chunk = preprocessor(file_path, do_chunk=True, do_patch=False)[0]
patches = preprocessor(file_path, do_chunk=True, do_patch=True)[0] # Get patches for the first chunk

fig = make_subplots(rows=2, cols=1, subplot_titles=("Full spectogram", "First Chunk"))

# Row 1 shows the full spectrogram, row 2 the first chunk.
for row, image in enumerate((spec, chunk), start=1):
    fig.add_trace(
        go.Heatmap(
            z=image,
            colorscale='Viridis',
        ),
        row=row,
        col=1,
    )
    # Use the same y-axis range in both rows so the subplots line up. Setting
    # it on row 1 only left row 2 on its automatic range.
    fig.update_yaxes(range=[0, spec.shape[0]], row=row, col=1)

# Update layout
fig.update_layout(
    title_text=f"Mel Spectrogram for {os.path.basename(file_path)}",
    height=800,  # Adjust height as needed
    showlegend=False
)

# Show the plot
fig.show()


# Create a subplot figure with 2 rows and 4 columns:
# row 1 shows the first patches of the chunk, row 2 the last ones.
if len(patches) == 0:
    raise ValueError("The first chunk produced no patches to plot")

n_cols = 4
# A short chunk may yield fewer than 4 patches; never index past the ends.
n_show = min(n_cols, len(patches))
first_indices = list(range(n_show))
last_indices = list(range(len(patches) - n_show, len(patches)))

# Pad each row's titles to n_cols entries so that the second row's titles
# stay aligned with the second row of subplots when n_show < n_cols.
title_padding = [""] * (n_cols - n_show)
subplot_titles = (
    [f"Patch {i+1}" for i in first_indices] + title_padding
    + [f"Patch {i+1}" for i in last_indices] + title_padding
)
fig = make_subplots(rows=2, cols=n_cols, subplot_titles=subplot_titles)

# One colour range shared by all heatmaps, so patches are visually comparable
min_val = np.min([np.min(patch) for patch in patches])
max_val = np.max([np.max(patch) for patch in patches])

# Display the first and last n_show patches (4 when the chunk has enough)
for row, indices in enumerate((first_indices, last_indices), start=1):
    for col, patch_index in enumerate(indices, start=1):
        fig.add_trace(
            go.Heatmap(
                z=patches[patch_index],
                colorscale='Viridis',
                zmin=min_val,
                zmax=max_val,
            ),
            row=row,
            col=col
        )

# Update layout
fig.update_layout(
    title_text=f"First/ Last 4 patches for {os.path.basename(file_path)}",
    height=600,  # Adjust height as needed
    showlegend=False
)

# Show the plot
fig.show()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AudioTransformerModel(nn.Module):
    """Transformer encoder that maps a sequence of spectrogram patches to one embedding.

    Pipeline: every patch goes through a strided 2-D convolution and GELU, is
    averaged over its spatial positions into one ``d_model`` vector, projected,
    layer-normalised and offset by a learned positional embedding. The patch
    sequence is then run through a Transformer encoder, normalised, averaged
    over patches and projected to the output width.

    Args:
        patch_size: Accepted for interface compatibility; no layer in this
            module depends on it.
        num_layers: Number of Transformer encoder layers.
        num_heads: Number of attention heads per layer.
        d_model: Width of the patch embeddings and of the encoder.
        dim_feedforward: Hidden width of each encoder layer's feed-forward net.
        dropout: Dropout probability used inside the encoder layers.
        max_patches: Number of learned positional embeddings, i.e. the largest
            patch count ``forward`` accepts. Defaults to the previously
            hard-coded 258.
        embedding_dim: Width of the returned embedding. Defaults to the
            previously hard-coded 768.
    """

    def __init__(self, patch_size, num_layers, num_heads, d_model, dim_feedforward, dropout=0.1,
                 max_patches=258, embedding_dim=768):
        super().__init__()

        self.max_patches = max_patches

        # Submodule/parameter names are unchanged so existing checkpoints load.
        self.conv2d = nn.Conv2d(in_channels=1, out_channels=d_model, kernel_size=(2, 2), stride=2)
        self.gelu = nn.GELU()
        self.linear_proj = nn.Linear(d_model, d_model)
        self.layer_norm1 = nn.LayerNorm(d_model)
        # Learned positional embeddings, zero-initialised.
        self.pos_encoder = nn.Parameter(torch.zeros(1, max_patches, d_model))
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads, dim_feedforward=dim_feedforward, dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.output_layer = nn.Linear(d_model, embedding_dim)

    def forward(self, x):
        """Embed a batch of patch sequences.

        Args:
            x: Tensor of shape ``(batch_size, num_patches, mel_bins, time)``.

        Returns:
            Tensor of shape ``(batch_size, embedding_dim)``.

        Raises:
            ValueError: If ``x`` is not 4-dimensional, or if it holds more
                patches than there are positional embeddings.
        """
        if x.dim() != 4:
            raise ValueError(
                f"Expected input of shape (batch, num_patches, mel_bins, time), got {tuple(x.shape)}"
            )
        batch_size, num_patches, mel_bins, time = x.shape
        if num_patches > self.pos_encoder.size(1):
            raise ValueError(
                f"Input has {num_patches} patches but the model only has "
                f"{self.pos_encoder.size(1)} positional embeddings; increase max_patches"
            )

        # Fold the patches into the batch axis and add a channel axis:
        # (batch_size * num_patches, 1, mel_bins, time).
        # reshape (unlike view) also accepts non-contiguous inputs.
        x = x.reshape(batch_size * num_patches, 1, mel_bins, time)

        # Apply 2D convolution
        x = self.gelu(self.conv2d(x))

        # Average each patch over its spatial positions:
        # (batch_size, num_patches, d_model, H' * W') -> (batch_size, num_patches, d_model)
        x = x.reshape(batch_size, num_patches, x.size(1), -1).mean(dim=-1)

        # Linear projection followed by layer normalization
        x = self.layer_norm1(self.linear_proj(x))

        # Add positional encoding
        x = x + self.pos_encoder[:, :num_patches, :]

        # The encoder was built without batch_first, so it expects
        # (sequence, batch, feature); transpose in and back out.
        x = self.transformer_encoder(x.transpose(0, 1)).transpose(0, 1)

        # Layer normalization
        x = self.layer_norm2(x)

        # Global average pooling over the patch axis, then the output projection
        return self.output_layer(x.mean(dim=1))

# Build the model from the shared hyperparameters and place it on the active device
model_kwargs = dict(
    patch_size=HYPERPARAMS['patch_size'],
    num_layers=HYPERPARAMS['num_layers'],
    num_heads=HYPERPARAMS['nhead'],
    d_model=HYPERPARAMS['d_model'],
    dim_feedforward=HYPERPARAMS['dim_feedforward'],
)
model = AudioTransformerModel(**model_kwargs).to(device)

# Show the architecture, then the number of trainable parameters
print(model)
trainable_counts = [param.numel() for param in model.parameters() if param.requires_grad]
print(f"Number of parameters: {sum(trainable_counts)}")

In [None]:
import plotly.graph_objects as go
from IPython.display import clear_output

# Plotting function for real-time loss visualization
def plot_losses(train_losses, val_losses=None):
    """Redraw the loss curves in place of the cell's previous output.

    Args:
        train_losses: Sequence of training-loss values, one per point.
        val_losses: Optional sequence of validation-loss values. The
            validation curve is omitted when this is ``None`` or empty.
    """
    # Clear the old figure first so the plot appears to update in place.
    clear_output(wait=True)

    curves = [('Training Loss', train_losses)]
    if val_losses:
        curves.append(('Validation Loss', val_losses))

    fig = go.Figure()
    for curve_name, values in curves:
        # The x-axis is simply the position of each value in its list.
        positions = list(range(len(values)))
        fig.add_trace(go.Scatter(x=positions, y=values, mode='lines', name=curve_name))

    fig.update_layout(
        title='Loss over time',
        xaxis_title='Epoch',
        yaxis_title='Loss',
        legend=dict(x=0.1, y=0.9),
        height=600,  # Adjust height as needed
        width=800,  # Adjust width as needed
        showlegend=True,
    )

    fig.show()


In [None]:
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn.utils import clip_grad_norm_

# Training loop
def train_model(model, train_dataloader, val_dataloader, criterion, optimizer, scheduler, num_epochs):
    """Train ``model`` with a pairwise embedding loss and validate after each epoch.

    Each batch is embedded, split into two equally sized halves, and the i-th
    embedding of the first half is paired with the i-th embedding of the
    second half. Every pair is given the target ``1``.

    NOTE(review): the two halves are simply different samples from the same
    batch and there are no negative pairs - confirm this is the intended
    training signal.

    Args:
        model: Module mapping a batch tensor to one embedding per sample.
        train_dataloader: Iterable of training batch tensors.
        val_dataloader: Iterable of validation batch tensors.
        criterion: Loss called as ``criterion(emb1, emb2, target)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        num_epochs: Number of passes over ``train_dataloader``.

    Returns:
        The same ``model`` instance, after training.
    """
    # Move batches to wherever the model lives rather than relying on a
    # notebook-level global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    def _pair_loss(batch):
        """Return the loss for one batch, or None if it cannot form a pair."""
        batch = batch.to(run_device)
        half = batch.size(0) // 2
        if half == 0:
            # A single-sample batch (e.g. the last one) has nothing to pair with.
            return None
        embeddings = model(batch)
        # For odd batch sizes drop the last sample so both halves match in length.
        emb1, emb2 = embeddings[:half], embeddings[half:2 * half]
        target = torch.ones(half, device=run_device)  # Positive pairs
        return criterion(emb1, emb2, target)

    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        # Re-enable training mode every epoch: the validation pass below
        # switches the model to eval mode, which would otherwise leave dropout
        # disabled for all epochs after the first.
        model.train()
        epoch_loss = 0.0
        train_steps = 0
        for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            optimizer.zero_grad()
            loss = _pair_loss(batch)
            if loss is None:
                continue

            loss.backward()
            clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            epoch_loss += loss.item()
            train_steps += 1

        # Average over the batches that actually contributed; NaN marks an
        # epoch without any usable batch instead of dividing by zero.
        mean_train_loss = epoch_loss / train_steps if train_steps else float('nan')
        train_losses.append(mean_train_loss)

        # Validation step
        model.eval()
        val_loss = 0.0
        val_steps = 0
        with torch.no_grad():
            for batch in val_dataloader:
                loss = _pair_loss(batch)
                if loss is None:
                    continue
                val_loss += loss.item()
                val_steps += 1

        val_losses.append(val_loss / val_steps if val_steps else float('nan'))

        scheduler.step()

        plot_losses(train_losses, val_losses)

        print(f"Epoch {epoch+1}, Loss: {mean_train_loss:.4f}")
        print()

    return model

# Objective, optimizer and learning-rate schedule
criterion = nn.CosineEmbeddingLoss()
optimizer = optim.Adam(params=model.parameters(), lr=HYPERPARAMS['learning_rate'])
# One cosine cycle stretched over the whole run
scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=HYPERPARAMS['num_epochs'])

# Run training
trained_model = train_model(
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=HYPERPARAMS['num_epochs'],
)

# Persist only the weights (state dict), not the pickled module
trained_weights = trained_model.state_dict()
torch.save(trained_weights, 'audio_embedding_model.pth')
print("Model saved successfully.")
