**This cell initializes the environment for a PyTorch-based project. It imports the necessary libraries, selects the compute device (GPU if available, otherwise CPU), and prints relevant GPU information such as CUDA availability, device count, and device name. The CUDA memory-allocation configuration is set in the next cell.**

In [1]:
import gc  # Garbage collection to free memory
import os
import pickle
from pathlib import Path

import numpy as np
import optuna as ot
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.fx import traceback
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

# Select the compute device: the default CUDA GPU when one is present, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print("Torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("Device count:", torch.cuda.device_count())
# current_device() / get_device_name() raise on a machine without CUDA, which would
# defeat the CPU fallback chosen above, so only query them when a GPU is present.
if torch.cuda.is_available():
    print("Current device:", torch.cuda.current_device())
    print("Device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
else:
    print("Current device: cpu (no CUDA device available)")

Using device: cuda
Torch version: 2.5.1+cu121
CUDA available: True
Device count: 1
Current device: 0
Device name: NVIDIA GeForce RTX 4070 Ti SUPER


In [2]:
# CUDA runtime settings: expandable allocator segments and synchronous kernel launches.
# NOTE(review): these are set after `torch` was imported and CUDA was already queried
# in the previous cell — confirm they still take effect, or move this cell to the top.
os.environ.update(
    {
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "CUDA_LAUNCH_BLOCKING": "1",
    }
)

In [3]:
# # music_df = pd.read_pickle("music_data.pkl")
# # music_df_with_lengths = pd.read_pickle("music_data_with_lengths.pkl")
# test_split = pd.read_pickle("test_split.pkl")
# train_split = pd.read_pickle("train_split.pkl")
# val_split = pd.read_pickle("val_split.pkl")
# train_split_with_spec_len = pd.read_pickle("train_split_with_spec_len.pkl")

In [4]:
# train_split_with_spec_len

In [5]:
# # Show the minimum and distribution of token sequence lengths
# token_lengths = music_df["Encoded_MIDI_Tokens"].apply(lambda x: len(x) if isinstance(x, list) else 0)
# print("Minimum token length:", token_lengths.min())
# print("Value counts:")
# print(token_lengths.value_counts().sort_index())

In [6]:
# def is_valid_token_list(x):
#     return isinstance(x, list) and len(x) > 0 and all(isinstance(t, (list, tuple)) and len(t) == 4 for t in x)
#
# bad_token_rows = music_df[~music_df["Encoded_MIDI_Tokens"].apply(is_valid_token_list)]
# print("🚨 Suspicious rows:", len(bad_token_rows))
# print(bad_token_rows[[".midi", "Encoded_MIDI_Tokens"]].head())

**This code defines two PyTorch classes: `ResidualBlock` and `CNNFeatureSequence`. The `ResidualBlock` implements a residual connection with two convolution layers, while `CNNFeatureSequence` builds a CNN feature extractor using multiple stages of residual blocks, allowing for sequence modeling from input spectrogram data.**

In [7]:
class ResidualBlock(nn.Module):
    """Two dilated 3x3 conv + batch-norm layers wrapped in an additive skip connection."""

    def __init__(self, in_channels, out_channels, dilation=1, downsample=None, dropout=0.0):
        """
        Build the two convolution stages and remember the optional shortcut adapter.
        Args:
            in_channels (int): Channels of the incoming feature map.
            out_channels (int): Channels produced by both convolutions.
            dilation (int, optional): Dilation of both 3x3 convolutions. Default is 1.
            downsample (callable, optional): Module applied to the input before it is added
                back, used when the shortcut has to change the channel count.
            dropout (float, optional): Channel-dropout probability applied after the first
                activation; values <= 0 disable dropout. Default is 0.0.
        """
        super().__init__()
        # padding == dilation keeps the spatial size unchanged for a 3x3 kernel.
        conv_kwargs = dict(kernel_size=3, padding=dilation, dilation=dilation, bias=False)

        # Stage 1: conv -> BN -> ReLU -> (dropout)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)
        else:
            self.dropout = nn.Identity()

        # Stage 2: conv -> BN (activation happens after the residual sum)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Shortcut adapter; None means the input is added back unchanged.
        self.downsample = downsample

    def forward(self, x):
        """
        Run the block.
        Args:
            x (Tensor): Feature map of shape [B, C, F, T].
        Returns:
            Tensor: Feature map with `out_channels` channels and the same spatial size.
        """
        # Shortcut branch is evaluated first, on the untouched input.
        shortcut = x if self.downsample is None else self.downsample(x)

        # Main branch.
        out = self.dropout(self.relu(self.bn1(self.conv1(x))))
        out = self.bn2(self.conv2(out))

        # Residual sum followed by the final activation.
        out += shortcut
        return self.relu(out)

class CNNFeatureSequence(nn.Module):
    """Residual CNN front-end that turns a multi-channel spectrogram into a per-frame feature sequence."""

    def __init__(self, input_channels=11, feature_dim=512, depths=(6, 6, 6), dropout=0.1):
        """
        Assemble the stem, three residual stages and the final 1x1 projection.
        Args:
            input_channels (int, optional): Channels of the input spectrogram stack. Default is 11.
            feature_dim (int, optional): Channels produced by the final 1x1 projection. Default is 512.
            depths (tuple of int): Exactly three values, the number of residual blocks in
                each stage. Default is (6, 6, 6).
            dropout (float): Dropout rate forwarded to every residual block. Default is 0.1.
        """
        super().__init__()

        # Stem: 7x7 convolution followed by a pooling step that halves only the frequency axis.
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=1, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.pool_freq = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))

        # (in_channels, out_channels, num_blocks, dilation) for each of the three stages.
        d1, d2, d3 = depths
        stage_specs = (
            (64, 128, d1, 1),
            (128, 256, d2, 2),
            (256, 512, d3, 4),
        )
        blocks = []
        for c_in, c_out, depth, dil in stage_specs:
            blocks.extend(self._make_layer(c_in, c_out, num_blocks=depth, dilation=dil, dropout=dropout))
        self.layers = nn.Sequential(*blocks)

        # 1x1 convolution mapping the 512 stage-3 channels to feature_dim.
        self.projection = nn.Conv2d(512, feature_dim, kernel_size=1)

    def _make_layer(self, in_channels, out_channels, num_blocks, dilation, dropout):
        """
        Build the residual blocks of one stage.
        Args:
            in_channels (int): Channels entering the stage.
            out_channels (int): Channels used by every block of the stage.
            num_blocks (int): How many residual blocks the stage contains.
            dilation (int): Dilation shared by all convolutions of the stage.
            dropout (float): Dropout rate for the blocks.
        Returns:
            list[nn.Module]: The blocks, in execution order.
        """
        # Only the first block changes the channel count, so only it needs a
        # 1x1 conv + BN on the shortcut path.
        if in_channels != out_channels:
            shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            shortcut = None

        stage = [ResidualBlock(in_channels, out_channels, dilation=dilation, downsample=shortcut, dropout=dropout)]
        stage.extend(
            ResidualBlock(out_channels, out_channels, dilation=dilation, dropout=dropout)
            for _ in range(num_blocks - 1)
        )
        return stage

    def forward(self, x):
        """
        Extract a time-major feature sequence.
        Args:
            x (Tensor): Input of shape [B, C, F, T] (batch, input channels,
                frequency bins, time frames).
        Returns:
            Tensor: Shape [B, T, D * F'], where D is `feature_dim` and F' = F // 2 is the
                frequency size after the stem pooling. Time is the sequence axis.
        """
        x = self.pool_freq(self.relu(self.bn1(self.conv1(x))))  # stem + frequency pooling
        x = self.projection(self.layers(x))  # residual stages, then 1x1 projection -> [B, D, F', T]
        # Move time in front of (channel, frequency) and merge those two into one feature axis.
        return x.permute(0, 3, 1, 2).flatten(2)

