In [1]:
import os
import random
import re
from collections import defaultdict
from pathlib import Path
from typing import List, Optional, Tuple, Dict, Any

import fadtk
import librosa
import numpy as np
import soundfile as sf
import yaml

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch

torch.cuda.set_device(4)

In [3]:
SLAKH_DIR = Path("/data/matt/slakh2100_flac_redux")
BABYSLAKH_DIR = Path("/data/matt/babyslakh_16k")
TRACK_ID_PATTERN = re.compile(r"slakh2100_flac_redux\/(.+?)\/Track(\d+)\/mix\.flac$")
BABYSLAKH_TRACK_ID_PATTERN = re.compile(r"\/Track(\d+)\/mix\.wav$")
DEFAULT_INSTRUMENTS = ["Piano", "Bass", "Guitar", "Drums"]
DEFAULT_MIDI_TEMPO = 500000
BABYSLAKH_SAMPLE_RATE = 16000
SLAKH_SAMPLE_RATE = 44100


def get_babyslakh_paths(root_dir: Path = BABYSLAKH_DIR) -> List[Path]:
    return [
        root_dir / track_dir / "mix.wav"
        for track_dir in os.listdir(root_dir)
        if "Track" in track_dir and (root_dir / track_dir / "mix.wav").exists()
    ]


def get_slakh_paths(root_dir: Path = SLAKH_DIR) -> List[Path]:
    splits = ["train", "test", "validation"]
    paths = []
    for split_dir in os.listdir(root_dir):
        if split_dir not in splits:
            continue
        split_path = root_dir / split_dir
        for track_dir in os.listdir(split_path):
            mix_path = split_path / track_dir / "mix.flac"
            if "Track" in track_dir and mix_path.exists():
                paths.append(mix_path)
    return paths


def extract_sample_id(path: str, is_babyslakh: bool = False) -> Tuple[str, str]:
    pattern = BABYSLAKH_TRACK_ID_PATTERN if is_babyslakh else TRACK_ID_PATTERN
    match = pattern.search(path)
    if match is None:
        raise ValueError(f"Track ID not found in path: {path}")
    if is_babyslakh:
        coin_flip = random.randint(0, 1)
        split = "test" if coin_flip == 0 else "train"
        return split, match.group(1)
    return match.group(1), match.group(2)


def get_midi_program_names(track_directory: Path) -> List[str]:
    try:
        with open(track_directory / "metadata.yaml", "r") as f:
            metadata = yaml.safe_load(f)
        program_names = []
        for stem_id, stem_info in metadata["stems"].items():
            if "midi_program_name" in stem_info:
                program_names.append(stem_info["midi_program_name"])
        return program_names
    except Exception as e:
        print(f"Failed to load metadata for {track_directory}: {e}")
        return DEFAULT_INSTRUMENTS


def get_tempo(mid):
    for track in mid.tracks:
        for msg in track:
            if msg.type == "set_tempo":
                return msg.tempo
    return DEFAULT_MIDI_TEMPO


def get_bpm(track_directory: Path) -> int:
    try:
        mid = mido.MidiFile(track_directory / "all_src.mid")
        tempo = get_tempo(mid)
    except Exception as e:
        print(f"Failed to get tempo for {track_directory}: {e}")
        tempo = DEFAULT_MIDI_TEMPO
    return round(mido.tempo2bpm(tempo))


def get_condition_data(slakh_paths, is_babyslakh: bool = False) -> Dict[str, Any]:
    condition_data = defaultdict(dict)
    for audio_path in tqdm(slakh_paths):
        track_directory = audio_path.parent
        path_str = str(audio_path)
        split, track_id = extract_sample_id(path_str, is_babyslakh=is_babyslakh)
        if split == "train":
            split = "training"
        try:
            bpm = get_bpm(track_directory)
            program_names = get_midi_program_names(track_directory)
            condition_data[split][track_id] = {
                "bpm": bpm,
                "midi_program_names": program_names,
                "track_path": str(audio_path),
            }
        except Exception as e:
            print(f"Failed on {audio_path}: {e}")
    return condition_data

In [4]:
# create reference slakh subset
test_dir = SLAKH_DIR / "test"
reference_dir = Path("/data/matt/st_ref")


def create_reference_dir(
    reference_dir: Path,
    num_tracks: int = 32,
) -> None:
    reference_dir.mkdir(parents=True, exist_ok=True)
    tracks_copied = 0
    for track_dir in os.listdir(test_dir):
        if "Track" not in track_dir:
            continue
        mix_flac = test_dir / track_dir / "mix.flac"
        if not mix_flac.exists():
            continue
        audio, original_sr = librosa.load(mix_flac, sr=None, mono=False)
        resampled = librosa.resample(audio, orig_sr=original_sr, target_sr=32000)
        if resampled.ndim == 2:
            # librosa returns (channels, samples), sf expects (samples, channels)
            resampled = resampled.T

        _, track_id = extract_sample_id(str(mix_flac))
        destination = reference_dir / f"track{track_id}.wav"
        sf.write(destination, resampled, samplerate=32000)
        tracks_copied += 1
        if tracks_copied >= num_tracks:
            break

