In [56]:
import os
import re
import numpy as np
import librosa
from collections import defaultdict
import shutil


def get_file_list(audio_folder, text_folder):
    """
    This function gets the base file names that are present in both the audio and text folders.

    Args:
    - audio_folder: Path to the folder containing the audio files.
    - text_folder: Path to the folder containing the text files.

    Returns:
    - A list of base file names (without extensions) that are present in both folders.
    """
    # Get the set of audio files (without extensions)
    audio_files = set(os.path.splitext(f)[0] for f in os.listdir(audio_folder) if f.endswith('.wav'))

    # Get the set of text files (without extensions)
    text_files = set(os.path.splitext(f)[0] for f in os.listdir(text_folder) if f.endswith('.txt'))

    # Find the intersection of both sets (i.e., files present in both folders)
    common_files = list(audio_files.intersection(text_files))

    return sorted(common_files)


def split_dataset(files):
    """
    Splits the dataset into train_tr, train_va, test_set1, and test_set2.

    Args:
    - files: List of file names (without extensions).

    Returns:
    - train_tr_files, train_va_files, test_set1_files, test_set2_files
    """
    virtual_instruments = {'AkPnCGdD', 'AkPnStgb', 'AkPnBcht', 'AkPnBsdf', 'SptkBGCl', 'SptkBGAm', 'StbgTGd2'}
    real_instruments = {'ENSTDkAm', 'ENSTDkCl'}

    pattern = re.compile(r'MAPS_MUS-(.*)_(\w+)')
    pieces_dict = defaultdict(set)

    for file_name in files:
        match = pattern.search(file_name)
        if match:
            piece_name = match.group(1)  # First capturing group for the piece name
            instrument_name = match.group(2)  # Second capturing group for the instrument name
            pieces_dict[piece_name].add(instrument_name)

    # Identify seen/unseen pieces
    unseen_pieces = {piece for piece, instruments in pieces_dict.items() if len(instruments) == 1}
    seen_pieces = {piece for piece in pieces_dict if piece not in unseen_pieces}

    # Initialize file lists for each set
    train_tr_files, train_va_files, test_set1_files, test_set2_files = [], [], [], []

    # Step 1: Split files
    for file_name in files:
        match = pattern.search(file_name)
        if match:
            piece_name = match.group(1)
            instrument_name = match.group(2)

            if instrument_name in virtual_instruments:
                if piece_name in seen_pieces:
                    train_tr_files.append(file_name)
                elif piece_name in unseen_pieces:
                    train_va_files.append(file_name)
            elif instrument_name in real_instruments:
                if piece_name in unseen_pieces:
                    test_set1_files.append(file_name)
                test_set2_files.append(file_name)

    return train_tr_files, train_va_files, test_set1_files, test_set2_files