In [8]:
# Load the training split and keep only the rows that have a spectrogram path.
train_split_with_spec_len_df = (
    pd.read_pickle("train_split_with_spec_len.pkl")
    .dropna(subset=["Spectrogram_Path"])
)

# Directory for CNN feature outputs (created if it does not exist yet).
cnn_output_dir = Path("cnn_outputs")
cnn_output_dir.mkdir(parents=True, exist_ok=True)


In [9]:
MAX_T = 2048  # Maximum number of time steps per chunk
CHUNK_OVERLAP = 256  # Number of overlapping frames between chunks

def chunk_tensor(tensor, max_t, overlap=0, pad_short=True):
    """
    Splits an input tensor along its last (time) dimension into fixed-size chunks.

    Consecutive chunks start `max_t - overlap` frames apart. A chunk that runs past the
    end of the tensor is either zero-padded on the right up to `max_t` frames
    (`pad_short=True`) or dropped together with everything after it (`pad_short=False`).
    Note that with `overlap > 0` the last start offset can fall inside the region the
    previous chunk already covered, so the final padded chunk may contain only frames
    that also appear in the chunk before it.

    Args:
        tensor (Tensor): Input tensor whose last dimension is time, e.g. [1, C, F, T].
        max_t (int): Number of time frames per chunk. Must be positive.
        overlap (int): Number of frames shared by consecutive chunks.
            Must satisfy 0 <= overlap < max_t.
        pad_short (bool): Whether to zero-pad a trailing short chunk instead of dropping it.
    Returns:
        list[Tensor]: Chunks with the same leading dimensions as `tensor` and exactly
            `max_t` frames each. Empty if the tensor has no frames.
    Raises:
        ValueError: If `max_t` is not positive or `overlap` is outside [0, max_t).
    """
    if max_t <= 0:
        raise ValueError(f"max_t must be positive, got {max_t}")
    # overlap >= max_t would give a zero or negative stride (range() error or a silently
    # empty result); a negative overlap would silently skip frames between chunks.
    if not 0 <= overlap < max_t:
        raise ValueError(f"overlap must satisfy 0 <= overlap < max_t, got overlap={overlap}, max_t={max_t}")

    stride = max_t - overlap
    total_t = tensor.shape[-1]

    chunks = []
    for start in range(0, total_t, stride):
        chunk = tensor[..., start:start + max_t]  # slicing clamps at the end of the tensor
        missing = max_t - chunk.shape[-1]
        if missing > 0:
            if not pad_short:
                break  # Short tail is not wanted; later starts would be short as well
            chunk = F.pad(chunk, (0, missing))  # Zero-pad at the end of the time axis
        chunks.append(chunk)
    return chunks

In [10]:
class MultiHeadOutput(nn.Module):
    """
    Four independent linear classifiers applied to the same transformer output:
    pitch, velocity bin, duration bin and time bin.
    """

    def __init__(self, d_model, pitch_classes=128, velocity_bins=32, duration_bins=64, time_bins=64):
        """
        Args:
            d_model (int): Size of the last dimension of the incoming features.
            pitch_classes (int): Number of pitch logits. Default is 128.
            velocity_bins (int): Number of velocity logits. Default is 32.
            duration_bins (int): Number of duration logits. Default is 64.
            time_bins (int): Number of time logits. Default is 64.
        """
        super().__init__()
        # One linear layer per token component; every head reads the same features.
        self.pitch_head = nn.Linear(in_features=d_model, out_features=pitch_classes)
        self.velocity_head = nn.Linear(in_features=d_model, out_features=velocity_bins)
        self.duration_head = nn.Linear(in_features=d_model, out_features=duration_bins)
        self.time_head = nn.Linear(in_features=d_model, out_features=time_bins)

    def forward(self, x):
        """
        Args:
            x (Tensor): Features of shape [B, T, D].
        Returns:
            dict[str, Tensor]: Logits keyed by "pitch", "velocity", "duration" and "time",
                each of shape [B, T, <number of classes of that head>].
        """
        return dict(
            pitch=self.pitch_head(x),
            velocity=self.velocity_head(x),
            duration=self.duration_head(x),
            time=self.time_head(x),
        )


In [11]:
class WAVtoMIDIModel(nn.Module):
    """
    Spectrogram-to-token model: CNN feature extractor -> linear projection to `d_model`
    -> Transformer encoder -> four classification heads.

    The projection's input width is only known once the first features are seen, so by
    default it is created lazily on the first call. Parameters created that way are NOT
    part of any optimizer that was built from `model.parameters()` before that call;
    pass `projection_in_dim` (or run one forward pass before creating the optimizer)
    to avoid training with a frozen, randomly initialised projection.
    """

    def __init__(self, input_channels=11, d_model=512,
                 pitch_classes=128, velocity_bins=32, duration_bins=64, time_bins=64,
                 projection_in_dim=None):
        """
        Args:
            input_channels (int): Channels of the input spectrogram stack. Default is 11.
            d_model (int): Transformer width. Default is 512.
            pitch_classes (int): Number of pitch classes. Default is 128.
            velocity_bins (int): Number of velocity bins. Default is 32.
            duration_bins (int): Number of duration bins. Default is 64.
            time_bins (int): Number of time bins. Default is 64.
            projection_in_dim (int, optional): Width of the features entering the projection.
                When given, the projection is built immediately so its parameters exist
                before an optimizer is created. Default is None (build lazily).
        """
        super().__init__()
        self.d_model = d_model
        self.cnn = CNNFeatureSequence(input_channels=input_channels, feature_dim=d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=8, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=6)
        self.output_heads = MultiHeadOutput(d_model, pitch_classes, velocity_bins, duration_bins, time_bins)

        # Placeholder until the projection is built. Once an nn.Linear is assigned here,
        # nn.Module registers it as a submodule under this name as well.
        self._dynamic_projection = None
        if projection_in_dim is not None:
            self._build_projection(projection_in_dim)

    def _build_projection(self, in_dim):
        """Create the `in_dim -> d_model` projection on the model's device and register it."""
        projection = nn.Linear(in_dim, self.d_model).to(next(self.parameters()).device)
        self._dynamic_projection = projection
        # Kept for compatibility: the same module is also reachable (and saved in the
        # state dict) under the name `post_cnn_projection`.
        self.add_module("post_cnn_projection", projection)

    def _ensure_projection(self, in_dim):
        """Make sure a projection accepting `in_dim` features exists, (re)building it if needed."""
        current = self._dynamic_projection
        if current is not None and current.in_features == in_dim:
            return
        if self.training:
            # Function-scope import: only needed on this rarely taken path.
            import warnings
            warnings.warn(
                f"WAVtoMIDIModel: projection for input width {in_dim} was (re)created during a "
                "forward pass in training mode. Optimizers created before this call do not "
                "track its parameters, and any previously trained projection is discarded. "
                "Pass projection_in_dim=... to the constructor or run a forward pass before "
                "building the optimizer.",
                RuntimeWarning,
                stacklevel=2,
            )
        self._build_projection(in_dim)

    def forward(self, x):
        """
        Args:
            x (Tensor): Either a spectrogram batch [B, C, F, T] (run through the CNN first)
                or precomputed features [B, T, feature_width].
        Returns:
            dict[str, Tensor]: Output of the classification heads.
        Raises:
            ValueError: If `x` is neither 3- nor 4-dimensional.
        """
        if x.dim() == 4:
            # The CNN output width is feature_dim times the pooled frequency size,
            # not feature_dim itself, hence the projection below.
            x = self.cnn(x)
        elif x.dim() != 3:
            raise ValueError(f"Unsupported input shape: {x.shape}")

        return self.output_heads(self.transformer_only(x))

    def transformer_only(self, features_tensor):
        """
        Project features to `d_model` and run the Transformer encoder, skipping CNN and heads.
        Args:
            features_tensor (Tensor): Features of shape [B, T, feature_width].
        Returns:
            Tensor: Encoder output whose last dimension is `d_model`.
        """
        self._ensure_projection(features_tensor.shape[-1])
        projected = self._dynamic_projection(features_tensor)
        return self.transformer(projected)

