# Task 1: Unconditioned Symbolic Generation


In [None]:
# One-time setup: clone AudioCraft, enter the repo, then install it (editable)
# plus the extra dependencies used by the preprocessing cells below.
# !git clone https://github.com/facebookresearch/audiocraft.git
# %cd audiocraft
# !uv pip install -e .
# !uv pip install dora-search numba librosa mido PyYAML datasets

  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


/home/matt/audiocraft
[2mUsing Python 3.9.21 environment at: /data/matt/miniconda3/envs/audiocraft[0m
[2K[2mResolved [1m152 packages[0m [2min 1.03s[0m[0m                                       [0m
[2K[37m⠙[0m [2mPreparing packages...[0m (0/1)                                                   
[2K[1A[37m⠙[0m [2mPreparing packages...[0m (0/1)----[0m[0m     0 B/5.38 KiB                      [1A
[2K[2mPrepared [1m1 package[0m [2min 7ms[0m[0m                                                    [1A
[2mUninstalled [1m15 packages[0m [2min 250ms[0m[0m
[2K[2mInstalled [1m107 packages[0m [2min 453ms[0m[0m                             [0m
 [32m+[39m [1maiofiles[0m[2m==23.2.1[0m
 [32m+[39m [1mannotated-types[0m[2m==0.7.0[0m
 [32m+[39m [1manyio[0m[2m==4.9.0[0m
 [32m+[39m [1maudiocraft[0m[2m==1.4.0a2 (from file:///home/matt/audiocraft)[0m
 [32m+[39m [1mav[0m[2m==11.0.0[0m
 [32m+[39m [1mblis[0m[2m==0.7.11[0m
 [32m+[39m [

In [3]:
%cd /home/matt/audiocraft

/home/matt/audiocraft


In [2]:
import json
import os
import random
import re
import shutil
from collections import defaultdict
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import librosa
import mido
import numpy as np
import soundfile as sf
import torch
import yaml
from datasets import Dataset, load_dataset
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from audiocraft import train
from audiocraft.data.audio import audio_write
from audiocraft.models import MusicGen
from audiocraft.utils import export

Dora directory: /tmp/audiocraft_matt


In [4]:
# NOTE(review): audiocraft was imported in an earlier cell and already reported
# "Dora directory: /tmp/audiocraft_matt"; later cells still print that same
# directory, so setting TMPDIR here does not appear to move Dora's XP folder.
# Set it before importing audiocraft (or pass it on the command line, as the
# MusicCaps training command does) if that is the intent.
os.environ["TMPDIR"] = "/data/matt/tmp"
PROJECT_DATA_DIR = Path("/data/matt/cse253a2")
# parents=True so the cell also works on a machine where /data/matt is missing,
# matching every other mkdir call in this notebook.
PROJECT_DATA_DIR.mkdir(parents=True, exist_ok=True)

## slakh dataset preprocessing


In [17]:
# Dataset roots on the shared data disk.
SLAKH_DIR = Path("/data/matt/slakh2100_flac_redux")
BABYSLAKH_DIR = Path("/data/matt/babyslakh_16k")
# Matches ".../slakh2100_flac_redux/<split>/Track<id>/mix.flac":
# group 1 is the split directory name, group 2 the numeric track id.
TRACK_ID_PATTERN = re.compile(r"slakh2100_flac_redux\/(.+?)\/Track(\d+)\/mix\.flac$")
# BabySlakh has no split directories, so only the track id is captured.
BABYSLAKH_TRACK_ID_PATTERN = re.compile(r"\/Track(\d+)\/mix\.wav$")
# Instrument list returned when a track's metadata.yaml cannot be read.
DEFAULT_INSTRUMENTS = ["Piano", "Bass", "Guitar", "Drums"]
# Fallback tempo in microseconds per beat (passed to mido.tempo2bpm);
# 60_000_000 / 500_000 = 120 BPM.
DEFAULT_MIDI_TEMPO = 500000
BABYSLAKH_SAMPLE_RATE = 16000
SLAKH_SAMPLE_RATE = 44100


def get_babyslakh_paths(root_dir: Path = BABYSLAKH_DIR) -> List[Path]:
    """Return the ``mix.wav`` of every ``Track*`` directory directly under *root_dir*.

    Track directories without a ``mix.wav`` are skipped; order follows
    ``os.listdir``.
    """
    mixes = []
    for entry in os.listdir(root_dir):
        if "Track" not in entry:
            continue
        candidate = root_dir / entry / "mix.wav"
        if candidate.exists():
            mixes.append(candidate)
    return mixes


def get_slakh_paths(root_dir: Path = SLAKH_DIR) -> List[Path]:
    """Collect the ``mix.flac`` of every track in the train/test/validation splits.

    Directories under *root_dir* other than those three splits are ignored, as
    are track directories that lack a ``mix.flac``.
    """
    wanted_splits = {"train", "test", "validation"}
    found: List[Path] = []
    for split_name in filter(wanted_splits.__contains__, os.listdir(root_dir)):
        split_root = root_dir / split_name
        found.extend(
            split_root / track / "mix.flac"
            for track in os.listdir(split_root)
            if "Track" in track and (split_root / track / "mix.flac").exists()
        )
    return found


def extract_sample_id(path: str, is_babyslakh: bool = False) -> Tuple[str, str]:
    """Parse ``(split, track_id)`` from a Slakh mix path.

    For full Slakh the split is the directory name captured from *path*.
    BabySlakh has no splits, so "test" or "train" is picked with a single
    ``random.randint(0, 1)`` draw from the global ``random`` state.

    Raises:
        ValueError: if *path* does not match the expected layout.
    """
    regex = BABYSLAKH_TRACK_ID_PATTERN if is_babyslakh else TRACK_ID_PATTERN
    found = regex.search(path)
    if found is None:
        raise ValueError(f"Track ID not found in path: {path}")
    if not is_babyslakh:
        return found.group(1), found.group(2)
    split = "train" if random.randint(0, 1) else "test"
    return split, found.group(1)


def get_midi_program_names(track_directory: Path) -> List[str]:
    """Return the ``midi_program_name`` of each stem in ``metadata.yaml``.

    Stems without a ``midi_program_name`` entry are skipped. If the metadata
    cannot be read or parsed (missing file, invalid YAML, no ``stems`` key, ...)
    the error is printed and a fresh copy of ``DEFAULT_INSTRUMENTS`` is returned.
    """
    try:
        with open(track_directory / "metadata.yaml", "r") as f:
            metadata = yaml.safe_load(f)
        return [
            stem_info["midi_program_name"]
            for stem_info in metadata["stems"].values()
            if "midi_program_name" in stem_info
        ]
    except Exception as e:
        # Best effort: one broken track should not abort a full dataset scan.
        print(f"Failed to load metadata for {track_directory}: {e}")
        # Return a copy so a caller mutating the result cannot alter the
        # module-level default that every other failed track also receives.
        return list(DEFAULT_INSTRUMENTS)


def get_tempo(mid):
    """Return the ``tempo`` of the first ``set_tempo`` message in *mid*.

    Tracks are scanned in order, then messages within each track; if none is a
    tempo change, ``DEFAULT_MIDI_TEMPO`` is returned.
    """
    tempo_values = (
        msg.tempo
        for track in mid.tracks
        for msg in track
        if msg.type == "set_tempo"
    )
    missing = object()
    first = next(tempo_values, missing)
    return DEFAULT_MIDI_TEMPO if first is missing else first


def get_bpm(track_directory: Path) -> int:
    """Return the tempo of ``all_src.mid`` in BPM, rounded to an integer.

    When the MIDI file cannot be opened or parsed, the error is printed and
    ``DEFAULT_MIDI_TEMPO`` is used instead.
    """
    tempo = DEFAULT_MIDI_TEMPO
    try:
        tempo = get_tempo(mido.MidiFile(track_directory / "all_src.mid"))
    except Exception as err:
        print(f"Failed to get tempo for {track_directory}: {err}")
    return round(mido.tempo2bpm(tempo))


def get_condition_data(slakh_paths, is_babyslakh: bool = False) -> Dict[str, Any]:
    """Build per-split conditioning metadata for a collection of Slakh mix files.

    Returns ``split -> track_id -> {"bpm", "midi_program_names", "track_path"}``.
    The split name ``"train"`` is stored as ``"training"``.

    Tracks whose path cannot be parsed (or whose metadata lookup fails) are
    reported and skipped instead of aborting the whole scan.
    """
    condition_data = defaultdict(dict)
    for audio_path in tqdm(slakh_paths):
        track_directory = audio_path.parent
        try:
            # extract_sample_id raises ValueError on unexpected layouts, so it
            # must sit inside the guard together with the metadata lookups.
            split, track_id = extract_sample_id(
                str(audio_path), is_babyslakh=is_babyslakh
            )
            bpm = get_bpm(track_directory)
            program_names = get_midi_program_names(track_directory)
        except Exception as e:
            print(f"Failed on {audio_path}: {e}")
            continue
        if split == "train":
            split = "training"
        condition_data[split][track_id] = {
            "bpm": bpm,
            "midi_program_names": program_names,
            "track_path": str(audio_path),
        }
    return condition_data

In [7]:
# Scan the full Slakh dataset for conditioning metadata
# (the BabySlakh variant is kept commented out for quick experiments).
# babyslakh_paths = get_babyslakh_paths()
# condition_data = get_condition_data(babyslakh_paths, is_babyslakh=True)
slakh_paths = get_slakh_paths()
condition_data = get_condition_data(slakh_paths, is_babyslakh=False)

  0%|          | 0/1710 [00:00<?, ?it/s]

100%|██████████| 1710/1710 [04:49<00:00,  5.90it/s]


In [8]:
# Save the scanned metadata so later sessions can reuse it.
with open("/data/matt/all_conditions.json", "w") as f:
    json.dump(condition_data, f, indent=4)

In [10]:
# Write the extracted features to AudioCraft-style data.jsonl manifests, one per existing Slakh split (train/test/validation).


def write_jsonl(data: list[dict], file_path: Path) -> None:
    """Write *data* to *file_path* in JSON Lines format (one object per line)."""
    with open(file_path, "w") as handle:
        handle.writelines(f"{json.dumps(record)}\n" for record in data)


def prepare_slakh_data(
    split_directories: dict[str, Path],
    sr: int = SLAKH_SAMPLE_RATE,
    file_extension: str = "flac",
    conditions: Optional[Dict[str, Any]] = None,
):
    """Write an AudioCraft ``data.jsonl`` manifest for each Slakh split.

    Args:
        split_directories: maps ``"train"``, ``"test"`` and ``"validation"`` to
            the directory receiving that split's ``data.jsonl`` (created if
            missing).
        sr: expected native sample rate of the mixes. The manifest records each
            file's actual rate; a warning is printed when it differs from *sr*.
        file_extension: value stored in each entry's ``file_extension`` field.
        conditions: ``split -> track_id -> info`` as produced by
            ``get_condition_data``; defaults to the notebook-global
            ``condition_data``.
    """
    if conditions is None:
        conditions = condition_data
    for directory in split_directories.values():
        directory.mkdir(parents=True, exist_ok=True)

    data_lists = {
        "train": [],
        "test": [],
        "validation": [],
    }

    for split, split_data in conditions.items():
        if split == "training":
            split = "train"
        for track_id, track_info in tqdm(split_data.items(), total=len(split_data)):
            path = Path(track_info["track_path"])
            # sr=None keeps the native rate; librosa's default would resample to
            # 22050 Hz, and that resampled rate would end up in the manifest.
            y, native_sr = librosa.load(path, sr=None)
            if native_sr != sr:
                print(f"Warning: {path} is {native_sr} Hz, expected {sr} Hz.")
            chroma = librosa.feature.chroma_stft(y=y, sr=native_sr)
            # Crude key estimate: index of the pitch-class bin with most energy.
            key = np.argmax(np.sum(chroma, axis=1))
            length = librosa.get_duration(y=y, sr=native_sr)
            entry = {
                "key": str(key),
                "sample_rate": native_sr,
                "file_extension": file_extension,
                "description": "",
                "keywords": "",
                "duration": length,
                "bpm": track_info["bpm"],
                "genre": "",
                "title": "",
                "name": "",
                "instrument": ", ".join(track_info["midi_program_names"]),
                "moods": [],
                "path": str(path),
            }
            data_lists[split].append(entry)

    # print split sizes
    for split, data in data_lists.items():
        print(f"{split} size: {len(data)}")
        write_jsonl(data, split_directories[split] / "data.jsonl")

In [11]:
# Output root for the AudioCraft manifests, with one sub-directory per split.
music_gen_slakh_directory = Path("/data/matt/music_gen_slakh")
split_directories = {
    "train": music_gen_slakh_directory / "train",
    "test": music_gen_slakh_directory / "test",
    "validation": music_gen_slakh_directory / "validation",
}
prepare_slakh_data(
    split_directories,
    sr=SLAKH_SAMPLE_RATE,
    file_extension="flac",
)

100%|██████████| 1289/1289 [23:11<00:00,  1.08s/it]
100%|██████████| 270/270 [04:49<00:00,  1.07s/it]
100%|██████████| 151/151 [02:51<00:00,  1.14s/it]

train size: 1289
test size: 151
validation size: 270





## run training with dora


In [12]:
os.listdir(music_gen_slakh_directory)

['train', 'validation', 'test']

In [13]:
# Smoke-test run: fine-tune MusicGen-small on BabySlakh for 2 epochs of 100
# updates with batch size 2, just to exercise the pipeline end to end.
baby_command = """\
CUDA_VISIBLE_DEVICES=4,5,6,7 dora -P audiocraft run \
  solver=musicgen/musicgen_base_32khz \
  +model.lm.model_scale=small \
  continue_from=//pretrained/facebook/musicgen-small \
  conditioner=text2music \
  dset=audio/babyslakh \
  dataset.num_workers=2 \
  dataset.valid.num_samples=1 \
  dataset.batch_size=2 \
  schedule.cosine.warmup=8 \
  optim.optimizer=adamw \
  optim.lr=1e-4 \
  optim.epochs=2 \
  optim.updates_per_epoch=100 \
  optim.adam.weight_decay=0.01 \
  generate.lm.prompted_samples=False \
  generate.lm.gen_gt_samples=True
"""

# Full Slakh run: same recipe with batch size 4, 32 validation samples and one
# epoch of 1000 updates.
# NOTE(review): the recorded traceback for this command comes from a `dora`
# in the `cse253` conda env, while the install cell used the `audiocraft`
# env. Check which environment the shell command actually resolves to.
command = """\
CUDA_VISIBLE_DEVICES=4,5,6,7 dora -P audiocraft run \
  solver=musicgen/musicgen_base_32khz \
  +model.lm.model_scale=small \
  continue_from=//pretrained/facebook/musicgen-small \
  conditioner=text2music \
  dset=audio/slakh \
  dataset.num_workers=4 \
  dataset.valid.num_samples=32 \
  dataset.batch_size=4 \
  schedule.cosine.warmup=8 \
  optim.optimizer=adamw \
  optim.lr=1e-4 \
  optim.epochs=1 \
  optim.updates_per_epoch=1000 \
  optim.adam.weight_decay=0.01 \
  generate.lm.prompted_samples=False \
  generate.lm.gen_gt_samples=True
"""

In [35]:
!{command}

Dora directory: /tmp/audiocraft_matt
Traceback (most recent call last):
  File "/data/matt/miniconda3/envs/cse253/bin/dora", line 10, in <module>
    sys.exit(main())
  File "/data/matt/miniconda3/envs/cse253/lib/python3.9/site-packages/dora/__main__.py", line 170, in main
    args.action(args, main)
  File "/data/matt/miniconda3/envs/cse253/lib/python3.9/site-packages/dora/run.py", line 51, in run_action
    xp = main.get_xp(args.argv)
  File "/data/matt/miniconda3/envs/cse253/lib/python3.9/site-packages/dora/hydra.py", line 190, in get_xp
    delta += self._get_delta(base, cfg)
  File "/data/matt/miniconda3/envs/cse253/lib/python3.9/site-packages/dora/hydra.py", line 297, in _get_delta
    for diff in _compare_config(init, other):
  File "/data/matt/miniconda3/envs/cse253/lib/python3.9/site-packages/dora/hydra.py", line 75, in _compare_config
    yield from _compare_config(ref_value, other_value, path)
  File "/data/matt/miniconda3/envs/cse253/lib/python3.9/site-packages/dora/hydra.p

In [16]:
# Dora XP signature of the training run to work with.
# original (tmp has since been set to /data/matt/tmp)
# samples_dir = Path("/tmp/audiocraft_matt/xps/ed9b1b62/samples")
# baby training run
# sig = "ed9b1b62"
sig = "12c4508d"

## export fine-tuned model params


In [15]:
checkpoints_dir = Path("/data/matt/mg_checkpoints")
v1_checkpoints_dir = checkpoints_dir / "v1/finetune"
v2_checkpoints_dir = checkpoints_dir / "v2/finetune"
# From here on, `checkpoints_dir` refers to the v2 fine-tune directory.
checkpoints_dir = v2_checkpoints_dir

In [17]:
# Exporting .bin files from a training run:

# NOTE(review): this overrides the `sig` chosen in the previous cell
# ("12c4508d") with "ed9b1b62", which that cell's comment labels the baby
# training run, and exports it into `checkpoints_dir` (the v2 directory).
# Confirm this is the intended run.
sig = "ed9b1b62"

# from https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md#importing--exporting-models
xp = train.main.get_xp_from_sig(sig)
checkpoints_dir.mkdir(parents=True, exist_ok=True)
export.export_lm(xp.folder / "checkpoint.th", checkpoints_dir / "state_dict.bin")
export.export_pretrained_compression_model(
    "facebook/encodec_32khz", checkpoints_dir / "compression_state_dict.bin"
)

Dora directory: /tmp/audiocraft_matt


In [13]:
CKPT_DIR = Path("/data/matt/mg_checkpoints")
CKPT_DIR.mkdir(parents=True, exist_ok=True)


def export_model_checkpoint(sig: str, ckpt_d: Path = CKPT_DIR):
    """Export Dora run *sig* as a directory loadable by ``MusicGen.get_pretrained``.

    Writes the run's LM weights to ``state_dict.bin`` and the pretrained
    ``facebook/encodec_32khz`` compression model to
    ``compression_state_dict.bin`` inside *ckpt_d*, creating it if needed.
    """
    ckpt_d.mkdir(parents=True, exist_ok=True)
    run_folder = train.main.get_xp_from_sig(sig).folder
    export.export_lm(run_folder / "checkpoint.th", ckpt_d / "state_dict.bin")
    compression_target = ckpt_d / "compression_state_dict.bin"
    export.export_pretrained_compression_model(
        "facebook/encodec_32khz",
        compression_target,
    )

In [29]:
# latest (monday midnight) slakh run
slakh_scaled_sig = "f8f7a1d3"
slakh_scaled_ckpt = Path("/data/matt/mg_checkpoints/slakh_scaled")
# Pass the variable instead of repeating the literal so the two cannot drift.
export_model_checkpoint(
    sig=slakh_scaled_sig,
    ckpt_d=slakh_scaled_ckpt,
)

## Set up reference directory


In [34]:
# create slakh test split reference (st_ref)
# Source mixes (Slakh test split) and destination for the reference WAVs.
slakh_test_dir = SLAKH_DIR / "test"
slakh_reference_dir = PROJECT_DATA_DIR / "slakh/reference"

In [36]:
def create_slakh_reference_dir(
    reference_dir: Path,
    num_tracks: int = 32,
    track_length: int = 16,
    source_dir: Optional[Path] = None,
    target_sr: int = 32000,
) -> list[str]:
    """Write the first *track_length* seconds of Slakh mixes as reference WAVs.

    Up to *num_tracks* ``Track*/mix.flac`` files from *source_dir* (default: the
    notebook's ``slakh_test_dir``) are resampled to *target_sr* (32 kHz,
    matching the 32 kHz MusicGen models used here) and written as
    ``reference_dir/track<id>.wav``. Track directories are visited in sorted
    order, so repeated runs select the same tracks. Paths must still match
    ``TRACK_ID_PATTERN`` for the track id to be parsed.

    Returns:
        The ids of the tracks written, in processing order.
    """
    if source_dir is None:
        source_dir = slakh_test_dir
    reference_dir.mkdir(parents=True, exist_ok=True)
    track_ids: list[str] = []
    if num_tracks <= 0:
        return track_ids
    for track_dir in sorted(os.listdir(source_dir)):
        if "Track" not in track_dir:
            continue
        mix_flac = source_dir / track_dir / "mix.flac"
        if not mix_flac.exists():
            continue
        # Decode only the part that is kept, not the whole mix.
        audio, original_sr = librosa.load(
            mix_flac, sr=None, mono=False, duration=track_length
        )
        resampled = librosa.resample(audio, orig_sr=original_sr, target_sr=target_sr)
        # Trim to exactly track_length seconds (works for mono and multichannel).
        resampled = resampled[..., : track_length * target_sr]
        if resampled.ndim == 2:
            # librosa returns (channels, samples), sf expects (samples, channels)
            resampled = resampled.T
        _, track_id = extract_sample_id(str(mix_flac))
        destination = reference_dir / f"track{track_id}.wav"
        sf.write(destination, resampled, samplerate=target_sr)
        track_ids.append(str(track_id))
        if len(track_ids) >= num_tracks:
            break
    return track_ids

In [37]:
# Build the reference set: 32 Slakh test mixes trimmed to the default 16 s.
reference_track_ids = create_slakh_reference_dir(slakh_reference_dir, num_tracks=32)

In [43]:
# Load cached Slakh conditioning metadata if present; otherwise rebuild it from
# the dataset and cache it for next time.
slakh_og_metadata_path = PROJECT_DATA_DIR / "slakh/original_metadata.json"
if slakh_og_metadata_path.exists():
    with open(slakh_og_metadata_path, "r") as f:
        # Load existing metadata
        original_metadata = json.load(f)
    condition_data = original_metadata
else:
    condition_data = get_condition_data(
        get_slakh_paths(SLAKH_DIR),
        is_babyslakh=False,
    )
    # The cache lives in a sub-directory that may not exist yet.
    slakh_og_metadata_path.parent.mkdir(parents=True, exist_ok=True)
    with open(slakh_og_metadata_path, "w") as f:
        json.dump(condition_data, f, indent=4)

In [45]:
def format_prompt(bpm: int, midi_program_names: list[str]) -> str:
    """Build a text prompt such as ``"120 BPM with Piano, Bass."``."""
    return f"{bpm} BPM with {', '.join(midi_program_names)}."


def create_reference_condition_data(
    condition_data: dict[str, Any],
    reference_dir: Path,
) -> Dict[str, Any]:
    """Collect conditioning info for test tracks that have a reference WAV.

    Only ``condition_data["test"]`` tracks with ``reference_dir/track<id>.wav``
    on disk are kept. Each entry holds ``bpm``, ``midi_program_names``, the
    reference file path and the text prompt built by ``format_prompt``.
    """

    def _entry(info: Dict[str, Any], wav: Path) -> Dict[str, Any]:
        return {
            "bpm": info["bpm"],
            "midi_program_names": info["midi_program_names"],
            "track_path": str(wav),
            "prompt": format_prompt(info["bpm"], info["midi_program_names"]),
        }

    candidates = (
        (track_id, info, reference_dir / f"track{track_id}.wav")
        for track_id, info in condition_data["test"].items()
    )
    return {
        track_id: _entry(info, wav)
        for track_id, info, wav in candidates
        if wav.exists()
    }

In [46]:
# Reference conditioning data: test tracks that have a reference WAV, plus prompts.
rcd = create_reference_condition_data(
    condition_data,
    slakh_reference_dir,
)

## Generate New Samples


In [18]:
# select GPU: make CUDA device 4 the current device for generation.
torch.cuda.set_device(4)

In [27]:
# load in baseline model: the pretrained (not fine-tuned) MusicGen-small.
baseline = MusicGen.get_pretrained("small")



In [25]:
# Fine-tuned model loaded from the exported checkpoint directory (`checkpoints_dir`).
# NOTE(review): `ft_slakh` is reassigned from `slakh_scaled_ckpt` in a later cell.
ft_slakh = MusicGen.get_pretrained(checkpoints_dir)

In [20]:
# Root for generated clips, with separate trees for baseline and fine-tuned models.
GENERATED_AUDIO_DIR = Path("/data/matt/mg_generated_audio")
BASELINE_OUTPUT_DIR = GENERATED_AUDIO_DIR / "baseline"
FINETUNE_OUTPUT_DIR = GENERATED_AUDIO_DIR / "finetune"

In [21]:
def _resample_and_overwrite(
    audio_file: Path,
    target_sr: Optional[int] = None,
) -> None:
    """Resample *audio_file* in place to *target_sr* Hz.

    Does nothing when *target_sr* is ``None`` or already equals the file's
    sample rate. All channels are kept.
    """
    if target_sr is None:
        return
    # mono=False: librosa's default would down-mix multichannel audio to mono.
    y, sr = librosa.load(audio_file, sr=None, mono=False)
    if sr == target_sr:
        return
    audio = librosa.resample(y, orig_sr=sr, target_sr=target_sr)
    if audio.ndim == 2:
        # librosa is (channels, samples); soundfile expects (samples, channels).
        audio = audio.T
    # overwrite the file with the resampled audio
    sf.write(audio_file, audio, samplerate=target_sr)


def unconditional_generate_wrapper(
    model: MusicGen,
    duration: int = 16,
    num_samples: int = 32,
    output_dir: Path = BASELINE_OUTPUT_DIR,
    batch_size: int = 4,
    target_sr: Optional[int] = None,
):
    """Generate *num_samples* unconditional clips and save them as WAV files.

    Sets the model's generation duration, generates in batches of at most
    *batch_size*, writes each clip with loudness normalisation to
    ``output_dir/sample_<n>.wav`` (n counts from 0), and optionally resamples
    it in place to *target_sr*.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    model.set_generation_params(duration=duration)
    for start in tqdm(range(0, num_samples, batch_size)):
        count = min(batch_size, num_samples - start)
        generated = model.generate_unconditional(num_samples=count).cpu()
        for offset, wav in enumerate(generated):
            index = start + offset
            audio_write(
                output_dir / f"sample_{index}",
                wav,
                model.sample_rate,
                strategy="loudness",
            )
            _resample_and_overwrite(
                output_dir / f"sample_{index}.wav",
                target_sr=target_sr,
            )


def conditional_generate_wrapper(
    model: MusicGen,
    prompts: dict[str, str],
    duration: int = 16,
    output_dir: Path = BASELINE_OUTPUT_DIR,
    batch_size: int = 4,
    target_sr: Optional[int] = None,
) -> dict[str, dict[str, str]]:
    """Generate one text-conditioned clip per prompt and save it as a WAV file.

    Args:
        model: MusicGen model; its generation duration is set to *duration*.
        prompts: maps an output id to its text prompt. The clip for id ``k`` is
            written to ``output_dir/<k>.wav`` (loudness-normalised, compressed).
        duration: clip length passed to ``set_generation_params``.
        output_dir: destination directory (created if missing).
        batch_size: maximum number of prompts per ``model.generate`` call.
        target_sr: if given, each file is resampled in place to this rate.

    Returns:
        ``{id: {"prompt": text, "generated_audio_file": path}}`` per clip.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    model.set_generation_params(duration=duration)
    items = list(prompts.items())
    results = {}
    for start in tqdm(range(0, len(items), batch_size)):
        chunk = items[start : start + batch_size]
        chunk_ids = [track_id for track_id, _ in chunk]
        chunk_prompts = [text for _, text in chunk]
        wavs = model.generate(chunk_prompts).cpu()
        for track_id, text, wav in zip(chunk_ids, chunk_prompts, wavs):
            audio_write(
                output_dir / f"{track_id}",
                wav,
                model.sample_rate,
                strategy="loudness",
                loudness_compressor=True,
            )
            out_path = output_dir / f"{track_id}.wav"
            _resample_and_overwrite(out_path, target_sr=target_sr)
            results[track_id] = {
                "prompt": text,
                "generated_audio_file": str(out_path),
            }
    return results

In [55]:
# generate baseline: 32 unconditional 16 s clips from the pretrained model.
output_dir = BASELINE_OUTPUT_DIR / "slakh_scaled_uncond"
unconditional_generate_wrapper(
    baseline,
    duration=16,
    num_samples=32,
    output_dir=output_dir,
    batch_size=4,
)
print(output_dir)

  0%|          | 0/8 [00:00<?, ?it/s]CLIPPING /data/matt/mg_generated_audio/baseline/slakh_scaled_uncond/sample_0 happening with proba (a bit of clipping is okay): 0.0013281250139698386 maximum scale:  1.2280017137527466
CLIPPING /data/matt/mg_generated_audio/baseline/slakh_scaled_uncond/sample_1 happening with proba (a bit of clipping is okay): 0.0010410156100988388 maximum scale:  1.4098337888717651
CLIPPING /data/matt/mg_generated_audio/baseline/slakh_scaled_uncond/sample_2 happening with proba (a bit of clipping is okay): 0.009599609300494194 maximum scale:  1.4785785675048828
CLIPPING /data/matt/mg_generated_audio/baseline/slakh_scaled_uncond/sample_3 happening with proba (a bit of clipping is okay): 0.009087890386581421 maximum scale:  1.3560988903045654
 12%|█▎        | 1/8 [00:15<01:48, 15.47s/it]CLIPPING /data/matt/mg_generated_audio/baseline/slakh_scaled_uncond/sample_4 happening with proba (a bit of clipping is okay): 0.00034374999813735485 maximum scale:  2.0891764163970947

/data/matt/mg_generated_audio/baseline/slakh_scaled_uncond





In [56]:
# generate unconditional: the same 32 x 16 s setup from the fine-tuned model.
output_dir = FINETUNE_OUTPUT_DIR / "slakh_scaled_uncond"
unconditional_generate_wrapper(
    ft_slakh,
    duration=16,
    num_samples=32,
    output_dir=output_dir,
    batch_size=4,
)
print(output_dir)

  0%|          | 0/8 [00:00<?, ?it/s]CLIPPING /data/matt/mg_generated_audio/finetune/slakh_scaled_uncond/sample_0 happening with proba (a bit of clipping is okay): 0.0013281250139698386 maximum scale:  1.789244532585144
CLIPPING /data/matt/mg_generated_audio/finetune/slakh_scaled_uncond/sample_1 happening with proba (a bit of clipping is okay): 0.00012109374802093953 maximum scale:  1.2391533851623535
CLIPPING /data/matt/mg_generated_audio/finetune/slakh_scaled_uncond/sample_2 happening with proba (a bit of clipping is okay): 0.006246093660593033 maximum scale:  1.9762834310531616
CLIPPING /data/matt/mg_generated_audio/finetune/slakh_scaled_uncond/sample_3 happening with proba (a bit of clipping is okay): 0.0003242187376599759 maximum scale:  1.3497477769851685
 12%|█▎        | 1/8 [00:09<01:06,  9.56s/it]CLIPPING /data/matt/mg_generated_audio/finetune/slakh_scaled_uncond/sample_4 happening with proba (a bit of clipping is okay): 0.002189453225582838 maximum scale:  1.5302330255508423


/data/matt/mg_generated_audio/finetune/slakh_scaled_uncond





## Conditional Generation


In [None]:
# # Example 2: text guided generation

# wavs = musicgen.generate([
#     'disco',
#     'slide guitar bluegrass',
#     'breakbeat, amen break',
#     'epic orchestral strings'
# ])

# # save and display generated audio
# for idx, one_wav in enumerate(wavs):
#     audio_write(f'{idx}', one_wav.cpu(), musicgen.sample_rate, strategy="loudness", loudness_compressor=True)
#     ipd.display(ipd.Audio(one_wav.cpu(), rate=32000))

In [48]:
# track_id -> text prompt for every reference track.
rcd_prompts = {track_id: v["prompt"] for track_id, v in rcd.items()}

In [47]:
# Fine-tuned model exported from the scaled Slakh run.
ft_slakh = MusicGen.get_pretrained(slakh_scaled_ckpt)

In [51]:
# Fine-tuned model on the reference prompts; the returned per-clip metadata is
# saved as JSON. (The earlier v2 run is kept commented out.)
# conditional_generate_wrapper(
#     ft_slakh,
#     rcd_prompts,
#     duration=16,
#     output_dir=FINETUNE_OUTPUT_DIR / "slakh_cond_v2",
#     batch_size=4,
# )
output_info = conditional_generate_wrapper(
    ft_slakh,
    rcd_prompts,
    duration=16,
    output_dir=FINETUNE_OUTPUT_DIR / "slakh_scaled_cond",
)
output_info_path = PROJECT_DATA_DIR / "slakh/scaled_cond_output.json"
with open(output_info_path, "w") as f:
    json.dump(output_info, f, indent=4)
print(f"Output info saved to {output_info_path}")

100%|██████████| 8/8 [01:52<00:00, 14.03s/it]


In [54]:
# Same reference prompts through the baseline model, for comparison.
output_info = conditional_generate_wrapper(
    baseline,
    rcd_prompts,
    duration=16,
    output_dir=BASELINE_OUTPUT_DIR / "slakh_cond1",
    batch_size=4,
)
with open(PROJECT_DATA_DIR / "slakh/baseline_cond1.json", "w") as f:
    json.dump(output_info, f, indent=4)

100%|██████████| 8/8 [01:37<00:00, 12.23s/it]


In [55]:
def combine_condition_data_with_output_audio_file_path(
    condition_data: dict[str, dict[str, Any]],
    output_dir: Path,
):
    """Join reference conditioning info with the generated clips on disk.

    For every track id, looks for ``output_dir/<id>.wav``. Ids without a file
    are reported and dropped. Kept entries contain ``prompt``, ``bpm``,
    ``midi_program_names`` and ``audio_file_path``.
    """
    combined = {}
    for track_id, info in condition_data.items():
        wav_path = output_dir / f"{track_id}.wav"
        if wav_path.exists():
            entry = {key: info[key] for key in ("prompt", "bpm", "midi_program_names")}
            entry["audio_file_path"] = str(wav_path)
            combined[track_id] = entry
        else:
            print(f"Warning: {wav_path} does not exist.")
    return combined

In [58]:
# Pair each reference prompt with the generated clip of the same id and save
# the result as JSON (presumably for CLAP scoring, per the file names).
# NOTE(review): both inputs read from "slakh_cond_v2" sub-directories, but the
# generation cells above write to BASELINE_OUTPUT_DIR / "slakh_cond1" and
# FINETUNE_OUTPUT_DIR / "slakh_scaled_cond". Confirm which outputs should be
# scored.
clap_input_baseline = combine_condition_data_with_output_audio_file_path(
    rcd,
    BASELINE_OUTPUT_DIR / "slakh_cond_v2",
)

clap_input_finetune = combine_condition_data_with_output_audio_file_path(
    rcd,
    FINETUNE_OUTPUT_DIR / "slakh_cond_v2",
)

with open(PROJECT_DATA_DIR / "clap_input_baseline.json", "w") as f:
    json.dump(clap_input_baseline, f, indent=4)

with open(PROJECT_DATA_DIR / "clap_input_finetune.json", "w") as f:
    json.dump(clap_input_finetune, f, indent=4)

# Music Caps Conditional Generation Experiment


In [78]:
# Locally downloaded MusicCaps clips, one `<ytid>.wav` per example.
music_caps_wavs_dir = Path("/data/matt/music_caps/wavs")

In [84]:
# MusicCaps metadata (YouTube ids and captions) from the Hugging Face Hub.
dataset = load_dataset("google/MusicCaps", split="train")

In [86]:
# Probe: does the first example's clip exist locally?
dummy_id = dataset[0]["ytid"]
print(dummy_id)
file_path = music_caps_wavs_dir / f"{dummy_id}.wav"
file_path.exists()

-0Gj8-vB1q4


In [93]:
# Probe: load that clip (librosa's default 22050 Hz resampling) and check its length.
y, sr = librosa.load(file_path)
y.shape

(220500,)

In [95]:
# Nominal sample rate (Hz) and clip length (s) of the MusicCaps WAVs.
# NOTE(review): the probe above got 220500 samples at librosa's default
# 22050 Hz (i.e. 10 s), which confirms the duration but not the files'
# native rate. Verify 16000 Hz.
MUSIC_CAPS_SR = 16000
MUSIC_CAPS_DURATION = 10


def prep_music_caps(
    dataset: Dataset,
    wav_dir: Path = music_caps_wavs_dir,
    output_dir: Path = music_caps_wavs_dir.parent / "audiocraft",
    train_split_size: int = 2048,
    test_split_size: int = 32,
) -> None:
    """Split MusicCaps into train/test AudioCraft manifests under *output_dir*.

    The dataset is shuffled with a fixed seed (42). The first *train_split_size*
    rows become ``train`` and the next *test_split_size* rows become ``test``.
    Each split is written by ``prep_music_caps_split``.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    # Dataset.shuffle returns a new dataset instead of shuffling in place. The
    # result must be kept, or the splits are just the first rows in file order.
    dataset = dataset.shuffle(seed=42)
    train_dataset = dataset.select(range(train_split_size))
    test_dataset = dataset.select(
        range(train_split_size, train_split_size + test_split_size)
    )
    prep_music_caps_split(train_dataset, "train", wav_dir, output_dir)
    prep_music_caps_split(test_dataset, "test", wav_dir, output_dir)


def prep_music_caps_split(
    dataset: Dataset,
    split: str,
    wav_dir: Path = music_caps_wavs_dir,
    output_dir: Path = music_caps_wavs_dir.parent / "audiocraft",
) -> None:
    """Write ``output_dir/<split>/data.jsonl`` for the MusicCaps rows in *dataset*.

    Each row's clip is expected at ``wav_dir/<ytid>.wav``. Rows whose file is
    missing are reported and skipped. The caption becomes the entry's
    ``description``, and a crude key (index of the strongest chroma bin) is
    estimated from the audio. ``sample_rate`` and ``duration`` are read from
    the file itself rather than assumed.
    """
    split_data = []
    for entry in tqdm(dataset):
        ytid = entry["ytid"]
        wav_path = wav_dir / f"{ytid}.wav"
        if not wav_path.exists():
            print(f"Warning: {wav_path} does not exist.")
            continue

        # sr=None keeps the native rate, so sr and duration describe the file
        # on disk (some clips may be shorter than the nominal length).
        y, sr = librosa.load(wav_path, sr=None)
        # use librosa to estimate key
        chroma = librosa.feature.chroma_stft(y=y, sr=sr)
        key = np.argmax(np.sum(chroma, axis=1))

        ac_entry = {
            "key": str(key),
            "sample_rate": sr,
            "file_extension": "wav",
            "description": entry["caption"],
            "keywords": "",
            "duration": librosa.get_duration(y=y, sr=sr),
            "bpm": "",
            "genre": "",
            "title": "",
            "name": "",
            "instrument": "",
            "moods": [],
            "path": str(wav_path),
        }
        split_data.append(ac_entry)
    split_dir = output_dir / split
    split_dir.mkdir(parents=True, exist_ok=True)
    write_jsonl(split_data, split_dir / "data.jsonl")

In [96]:
# Write the MusicCaps manifests: 2048 training and 32 test clips.
prep_music_caps(
    dataset,
    wav_dir=music_caps_wavs_dir,
    output_dir=music_caps_wavs_dir.parent / "audiocraft",
    train_split_size=2048,
    test_split_size=32,
)

  7%|▋         | 147/2048 [00:06<01:22, 23.06it/s]



 10%|▉         | 202/2048 [00:08<00:55, 33.00it/s]



 19%|█▉        | 397/2048 [00:17<01:08, 24.19it/s]



 35%|███▌      | 724/2048 [00:31<00:51, 25.88it/s]



 40%|████      | 821/2048 [00:35<00:47, 26.04it/s]



 42%|████▏     | 855/2048 [00:36<00:51, 23.39it/s]



 56%|█████▋    | 1157/2048 [00:48<00:35, 25.26it/s]



 59%|█████▉    | 1210/2048 [00:50<00:33, 25.38it/s]



 63%|██████▎   | 1282/2048 [00:53<00:31, 24.24it/s]



 65%|██████▌   | 1336/2048 [00:56<00:31, 22.58it/s]



  return pitch_tuning(
 87%|████████▋ | 1783/2048 [01:14<00:09, 26.97it/s]



 94%|█████████▍| 1924/2048 [01:20<00:04, 25.88it/s]



100%|██████████| 2048/2048 [01:25<00:00, 23.92it/s]
100%|██████████| 32/32 [00:01<00:00, 21.60it/s]


## run training with dora framework


In [99]:
# MusicCaps fine-tuning run: 3 epochs of 1000 updates at batch size 4. TMPDIR
# is set on the command line so Dora runs use /data/matt/tmp.
music_caps_train_command = """\
TMPDIR=/data/matt/tmp CUDA_VISIBLE_DEVICES=4,5,6,7 dora -P audiocraft run \
  solver=musicgen/musicgen_base_32khz \
  +model.lm.model_scale=small \
  continue_from=//pretrained/facebook/musicgen-small \
  conditioner=text2music \
  dset=audio/music_caps \
  dataset.num_workers=4 \
  dataset.valid.num_samples=32 \
  dataset.batch_size=4 \
  schedule.cosine.warmup=8 \
  optim.optimizer=adamw \
  optim.lr=1e-4 \
  optim.epochs=3 \
  optim.updates_per_epoch=1000 \
  optim.adam.weight_decay=0.01 \
  generate.lm.prompted_samples=False \
  generate.lm.gen_gt_samples=True
"""

In [100]:
# !{music_caps_train_command}

In [102]:
# export checkpoint of the MusicCaps v1 run (Dora XP 40b4c24f)
MC_V1_CKPT_DIR = CKPT_DIR / "mc_v1"
MC_V1_SIG = "40b4c24f"
export_model_checkpoint(MC_V1_SIG, ckpt_d=MC_V1_CKPT_DIR)

## Generate new samples from MusicCaps fine-tuned model


In [103]:
# Use CUDA device 5 and load the baseline plus the MusicCaps fine-tuned model.
torch.cuda.set_device(5)
baseline_mc = MusicGen.get_pretrained("small")
mc_ckpt = MC_V1_CKPT_DIR
ft_mc = MusicGen.get_pretrained(mc_ckpt)



In [110]:
def get_music_cap_test_prompts():
    """Return ``{"mc_test<i>": caption}`` for each line of the MusicCaps test manifest.

    Reads ``<music_caps_wavs_dir.parent>/audiocraft/test/data.jsonl``. ``i`` is
    the 0-based line index and the caption is the entry's ``description``.
    """
    manifest = music_caps_wavs_dir.parent / "audiocraft/test/data.jsonl"
    records = map(json.loads, manifest.read_text().splitlines())
    return {f"mc_test{idx}": record["description"] for idx, record in enumerate(records)}


music_cap_prompts = get_music_cap_test_prompts()

In [116]:
# Baseline MusicGen on the MusicCaps test captions, resampled to MUSIC_CAPS_SR.
# NOTE(review): the recorded run of this cell failed during the first batch with
# "TypeError: list indices must be integers or slices, not str". The traceback
# is not shown here, so the failing line is unknown. Re-run to locate it.
conditional_generate_wrapper(
    model=baseline_mc,
    prompts=music_cap_prompts,
    duration=MUSIC_CAPS_DURATION,
    output_dir=BASELINE_OUTPUT_DIR / "mc_cond_v1",
    target_sr=MUSIC_CAPS_SR,
)

  0%|          | 0/8 [00:14<?, ?it/s]


TypeError: list indices must be integers or slices, not str

## music genre dataset


In [2]:
# Music-genre audio dataset from the Hugging Face Hub (train and test splits).
music_genre_dataset = load_dataset("lewtun/music_genres")

Downloading data: 100%|██████████| 16/16 [01:41<00:00,  6.33s/files]
Generating train split: 100%|██████████| 19909/19909 [00:39<00:00, 508.62 examples/s]
Generating test split: 100%|██████████| 5076/5076 [00:09<00:00, 538.92 examples/s]


In [5]:
# Sanity check: audio arrays come back as NumPy arrays.
mge0 = music_genre_dataset["train"][0]
isinstance(mge0["audio"]["array"], np.ndarray)

True

In [6]:
mge0

{'audio': {'path': None,
  'array': array([ 3.97140170e-07,  7.30310376e-07,  7.56406820e-07, ...,
         -1.19636677e-01, -1.16811883e-01, -1.12441715e-01]),
  'sampling_rate': 44100},
 'song_id': 0,
 'genre_id': 0,
 'genre': 'Electronic'}

In [22]:
# Directory for the resampled genre WAV files referenced by the manifests.
GENRE_AUDIO_FILES = PROJECT_DATA_DIR / "music_genre_audio_files"
GENRE_AUDIO_FILES.mkdir(parents=True, exist_ok=True)


def prepare_genre_data(
    dataset: Dataset,
    output_dir: Path,
    train_split_size: int = 2048,
    test_split_size: int = 32,
    target_sr: int = 32000,
) -> None:
    """Build train/test AudioCraft manifests from the music-genre dataset.

    Shuffles with seed 42, takes the first *train_split_size* rows as ``train``
    and the following *test_split_size* rows as ``test``, then passes each
    subset to ``prep_genre_split`` with *target_sr*.
    """
    shuffled = dataset.shuffle(seed=42)
    boundary = train_split_size + test_split_size
    subsets = {
        "train": shuffled.select(range(train_split_size)),
        "test": shuffled.select(range(train_split_size, boundary)),
    }
    for split_name, subset in subsets.items():
        prep_genre_split(subset, split_name, output_dir, target_sr)


def prep_genre_split(
    dataset: Dataset,
    split: str,
    output_dir: Path,
    target_sr: int = 32000,
    audio_dir: Optional[Path] = None,
) -> None:
    """Write resampled audio and a ``data.jsonl`` manifest for one genre split.

    Each example's audio array is resampled to *target_sr* and saved as
    ``<audio_dir>/<split>/<i>.wav`` (``audio_dir`` defaults to
    ``GENRE_AUDIO_FILES``). The manifest at ``output_dir/<split>/data.jsonl``
    records the example's ``genre`` plus the written file's sample rate and
    duration.
    """
    if audio_dir is None:
        audio_dir = GENRE_AUDIO_FILES
    # File names restart at 0 for every split. Without a per-split sub-directory,
    # the test split would overwrite train clips 0..N-1.
    split_audio_dir = audio_dir / split
    split_audio_dir.mkdir(parents=True, exist_ok=True)
    split_data = []
    for i, e in tqdm(enumerate(dataset), total=len(dataset)):
        # step 1: resample the audio array and write it to a file
        audio_array = e["audio"]["array"]
        source_sr = e["audio"]["sampling_rate"]
        audio_file_path = split_audio_dir / f"{i}.wav"
        resampled_audio = librosa.resample(
            audio_array, orig_sr=source_sr, target_sr=target_sr
        )
        # The header must carry the rate of the resampled data, not the source
        # rate, otherwise playback speed and pitch are wrong.
        sf.write(
            audio_file_path,
            resampled_audio,
            samplerate=target_sr,
            format="WAV",
        )
        # step 2: write out the metadata in the format expected by MusicGen
        entry = {
            "key": "",
            "sample_rate": target_sr,
            "file_extension": "wav",
            "description": "",
            "keywords": "",
            "duration": resampled_audio.shape[-1] / target_sr,
            "bpm": "",
            "genre": e["genre"],
            "title": "",
            "name": "",
            "instrument": "",
            "moods": [],
            "path": str(audio_file_path),
        }
        split_data.append(entry)
    split_dir = output_dir / split
    split_dir.mkdir(parents=True, exist_ok=True)
    write_jsonl(split_data, split_dir / "data.jsonl")


In [27]:
# Build the genre manifests: 2048 train + 32 test examples drawn from the
# dataset's train split, resampled to 32 kHz.
ac_genre_dir = PROJECT_DATA_DIR / "genre_audiocraft"
ac_genre_dir.mkdir(parents=True, exist_ok=True)

prepare_genre_data(
    music_genre_dataset["train"],
    output_dir=ac_genre_dir,
    train_split_size=2048,
    test_split_size=32,
    target_sr=32000,
)

  0%|          | 0/2048 [00:00<?, ?it/s]

100%|██████████| 2048/2048 [05:47<00:00,  5.89it/s]
100%|██████████| 32/32 [00:04<00:00,  6.78it/s]


## Music genre generation


In [60]:
def get_genre_prompts(jsonl_path: Optional[Path] = None) -> Dict[str, str]:
    """Map test-split ids to their genre labels, for use as text prompts.

    Reads an AudioCraft ``data.jsonl`` manifest (one JSON object per line) and
    returns ``{"genre_test<i>": entry["genre"]}``, where ``i`` is the entry's
    position among the non-blank lines of the manifest.

    Args:
        jsonl_path: Manifest to read. Defaults to the genre test manifest,
            ``ac_genre_dir / "test/data.jsonl"``.

    Returns:
        Dict of prompt id -> genre string, in manifest order.
    """
    if jsonl_path is None:
        jsonl_path = ac_genre_dir / "test/data.jsonl"
    lines = Path(jsonl_path).read_text(encoding="utf-8").splitlines()
    # Skip blank lines (e.g. a stray empty line) instead of failing in json.loads.
    entries = [json.loads(line) for line in lines if line.strip()]
    return {f"genre_test{i}": entry["genre"] for i, entry in enumerate(entries)}


# Prompt id -> genre label for every entry of the genre test manifest.
genre_prompts = get_genre_prompts()

In [57]:
# Signature of the genre fine-tuning run; its raw checkpoint is at
# /tmp/audiocraft_matt/xps/49fd2443/checkpoint.th
genre_sig = "49fd2443"
genre_checkpoint_dir = CKPT_DIR / "music_genre_v1e1"
# Export that run's checkpoint into genre_checkpoint_dir (helper defined
# elsewhere in this notebook). The directory is then passed to
# MusicGen.get_pretrained.
export_model_checkpoint(
    genre_sig,
    ckpt_d=genre_checkpoint_dir,
)

In [58]:
# Load the exported fine-tuned genre model.
genre_v1e1 = MusicGen.get_pretrained(genre_checkpoint_dir)

In [63]:
genre_prompts  # display the prompt id -> genre mapping

{'genre_test0': 'Pop',
 'genre_test1': 'Rock',
 'genre_test2': 'International',
 'genre_test3': 'Hip-Hop',
 'genre_test4': 'Experimental',
 'genre_test5': 'Electronic',
 'genre_test6': 'Punk',
 'genre_test7': 'Hip-Hop',
 'genre_test8': 'Chiptune / Glitch',
 'genre_test9': 'Folk',
 'genre_test10': 'Old-Time / Historic',
 'genre_test11': 'Hip-Hop',
 'genre_test12': 'Punk',
 'genre_test13': 'Chiptune / Glitch',
 'genre_test14': 'Rock',
 'genre_test15': 'Punk',
 'genre_test16': 'Pop',
 'genre_test17': 'Electronic',
 'genre_test18': 'Electronic',
 'genre_test19': 'Rock',
 'genre_test20': 'Instrumental',
 'genre_test21': 'Jazz',
 'genre_test22': 'Punk',
 'genre_test23': 'Folk',
 'genre_test24': 'Hip-Hop',
 'genre_test25': 'Rock',
 'genre_test26': 'Punk',
 'genre_test27': 'Blues',
 'genre_test28': 'Classical',
 'genre_test29': 'Rock',
 'genre_test30': 'Rock',
 'genre_test31': 'Pop'}

In [64]:
# Unconditional samples from the fine-tuned genre model (num_samples=8,
# duration=30), written under FINETUNE_OUTPUT_DIR / "genre_uncond_v1e1".
unconditional_generate_wrapper(
    model=genre_v1e1,
    duration=30,
    num_samples=8,
    output_dir=FINETUNE_OUTPUT_DIR / "genre_uncond_v1e1",
)

  0%|          | 0/2 [00:00<?, ?it/s]CLIPPING /data/matt/mg_generated_audio/finetune/genre_uncond_v1e1/sample_1 happening with proba (a bit of clipping is okay): 1.0416666782475659e-06 maximum scale:  1.0830597877502441
CLIPPING /data/matt/mg_generated_audio/finetune/genre_uncond_v1e1/sample_2 happening with proba (a bit of clipping is okay): 3.12499992105586e-06 maximum scale:  1.0476261377334595
CLIPPING /data/matt/mg_generated_audio/finetune/genre_uncond_v1e1/sample_3 happening with proba (a bit of clipping is okay): 0.0004385416687000543 maximum scale:  1.5941087007522583
 50%|█████     | 1/2 [00:20<00:20, 20.20s/it]CLIPPING /data/matt/mg_generated_audio/finetune/genre_uncond_v1e1/sample_5 happening with proba (a bit of clipping is okay): 0.0006447916384786367 maximum scale:  1.8232024908065796
100%|██████████| 2/2 [00:37<00:00, 18.57s/it]


In [62]:
# Genre-conditioned generation with the fine-tuned model, prompting with the
# test-split genre labels in genre_prompts. The wrapper's return value is saved
# to genre_cond_v1e1_outputs.json.
# NOTE(review): the recorded cell output below prints
# .../genre_uncond_v1e1, which does not match this output_dir, so it looks
# stale. Re-run to confirm.
output_dir = FINETUNE_OUTPUT_DIR / "genre_cond_v1e1"
genre_cond_outputs = conditional_generate_wrapper(
    model=genre_v1e1,
    prompts=genre_prompts,
    duration=30,
    output_dir=output_dir,
)
print(f"Genre conditional outputs saved to {output_dir}")
with open(PROJECT_DATA_DIR / "genre_cond_v1e1_outputs.json", "w") as f:
    json.dump(genre_cond_outputs, f, indent=4)

100%|██████████| 8/8 [02:59<00:00, 22.48s/it]

Genre conditional outputs saved to /data/matt/mg_generated_audio/finetune/genre_uncond_v1e1





In [None]:
# Baseline comparison: the same genre test prompts, generated with the
# `baseline` model (defined elsewhere in this notebook) instead of the
# fine-tuned genre_v1e1.
# Audio goes under BASELINE_OUTPUT_DIR, like the bespoke baseline run. The
# outputs go to a baseline-specific JSON file, because reusing
# genre_cond_v1e1_outputs.json would overwrite the fine-tuned model's results.
# NOTE: this rebinds genre_cond_outputs in memory; the fine-tuned results stay
# on disk in genre_cond_v1e1_outputs.json.
output_dir = BASELINE_OUTPUT_DIR / "genre_cond"
genre_cond_outputs = conditional_generate_wrapper(
    model=baseline,
    prompts=genre_prompts,
    duration=30,
    output_dir=output_dir,
)
print(f"Genre conditional outputs saved to {output_dir}")
with open(PROJECT_DATA_DIR / "genre_cond_baseline_outputs.json", "w") as f:
    json.dump(genre_cond_outputs, f, indent=4)

In [39]:
# Per-file BPM estimates for the bespoke recordings, keyed by WAV path.
bpm_info = json.loads(Path("/data/matt/pg_bpm.json").read_text())

In [42]:
# Keys in the BPM file reference a "bespoke_data" directory. Rewrite them to
# "bespoke30" so they match the paths built from BESPOKE_DIR when looking up BPMs.
bpm_info_fixed = {
    path.replace("bespoke_data", "bespoke30"): tempo
    for path, tempo in bpm_info.items()
}


In [43]:
bpm_info_fixed  # display the path -> BPM mapping after re-keying

{'/data/matt/bespoke30/piano_00.wav': 117,
 '/data/matt/bespoke30/piano_01.wav': 108,
 '/data/matt/bespoke30/piano_02.wav': 77,
 '/data/matt/bespoke30/piano_03.wav': 105,
 '/data/matt/bespoke30/piano_04.wav': 110,
 '/data/matt/bespoke30/piano_05.wav': 98,
 '/data/matt/bespoke30/piano_06.wav': 92,
 '/data/matt/bespoke30/piano_07.wav': 100,
 '/data/matt/bespoke30/piano_08.wav': 69,
 '/data/matt/bespoke30/piano_09.wav': 94,
 '/data/matt/bespoke30/piano_10.wav': 128,
 '/data/matt/bespoke30/piano_11.wav': 90,
 '/data/matt/bespoke30/piano_12.wav': 115,
 '/data/matt/bespoke30/piano_13.wav': 65,
 '/data/matt/bespoke30/piano_14.wav': 104,
 '/data/matt/bespoke30/piano_15.wav': 120,
 '/data/matt/bespoke30/piano_16.wav': 110,
 '/data/matt/bespoke30/piano_17.wav': 120,
 '/data/matt/bespoke30/piano_18.wav': 120,
 '/data/matt/bespoke30/piano_19.wav': 69,
 '/data/matt/bespoke30/guitar_00.wav': 120,
 '/data/matt/bespoke30/guitar_01.wav': 50,
 '/data/matt/bespoke30/guitar_02.wav': 80,
 '/data/matt/bespo

In [51]:
def prep_bespoke(
    data_dir: Path,
    output_dir: Path,
    num_train_samples: int = 15,
    num_test_samples: int = 5,
) -> list[str]:
    """Split the bespoke piano/guitar WAVs into AudioCraft train/test manifests.

    For each instrument, the first ``num_train_samples`` files (in directory
    listing order) go to the train split. Up to ``num_test_samples`` of the
    remaining files go to the test split; any further files are left out.
    Manifests are written to ``<output_dir>/{train,test}/data.jsonl``. BPM
    values are looked up in the module-level ``bpm_info_fixed`` by
    ``str(data_dir / wav_file)``, falling back to "" when missing.

    Args:
        data_dir: Directory containing ``piano_*.wav`` / ``guitar_*.wav`` files.
        output_dir: Root directory for the split manifests.
        num_train_samples: Maximum number of training files per instrument.
        num_test_samples: Maximum number of test files per instrument.

    Returns:
        Test-file ids of the form ``"<Instrument>_<id>.wav"`` (e.g. "Piano_14.wav").

    Raises:
        ValueError: If a ``.wav`` file name contains neither "piano" nor "guitar".
    """
    train_data = []
    test_data = []
    test_ids = []
    train_inst_counts = defaultdict(int)
    test_inst_counts = defaultdict(int)
    # NOTE: os.listdir order is filesystem-dependent, so the split is only
    # reproducible for the same directory. It is deliberately not sorted: the
    # hard-coded evaluation test set was produced from this order.
    for wav_file in os.listdir(data_dir):
        if not wav_file.lower().endswith(".wav"):
            # Ignore stray non-audio files (e.g. .DS_Store, notes).
            continue
        if "piano" in wav_file:
            inst = "Piano"
            prefix = "piano_"
        elif "guitar" in wav_file:
            inst = "Guitar"
            prefix = "guitar_"
        else:
            # Explicit error instead of `assert`, which is stripped under -O.
            raise ValueError(f"Cannot infer instrument from file name: {wav_file!r}")
        id_ = wav_file.split(".")[0].replace(prefix, "")
        path_ = str(data_dir / wav_file)
        bpm = bpm_info_fixed.get(path_, "")

        ac_entry = {
            "key": "",
            # NOTE(review): uses the Slakh sample-rate constant for these
            # recordings; confirm it matches the bespoke files.
            "sample_rate": SLAKH_SAMPLE_RATE,
            "file_extension": "wav",
            "description": inst,
            "keywords": "",
            # Assumes every clip is 30 s long (directory is "bespoke30"); TODO confirm.
            "duration": 30,
            "bpm": bpm,
            "genre": "",
            "title": "",
            "name": "",
            "instrument": inst,
            "moods": [],
            "path": path_,
        }
        if train_inst_counts[inst] < num_train_samples:
            train_inst_counts[inst] += 1
            train_data.append(ac_entry)
        elif test_inst_counts[inst] < num_test_samples:
            test_inst_counts[inst] += 1
            test_ids.append(f"{inst}_{id_}.wav")
            test_data.append(ac_entry)

    for split, split_data in (("train", train_data), ("test", test_data)):
        split_dir = output_dir / split
        split_dir.mkdir(parents=True, exist_ok=True)
        write_jsonl(split_data, split_dir / "data.jsonl")

    return test_ids

In [48]:
# Directory holding the bespoke piano_*.wav / guitar_*.wav recordings.
BESPOKE_DIR = Path("/data/matt/bespoke30")

In [52]:
# Write the train/test manifests for the bespoke data. The displayed test ids
# match the hard-coded pg_test_set used later for evaluation.
prep_bespoke(
    data_dir=BESPOKE_DIR,
    output_dir=PROJECT_DATA_DIR / "audiocraft_bespoke30bpm",
)

['Guitar_04.wav',
 'Guitar_15.wav',
 'Piano_19.wav',
 'Piano_14.wav',
 'Piano_12.wav',
 'Piano_15.wav',
 'Piano_04.wav',
 'Guitar_14.wav',
 'Guitar_06.wav',
 'Guitar_18.wav']

In [56]:
# Export a fine-tuned bespoke checkpoint so it can be loaded with
# MusicGen.get_pretrained. The commented-out block exported an earlier run
# (d7ebad35, "bespoke30_v1"). The active block exports run d703f0e0 into
# "bespoke30_bpm", presumably the run trained on the BPM-annotated manifests.
# [06-03 10:13:47][audiocraft.utils.checkpoint][INFO] - Checkpoint saved to /tmp/audiocraft_matt/xps/d7ebad35/checkpoint.th
# b_sig = "d7ebad35"
# b_checkpoint_dir = CKPT_DIR / "bespoke30_v1"
# export_model_checkpoint(
#     b_sig,
#     ckpt_d=b_checkpoint_dir,
# )

# [06-03 11:09:42][audiocraft.utils.checkpoint][INFO] - Checkpoint saved to /tmp/audiocraft_matt/xps/d703f0e0/checkpoint.th
b_sig = "d703f0e0"
b_checkpoint_dir = CKPT_DIR / "bespoke30_bpm"
export_model_checkpoint(
    b_sig,
    ckpt_d=b_checkpoint_dir,
)

In [60]:
# Held-out bespoke files (the test ids returned by prep_bespoke). For each one,
# copy the reference audio into pg_ref_dir and build a text prompt
# "<bpm> BPM with <Instrument>." keyed by the lower-cased file stem.
pg_test_set = [
    "Guitar_04.wav",
    "Guitar_15.wav",
    "Piano_19.wav",
    "Piano_14.wav",
    "Piano_12.wav",
    "Piano_15.wav",
    "Piano_04.wav",
    "Guitar_14.wav",
    "Guitar_06.wav",
    "Guitar_18.wav",
]
pg_prompts = {}
pg_ref_dir = PROJECT_DATA_DIR / "pg_test_ref"
pg_ref_dir.mkdir(parents=True, exist_ok=True)


for wav_file in pg_test_set:
    # Files on disk are lower-case (e.g. guitar_04.wav); ids carry the
    # capitalised instrument name used in the prompt.
    file_name = wav_file.lower()
    wav_file_p = BESPOKE_DIR / file_name
    shutil.copy(wav_file_p, pg_ref_dir / file_name)
    inst = wav_file.split("_")[0]
    bpm = bpm_info_fixed.get(str(wav_file_p))
    # Without a known BPM, fall back to an instrument-only prompt rather than
    # the malformed " BPM with <Instrument>.".
    prompt = f"{bpm} BPM with {inst}." if bpm is not None else f"{inst}."
    pg_prompts[file_name.replace(".wav", "")] = prompt


# pg_prompts = {}
# for i in range(5):
#     pg_prompts[f"piano_{i}"] = "Piano"
#     pg_prompts[f"guitar_{i}"] = "Guitar"

In [55]:
pg_prompts  # display the bespoke test prompts

{'guitar_04': '70 BPM with Guitar.',
 'guitar_15': '122 BPM with Guitar.',
 'piano_19': '69 BPM with Piano.',
 'piano_14': '104 BPM with Piano.',
 'piano_12': '115 BPM with Piano.',
 'piano_15': '120 BPM with Piano.',
 'piano_04': '110 BPM with Piano.',
 'guitar_14': '135 BPM with Guitar.',
 'guitar_06': '117 BPM with Guitar.',
 'guitar_18': '129 BPM with Guitar.'}

In [57]:
# Load the exported fine-tuned bespoke model.
b_model = MusicGen.get_pretrained(b_checkpoint_dir)

In [58]:
# Condition the fine-tuned bespoke model on the BPM/instrument test prompts
# (duration=10) and save the wrapper's return value as JSON.
b_model_cond_out = conditional_generate_wrapper(
    model=b_model,
    prompts=pg_prompts,
    duration=10,
    output_dir=FINETUNE_OUTPUT_DIR / "bespoke_cond_bpm",
)

with open(PROJECT_DATA_DIR / "bespoke_ft_cond_bpm_outputs.json", "w") as f:
    json.dump(b_model_cond_out, f, indent=4)

100%|██████████| 3/3 [00:34<00:00, 11.61s/it]


In [37]:
# Show the full path of the v1 fine-tuned bespoke outputs file.
# NOTE(review): the BPM-conditioned cell writes bespoke_ft_cond_bpm_outputs.json,
# not this file.
print(PROJECT_DATA_DIR / "bespoke_ft_cond_v1_outputs.json")

/data/matt/cse253a2/bespoke_ft_cond_v1_outputs.json


In [59]:
# Baseline comparison: same bespoke prompts and duration, generated with the
# `baseline` model (defined elsewhere) and saved under BASELINE_OUTPUT_DIR.
# NOTE(review): despite the "ft" in its name, pg_ft_out holds the *baseline*
# model's outputs, not the fine-tuned model's.
pg_ft_out = conditional_generate_wrapper(
    model=baseline,
    prompts=pg_prompts,
    duration=10,
    output_dir=BASELINE_OUTPUT_DIR / "bespoke_cond_bpm",
)

with open(PROJECT_DATA_DIR / "bespoke_cond_bpm_outputs.json", "w") as f:
    json.dump(pg_ft_out, f, indent=4)

100%|██████████| 3/3 [00:34<00:00, 11.37s/it]


In [61]:
pg_prompts  # display the bespoke test prompts again

{'guitar_04': '70 BPM with Guitar.',
 'guitar_15': '122 BPM with Guitar.',
 'piano_19': '69 BPM with Piano.',
 'piano_14': '104 BPM with Piano.',
 'piano_12': '115 BPM with Piano.',
 'piano_15': '120 BPM with Piano.',
 'piano_04': '110 BPM with Piano.',
 'guitar_14': '135 BPM with Guitar.',
 'guitar_06': '117 BPM with Guitar.',
 'guitar_18': '129 BPM with Guitar.'}

In [38]:
# Show the full path of the v1 baseline bespoke outputs file.
# NOTE(review): the BPM baseline cell writes bespoke_cond_bpm_outputs.json,
# not this file.
print(PROJECT_DATA_DIR / "bespoke_cond_v1_outputs.json")

/data/matt/cse253a2/bespoke_cond_v1_outputs.json
