# Imports:

In [1]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


In [None]:
class MelDataset(Dataset):
    """Mel spectrograms stored as ``.npy`` files listed in a CSV manifest.

    The manifest must contain a ``mel_npy_path`` column. Each item is min-max
    scaled to [0, 1] on its own, converted to a float32 tensor, and given a
    leading channel axis via ``unsqueeze(0)``.

    Args:
        csv_path: path to the manifest CSV (kept as ``self.df``).
        transform: optional callable applied to the tensor before returning it.
    """

    def __init__(self, csv_path, transform=None):
        self.df = pd.read_csv(csv_path)
        self.mel_paths = self.df["mel_npy_path"].tolist()
        self.transform = transform

    def __len__(self):
        return len(self.mel_paths)

    def __getitem__(self, idx):
        mel = np.load(self.mel_paths[idx])  # Shape: (n_mels, time)
        mel = self._min_max_scale(mel)
        mel = torch.tensor(mel, dtype=torch.float32).unsqueeze(0)  # (1, n_mels, time)
        if self.transform is not None:
            mel = self.transform(mel)
        return mel

    @staticmethod
    def _min_max_scale(mel):
        """Scale ``mel`` to [0, 1].

        A constant array (e.g. a silent clip) has max == min. The plain formula
        would give 0/0 = NaN there, which poisons the training loss, so such an
        array is mapped to all zeros instead.
        """
        lo = mel.min()
        span = mel.max() - lo
        if span == 0:
            return np.zeros_like(mel, dtype=np.float32)
        return (mel - lo) / span

In [None]:
class DownBlock(nn.Module):
    """Residual stride-2 block: (B, C_in, H, W) -> (B, C_out, ceil(H/2), ceil(W/2)).

    Main branch: 3x3 stride-2 conv -> BN -> ReLU -> 3x3 conv -> BN.
    Shortcut: 1x1 stride-2 conv, so both branches have the same shape.
    The two branches are summed and passed through a final ReLU.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        main_branch = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
        ]
        # Attribute names / Sequential indices define state_dict keys; keep them stable.
        self.conv = nn.Sequential(*main_branch)
        self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2)
        self.activation = nn.ReLU()

    def forward(self, x):
        identity = self.shortcut(x)
        return self.activation(self.conv(x) + identity)


class UpBlock(nn.Module):
    """2x bilinear upsampling followed by two (3x3 conv -> BN -> ReLU) layers.

    Maps (B, in_channels, H, W) to (B, out_channels, 2H, 2W).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def conv_bn_relu(channels_in):
            # One 3x3 "same" conv layer with batch norm and ReLU.
            return [
                nn.Conv2d(channels_in, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]

        self.conv = nn.Sequential(*conv_bn_relu(in_channels), *conv_bn_relu(out_channels))
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x):
        enlarged = self.upsample(x)
        return self.conv(enlarged)