In [12]:
def compute_loss_classical(predictions, targets, component_weights=None, class_weights=None):
    """
    Weighted sum of per-component cross-entropy losses for quantized music tokens.

    Out-of-range target indices are reported on stdout and clamped into the valid
    class range instead of raising.

    Parameters:
        predictions (dict): Logits per component ("pitch", "velocity", "duration", "time"),
            each of shape [B, T, C].
        targets (dict): Class indices under the same keys, each of shape [B, T].
        component_weights (dict, optional): Scalar multiplier per component; components
            missing from the dict use 1.0.
            Defaults to {"pitch": 1.0, "velocity": 0.3, "duration": 1.0, "time": 1.2}.
        class_weights (dict, optional): Per-class weight tensors keyed by component,
            passed to the cross-entropy of that component.
    Returns:
        Tuple[Tensor, dict]: The summed weighted loss and a dict with each component's
            weighted loss as a Python float.
    """
    if component_weights is None:
        component_weights = {"pitch": 1.0, "velocity": 0.3, "duration": 1.0, "time": 1.2}

    total_loss = 0.0
    component_losses = {}

    for key, logits in predictions.items():
        num_classes = logits.shape[-1]
        flat_logits = logits.reshape(-1, num_classes)  # [B*T, C]
        flat_target = targets[key].reshape(-1)  # [B*T]

        # Guard against labels outside [0, C) before they reach cross_entropy.
        highest = flat_target.max()
        lowest = flat_target.min()
        if highest >= num_classes or lowest < 0:
            print(f"❌ Invalid class index in {key}: min={lowest.item()}, max={highest.item()}, num_classes={num_classes}")
            flat_target = torch.clamp(flat_target, 0, num_classes - 1)

        # Optional per-class weighting for this component.
        per_class = None
        if class_weights and key in class_weights:
            per_class = class_weights[key].to(flat_logits.device)
        loss = F.cross_entropy(flat_logits, flat_target, weight=per_class)

        weighted_loss = component_weights.get(key, 1.0) * loss
        component_losses[key] = weighted_loss.item()
        total_loss += weighted_loss

    return total_loss, component_losses

In [13]:
def timing_similarity(pred_time, target_time, p=2.0, tau=5.0):
    """
    Soft agreement score between predicted and target timing values.

    Computes exp(-(|pred - target| / tau) ** p) element-wise: 1.0 for an exact match,
    decaying towards 0 as the absolute difference grows.

    Args:
        pred_time (Tensor): Predicted timing values.
        target_time (Tensor): Reference timing values, broadcastable with `pred_time`.
        p (float): Exponent applied to the scaled error. Default is 2.0.
        tau (float): Scale the absolute error is divided by. Default is 5.0.
    Returns:
        Tensor: Float tensor of similarity scores.
    """
    gap = (pred_time - target_time).abs().float()
    return torch.exp(-((gap / tau) ** p))


In [14]:
def compute_streak_compensated_score(preds, targets, alpha=1.0, beta=1.0):
    """
    Score predictions so that long runs of fully correct tokens outweigh isolated mistakes.

    A token counts as correct only if pitch, velocity, duration and time all match.
    Within a sequence, the k-th consecutive correct token adds k to the reward; every
    incorrect token adds 1 to the penalty and resets the run. Both sums are divided by
    the sequence length, and the per-sequence value `alpha * reward - beta * penalty`
    is averaged over the batch.

    Args:
        preds (dict): logits for each component, shape [B, T, C]
        targets (dict): true class indices, shape [B, T]
        alpha (float): weight of the streak reward
        beta (float): weight of the mistake penalty

    Returns:
        float: mean streak-compensated score over the batch (higher is better)
    """
    # Arg-max class per component, then require all four components to agree with the target.
    best_pitch = preds["pitch"].argmax(dim=-1)
    best_velocity = preds["velocity"].argmax(dim=-1)
    best_duration = preds["duration"].argmax(dim=-1)
    best_time = preds["time"].argmax(dim=-1)

    hits = (best_pitch == targets["pitch"]) & (best_velocity == targets["velocity"])
    hits = hits & (best_duration == targets["duration"]) & (best_time == targets["time"])

    per_sequence = []
    for row in hits.int().tolist():  # one list of 0/1 flags per sequence
        run_length = 0
        run_total = 0
        misses = 0
        for is_hit in row:
            if is_hit:
                run_length += 1
                run_total += run_length  # reward grows with the current run length
            else:
                misses += 1
                run_length = 0
        length = len(row)
        per_sequence.append(alpha * (run_total / length) - beta * (misses / length))

    return sum(per_sequence) / len(per_sequence)

In [15]:
def compute_soft_streak_score(preds, targets, alpha=1.0, beta=1.0, p=2.0, tau=5.0):
    """
    Streak score with a tolerant timing component.

    Pitch, velocity and duration must match exactly; timing contributes the soft
    `timing_similarity` value instead of a hard match. A token extends the current
    streak when the product of those four factors exceeds 0.5. Per sequence, the sum of
    running streak lengths is divided by the sequence length; the batch mean `m` of
    that value is combined as `alpha * m - beta * (1 - m)`.

    Args:
        preds (dict): logits for each component, shape [B, T, C]
        targets (dict): true class indices, shape [B, T]
        alpha (float): weight of the mean streak score
        beta (float): weight of the `(1 - mean)` penalty term
        p (float): exponent forwarded to `timing_similarity`
        tau (float): scale forwarded to `timing_similarity`

    Returns:
        float: combined score (higher is better)
    """
    best = {name: preds[name].argmax(dim=-1) for name in ("pitch", "velocity", "duration", "time")}

    exact = (
        (best["pitch"] == targets["pitch"]).float()
        * (best["velocity"] == targets["velocity"]).float()
        * (best["duration"] == targets["duration"]).float()
    )
    token_scores = exact * timing_similarity(best["time"], targets["time"], p=p, tau=tau)

    per_sequence = []
    for flags in (token_scores > 0.5).tolist():  # Optional softness threshold
        run_length = 0
        run_total = 0
        for counts in flags:
            run_length = run_length + 1 if counts else 0
            run_total += run_length
        per_sequence.append(run_total / len(flags))

    mean_score = torch.tensor(per_sequence).mean().item()
    return alpha * mean_score - beta * (1 - mean_score)