In [57]:
def preprocess_files(audio_folder, text_folder, split_file_list, output_folder, sr=16000, hop_length=512, bins_per_octave=36, n_bins=252, frames_per_file=40000, batch_size=5000, normalization=None):
    """
    Preprocess audio files by applying CQT, aligning labels, concatenating frames of the same song,
    splitting into chunks of frames (default 40,000) and saving as sequentially numbered files.
    Mini-batches of 5000 frames are used for training, with silence-only chunks removed.

    Args:
    - audio_folder: Path to the folder containing the original audio files.
    - text_folder: Path to the folder containing the original text files.
    - split_file_list: List of files to process for a specific data split (train, val, test).
    - output_folder: Folder where the processed files should be saved.
    - sr: Target sample rate for downsampling.
    - hop_length: Hop length for CQT.
    - bins_per_octave: Number of bins per octave for CQT.
    - n_bins: Total number of bins for CQT (default 252 for 7 octaves).
    - frames_per_file: Number of frames per output file (adjusted for sample rate if needed).
    - batch_size: Number of frames per mini-batch (default 5000).
    - normalization: Tuple of (mean, min, max) for normalization, or None for training set.
    """

    os.makedirs(output_folder, exist_ok=True)
    mean_X, min_X, max_X = [], [], []

    # Initialize concatenated lists
    concatenated_cqt = []
    concatenated_labels = []

    # File counter for naming the output files sequentially
    file_counter = 0

    for file_base in split_file_list:
        audio_file = os.path.join(audio_folder, f"{file_base}.wav")
        text_file = os.path.join(text_folder, f"{file_base}.txt")

        if not os.path.exists(audio_file) or not os.path.exists(text_file):
            print(f"Missing files for {file_base}, skipping.")
            continue

        # Load audio and convert to mono
        y, sr_orig = librosa.load(audio_file, sr=sr, mono=True)

        # Apply CQT
        cqt_features = np.abs(librosa.cqt(y, sr=sr, hop_length=hop_length, n_bins=n_bins, bins_per_octave=bins_per_octave)).T
        num_frames = cqt_features.shape[0]

        # Vector of time stamps
        win_len = hop_length / float(sr)
        vector_aux = np.arange(1, num_frames + 1) * win_len

        # Align labels with CQT frames
        labels = np.zeros((num_frames, 88))  # 88 piano keys (from MIDI 21 to 108)
        with open(text_file, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:  # Skip empty lines
                    continue
                parts = line.strip().split()
                if "OnsetTime" not in line and len(parts) == 3:
                    init_range, fin_range, pitch = float(parts[0]), float(parts[1]), int(parts[2])
                    pitch = int(pitch) - 21  # MIDI note to index (MIDI 21-108)

                    index_min = np.where(vector_aux >= init_range)[0]
                    index_max = np.where(vector_aux - 0.01 > int((fin_range) * 100) / float(100))[0]

                    if len(index_min) == 0 or len(index_max) == 0:
                        continue
                    labels[index_min[0]:index_max[0], pitch] = 1

        # Normalize features if applicable
        if normalization:
            min_train, max_train, _ = normalization  # Do not subtract mean if you want strict [0, 1] normalization
            cqt_features = (cqt_features - min_train) / (max_train - min_train)  # Min-max normalization to [0, 1]
        else:
            min_X.append(cqt_features.min())
            max_X.append(cqt_features.max())
            mean_X.append(cqt_features.sum(axis=0))

        # Concatenate frames and labels for this song
        concatenated_cqt.append(cqt_features)
        concatenated_labels.append(labels)

    # Once all songs are processed, concatenate into a single large array
    concatenated_cqt = np.concatenate(concatenated_cqt, axis=0)
    concatenated_labels = np.concatenate(concatenated_labels, axis=0)

    # Handle chunking and mini-batch splitting
    num_chunks = len(concatenated_cqt) // frames_per_file
    total_frames = len(concatenated_cqt)

    for i in range(num_chunks):
        chunk_cqt = concatenated_cqt[i * frames_per_file: (i + 1) * frames_per_file]
        chunk_labels = concatenated_labels[i * frames_per_file: (i + 1) * frames_per_file]

        num_mini_batches = frames_per_file // batch_size

        for j in range(num_mini_batches):
            mini_cqt = chunk_cqt[j * batch_size: (j + 1) * batch_size]
            mini_labels = chunk_labels[j * batch_size: (j + 1) * batch_size]

            # Skip silent mini-batches (all-zero frames)
            if np.all(mini_cqt == 0):
                print(f"Skipping silent mini-batch at chunk {i}, mini-batch {j}.")
                continue

            # Save the mini-batch
            np.save(os.path.join(output_folder, f"{file_counter}_X.npy"), mini_cqt)
            np.save(os.path.join(output_folder, f"{file_counter}_y.npy"), mini_labels)
            file_counter += 1

    # Handle the remainder of the data
    remainder_cqt = concatenated_cqt[num_chunks * frames_per_file:]
    remainder_labels = concatenated_labels[num_chunks * frames_per_file:]

    if len(remainder_cqt) > 0:
        # Pad the remainder to the nearest multiple of 5000
        num_remainder_frames = len(remainder_cqt)
        pad_size = batch_size * ((num_remainder_frames + batch_size - 1) // batch_size) - num_remainder_frames

        padded_cqt = np.pad(remainder_cqt, ((0, pad_size), (0, 0)), mode='constant')
        padded_labels = np.pad(remainder_labels, ((0, pad_size), (0, 0)), mode='constant')

        # Split the padded remainder into mini-batches
        num_mini_batches = len(padded_cqt) // batch_size

        for j in range(num_mini_batches):
            mini_cqt = padded_cqt[j * batch_size: (j + 1) * batch_size]
            mini_labels = padded_labels[j * batch_size: (j + 1) * batch_size]

            if np.all(mini_cqt == 0):
                print(f"Skipping silent mini-batch in remainder, mini-batch {j}.")
                continue

            # Save the remainder mini-batches
            np.save(os.path.join(output_folder, f"{file_counter}_X.npy"), mini_cqt)
            np.save(os.path.join(output_folder, f"{file_counter}_y.npy"), mini_labels)
            file_counter += 1

    # Return normalization values for training
    if not normalization:
        total_frames = sum([x.shape[0] for x in mean_X])
        train_mean = np.sum(mean_X, axis=0) / total_frames
        min_train = min(min_X)
        max_train = max(max_X)
        print(f'Train mean: {train_mean}, Min_train: {min_train}, Max_train: {max_train}')
        return train_mean, min_train, max_train


In [58]:
def normalize_dataset(audio_folder, text_folder, train_tr_files, train_va_files, test_set1_files, test_set2_files, output_folder):
    """
    Handles preprocessing and normalizing the dataset.
    """
    # Process train_tr first to get normalization parameters
    print("Processing train_tr set...")
    train_mean, min_train, max_train = preprocess_files(audio_folder, text_folder, train_tr_files, os.path.join(output_folder, 'train_tr'))

    normalization = (min_train, max_train, train_mean)

    # Normalize the training set after computing normalization parameters
    print("Normalizing train_tr set...")
    preprocess_files(audio_folder, text_folder, train_tr_files, os.path.join(output_folder, 'train_tr_normalized'), normalization=normalization)

    # Process validation and test sets with normalization
    print("Processing train_va set...")
    preprocess_files(audio_folder, text_folder, train_va_files, os.path.join(output_folder, 'train_va'), normalization=normalization)

    print("Processing test_set1...")
    preprocess_files(audio_folder, text_folder, test_set1_files, os.path.join(output_folder, 'test_set1'), normalization=normalization)

    print("Processing test_set2...")
    preprocess_files(audio_folder, text_folder, test_set2_files, os.path.join(output_folder, 'test_set2'), normalization=normalization)


In [59]:
audio_folder = '/root/dev/data/audio'
text_folder = '/root/dev/data/text'
output_folder = '/root/dev/data/paper_split'

file_list = get_file_list(audio_folder, text_folder)


# Split dataset
train_tr_files, train_va_files, test_set1_files, test_set2_files = split_dataset(file_list)

# Preprocess and normalize dataset
normalize_dataset(audio_folder, text_folder, train_tr_files, train_va_files, test_set1_files, test_set2_files, output_folder)


Processing train_tr set...
Train mean: [0.02457724 0.02529809 0.02786504 0.03002799 0.03079498 0.03267209
 0.03368827 0.03351117 0.03332503 0.03426971 0.03673549 0.04234582
 0.04762595 0.0425057  0.04863415 0.05898338 0.04873133 0.03821773
 0.03713151 0.03762642 0.03512997 0.03420176 0.03457216 0.03939056
 0.04131493 0.03959154 0.04318852 0.04799475 0.04913368 0.07063571
 0.08748765 0.07562285 0.08489244 0.10204061 0.10850244 0.15621868
 0.19656268 0.18472055 0.20605214 0.23337217 0.23000203 0.26452413
 0.29238722 0.29059452 0.41652662 0.55452853 0.44770378 0.3807099
 0.42793623 0.36022124 0.39972794 0.48097602 0.3743621  0.3845272
 0.46538952 0.4226783  0.59597176 0.7441719  0.60925376 0.6367992
 0.7473693  0.64782584 0.66562074 0.7969145  0.6706506  0.7637556
 0.99923366 0.7478186  0.6015459  0.7401919  0.60178983 0.6925185
 0.9598573  0.7034014  0.60420597 0.71232414 0.5562162  0.59739554
 0.7230227  0.5760288  0.59792984 0.74375755 0.5571973  0.5982544
 0.8099549  0.6414364  0.7892