# Imports:

In [6]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


In [7]:
class MelDataset(Dataset):
    def __init__(self, csv_path, transform=None):
        self.df = pd.read_csv(csv_path)
        self.mel_paths = self.df["mel_npy_path"].tolist()
        self.transform = transform

    def __len__(self):
        return len(self.mel_paths)

    def __getitem__(self, idx):
        mel = np.load(self.mel_paths[idx])  # Shape: (n_mels, time)
        mel = (mel - mel.min()) / (mel.max() - mel.min()) # Normalize to [0, 1]
        mel = torch.tensor(mel, dtype=torch.float32).unsqueeze(0)  # (1, n_mels, time)
        if self.transform:
            mel = self.transform(mel)
        return mel

In [None]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True)
            )  # Better than ReLU for autoencoders
    
    def forward(self, x):
        return self.block(x)

class DownBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            ConvBlock(in_channels, out_channels),
            nn.MaxPool2d(2)
        )
    
    def forward(self, x):
        return self.block(x)
    
class UpBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv = ConvBlock(out_channels * 2, out_channels)  # *2 for skip connection
    
    def forward(self, x, skip):
        x = self.up(x)
        # Handle potential size mismatches due to odd dimensions
        if x.shape != skip.shape:
            x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=True)
        x = torch.cat([x, skip], dim=1)
        return self.conv(x)

In [None]:
class MelEncoder(nn.Module):
    def __init__(self, embedding_dim=512):
        super().__init__()
        # Input: (1, 128, 861)
        self.initial_conv = ConvBlock(1, 32)  # -> (32, 128, 861)
        
        self.down1 = DownBlock(32, 64)       # -> (64, 64, 430)
        self.down2 = DownBlock(64, 128)      # -> (128, 32, 215)
        self.down3 = DownBlock(128, 256)     # -> (256, 16, 107)
        self.down4 = DownBlock(256, 512)     # -> (512, 8, 53)
        
        self.pool = nn.AdaptiveAvgPool2d((1, 1))  # -> (512, 1, 1)
        self.proj = nn.Linear(512, embedding_dim)
    
    def forward(self, x):
        skips = []
        
        x = self.initial_conv(x)
        skips.append(x)
        
        x = self.down1(x)
        skips.append(x)
        
        x = self.down2(x)
        skips.append(x)
        
        x = self.down3(x)
        skips.append(x)
        
        x = self.down4(x)
        
        x = self.pool(x)
        x = x.flatten(1)
        x = self.proj(x)
        
        return x, skips

In [None]:
class MelDecoder(nn.Module):
    def __init__(self, embedding_dim=512):
        super().__init__()
        self.embedding_dim = embedding_dim
        
        # Project latent vector back to feature map
        self.fc = nn.Linear(embedding_dim, 512 * 8 * 8)
        
        # Upsampling blocks
        self.up1 = UpBlock(512, 256)  # Input: 512, Output: 256
        self.up2 = UpBlock(256, 128)   # Input: 256, Output: 128
        self.up3 = UpBlock(128, 64)    # Input: 128, Output: 64
        self.up4 = UpBlock(64, 32)     # Input: 64, Output: 32
        
        # Final convolution to get to original channels
        self.final_conv = nn.Sequential(
            nn.Conv2d(32, 1, kernel_size=3, padding=1),
            nn.Sigmoid()  # Since mel-specs are normalized to [0,1]
        )
    
    def forward(self, x, skips):
        # Project latent vector
        x = self.fc(x)
        x = x.view(-1, 512, 8, 8)
        
        # Reverse the skip connections for the decoder
        skips = skips[::-1]
        
        # Upsampling with skip connections
        x = self.up1(x, skips[3][:, :, :x.shape[2], :x.shape[3]])  # Match dimensions
        x = self.up2(x, skips[2][:, :, :x.shape[2], :x.shape[3]])
        x = self.up3(x, skips[1][:, :, :x.shape[2], :x.shape[3]])
        x = self.up4(x, skips[0][:, :, :x.shape[2], :x.shape[3]])
        
        # Final convolution
        x = self.final_conv(x)
        
        # Ensure output matches input dimensions
        x = F.interpolate(x, size=(128, 861), mode='bilinear', align_corners=True)
        
        return x

In [None]:
class MelAutoencoder(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
    
    def forward(self, x):
        z, skips = self.encoder(x)
        recon = self.decoder(z, skips)
        return recon

In [24]:
def train_auto_encoder(model, train_loader, val_loader, epochs=200, checkpoint_dir='checkpoints/reference_encoder'):
    model.to(device)
    optimizer = optim.Adam(
        list(model.parameters()),
        lr=1e-3, weight_decay=1e-4  # L2 regularization term
    )

    os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(1, epochs + 1):
        # ----- Training -----
        model.train()
        train_loss = 0
        for mel in train_loader:
            mel = mel.to(device)  # [B, 1, M, T]
            recon = model(mel)
            loss = F.l1_loss(recon, mel) + 0.1 * F.mse_loss(recon, mel)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)

        # ----- Validation -----
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for mel in val_loader:
                mel = mel.to(device)
                recon = model(mel)
                loss = F.l1_loss(recon, mel)
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)

        print(f"Epoch {epoch}, Train Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")

        # ----- Save checkpoint -----
        if epoch % 5 == 0:
            checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Checkpoint saved to {checkpoint_path}")


In [25]:
train_df = pd.read_csv("Data/split_train_val/train_split.csv")
val_df = pd.read_csv("Data/split_train_val/val_split.csv")

train_df["mel_npy_path"] = train_df["mel_npy_path"].apply(lambda x: x.replace("\\", "/"))
val_df["mel_npy_path"] = val_df["mel_npy_path"].apply(lambda x: x.replace("\\", "/"))

train_dataset = MelDataset(csv_path="Data/split_train_val/train_split.csv")
val_dataset = MelDataset(csv_path="Data/split_train_val/val_split.csv")

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)


In [26]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

encoder = MelEncoder(embedding_dim=512).to(device)
decoder = MelDecoder(embedding_dim=512, output_shape=(128, 861)).to(device)
autoencoder = MelAutoencoder(encoder=encoder, decoder=decoder).to(device)

train_auto_encoder(autoencoder, train_loader, val_loader, epochs=200)

RuntimeError: Given transposed=1, weight of size [256, 256, 4, 4], expected input[32, 192, 32, 64] to have 256 channels, but got 192 channels instead

In [21]:
def visualize_first_10(autoencoder, dataloader, device, checkpoint_path = 'checkpoints/reference_encoder/checkpoint_epoch_20.pth'):
    autoencoder.load_state_dict(torch.load(checkpoint_path, map_location=device))
    autoencoder.eval()

    with torch.no_grad():
        for batch in dataloader:
            original = batch.to(device)
            reconstructed = autoencoder(original)
            break  # Only take the first batch

    original = original.cpu()
    reconstructed = reconstructed.cpu()

    # Plot the first 10 samples
    for i in range(10):
        plt.figure(figsize=(10, 3))

        # Original
        plt.subplot(1, 2, 1)
        plt.imshow(original[i][0], aspect='auto', origin='lower', cmap='magma')
        plt.title(f'Original #{i+1}')
        plt.colorbar()

        # Reconstructed
        plt.subplot(1, 2, 2)
        plt.imshow(reconstructed[i][0], aspect='auto', origin='lower', cmap='magma')
        plt.title(f'Reconstructed #{i+1}')
        plt.colorbar()

        plt.tight_layout()
        plt.show()


In [None]:
visualize_first_10(autoencoder, val_loader, device)