In [16]:
def collate_fn(batch):
    """
    Assemble a list of `(spectrogram, midi_tokens)` samples into one batch.

    `None` samples and spectrograms whose first dimension is not 11 are dropped.
    Token sequences are right-padded with 0 to a common length; spectrograms are
    stacked and therefore must all have the same shape.

    Args:
        batch (list): Samples as produced by the dataset; entries may be None.
    Returns:
        tuple: `(spectrograms, midi_tokens_padded)` with a leading batch dimension,
            or `(None, None)` when no usable sample remains or stacking fails.
            Note that the failure value is a tuple, so callers have to test its
            elements for None rather than the returned value itself.
    """
    # Keep the original positions so that diagnostics refer to the incoming batch order.
    present = [(i, sample) for i, sample in enumerate(batch) if sample is not None]
    if len(present) == 0:
        return None, None

    # Filter out spectrograms with incorrect channel count
    filtered = []
    for original_idx, (spec, tokens) in present:
        if spec.shape[0] != 11:
            print(f"❌ Spectrogram at idx {original_idx} has shape {spec.shape}, expected 11 channels")
            continue
        filtered.append((spec, tokens))

    if len(filtered) == 0:
        print("⚠️ All spectrograms had invalid channel counts")
        return None, None

    spectrograms, midi_tokens = zip(*filtered)

    # as_tensor accepts lists and tensors alike; unlike torch.tensor it does not
    # copy-construct (and warn) when the sample already is a tensor.
    midi_tokens = [torch.as_tensor(seq, dtype=torch.long) for seq in midi_tokens]
    midi_tokens_padded = pad_sequence(midi_tokens, batch_first=True, padding_value=0)

    try:
        spectrograms = torch.stack(spectrograms)
    except RuntimeError as e:
        # Raised when the spectrograms in the batch differ in shape.
        print(f"Error stacking spectrograms: {e}")
        return None, None

    print("\n📦 Batch sizes:")
    for i, tok in enumerate(midi_tokens):
        print(f"  ▸ Token {i}: shape={tok.shape}")
    print(f"Spectrogram shapes: {[s.shape for s in spectrograms]}\n")

    return spectrograms, midi_tokens_padded

