### Setup

Follow the setup instructions to create the pipenv environment, then connect this notebook to the
Python kernel in the "`music-interpolation-...`" environment.

In [None]:
import torch
from IPython.display import Audio
from music_interpolation.audio import load_audio, resample, to_mono_resampled, time_stretch, trim_samples_to_match
from music_interpolation.encodec_interpolation import EncodecInterpolation
from music_interpolation.beats import tempo_beats_downbeats

# Input tracks, given as relative paths. Track A supplies the lead-up audio;
# track B is time-stretched to match track A's tempo.
AUDIO_A_PATH = "../tests/data/house-equanimity.mp3"
AUDIO_B_PATH = "../tests/data/they-know-me.mp3"

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Instantiate the interpolation model on the chosen device and report its rate
interp = EncodecInterpolation(device=device)
print(f"Loaded {interp.sampling_rate} Hz interpolation model on {device}")

In [None]:
# pyright: basic

# Load both audio files into raw waveform numpy arrays, keeping the sample rate
# reported for each file alongside its waveform
audio_a, orig_sr_a = load_audio(AUDIO_A_PATH)
audio_b, orig_sr_b = load_audio(AUDIO_B_PATH)

# Preview track A at its original sample rate
Audio(audio_a, rate=orig_sr_a)

In [None]:
# Preview track B at its original sample rate
Audio(audio_b, rate=orig_sr_b)

In [None]:
# Downmix to mono 44.1kHz for audio analysis. These copies are only fed to the
# tempo/beat detection; the audio used for interpolation is resampled separately
audio_a_mono = to_mono_resampled(audio_a, orig_sr_a, 44100)
audio_b_mono = to_mono_resampled(audio_b, orig_sr_b, 44100)

# Display the shapes of the two analysis copies
audio_a_mono.shape, audio_b_mono.shape

In [None]:
# Compute the tempo, the tempo confidence, and the beat positions together with
# each beat's number within its bar (1-4) for both tracks
tempo_a, tempo_a_confidence, beats_downbeats_a = tempo_beats_downbeats(audio_a_mono)
tempo_b, tempo_b_confidence, beats_downbeats_b = tempo_beats_downbeats(audio_b_mono)

# Display the shapes of the two beat arrays
beats_downbeats_a.shape, beats_downbeats_b.shape

In [None]:
# Build an array of bar start times for each audio file. beats_downbeats_a has
# shape (num_beats, 2): the first column is the beat position in seconds and the
# second column is the beat's number within its bar (1-4). Keeping only the rows
# whose beat number is 1 (the downbeats) yields an array of shape (num_bars,)
# where each element is the position of the first beat in a bar
bars_a = beats_downbeats_a[beats_downbeats_a[:, 1] == 1, 0]
bars_b = beats_downbeats_b[beats_downbeats_b[:, 1] == 1, 0]

# Summarize the detection results for each track
print(
    f"Tempo of audio_a: {tempo_a} bpm (confidence: {tempo_a_confidence * 100:.1f}%), "
    f"{beats_downbeats_a.shape[0]} beats, {bars_a.shape[0]} bars"
)
print(
    f"Tempo of audio_b: {tempo_b} bpm (confidence: {tempo_b_confidence * 100:.1f}%), "
    f"{beats_downbeats_b.shape[0]} beats, {bars_b.shape[0]} bars"
)

In [None]:
# Bring each track to the interpolation model's sampling rate, skipping the
# resampler when the file is already at that rate. This is done manually rather
# than at load time to enable the highest quality resampler
audio_resampled_a = (
    audio_a if orig_sr_a == interp.sampling_rate else resample(audio_a, orig_sr_a, interp.sampling_rate)
)
audio_resampled_b = (
    audio_b if orig_sr_b == interp.sampling_rate else resample(audio_b, orig_sr_b, interp.sampling_rate)
)

audio_resampled_a.shape, audio_resampled_b.shape

In [None]:
# Bring track_b to the tempo of track_a
tempo_ratio = tempo_a / tempo_b
if tempo_ratio == 1.0:
    # Tempos already agree: reuse the resampled audio and bar positions as-is
    audio_stretched_b = audio_resampled_b
    bars_stretched_b = bars_b
else:
    print(f"Time stretching audio_b by {tempo_ratio:.3f}x")
    audio_stretched_b = time_stretch(audio_resampled_b, tempo_ratio)
    # Rescale the bar timestamps by the same ratio so they still line up with
    # the stretched audio (assumes time_stretch treats the ratio as a speed-up
    # factor -- TODO confirm)
    bars_stretched_b = bars_b / tempo_ratio