class MelAutoencoder(nn.Module):
    """Encoder/decoder wrapper for mel-spectrogram reconstruction.

    The encoder must return ``(embedding, skips)``. Both values are passed to the
    decoder, and the decoder's output is returned unchanged.

    NOTE(review): an identical ``MelAutoencoder`` is defined again in a later
    cell and shadows this one; one of the two definitions should be removed.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        # Assigned as attributes so they register as submodules (.to(), state_dict, ...).
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        embedding, skip_features = self.encoder(x)
        return self.decoder(embedding, skip_features)

In [None]:
class MelEncoder(nn.Module):
    """Convolutional encoder: spectrogram batch -> (embedding, skip features).

    Expects a single-channel batch (B, 1, H, W). Returns:
        embedding: (B, embedding_dim) tensor from global average pooling plus a
            linear projection.
        skips: feature maps ordered shallowest first. They have 32, 64 and 128
            channels at roughly 1/2, 1/4 and 1/8 resolution and are meant for a
            U-Net style decoder.
    """

    def __init__(self, embedding_dim=256):
        super().__init__()

        # Stem: 5x5 stride-2 conv, so the first skip is already at 1/2 resolution.
        self.initial_conv = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(),
        )

        # Residual stride-2 stages at 1/4, 1/8 and 1/16 resolution.
        self.down1 = DownBlock(32, 64)
        self.down2 = DownBlock(64, 128)
        self.down3 = DownBlock(128, 256)

        # Head: global average pool, then a linear map to the embedding size.
        self.pool = nn.AdaptiveAvgPool2d(output_size=1)
        self.proj = nn.Linear(in_features=256, out_features=embedding_dim)

    def forward(self, x):
        features = self.initial_conv(x)
        skips = [features]
        # The stem, down1 and down2 outputs are kept as skips; down3 only feeds the head.
        for stage in (self.down1, self.down2):
            features = stage(features)
            skips.append(features)
        features = self.down3(features)

        pooled = torch.flatten(self.pool(features), start_dim=1)  # (B, 256)
        return self.proj(pooled), skips

In [None]:
class MelDecoder(nn.Module):
    """Decode an embedding plus encoder skip features back into a mel spectrogram.

    The embedding is expanded to a fixed 256 x 4 x 8 seed map and upsampled
    through three ``UpBlock`` stages. Skip tensors come from the encoder
    shallowest-first: 32 channels at ~1/2, 64 at ~1/4, 128 at ~1/8 resolution.
    They are concatenated on the channel axis deepest-first, which is what the
    ``128 + 128`` / ``64 + 64`` / ``32 + 32`` input widths below expect.

    The skips' spatial sizes depend on the input size (16x108, 32x216 and 64x431
    for a 128x861 spectrogram) and do not line up with the fixed seed map. The
    decoder map is therefore resized to each skip's size before concatenation.

    Args:
        embedding_dim: size of the latent vector produced by the encoder.
        output_shape: (n_mels, time) size the reconstruction is resized to.
    """

    def __init__(self, embedding_dim=256, output_shape=(128, 861)):
        super().__init__()
        self.output_shape = output_shape

        # Expand the latent vector into a small seed map. The 4x8 size is a free
        # design choice: forward() resizes to each skip's spatial size anyway.
        self.fc = nn.Linear(embedding_dim, 256 * 4 * 8)

        # Input channels = decoder channels + channels of the concatenated skip.
        self.up1 = UpBlock(256, 128)       # no skip
        self.up2 = UpBlock(128 + 128, 64)  # + deepest skip (128 ch)
        self.up3 = UpBlock(64 + 64, 32)    # + middle skip (64 ch)

        self.final = nn.Sequential(
            nn.Conv2d(32 + 32, 16, 3, padding=1),  # + shallowest skip (32 ch)
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 1, 3, padding=1),
            # NOTE(review): MelDataset scales targets to [0, 1], but Tanh outputs
            # [-1, 1]; a Sigmoid would match the target range. Kept as Tanh so
            # behaviour and checkpoint layout are unchanged.
            nn.Tanh()
        )

    @staticmethod
    def _concat_skip(x, skip):
        """Resize ``x`` to ``skip``'s spatial size if needed, then concat on channels."""
        if x.shape[-2:] != skip.shape[-2:]:
            x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
        return torch.cat([x, skip], dim=1)

    def forward(self, x, skips):
        """Reconstruct from ``x`` (B, embedding_dim) and the encoder's 3 skips.

        Returns a (B, 1, *output_shape) tensor.
        """
        x = self.fc(x).view(-1, 256, 4, 8)  # [B, 256, 4, 8]

        # Encoder order is shallow -> deep; the decoder consumes them deep -> shallow.
        shallow, middle, deep = skips[0], skips[1], skips[2]

        x = self.up1(x)                                  # [B, 128, 8, 16]
        x = self.up2(self._concat_skip(x, deep))         # [B, 64, ~2x deep]
        x = self.up3(self._concat_skip(x, middle))       # [B, 32, ~2x middle]
        x = self.final(self._concat_skip(x, shallow))    # [B, 1, shallow size]
        return F.interpolate(x, size=self.output_shape, mode='bilinear', align_corners=False)