In [17]:
class Quantizer:
    """
    Maps continuous note attributes to integer bins and back.

    Velocity uses fixed-width bins over the 128 MIDI velocity values; duration and
    time use logarithmically spaced bins over a configurable range (in seconds).
    Every quantize method returns an index inside `[0, bins - 1]`, also for inputs
    outside the nominal range.
    """

    def __init__(
        self,
        velocity_bins=32,
        duration_bins=64,
        duration_range=(0.01, 10.0),
        time_bins=32,
        time_range=(0.01, 8.0),
    ):
        """
        Args:
            velocity_bins (int): Number of velocity bins. Default is 32.
            duration_bins (int): Number of duration bins. Default is 64.
            duration_range (tuple): (min, max) duration in seconds covered by the bins;
                both must be positive because the edges are log-spaced.
            time_bins (int): Number of time bins. Default is 32.
            time_range (tuple): (min, max) time gap in seconds covered by the bins;
                both must be positive because the edges are log-spaced.
        """
        self.velocity_bins = velocity_bins
        self.duration_bins = duration_bins
        self.time_bins = time_bins

        # Velocity edges are implicit (fixed-width)
        self.velocity_bin_size = 128 / velocity_bins

        # Log-space bins for duration and time: bins + 1 edges delimit `bins` intervals
        self.duration_edges = np.logspace(np.log10(duration_range[0]), np.log10(duration_range[1]), num=duration_bins + 1)
        self.time_edges = np.logspace(np.log10(time_range[0]), np.log10(time_range[1]), num=time_bins + 1)

    def quantize_velocity(self, velocity):
        """Quantize velocity [0–127] into bins; out-of-range values map to the first/last bin."""
        bin_idx = int(velocity // self.velocity_bin_size)
        # Clamp on both sides: a negative velocity must not become a negative class index.
        return max(0, min(bin_idx, self.velocity_bins - 1))

    def quantize_duration(self, duration_sec):
        """Quantize duration (in seconds) into log-spaced bins, clipped to the valid bin range."""
        bin_idx = np.digitize(duration_sec, self.duration_edges) - 1
        return int(np.clip(bin_idx, 0, self.duration_bins - 1))

    def quantize_time(self, time_sec):
        """Quantize time since the last note (in seconds) into log-spaced bins, clipped to the valid bin range."""
        bin_idx = np.digitize(time_sec, self.time_edges) - 1
        return int(np.clip(bin_idx, 0, self.time_bins - 1))

    def inverse_velocity(self, bin_idx):
        """Approximate original velocity from bin (bin centre); the index is clipped to the valid range."""
        # Clip like the other inverse_* methods so the result always is a valid velocity.
        bin_idx = max(0, min(bin_idx, self.velocity_bins - 1))
        return int((bin_idx + 0.5) * self.velocity_bin_size)

    def inverse_duration(self, bin_idx):
        """Approximate original duration (seconds) from bin, as the arithmetic midpoint of its edges."""
        bin_idx = np.clip(bin_idx, 0, self.duration_bins - 1)
        return float((self.duration_edges[bin_idx] + self.duration_edges[bin_idx + 1]) / 2)

    def inverse_time(self, bin_idx):
        """Approximate original time gap (seconds) from bin, as the arithmetic midpoint of its edges."""
        bin_idx = np.clip(bin_idx, 0, self.time_bins - 1)
        return float((self.time_edges[bin_idx] + self.time_edges[bin_idx + 1]) / 2)

In [18]:
# Module-level quantizer instance: 32 velocity bins plus 64 log-spaced bins each for
# note duration and time since the last note. Both ranges span 0.01 s to 5.0 s,
# bounds the original author chose as reasonable from the data distribution.
quantizer = Quantizer(
    velocity_bins=32,
    duration_bins=64,
    duration_range=(0.01, 5.0),
    time_bins=64,
    time_range=(0.01, 5.0),
)


**This code defines a `MaestroDataset` class, a PyTorch `Dataset` for loading and preprocessing data for MIDI and spectrogram files.**

- It drops rows with missing `Spectrogram_Path` or `Encoded_MIDI_Tokens`.
- In `__getitem__`, it:
  - Loads and validates the spectrogram file.
  - Fixes spectrogram shape issues if needed.
  - Validates the structure of `Encoded_MIDI_Tokens`.
  - Quantizes velocity, duration, and time into bins, then returns the processed `spectrogram` and `midi_tokens` as tensors.

Invalid or missing data raises an `IndexError` from `__getitem__`; bad samples are not skipped silently.

In [19]:
class MaestroDataset(Dataset):
    """
    Dataset of (spectrogram, quantized MIDI tokens) pairs backed by a DataFrame.

    Each row must provide `Spectrogram_Path` (path to a pickled spectrogram) and
    `Encoded_MIDI_Tokens` (a non-empty list of 4-element tokens, unpacked as
    pitch, velocity, duration, time). Rows missing either column are dropped on
    construction. Quantization uses the module-level `quantizer` instance.
    """

    def __init__(self, df):
        """
        Args:
            df (pandas.DataFrame): Table with `Spectrogram_Path` and `Encoded_MIDI_Tokens` columns.
        """
        # Drop rows with missing spectrogram paths or encoded MIDI tokens
        self.df = df.dropna(subset=["Spectrogram_Path", "Encoded_MIDI_Tokens"])

    def __len__(self):
        """Number of rows left after dropping incomplete ones."""
        return len(self.df)

    def __getitem__(self, idx):
        """
        Load one sample.
        Args:
            idx (int): Positional row index.
        Returns:
            tuple: `(spectrogram, midi_tokens)`; the spectrogram as loaded (converted to a
                float32 tensor if it was not a tensor already) and the tokens as a long
                tensor of shape [N, 4] holding (pitch, velocity_bin, duration_bin, time_bin).
        Raises:
            IndexError: If the file is missing or unreadable, the spectrogram has an
                unsupported shape, or the token list is empty or malformed.
        """
        row = self.df.iloc[idx]
        spectrogram_path = row["Spectrogram_Path"]

        if not spectrogram_path or not os.path.exists(spectrogram_path):
            raise IndexError(f"Missing spectrogram file at index {idx}: {spectrogram_path}")

        try:
            # NOTE(security): pickle.load can execute arbitrary code; only load
            # spectrogram files that come from a trusted source.
            with open(spectrogram_path, "rb") as f:
                spectrogram = pickle.load(f)
            if not isinstance(spectrogram, torch.Tensor):
                spectrogram = torch.tensor(spectrogram, dtype=torch.float32)
        except Exception as e:
            # Keep the original exception as the cause so the real failure stays visible.
            raise IndexError(f"Error loading spectrogram at index {idx}: {e}") from e

        # Ensure spectrogram shape is [11, F, T]. This check sits outside the try block so
        # that its own IndexError is not re-wrapped as a loading error.
        if spectrogram.ndim == 3 and spectrogram.shape[0] == 11:
            pass
        elif spectrogram.ndim == 2 and spectrogram.shape[0] == 11:
            # NOTE(review): this yields [1, 11, T], which still does not have 11 channels
            # first and is rejected by `collate_fn` in this file — confirm the intended layout.
            spectrogram = spectrogram.unsqueeze(0)
        elif spectrogram.ndim == 3 and spectrogram.shape[1] == 11:
            spectrogram = spectrogram.permute(1, 0, 2)
        else:
            raise IndexError(f"Invalid spectrogram shape at index {idx}: {spectrogram.shape}")

        # Load and validate MIDI tokens
        midi_tokens = row["Encoded_MIDI_Tokens"]
        if not isinstance(midi_tokens, list) or len(midi_tokens) == 0:
            raise IndexError(f"Empty or invalid MIDI token list at index {idx}")

        if not all(isinstance(t, (list, tuple)) and len(t) == 4 for t in midi_tokens):
            raise IndexError(f"Malformed MIDI token structure at index {idx}: {midi_tokens[:3]}")

        # Quantize tokens; pitch is passed through unchanged
        quantized_tokens = [
            (
                pitch,
                quantizer.quantize_velocity(velocity),
                quantizer.quantize_duration(duration),
                quantizer.quantize_time(time),
            )
            for pitch, velocity, duration, time in midi_tokens
        ]

        midi_tokens = torch.tensor(quantized_tokens, dtype=torch.long)

        if midi_tokens.ndim != 2 or midi_tokens.shape[1] != 4:
            raise IndexError(f"Unexpected MIDI tensor shape at index {idx}: {midi_tokens.shape}")

        return spectrogram, midi_tokens

In [20]:
# Training data pipeline: one sample per batch, reshuffled on every pass,
# with batches assembled by `collate_fn`.
train_dataset = MaestroDataset(df=train_split_with_spec_len_df)
train_loader = DataLoader(
    dataset=train_dataset,
    collate_fn=collate_fn,
    shuffle=True,
    batch_size=1,
)

**This cell defines the `objective` function for an Optuna study to optimize hyperparameters for a model converting WAV to MIDI.**

Key steps:
1. **Clean Resources**: Frees GPU/CPU memory before training.
2. **Hyperparameter Sampling**: Samples CNN stage depths, number of transformer layers, number of attention heads, and learning rate. Dropout is not sampled; it is fixed at 0.0 to save memory.
3. **Model Setup**: Configures the model, transformer encoder, and optimizer with sampled hyperparameters.
4. **Training Loop**: Processes batches - applies chunking if spectrograms are too long, computes predictions, loss, and performs backpropagation.
5. **Streak Evaluation**: Computes soft streak scores for performance feedback.
6. **Plateau Handling**: Adjusts penalty (`p`) if performance improvement stalls beyond a patience threshold.
7. **Result Output**: Returns the negative mean score for minimization by Optuna.

In [21]:
# Module-level tuning state. `objective` declares plateau_count, best_score, p,
# best_model_state, best_trial_score and use_fine_timing as `global`, so these
# values persist (and change) across Optuna trials.
plateau_count = 0
best_score = float("-inf")
p = 1.0  # Start soft
delta_p = 0.5  # Penalty increment value
max_p = 4.0  # Maximum penalty weight
patience = 15  # Trials without significant improvement before increasing penalty

# Store the best model per trial
best_model_state = None
best_trial_score = float("-inf")

use_fine_timing = False  # Start with snapped timing, refine later

# NOTE: these two rebind the MAX_T / CHUNK_OVERLAP defined next to `chunk_tensor`
# (2048 / 256); from this cell on the smaller values apply.
MAX_T = 1048  # Lowered max time length for chunking
CHUNK_OVERLAP = 128
# Fraction of combined GPU + host memory (0.8 = 80%).
# NOTE(review): the previous comment said "70%" while the value is 0.8 — confirm which is intended.
MEMORY_THRESHOLD = 0.8

import pynvml
import psutil
# Initialise NVML and take a handle to GPU index 0 for the memory helpers below.
# NOTE(review): there is no CPU-only fallback here, unlike the device selection at the
# top of the notebook — presumably this fails without an NVIDIA driver; confirm.
pynvml.nvmlInit()
gpu_handle = pynvml.nvmlDeviceGetHandleByIndex(0)

def total_memory_bytes():
    """Return the total memory of the GPU behind `gpu_handle` plus total host RAM, in bytes."""
    gpu_total = pynvml.nvmlDeviceGetMemoryInfo(gpu_handle).total
    host_total = psutil.virtual_memory().total
    return gpu_total + host_total

def used_memory_bytes():
    """Return the used memory of the GPU behind `gpu_handle` plus used host RAM, in bytes."""
    gpu_used = pynvml.nvmlDeviceGetMemoryInfo(gpu_handle).used
    host_used = psutil.virtual_memory().used
    return gpu_used + host_used

def gpu_memory_fraction():
    """
    Return used / total memory as a fraction.

    Despite the name, both figures combine GPU memory and host RAM (see
    `total_memory_bytes` and `used_memory_bytes`). Returns 0 when the total is 0.
    """
    capacity = total_memory_bytes()
    in_use = used_memory_bytes()
    if not capacity:
        return 0
    return in_use / capacity

def objective(trial):
    """
    Optuna objective for the WAV-to-MIDI model.

    Builds a model from the sampled hyperparameters, trains it on a few batches
    drawn from the module-level ``train_loader`` (the loop stops once the batch
    index reaches 5), scores every trained batch with
    ``compute_soft_streak_score`` and returns the negated mean score so that a
    ``direction="minimize"`` study maximises the score.

    Module-level state updated here: ``best_score`` / ``plateau_count`` (plateau
    tracking), ``p`` and ``use_fine_timing`` (raised/enabled after ``patience``
    plateaued trials), and ``best_trial_score`` / ``best_model_state`` (best
    model seen so far).

    :param trial: Optuna trial used to sample hyperparameters.
    :return: ``-mean_score``; ``float("inf")`` when slicing the token tensor
        raises ``IndexError`` or when no batch could be scored.
    """
    global best_score, plateau_count, p, best_model_state, best_trial_score, use_fine_timing
    gc.collect()
    torch.cuda.empty_cache()

    # Sample hyperparameters
    # NOTE(review): `depths` and `dropout` are recorded/sampled but are not passed to
    # WAVtoMIDIModel below — confirm whether the constructor should receive them.
    # The suggest_* calls are kept because they register the parameters on the trial.
    depths = (
        trial.suggest_int("depth_stage1", 2, 6),
        trial.suggest_int("depth_stage2", 2, 6),
        trial.suggest_int("depth_stage3", 2, 6),
    )
    dropout = 0.0  # 🔒 Force dropout off to save memory
    num_layers = trial.suggest_int("num_transformer_layers", 2, 8)
    nhead = trial.suggest_categorical("nhead", [4, 8, 16])
    lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)

    model = WAVtoMIDIModel(input_channels=11, d_model=512).to(device)
    encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=nhead, batch_first=True)
    model.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()

    # Route the CNN stack through activation checkpointing for this trial's model.
    from torch.utils.checkpoint import checkpoint_sequential
    model.cnn.layers.forward = lambda x: checkpoint_sequential(model.cnn.layers, segments=3, input=x, use_reentrant=False)

    scores = []

    for i, batch in enumerate(train_loader):
        if batch is None:
            continue

        spectrograms, midi_tokens = batch
        t_lengths = [s.shape[-1] for s in spectrograms]
        print(f"\U0001f9ea Spectrogram time lengths → min: {min(t_lengths)}, max: {max(t_lengths)}, avg: {sum(t_lengths) / len(t_lengths):.2f}")
        spectrograms = spectrograms.to(device)
        midi_tokens = midi_tokens.to(device)

        try:
            targets = {
                "pitch": midi_tokens[..., 0],
                "velocity": midi_tokens[..., 1],
                "duration": midi_tokens[..., 2],
                "time": midi_tokens[..., 3],
            }
        except IndexError:
            print(f"⚠️ Bad token shape: {midi_tokens.shape}")
            return float("inf")

        optimizer.zero_grad()

        if spectrograms.shape[-1] > MAX_T:
            # Too long for one pass: split along time and run the chunks one by one.
            # Only the first item of the batch is chunked.
            chunks = chunk_tensor(spectrograms[0].unsqueeze(0), MAX_T, overlap=CHUNK_OVERLAP, pad_short=True)
            all_preds = []
            for chunk in chunks:
                current_fraction = gpu_memory_fraction()
                if current_fraction >= MEMORY_THRESHOLD:
                    print(f"🚨 Combined memory usage exceeded {MEMORY_THRESHOLD*100:.0f}% ({current_fraction:.2f}), stopping chunk loading early.")
                    break
                print(f"🚦 Chunk shape: {chunk.shape}")
                chunk = chunk.to(device)
                chunk = chunk.float()
                feat = model.cnn(chunk)
                if feat.dim() == 2:
                    feat = feat.unsqueeze(0)
                out = model.transformer_only(feat)
                out = model.output_heads(out)
                all_preds.append(out)
                del chunk, feat, out
                torch.cuda.empty_cache()

            if not all_preds:
                print("❌ No chunks processed before hitting memory limit.")
                continue

            # Stitch the per-chunk head outputs back together along dim 1.
            # The loop variable is deliberately not called `p`, which is the global penalty.
            preds = {
                key: torch.cat([chunk_pred[key] for chunk_pred in all_preds], dim=1)
                for key in all_preds[0]
            }

            # Truncate targets to match final prediction length
            chunked_length = preds["pitch"].shape[1]
            for key in targets:
                targets[key] = targets[key][:, :chunked_length]
        else:
            preds = model(spectrograms)

        # Never let predictions run past the (possibly truncated) targets.
        target_len = targets["pitch"].shape[1]
        for key in preds:
            preds[key] = preds[key][:, :target_len, :]

        if use_fine_timing:
            print("🎯 Using fine timing targets")

        loss, _ = compute_loss_classical(preds, targets)
        assert loss.requires_grad, "Loss does not require grad — check no_grad() usage"
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            # Store a plain float so no tensor outlives the batch inside `scores`
            # or the module-level best_* trackers.
            scores.append(float(compute_soft_streak_score(preds, targets, p=p)))

        if i >= 5:
            break

    if not scores:
        # Every batch was None or skipped; the mean is undefined, so report the
        # same sentinel as the bad-token path instead of dividing by zero.
        print("❌ No batches were scored in this trial.")
        return float("inf")

    mean_score = sum(scores) / len(scores)
    improvement = mean_score - best_score

    if improvement < 0.01:
        plateau_count += 1
    else:
        best_score = mean_score
        plateau_count = 0

    if plateau_count >= patience and p < max_p:
        p = min(max_p, p + delta_p)
        print(f"⏫ Increasing streak penalty: p → {p}")
        plateau_count = 0
        use_fine_timing = True

    if mean_score > best_trial_score:
        best_trial_score = mean_score
        # Snapshot on the CPU: state_dict() returns references to the live
        # parameter tensors, which would otherwise pin this trial's device
        # memory for as long as the snapshot is kept.
        best_model_state = {
            name: tensor.detach().to("cpu", copy=True)
            for name, tensor in model.state_dict().items()
        }

    return -mean_score