In [5]:
create_reference_dir(reference_dir, num_tracks=32)

In [6]:
embedding_model = fadtk.CLAPLaionModel("music")
fad = fadtk.FrechetAudioDistance(
    ml=embedding_model,
)

Loading HTSAT-base model config.
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Load the specified checkpoint /data/matt/miniconda3/envs/fadenv/lib/python3.12/site-packages/fadtk/.model-checkpoints/music_audioset_epoch_15_esc_90.14.pt from users.
Load Checkpoint...
logit_scale_a 	 Loaded
logit_scale_t 	 Loaded
audio_branch.spectrogram_extractor.stft.conv_real.weight 	 Loaded
audio_branch.spectrogram_extractor.stft.conv_imag.weight 	 Loaded
audio_branch.logmel_extractor.melW 	 Loaded
audio_branch.bn0.weight 	 Loaded
audio_branch.bn0.bias 	 Loaded
audio_branch.patch_embed.proj.weight 	 Loaded
audio_branch.patch_embed.proj.bias 	 Loaded
audio_branch.patch_embed.norm.weight 	 Loaded
audio_branch.patch_embed.norm.bias 	 Loaded
audio_branch.layers.0.blocks.0.norm1.weight 	 Loaded
audio_branch.layers.0.blocks.0.norm1.bias 	 Loaded
audio_branch.layers.0.blocks.0.attn.relative_position_bias_table 	 Loaded
audio_branch.layers.0.blocks.0.attn.qkv.weight 	 Loaded
audio_branch.layers.0.blocks.0.attn.qkv.bias 	 Loaded
audio_branch.layers.0.blocks.0.attn.proj.weight 	 Loaded
aud

In [8]:
BASELINE_OUTPUT_DIR = Path("/data/matt/mg_baseline_output")
FINETUNE_OUTPUT_DIR = Path("/data/matt/mg_finetune_output")

In [16]:
!ls {BASELINE_OUTPUT_DIR}

sample_0.wav   sample_16.wav  sample_22.wav  sample_29.wav  sample_6.wav
sample_10.wav  sample_17.wav  sample_23.wav  sample_2.wav   sample_7.wav
sample_11.wav  sample_18.wav  sample_24.wav  sample_30.wav  sample_8.wav
sample_12.wav  sample_19.wav  sample_25.wav  sample_31.wav  sample_9.wav
sample_13.wav  sample_1.wav   sample_26.wav  sample_3.wav
sample_14.wav  sample_20.wav  sample_27.wav  sample_4.wav
sample_15.wav  sample_21.wav  sample_28.wav  sample_5.wav


In [9]:
for dir_ in [BASELINE_OUTPUT_DIR, FINETUNE_OUTPUT_DIR, reference_dir]:
    fadtk.cache_embedding_files(
        dir_,
        embedding_model,
    )

[Frechet Audio Distance] Loading 32 audio files...
Loading HTSAT-base model config.
Loading HTSAT-base model config.
Loading HTSAT-base model config.
Loading HTSAT-base model config.
Loading HTSAT-base model config.
Loading HTSAT-base model config.
Loading HTSAT-base model config.
Loading HTSAT-base model config.
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized

OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 44.52 GiB of which 15.00 MiB is free. Process 94591 has 37.20 GiB memory in use. Process 248927 has 576.00 MiB memory in use. Process 286967 has 576.00 MiB memory in use. Process 287044 has 576.00 MiB memory in use. Process 287123 has 576.00 MiB memory in use. Process 287202 has 576.00 MiB memory in use. Process 287279 has 576.00 MiB memory in use. Process 288203 has 574.00 MiB memory in use. Process 288362 has 574.00 MiB memory in use. Process 288842 has 574.00 MiB memory in use. Process 298990 has 1.20 GiB memory in use. Including non-PyTorch memory, this process has 1.00 GiB memory in use. Of the allocated memory 577.03 MiB is allocated by PyTorch, and 24.97 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [21]:
baseline_fad_score = fad.score(
    reference_dir,
    BASELINE_OUTPUT_DIR,
)
print(baseline_fad_score)

finetune_score = fad.score(
    reference_dir,
    FINETUNE_OUTPUT_DIR,
)
print(finetune_score)

Loading embedding files from /data/matt/st_ref...


AssertionError: No files provided