Audio(audio_stretched_b, rate=interp.sampling_rate)

In [None]:
# Interpolation window, expressed in bars (1 bar = 4 beats)
bar_start_a = 36
bar_start_b = 12
bar_count = 8

# Convert the bar timestamps (seconds) that bound each window into sample
# offsets at the model's sampling rate. Track B uses its time-stretched bars
start_a, end_a = (
    int(bars_a[bar] * interp.sampling_rate) for bar in (bar_start_a, bar_start_a + bar_count)
)
start_b, end_b = (
    int(bars_stretched_b[bar] * interp.sampling_rate) for bar in (bar_start_b, bar_start_b + bar_count)
)

# Slice the window out of each track along the sample axis
audio_overlap_a = audio_resampled_a[:, start_a:end_a]
audio_overlap_b = audio_stretched_b[:, start_b:end_b]

print(f"audio_overlap_a = {audio_overlap_a.shape[1] / interp.sampling_rate:.3f} seconds")
print(f"audio_overlap_b = {audio_overlap_b.shape[1] / interp.sampling_rate:.3f} seconds")

In [None]:
# Trim any rounding error frames so the samples exactly match. The two windows
# were cut at independently rounded sample offsets, so their lengths can differ
audio_overlap_trimmed_a, audio_overlap_trimmed_b = trim_samples_to_match(audio_overlap_a, audio_overlap_b)

# Show both shapes, then preview the trimmed excerpt of track A
print(audio_overlap_trimmed_a.shape, audio_overlap_trimmed_b.shape)
Audio(audio_overlap_trimmed_a, rate=interp.sampling_rate)

In [None]:
# Preview the trimmed excerpt of track B at the model's sampling rate
Audio(audio_overlap_trimmed_b, rate=interp.sampling_rate)

In [None]:
# Extract up to four bars of audio from audio_a leading up to the start of the
# interpolation. The first lead-up bar is clamped to bar 0: without the clamp, a
# bar_start_a smaller than leadup_bars would produce a negative index, which
# wraps around to the end of bars_a and selects the wrong (typically empty) slice
leadup_bars = 4
leadup_samples_a = int(bars_a[max(bar_start_a - leadup_bars, 0)] * interp.sampling_rate)
leadup_a = audio_resampled_a[:, leadup_samples_a:start_a]

# Show the lead-up's shape and preview it at the model's sampling rate
print(leadup_a.shape)
Audio(leadup_a, rate=interp.sampling_rate)

In [None]:
# Run the interpolation model on the two equal-length trimmed excerpts
audio_c = interp.interpolate(audio_overlap_trimmed_a, audio_overlap_trimmed_b)

# Preview the interpolated result at the model's sampling rate
Audio(audio_c, rate=interp.sampling_rate)

In [None]:
from audiocraft.models.musicgen import MusicGen

# Load the pretrained "melody" MusicGen checkpoint onto the same device that
# the interpolation model uses
print(f"Loading MusicGen model ({device})")
model = MusicGen.get_pretrained("melody", device)

In [None]:
from music_interpolation.musicgen import generate_continuation_with_chroma

# Sampling parameters forwarded to MusicGen as cfg_coef and temperature
CFG_COEF = 3
TEMPERATURE = 1

# Durations in seconds: sample count along axis 1 divided by the sampling rate.
# The requested generation length covers the lead-up plus the interpolated part
leadup_duration = leadup_a.shape[1] / interp.sampling_rate
overlap_duration = audio_c.shape[1] / interp.sampling_rate
total_duration = leadup_duration + overlap_duration

print(
    f"Generating {leadup_duration:.1f} + {overlap_duration:.1f} = "
    f"{total_duration:.1f} seconds of audio"
)
model.set_generation_params(duration=total_duration, cfg_coef=CFG_COEF, temperature=TEMPERATURE)
# The lead-up audio is passed as the prompt and the interpolated audio as the
# melody input; both are described to the helper as being at interp.sampling_rate
prompt = torch.tensor(leadup_a)
melody = torch.tensor(audio_c)
# Indexing with None prepends an axis (presumably the batch dimension the
# helper expects -- TODO confirm against generate_continuation_with_chroma)
melody_wavs = melody[None]
wav = generate_continuation_with_chroma(
    model, prompt, interp.sampling_rate, None, melody_wavs, interp.sampling_rate, progress=True
)

# Take the first generated item, move it to the CPU, and convert it to numpy
wav = wav[0].cpu().numpy()
print(wav.shape)
# Playback uses MusicGen's own sample rate rather than interp.sampling_rate
Audio(wav, rate=model.sample_rate)