In [None]:
# Run an in-memory Optuna study for 100 trials. `objective` returns the negated
# mean score, hence direction="minimize".
study = ot.create_study(direction="minimize")
study.optimize(objective, n_trials=100)

In [None]:
storage_path = "sqlite:///optuna_study_music_model.db"
study_name = "music_model_tuning"

# Create the persistent study, or reopen it if it already exists in the database
persistent_study = ot.create_study(
    direction=study.direction,
    study_name=study_name,
    storage=storage_path,
    load_if_exists=True
)


def _trial_key(frozen_trial):
    """Identify a trial by its start/complete timestamps and its sampled parameters."""
    return (
        frozen_trial.datetime_start,
        frozen_trial.datetime_complete,
        tuple(sorted(frozen_trial.params.items())),
    )


# Copy trials from the source study, skipping any that are already stored.
# Because of load_if_exists=True, re-running this cell (or running it after
# `study` has been rebound to the persistent study) would otherwise append the
# same trials again.
already_stored = {_trial_key(stored) for stored in persistent_study.trials}
for trial in study.trials:
    key = _trial_key(trial)
    if key in already_stored:
        continue
    persistent_study.add_trial(trial)
    already_stored.add(key)

print(f"✅ Study saved to: {storage_path}")

In [22]:
# Reload the persisted study from the SQLite storage and show its best trial.
# This rebinds `study`, replacing any study object created earlier in the session.
study = ot.load_study(
    study_name="music_model_tuning",
    storage="sqlite:///optuna_study_music_model.db"
)

