# MusicGen
Welcome to MusicGen's demo Jupyter notebook. Here you will find a series of self-contained examples of how to use MusicGen in different settings.

First, we initialize MusicGen. You can choose a model from the following selection:
1. `facebook/musicgen-small` - 300M transformer decoder.
2. `facebook/musicgen-medium` - 1.5B transformer decoder.
3. `facebook/musicgen-melody` - 1.5B transformer decoder also supporting melody conditioning.
4. `facebook/musicgen-large` - 3.3B transformer decoder.

We will use the `facebook/musicgen-small` variant for the purpose of this demonstration.

In [1]:
from audiocraft.models import MusicGen
#from audiocraft.models import MultiBandDiffusion
import torch, torchaudio
from audiocraft.utils.notebook import display_audio
from audiocraft.data.audio import audio_write

# Toggle for the optional MultiBandDiffusion decoder. The code that would
# consume this flag is commented out below, so it currently has no effect.
USE_DIFFUSION_DECODER = False

# Load the small checkpoint for this demo; the `medium` or `large`
# checkpoints would give better results.
model = MusicGen.get_pretrained(
    'facebook/musicgen-small',
)
# if USE_DIFFUSION_DECODER:
#     mbd = MultiBandDiffusion.get_mbd_musicgen()

  from .autonotebook import tqdm as notebook_tqdm


