# MusicNetEM train data
模仿 `multi_large_256`: 900帧，hop=256，fs=22050

训练集和测试集大小比为10:1

为了音色分离的训练，将相同的乐器合并了

- `.mid` 为midi
- `.npy` 为midi转pianoroll
- `.wav` 为22050Hz的900帧音频

In [1]:
hop = 256
fs = 22050
frames = 900
s_per_frame = hop / fs

In [2]:
import os
for d in ("train", "test"):
    os.makedirs(d, exist_ok=True)

wave_dirs = [
    "./musicnet/train_data/",
    "./musicnet/test_data/",
]
wave_id_path = {}   # key: wave_id, value: path relative to this file
for d in wave_dirs:
    for root, _, files in os.walk(d):
        for fname in files:
            if fname.lower().endswith(('.wav')):
                path = os.path.join(root, fname)
                key = os.path.splitext(fname)[0]
                wave_id_path[key] = os.path.relpath(path, start=os.getcwd())

midi_dir = "./musicnet_em"
midi_wave_pairs = []
for root, _, files in os.walk(midi_dir):
    for fname in files:
        if fname.lower().endswith(('.mid', '.midi')):
            key = os.path.splitext(fname)[0]
            midi_path = os.path.relpath(os.path.join(root, fname), start=os.getcwd())
            wave_path = wave_id_path.get(key)
            if wave_path:
                midi_wave_pairs.append([midi_path, wave_path])
            else:
                print(f"warning: no matching wave for midi '{key}'")

In [3]:
import sys
sys.path.append("../..")

import soundfile as sf
from scipy.signal import resample_poly
import math
import mido
from utils.midiarray import numpy2midi, midi2numpy, midiInstruments
import numpy as np
os.makedirs("temp", exist_ok=True)

i = 0
for midi_path, wave_path in midi_wave_pairs:
    y, sr = sf.read(wave_path)
    target_sr = fs
    if sr != target_sr:
        gcd = math.gcd(sr, target_sr)
        up = target_sr // gcd
        down = sr // gcd
        y = resample_poly(y, up, down, axis=0)
    y = y.astype('float32', copy=False)
    # convert to mono if multi-channel
    if getattr(y, "ndim", 1) > 1:
        y = y.mean(axis=1)
    
    mid = mido.MidiFile(midi_path)
    midiarr = midi2numpy(mid, time_step=s_per_frame, track_separate=True)   # midiarr: (track, notes, time)
    instruments = midiInstruments(mid)
    # group track indices by instrument identity (fall back to string key if unhashable)
    groups = {}
    for idx, inst in enumerate(instruments):
        try:
            key = inst
            hash(key)
        except Exception:
            key = str(inst)
        groups.setdefault(key, []).append(idx)

    merged_tracks = []
    merged_instruments = []
    for key, idxs in groups.items():
        if len(idxs) == 1:
            track = midiarr[idxs[0]].copy()
            inst = instruments[idxs[0]]
        else:
            # elementwise max across tracks that share the same instrument
            track = np.max(midiarr[idxs], axis=0, keepdims=False)
            inst = instruments[idxs[0]]
        merged_tracks.append(track)
        merged_instruments.append(inst)

    midiarr = np.stack(merged_tracks, axis=0)
    instruments = merged_instruments

    total_frames = midiarr.shape[-1]
    overlap = int(frames * 0.1)
    step = frames - overlap
    segment_samples = frames * hop

    start = 0
    while start + frames <= total_frames:
        end = start + frames
        # slice midi: (track, notes, time)
        midi_seg = midiarr[..., start:end].copy()
        # if notes at t=0 are non-zero, set them to 2
        nonzero = midi_seg[..., 0] != 0
        if nonzero.any():
            midi_seg[..., 0][nonzero] = 2

        # save midi segment
        mid_file = numpy2midi(midi_seg, s_per_frame, instrument=instruments)
        mid_file.save(f"temp/{i}.mid")

        # save np
        np.save(f"temp/{i}.npy", midi_seg)

        # slice audio and save wav
        s0 = start * hop
        s1 = s0 + segment_samples
        audio_seg = y[s0:s1]
        if len(audio_seg) < segment_samples:
            pad_width = segment_samples - len(audio_seg)
            audio_seg = np.pad(audio_seg, (0, pad_width), mode='constant')
        wav_out = f"temp/{i}.wav"
        sf.write(wav_out, audio_seg, target_sr)

        i += 1
        start += step

In [4]:
import soundfile as sf

lengths = []
for fname in os.listdir("temp"):
    if fname.endswith(".wav"):
        path = os.path.join("temp", fname)
        y, sr = sf.read(path)
        lengths.append(len(y))

lengths = np.array(lengths)
print(f"Total wav files: {len(lengths)}")
print(f"Min length: {lengths.min()}")
print(f"Max length: {lengths.max()}")
print(f"Mean length: {lengths.mean()}")
print(f"Std length: {lengths.std()}")
print(f"All lengths equal: {np.all(lengths == lengths[0])}")

Total wav files: 12756
Min length: 230400
Max length: 230400
Mean length: 230400.0
Std length: 0.0
All lengths equal: True


In [7]:
import os
import shutil

# 获取temp文件夹下所有.wav文件，确定总数n
temp_dir = "temp"
wav_files = [f for f in os.listdir(temp_dir) if f.endswith(".wav")]
n = len(wav_files)
print(n)

# 计算test数量
test_num = n // 11  # 10:1划分，约等于n/11
indices = list(range(n))
np.random.shuffle(indices)  # 打乱索引
test_indices = set(indices[:test_num])
train_indices = set(indices[test_num:])

# 移动文件
for idx in indices:
    for ext in [".wav", ".mid", ".npy"]:
        fname = f"{idx}{ext}"
        src = os.path.join(temp_dir, fname)
        if os.path.exists(src):
            if idx in test_indices:
                dst = os.path.join("test", fname)
            else:
                dst = os.path.join("train", fname)
            shutil.move(src, dst)

FileNotFoundError: [Errno 2] No such file or directory: 'temp'

In [6]:
import shutil

# 删除 temp 文件夹及其所有内容
shutil.rmtree("temp", ignore_errors=True)