print(study.best_trial)

FrozenTrial(number=22, state=1, values=[0.9977099236566573], datetime_start=datetime.datetime(2025, 4, 20, 21, 36, 50, 758089), datetime_complete=datetime.datetime(2025, 4, 20, 22, 4, 20, 213042), params={'depth_stage1': 5, 'depth_stage2': 3, 'depth_stage3': 4, 'num_transformer_layers': 6, 'nhead': 4, 'lr': 0.0006127182157069003}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'depth_stage1': IntDistribution(high=6, log=False, low=2, step=1), 'depth_stage2': IntDistribution(high=6, log=False, low=2, step=1), 'depth_stage3': IntDistribution(high=6, log=False, low=2, step=1), 'num_transformer_layers': IntDistribution(high=8, log=False, low=2, step=1), 'nhead': CategoricalDistribution(choices=(4, 8, 16)), 'lr': FloatDistribution(high=0.001, log=True, low=1e-05, step=None)}, trial_id=23, value=None)


In [23]:
# Take the best trial's hyperparameter dict from the loaded study; the model-building cell below reads it.
best_params = study.best_params
print("🔧 Best hyperparameters:", best_params)

🔧 Best hyperparameters: {'depth_stage1': 5, 'depth_stage2': 3, 'depth_stage3': 4, 'num_transformer_layers': 6, 'nhead': 4, 'lr': 0.0006127182157069003}


In [24]:
# Pull the tuned values out of `best_params`
depths = tuple(
    best_params[key] for key in ("depth_stage1", "depth_stage2", "depth_stage3")
)
num_layers, nhead, lr = (
    best_params[key] for key in ("num_transformer_layers", "nhead", "lr")
)

# Build the model and swap in a transformer encoder with the tuned head/layer counts
model = WAVtoMIDIModel(input_channels=11, d_model=512).to(device)
encoder_layer = nn.TransformerEncoderLayer(
    d_model=512,
    nhead=nhead,
    batch_first=True,
)
model.transformer = nn.TransformerEncoder(
    encoder_layer,
    num_layers=num_layers,
).to(device)

# Adam over every parameter of the assembled model, at the tuned learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [None]:
# Build a standalone CNN feature extractor with the tuned depths and cache its
# output for every row of the training split.

# The notebook-level name `traceback` is bound by `from torch.fx import traceback`,
# which does not provide print_exc(); import the standard-library module explicitly.
import traceback as std_traceback

best_depths = depths
best_dropout = 0.0       # Dropout was fixed at 0.0 during tuning
input_channels = 11
feature_dim = 512

# The extractor gets its own name so it does not overwrite `model`: the
# module-level `optimizer` was built from the WAVtoMIDIModel's parameters, and
# the training cell passes `model` to train_full_model.
# NOTE(review): no weights are loaded here, so the cached features come from a
# freshly constructed CNNFeatureSequence — confirm that this is intended.
cnn_feature_model = CNNFeatureSequence(
    input_channels=input_channels,
    feature_dim=feature_dim,
    depths=best_depths,
    dropout=best_dropout
).to(device)
cnn_feature_model.eval()

# Create a list to store paths to CNN feature outputs
cnn_feature_paths = []

# Process and save CNN outputs for each row in the DataFrame
for idx, row in train_split_with_spec_len_df.iterrows():
    feature_path = cnn_output_dir / f"cnn_output_{idx}.pt"

    # Already cached by an earlier run: just record the path.
    if feature_path.exists():
        train_split_with_spec_len_df.at[idx, "CNN_Feature_Path"] = str(feature_path)
        continue

    print(f"{idx}")
    try:
        # SECURITY: pickle.load executes arbitrary code from the file; only use
        # it on spectrogram files produced by this project.
        with open(row["Spectrogram_Path"], "rb") as f:
            spec = pickle.load(f)

        if not isinstance(spec, torch.Tensor):
            spec = torch.tensor(spec, dtype=torch.float32)

        spec = spec.unsqueeze(0).to(device)  # Shape: [1, C, F, T]

        with torch.no_grad():
            if spec.shape[-1] > MAX_T:
                # Long input: run chunk by chunk and concatenate the outputs on the CPU.
                chunks = chunk_tensor(spec, MAX_T, overlap=CHUNK_OVERLAP)
                outputs = [cnn_feature_model(chunk).squeeze(0).cpu() for chunk in chunks]
                output = torch.cat(outputs, dim=0)
            else:
                output = cnn_feature_model(spec).squeeze(0).cpu()

        torch.save(output, feature_path)
        train_split_with_spec_len_df.at[idx, "CNN_Feature_Path"] = str(feature_path)

        del spec, output
        gc.collect()
        print(f"RAM used: {psutil.virtual_memory().used / 1e9:.2f} GB")

        # Periodic progress snapshot so an interrupted run can be inspected.
        if idx % 10 == 0:
            train_split_with_spec_len_df.to_pickle("train_split_with_cnn_paths_progress.pkl")

    except RuntimeError as e:
        print(f"⚠️ RuntimeError at index {idx}: {e}")
        torch.cuda.empty_cache()
        gc.collect()
        train_split_with_spec_len_df.at[idx, "CNN_Feature_Path"] = None

    except Exception as e:
        print(f"⚠️ Unexpected error at index {idx}: {e}")
        std_traceback.print_exc()
        train_split_with_spec_len_df.at[idx, "CNN_Feature_Path"] = None

# Final snapshot: the in-loop save only fires when idx % 10 == 0, so rows handled
# after the last such index would otherwise be missing from the progress file.
train_split_with_spec_len_df.to_pickle("train_split_with_cnn_paths_progress.pkl")

print("✅ Done.")