Next, let us configure the generation parameters. Specifically, you can control the following:
* `use_sampling` (bool, optional): use sampling if True, else do argmax decoding. Defaults to True.
* `top_k` (int, optional): top_k used for sampling. Defaults to 250.
* `top_p` (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
* `temperature` (float, optional): softmax temperature parameter. Defaults to 1.0.
* `duration` (float, optional): duration of the generated waveform. Defaults to 30.0.
* `cfg_coef` (float, optional): coefficient used for classifier free guidance. Defaults to 3.0.

When left unchanged, MusicGen will revert to its default parameters.

In [2]:
# Value passed as `duration` (duration of the generated waveform) here and
# again by the generation helper defined later in this notebook.
output_duration = 10

# Sample with top-k filtering; parameters not listed keep their defaults.
model.set_generation_params(
    duration=output_duration,
    temperature=1.0,
    top_k=250,
    use_sampling=True,
)

Next, we can go ahead and start generating music using one of the following modes:
* Unconditional samples using `model.generate_unconditional`
* Music continuation using `model.generate_continuation`
* Text-conditional samples using `model.generate`
* Melody-conditional samples using `model.generate_with_chroma`

### Music Continuation

In [3]:
from madmom.features.downbeats import DBNDownBeatTrackingProcessor
from madmom.features.downbeats import RNNDownBeatProcessor
import os
import soundfile as sf
import pyrubberband as pyrb
import librosa

proc = DBNDownBeatTrackingProcessor(beats_per_bar=4, fps = 100, verbose=False)

def extract_loop(file_path, desired_bpm, num_bars=2):
    """Cut the first downbeat-aligned loop out of an audio file and retime it.

    The file is run through the downbeat tracker. The first position whose
    beat index is 1 and whose counterpart ``4 * num_bars`` beats later is
    also 1 marks the loop. That segment is time-stretched so that it lasts
    exactly ``4 * num_bars`` beats at ``desired_bpm`` and is written next to
    the input as ``<input without .wav>_1.wav``.

    Args:
        file_path: Path of the audio file to analyse.
        desired_bpm: Target tempo (beats per minute) of the written loop.
        num_bars: Loop length in bars; 4 beats per bar are assumed, matching
            the ``beats_per_bar=4`` tracker configured at module level.

    Returns:
        The path of the written loop, or ``None`` when the file cannot be
        loaded, beat tracking fails, or no qualifying pair of downbeats
        exists.
    """
    try:
        # sr=None keeps the file's native sampling rate instead of resampling.
        _, sr = librosa.core.load(file_path, sr=None)
    except Exception:
        # Deliberately best-effort, but no longer a bare `except:` so that
        # KeyboardInterrupt / SystemExit still propagate.
        print('load file failed')
        return None

    try:
        act = RNNDownBeatProcessor(verbose=False)(file_path)
        # 2-D array: column 0 is used below as a time offset, column 1 as the
        # beat position within the bar (1 is treated as the downbeat).
        down_beat = proc(act, verbose=False)
    except Exception as exp:
        print('except happended', exp)
        return None

    # Strip only a trailing '.wav'; a blanket str.replace would also mangle
    # directory names that happen to contain '.wav'.
    if file_path.endswith('.wav'):
        name = file_path[:-len('.wav')]
    else:
        name = file_path

    beats_per_loop = 4 * num_bars
    # Only the first match is ever written, so its index is always 1.
    loop_index = 1

    # Stopping `beats_per_loop` rows early guarantees the look-ahead index
    # stays in bounds (range() is empty when there are too few beats).
    for i in range(down_beat.shape[0] - beats_per_loop):
        if down_beat[i][1] != 1 or down_beat[i + beats_per_loop][1] != 1:
            continue

        start_time = down_beat[i][0]
        end_time = down_beat[i + beats_per_loop][0]
        original_duration = end_time - start_time

        out_path = os.path.join("./", f'{name}_{loop_index}.wav')
        segment, _ = librosa.core.load(
            file_path, offset=start_time, duration=original_duration, sr=None)

        # Length of `beats_per_loop` beats at the target tempo; the ratio of
        # the two durations is the rate handed to rubberband.
        desired_duration = 60. / desired_bpm * beats_per_loop
        stretched = pyrb.time_stretch(
            segment, sr, original_duration / desired_duration)
        sf.write(out_path, stretched, sr)
        return out_path

    # No pair of downbeats exactly `beats_per_loop` beats apart was found.
    return None

In [4]:
# # You can also use any audio from a file. Make sure to trim the file if it is too long!
# from audiocraft.data.audio import audio_write

# prompt_waveform, prompt_sr = torchaudio.load("../assets/beat.wav")
# prompt_duration = 4
# print(prompt_waveform.shape)

# prompt_waveform = prompt_waveform[..., :int(prompt_duration * prompt_sr)]
# prompt_waveform = prompt_waveform[None, ...].repeat(3, 1, 1)

# output = model.generate_continuation(prompt_waveform, prompt_sample_rate=prompt_sr, descriptions=["reggae music on the tropical beach", "evocative jazz music with women voice singing emotionally like Bjork", "funky beats"], progress=True, return_tokens=True)
# print(output[0].shape)
# sr = model.sample_rate
# start_pos = prompt_duration * sr 
# display_audio(output[0], sample_rate=32000)

# for index in range(output[0].shape[0]):
#     file_path = "output"
#     output_data = output[0][index].cpu().squeeze()
#     audio_write(file_path, output_data, sr, strategy="loudness", loudness_compressor=True)
#     #print(output_data.shape)
#     #sf.write(file_path, output_data, sr)
#     #extract_loop(file_path + ".wav", 120)

torch.Size([2, 768000])
torch.Size([3, 1, 512000])


In [5]:
batch_size = 3

def _loop_gen(filepath, duration, bpm, temperature=1.0, description="jazzy beat"):
    """Continue an audio prompt with MusicGen and extract a loop from the result.

    The first ``duration / 1000`` units of ``filepath`` (presumably
    milliseconds converted to seconds -- confirm against the OSC sender) are
    used as the prompt. It is repeated ``batch_size`` times and continued by
    the model. Each continuation, with the prompt part cut off, is written to
    ``/mnt/c/tmp/output<index>.wav`` and passed to ``extract_loop``.

    Args:
        filepath: Audio file used as the continuation prompt.
        duration: Prompt length; divided by 1000 before use.
        bpm: Target tempo forwarded to ``extract_loop``.
        temperature: Sampling temperature for generation.
        description: Only printed at the moment (see the review note below).

    Returns:
        The path of the last loop that was successfully extracted, or
        ``None`` if no sample produced one.
    """
    model.set_generation_params(
        use_sampling=True,
        top_k=250,
        duration=output_duration,
        temperature=temperature
    )

    print("description:", description)
    prompt_waveform, prompt_sr = torchaudio.load(filepath)
    prompt_duration = duration/1000.
    # Keep only the leading `prompt_duration` worth of samples, then add a
    # batch dimension and replicate the prompt for every batch entry.
    prompt_waveform = prompt_waveform[..., :int(prompt_duration * prompt_sr)]
    prompt_waveform = prompt_waveform[None, ...].repeat(batch_size, 1, 1)

    # NOTE(review): `description` is not used for conditioning; the three
    # prompts below are hard-coded and their count must match `batch_size`.
    # Confirm whether this is intentional.
    output = model.generate_continuation(prompt_waveform, prompt_sample_rate=prompt_sr, descriptions=["heavy bass", "deep house with piano", "jungle, dubstep"], progress=True, return_tokens=True)
    print(output[0].shape)
    sr = model.sample_rate
    # Index of the first generated sample, i.e. where the prompt ends.
    start_pos = int(prompt_duration * sr)
    # Play back at the model's own rate rather than a hard-coded 32000.
    display_audio(output[0][:, :, start_pos:], sample_rate=sr)

    # Initialised up front so the return is defined even when the loop body
    # never runs, and so one failed extraction cannot erase an earlier hit.
    outpath = None
    for index in range(batch_size):
        file_path = "/mnt/c/tmp/output%d" % index
        output_data = output[0][index].cpu().squeeze()[start_pos:]
        audio_write(file_path, output_data, sr, strategy="loudness", loudness_compressor=True)
        loop_path = extract_loop(file_path + ".wav", bpm)
        if loop_path is not None:
            outpath = loop_path
    return outpath

In [6]:

# %%
import os 
import time
from threading import Thread
from pythonosc import dispatcher
from pythonosc import osc_server, udp_client
from IPython.display import clear_output


def loop_gen(unused_addr, filepath, duration, bpm, temperature=1.0, description="funky beat"):
    """OSC handler: generate a loop from ``filepath`` and report its path.

    Prints the requested path. If the file exists, generation is delegated
    to ``_loop_gen`` and a non-``None`` result is sent to the client on
    ``/generated``; otherwise a "file not found" line is printed. The cell
    output is cleared afterwards in either case.
    """
    print(filepath)
    if not os.path.exists(filepath):
        print("file not found", filepath)
    else:
        generated_path = _loop_gen(filepath, duration, bpm, temperature, description)
        if generated_path is not None:
            client.send_message("/generated", generated_path)
    # wait=True defers clearing until new output is available.
    clear_output(wait=True)

# Route incoming "/loop_gen" OSC messages to the handler defined above.
# NOTE: this rebinds `dispatcher` from the imported module to the instance;
# the import at the top of this cell restores the module on every re-run.
dispatcher = dispatcher.Dispatcher()
dispatcher.map("/loop_gen", loop_gen)

#%%

# Outgoing OSC client used for "/generated" replies and the heartbeat.
#client = udp_client.SimpleUDPClient('127.0.0.1', 10018)
client = udp_client.SimpleUDPClient('10.0.1.16', 10018)


def beacon():
    """Send a "/heartbeat" message with value 1 to the client once per second, forever."""
    while True:
        client.send_message("/heartbeat", 1)
        time.sleep(1.0)


# daemon=True replaces the deprecated Thread.setDaemon(); the heartbeat
# thread must not keep the interpreter alive on shutdown.
t1 = Thread(target=beacon, daemon=True)
t1.start()

server = osc_server.ThreadingOSCUDPServer(
    ('172.17.140.208', 10026), dispatcher)
print("Serving on {}".format(server.server_address))
try:
    # Blocks until interrupted (e.g. the notebook kernel is stopped).
    server.serve_forever()
finally:
    # Release the UDP socket so the cell can be re-run without hitting
    # "address already in use".
    server.server_close()


/mnt/c/tmp/input.wav
description: s
torch.Size([3, 1, 512000])


  best = np.argmax(np.asarray(results)[:, 1])