In [None]:
class MelAutoencoder(nn.Module):
    """Chain an encoder and a decoder into one reconstruction model.

    NOTE(review): this redefines an identical class from an earlier cell and
    shadows it; keep only one definition.

    Args:
        encoder: module returning ``(z, skips)`` for an input batch.
        decoder: module called as ``decoder(z, skips)``; its output is the
            reconstruction.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder, self.decoder = encoder, decoder

    def forward(self, x):
        encoded = self.encoder(x)
        latent, skips = encoded
        return self.decoder(latent, skips)

In [None]:
def _reconstruction_loss(recon, target):
    """Reconstruction objective: L1 plus a 0.1-weighted MSE term."""
    return F.l1_loss(recon, target) + 0.1 * F.mse_loss(recon, target)


def train_auto_encoder(model, train_loader, val_loader, epochs=200, checkpoint_dir='checkpoints/reference_encoder', device=None):
    """Train an autoencoder to reconstruct its inputs and report per-epoch losses.

    Each epoch:
    - runs one Adam pass over ``train_loader``;
    - evaluates on ``val_loader`` with the same loss, so the two printed numbers
      are comparable;
    - saves ``model.state_dict()`` every 5 epochs and after the final epoch.

    Args:
        model: module mapping a batch to a reconstruction of the same shape.
        train_loader: sized iterable of input batches used for training.
        val_loader: sized iterable of input batches used for validation.
        epochs: number of training epochs.
        checkpoint_dir: directory checkpoints are written to (created if missing).
        device: device to train on; defaults to CUDA when available, else CPU.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    # weight_decay adds an L2 penalty on the weights.
    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

    os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(1, epochs + 1):
        # ----- Training -----
        model.train()
        train_loss = 0.0
        for mel in train_loader:
            mel = mel.to(device)  # [B, 1, M, T]
            loss = _reconstruction_loss(model(mel), mel)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        n_train = len(train_loader)
        # An empty loader reports NaN instead of raising ZeroDivisionError.
        avg_train_loss = train_loss / n_train if n_train else float('nan')

        # ----- Validation -----
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for mel in val_loader:
                mel = mel.to(device)
                val_loss += _reconstruction_loss(model(mel), mel).item()

        n_val = len(val_loader)
        avg_val_loss = val_loss / n_val if n_val else float('nan')

        print(f"Epoch {epoch}, Train Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")

        # ----- Save checkpoint -----
        # Every 5 epochs, plus the last epoch so the final weights are always saved.
        if epoch % 5 == 0 or epoch == epochs:
            checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Checkpoint saved to {checkpoint_path}")


In [None]:
# Train/validation split manifests; each lists spectrogram files in a `mel_npy_path` column.
TRAIN_CSV = "Data/split_train_val/train_split.csv"
VAL_CSV = "Data/split_train_val/val_split.csv"
BATCH_SIZE = 32
SEED = 42

# Seed torch so loader shuffling and model weight init (next cell) are reproducible.
torch.manual_seed(SEED)

train_dataset = MelDataset(csv_path=TRAIN_CSV)
val_dataset = MelDataset(csv_path=VAL_CSV)

# Reuse the frames the datasets already loaded instead of reading each CSV twice.
train_df = train_dataset.df
val_df = val_dataset.df

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)


In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Encoder and decoder share a 512-d latent; the reconstruction is resized to 128 x 861.
# Module.to() moves in place, so `encoder`/`decoder` end up on `device` as well.
encoder = MelEncoder(embedding_dim=512)
decoder = MelDecoder(embedding_dim=512, output_shape=(128, 861))
autoencoder = MelAutoencoder(encoder=encoder, decoder=decoder).to(device)

train_auto_encoder(model=autoencoder, train_loader=train_loader, val_loader=val_loader, epochs=200)

In [None]:
def visualize_first_10(autoencoder, dataloader, device, checkpoint_path='checkpoints/reference_encoder/checkpoint_epoch_20.pth', num_samples=10):
    """Plot originals next to their reconstructions for the first batch of ``dataloader``.

    Loads the state_dict at ``checkpoint_path`` into ``autoencoder`` and
    reconstructs the first batch. It then shows one figure per sample, up to
    ``num_samples`` and capped at the batch size.

    Args:
        autoencoder: model whose forward returns a reconstruction of its input.
        dataloader: iterable of batches; only the first batch is used.
        device: device the model and batch are run on.
        checkpoint_path: path to a state_dict saved during training.
        num_samples: maximum number of samples to plot.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    autoencoder.load_state_dict(torch.load(checkpoint_path, map_location=device))
    # map_location only controls where the loaded tensors land; move the model
    # too so it is on the same device as the batch.
    autoencoder.to(device)
    autoencoder.eval()

    try:
        first_batch = next(iter(dataloader))
    except StopIteration:
        raise ValueError("dataloader is empty; nothing to visualize") from None

    with torch.no_grad():
        original = first_batch.to(device)
        reconstructed = autoencoder(original)

    original = original.cpu()
    reconstructed = reconstructed.cpu()

    # Cap at the batch size: a short batch would otherwise raise IndexError.
    for i in range(min(num_samples, original.shape[0])):
        fig, axes = plt.subplots(1, 2, figsize=(10, 3))
        panels = (
            (original[i][0], f'Original #{i+1}'),
            (reconstructed[i][0], f'Reconstructed #{i+1}'),
        )
        for ax, (image, title) in zip(axes, panels):
            im = ax.imshow(image, aspect='auto', origin='lower', cmap='magma')
            ax.set_title(title)
            fig.colorbar(im, ax=ax)
        fig.tight_layout()
        plt.show()
        # Release each figure so many plots do not accumulate in memory.
        plt.close(fig)


In [None]:
# Compare originals with reconstructions for the first validation batch (default checkpoint).
visualize_first_10(autoencoder=autoencoder, dataloader=val_loader, device=device)