In [None]:
def train_full_model(
        model: torch.nn.Module,
        train_loader: torch.utils.data.DataLoader,
        epochs: int = 10,
        use_fine_timing: bool = False,
        checkpoint_dir: str = "checkpoints",
        optimizer: "torch.optim.Optimizer | None" = None,
) -> None:
    """
    Train ``model`` on ``train_loader`` for ``epochs`` epochs, writing one
    checkpoint per epoch plus ``best_model.pt`` whenever the epoch's average
    soft streak score improves.

    Spectrograms longer than the module-level ``MAX_T`` are split with
    ``chunk_tensor`` and run chunk by chunk; chunking stops early once combined
    GPU + host memory use exceeds ``MEMORY_THRESHOLD``. Predictions and targets
    are cut to a common length before the loss is computed. A batch that raises
    is reported with a traceback and skipped.

    Side effect: sets the module-level streak penalty ``p`` to 2.0.

    :param model: Model exposing ``cnn``, ``transformer_only`` and ``output_heads``
        (used on the chunked path) and callable on a whole spectrogram batch.
    :param train_loader: Iterable yielding ``(spectrograms, midi_tokens)`` or ``None``.
    :param epochs: The number of training epochs. Default is 10.
    :param use_fine_timing: Accepted for interface compatibility; the body does
        not currently read it.
    :type use_fine_timing: bool
    :param checkpoint_dir: Directory path to store the model checkpoints.
    :type checkpoint_dir: str
    :param optimizer: Optimizer to step. When ``None``, the module-level
        ``optimizer`` is used, as before this parameter existed.
    :raises ValueError: If no optimizer is available, or if the optimizer tracks
        none of ``model``'s parameters.
    :return: None
    """
    import traceback
    import pynvml
    import psutil

    # Resolve the optimizer up front. A missing global used to surface as a
    # NameError that the per-batch handler swallowed on every single batch.
    if optimizer is None:
        optimizer = globals().get("optimizer")
    if optimizer is None:
        raise ValueError("train_full_model needs an optimizer: pass one or define a module-level `optimizer`.")

    # An optimizer built for a different model would step without ever updating `model`.
    model_param_ids = {id(param) for param in model.parameters()}
    optimizer_param_ids = {id(param) for group in optimizer.param_groups for param in group["params"]}
    if model_param_ids and model_param_ids.isdisjoint(optimizer_param_ids):
        raise ValueError("The optimizer does not track any parameter of `model`; rebuild it from model.parameters().")

    # Initialise NVML once instead of on every chunk.
    try:
        pynvml.nvmlInit()
        nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(0)
    except Exception as e:
        print(f"⚠️ Could not read GPU memory stats: {e}")
        nvml_handle = None

    def gpu_memory_fraction():
        """Combined (GPU + host) used/total memory fraction; 0.0 when stats are unavailable."""
        if nvml_handle is None:
            return 0.0
        try:
            gpu_mem = pynvml.nvmlDeviceGetMemoryInfo(nvml_handle)
            cpu_mem = psutil.virtual_memory()
            total = gpu_mem.total + cpu_mem.total
            used = gpu_mem.used + cpu_mem.used
            return used / total if total else 0
        except Exception as e:
            print(f"⚠️ Could not read GPU memory stats: {e}")
            return 0.0

    model.train()
    global p
    p = 2.0  # Final fine-grained timing sensitivity

    os.makedirs(checkpoint_dir, exist_ok=True)
    best_score = float("-inf")

    for epoch in range(epochs):
        total_loss = 0.0
        score_sum = 0.0
        count = 0

        for batch_idx, batch in enumerate(train_loader):
            try:
                if batch is None:
                    print(f"⚠️ Skipping batch {batch_idx} due to None value")
                    continue

                spectrograms, midi_tokens = batch
                if spectrograms.shape[0] != midi_tokens.shape[0]:
                    print(f"❌ Mismatched batch size: spectrograms={spectrograms.shape[0]}, tokens={midi_tokens.shape[0]}")
                    continue

                spectrograms = spectrograms.to(device)
                midi_tokens = midi_tokens.to(device)

                if midi_tokens.shape[-1] != 4:
                    print(f"❌ Unexpected MIDI token shape: {midi_tokens.shape}")
                    continue

                targets = {
                    "pitch": midi_tokens[..., 0],
                    "velocity": midi_tokens[..., 1],
                    "duration": midi_tokens[..., 2],
                    "time": midi_tokens[..., 3],
                }

                optimizer.zero_grad()

                if spectrograms.shape[-1] > MAX_T:
                    # Only the first item of the batch is chunked.
                    chunks = chunk_tensor(spectrograms[0].unsqueeze(0), MAX_T, overlap=CHUNK_OVERLAP, pad_short=True)
                    all_preds = []
                    for chunk_idx, chunk in enumerate(chunks):
                        if gpu_memory_fraction() > MEMORY_THRESHOLD:
                            print(f"🚨 Memory threshold hit after chunk {chunk_idx}, skipping remaining chunks")
                            break
                        chunk = chunk.to(device)
                        feat = model.cnn(chunk)
                        if feat.dim() == 2:
                            feat = feat.unsqueeze(0)
                        out = model.transformer_only(feat)
                        all_preds.append(model.output_heads(out))
                        del chunk, feat, out
                        torch.cuda.empty_cache()

                    if not all_preds:
                        print(f"⚠️ No valid chunks processed for batch {batch_idx}")
                        continue

                    # Stitch the per-chunk head outputs back together along dim 1.
                    # The loop variable is deliberately not called `p`, which is the global penalty.
                    preds = {
                        k: torch.cat([chunk_pred[k] for chunk_pred in all_preds], dim=1)
                        for k in all_preds[0]
                    }
                else:
                    preds = model(spectrograms)

                # Cut predictions and targets to a common length on BOTH paths;
                # previously only the chunked path was aligned.
                for key in targets:
                    if key not in preds:
                        continue
                    min_T = min(preds[key].shape[1], targets[key].shape[1])
                    preds[key] = preds[key][:, :min_T, :]
                    targets[key] = targets[key][:, :min_T]

                loss, _ = compute_loss_classical(preds, targets)
                if not loss.requires_grad:
                    print(f"❌ Loss does not require grad in batch {batch_idx}, skipping")
                    continue
                loss.backward()
                optimizer.step()

                with torch.no_grad():
                    # Accumulate plain floats so the epoch averages and the saved
                    # checkpoint metadata never hold on to tensors.
                    score_sum += float(compute_soft_streak_score(preds, targets, p=p))
                    total_loss += loss.item()
                    count += 1

            except Exception as e:
                # Best effort: report the failing batch and carry on with the next one.
                print(f"🛑 Error in batch {batch_idx}: {e}")
                traceback.print_exc()
                continue

        if count == 0:
            print("⚠️ No valid batches processed in this epoch")
            continue

        avg_loss = total_loss / count
        avg_score = score_sum / count
        print(f"📚 Epoch {epoch+1}/{epochs} — Loss: {avg_loss:.4f}, Score: {avg_score:.4f}")

        checkpoint_path = os.path.join(checkpoint_dir, f"epoch_{epoch+1:03d}_score_{avg_score:.4f}.pt")
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
            'score': avg_score
        }, checkpoint_path)

        if avg_score > best_score:
            best_score = avg_score
            best_path = os.path.join(checkpoint_dir, "best_model.pt")
            torch.save(model.state_dict(), best_path)
            print(f"💾 Best model updated: {best_path}")

In [None]:
# Run the full training loop for 20 epochs; checkpoints go to the default "checkpoints" directory.
# NOTE(review): train_full_model accepts use_fine_timing but its body does not read it,
# so passing True has no effect as written — confirm the intended behaviour.
train_full_model(model, train_loader, epochs=20, use_fine_timing=True)