# MAGNeT
Welcome to MAGNeT's demo Jupyter notebook.
Here you will find a self-contained example of how to use MAGNeT for music/sound-effect generation.

First, we initialize MAGNeT for music generation. You can choose one of the following models:
1. facebook/magnet-small-10secs - a 300M non-autoregressive transformer capable of generating 10-second music conditioned on text.
2. facebook/magnet-medium-10secs - 1.5B parameters, 10 seconds music samples.
3. facebook/magnet-small-30secs - 300M parameters, 30 seconds music samples.
4. facebook/magnet-medium-30secs - 1.5B parameters, 30 seconds music samples.

We will use the `facebook/magnet-small-10secs` variant for the purpose of this demonstration.

In [1]:
from audiocraft.models import MAGNeT
import torch, torchaudio
from audiocraft.utils.notebook import display_audio
from audiocraft.data.audio import audio_write

model = MAGNeT.get_pretrained('facebook/magnet-small-10secs')

# Decoding steps per RVQ codebook: the first codebook's step count scales
# with the model's configured segment_duration (20 * segment_duration // 10);
# the remaining three codebooks use a fixed 10 steps each.
segment_seconds = model.lm.cfg.dataset.segment_duration
first_codebook_steps = int(20 * segment_seconds // 10)

model.set_generation_params(
    use_sampling=True,
    top_k=0,
    top_p=0.9,
    temperature=3.0,
    max_cfg_coef=10.0,
    min_cfg_coef=1.0,
    decoding_steps=[first_codebook_steps] + [10] * 3,
    span_arrangement='stride1',
)

  from .autonotebook import tqdm as notebook_tqdm


The cell above also configures the generation parameters through `model.set_generation_params`. Specifically, you can control the following:
* `use_sampling` (bool, optional): use sampling if True, else do argmax decoding. Defaults to True.
* `top_k` (int, optional): top_k used for sampling. Defaults to 0.
* `top_p` (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.9.
* `temperature` (float, optional): Initial softmax temperature parameter. Defaults to 3.0.
* `max_cfg_coef` (float, optional): Initial coefficient used for classifier-free guidance. Defaults to 10.0.
* `min_cfg_coef` (float, optional): Final coefficient used for classifier-free guidance. Defaults to 1.0.
* `decoding_steps` (list of n_q ints, optional): The number of iterative decoding steps, for each of the n_q RVQ codebooks.
* `span_arrangement` (str, optional): Use either non-overlapping spans ('nonoverlap') or overlapping spans ('stride1') in the masking scheme.

When left unchanged, MAGNeT will revert to its default parameters.

Next, we can go ahead and start generating music given textual prompts.

### Text-conditional Generation - Music

In [2]:
import numpy as np

def apply_fade(audio, sr, duration=.1):
    """Apply a linear fade-in and fade-out to ``audio`` in place.

    Args:
        audio: NumPy array whose first axis is time (samples). Any trailing
            axes (e.g. channels of a ``(samples, channels)`` array) receive
            the same fade.
        sr: Sample rate in Hz, used to convert ``duration`` to samples.
        duration: Length of each fade in seconds. It is clamped to the clip
            length, so clips shorter than the fade no longer raise. A
            non-positive duration leaves ``audio`` unchanged.

    Returns:
        The same ``audio`` array, modified in place.
    """
    n_samples = audio.shape[0]
    # Clamp: with a fade longer than the clip, a negative slice start would
    # wrap around and the curve would no longer match the slice length.
    length = min(int(duration * sr), n_samples)
    if length <= 0:
        return audio

    # Shape the 1-D ramp so it broadcasts along the time axis only.
    ramp_shape = (length,) + (1,) * (audio.ndim - 1)

    # Fade out over the last `length` samples.
    fade_out = np.linspace(1.0, 0.0, length).reshape(ramp_shape)
    audio[n_samples - length:] = audio[n_samples - length:] * fade_out

    # Fade in over the first `length` samples.
    fade_in = np.linspace(0.0, 1.0, length).reshape(ramp_shape)
    audio[:length] = audio[:length] * fade_in

    return audio

In [3]:
# LOOP Extraction
from madmom.features.downbeats import DBNDownBeatTrackingProcessor
from madmom.features.downbeats import RNNDownBeatProcessor
import os
import soundfile as sf
import pyrubberband as pyrb
import librosa

# Downbeat tracker used by extract_loop(): it is applied to the output of
# RNNDownBeatProcessor, configured for 100 frames per second and 4 beats per bar.
proc = DBNDownBeatTrackingProcessor(
    fps=100,
    beats_per_bar=4,
    verbose=False,
)

def extract_loop(file_path, desired_bpm, num_bars=2, skip_beats=8):
    """Cut a ``num_bars``-bar loop from ``file_path`` and time-stretch it to ``desired_bpm``.

    Beats are tracked with ``RNNDownBeatProcessor`` followed by the
    module-level ``proc``. The loop starts at the first downbeat at or after
    beat index ``skip_beats`` for which the beat ``4 * num_bars`` positions
    later is also a downbeat.

    The segment between those two downbeats is then processed as follows:
    - stretched to last exactly ``4 * num_bars`` beats at ``desired_bpm``;
    - faded in/out with ``apply_fade``;
    - written as ``<input stem>_1.wav``.

    Args:
        file_path: Path to the input audio file.
        desired_bpm: Target tempo in beats per minute (must be non-zero).
        num_bars: Loop length in bars of 4 beats.
        skip_beats: Number of leading tracked beats ignored when searching
            for the loop start (previously hard-coded to 8).

    Returns:
        Path of the written loop file, or ``None`` if loading or beat
        tracking failed, or if no suitable pair of downbeats was found.
    """
    try:
        # sr=None keeps the file's native sampling rate.
        _, sr = librosa.core.load(file_path, sr=None)
    except Exception as exp:  # narrower than a bare except; keeps Ctrl-C working
        print('load file failed', exp)
        return None
    try:
        act = RNNDownBeatProcessor(verbose=False)(file_path)
        # Column 0 is the beat time in seconds (used as a load offset below);
        # column 1 is the beat's position in the bar, where 1 marks a downbeat.
        down_beat = proc(act, verbose=False)
    except Exception as exp:
        print('except happended', exp)
        return None

    beats_per_loop = 4 * num_bars
    # splitext strips only the final extension; str.replace('.wav', '')
    # would also remove '.wav' appearing elsewhere in the path.
    stem, _ = os.path.splitext(file_path)

    # Stop early enough that the closing beat index is still in range.
    for i in range(skip_beats, down_beat.shape[0] - beats_per_loop):
        if down_beat[i][1] != 1 or down_beat[i + beats_per_loop][1] != 1:
            continue
        start_time = down_beat[i][0]
        end_time = down_beat[i + beats_per_loop][0]
        segment_duration = end_time - start_time
        segment, _ = librosa.core.load(
            file_path, offset=start_time, duration=segment_duration, sr=None)
        target_duration = 60. / desired_bpm * beats_per_loop
        # A rate > 1 speeds the audio up, shortening it toward target_duration.
        stretched = pyrb.time_stretch(segment, sr, segment_duration / target_duration)
        stretched = apply_fade(stretched, sr)
        # Only the first matching segment is written, so the suffix is always _1.
        out_path = os.path.join("./", f'{stem}_1.wav')
        sf.write(out_path, stretched, sr)
        return out_path
    return None

In [4]:
from audiocraft.utils.notebook import display_audio

batch_size = 3


def _loop_gen(filepath, bpm, temperature=1.0, description="jazzy beat",
              prompt_duration=2.0, output_dir="/mnt/c/tmp"):
    """Continue an audio prompt with MAGNeT and extract a tempo-matched loop from each output.

    The first ``prompt_duration`` seconds of ``filepath`` are tiled
    ``batch_size`` times and continued by the global ``model``, conditioned
    on ``description``. The continuations are shown in the notebook without
    the prompt. Each full output (prompt included) is written to
    ``<output_dir>/output<i>.wav`` and handed to ``extract_loop``.

    Args:
        filepath: Path of the prompt audio file.
        bpm: Target tempo for the extracted loops.
        temperature: Currently unused. Sampling uses the parameters set
            earlier via ``model.set_generation_params``.
        description: Text prompt used for every item of the batch.
        prompt_duration: Seconds of ``filepath`` used as the prompt.
        output_dir: Directory that receives the raw generated files.

    Returns:
        List of ``batch_size`` loop paths; an entry is ``None`` when
        ``extract_loop`` found no loop.
    """
    prompt_waveform, prompt_sr = torchaudio.load(filepath)
    prompt_waveform = prompt_waveform[..., :int(prompt_duration * prompt_sr)]
    # Add a batch axis and tile the same prompt batch_size times.
    prompt_waveform = prompt_waveform[None, ...].repeat(batch_size, 1, 1)

    desc = [description] * batch_size
    output = model.generate_continuation(prompt_waveform, prompt_sample_rate=prompt_sr,
                                         progress=True, return_tokens=True, descriptions=desc)
    wavs = output[0]  # element 0 of the returned pair holds the generated audio
    print(wavs.shape)

    sr = model.sample_rate
    # Preview only the newly generated part, at the model's actual sample
    # rate (previously hard-coded to 32000).
    display_audio(wavs[:, :, int(prompt_duration * sr):], sample_rate=sr)

    outpaths = []
    for index in range(batch_size):
        file_path = os.path.join(output_dir, "output%d" % index)
        # The saved file keeps the prompt; only the preview above trims it.
        audio_write(file_path, wavs[index].cpu().squeeze(), sr,
                    strategy="loudness", loudness_compressor=True)
        outpaths.append(extract_loop(file_path + ".wav", bpm, num_bars=2))
    return outpaths

In [5]:
# %%
import os 
import time
from threading import Thread
from pythonosc import dispatcher
from pythonosc import osc_server, udp_client
from IPython.display import clear_output


def loop_gen(unused_addr, filepath, duration, bpm, temperature=1.0, description="funky beat"):
    """OSC handler for ``/loop_gen``: generate loops from a prompt file, then notify the client.

    Args:
        unused_addr: OSC address supplied by the dispatcher (ignored).
        filepath: Prompt file path in Windows form. A ``C:/`` prefix is
            rewritten to ``/mnt/c/``.
        duration: Sent by the client but currently unused.
        bpm: Target loop tempo, forwarded to ``_loop_gen``.
        temperature: Forwarded to ``_loop_gen``.
        description: Text prompt, forwarded to ``_loop_gen``.
    """
    print(filepath)
    filepath = filepath.replace("C:/", "/mnt/c/")  # file path is in windows format

    if os.path.exists(filepath):
        # _loop_gen's signature is (filepath, bpm, temperature, description).
        # Passing `duration` positionally as well (5 args) raised TypeError.
        _loop_gen(filepath, bpm, temperature=temperature, description=description)
        # NOTE(review): the notification is sent twice, as before. This is
        # presumably for redundancy over UDP; confirm with the receiver.
        client.send_message("/generated", 1)
        client.send_message("/generated", 1)
    else:
        print("file not found", filepath)
    clear_output(wait=True)

# Import the class directly instead of writing
# `dispatcher = dispatcher.Dispatcher()`: that shadowed the imported
# `pythonosc.dispatcher` module, so any later `dispatcher.Dispatcher` access
# (e.g. running these lines without re-running the import) failed.
from pythonosc.dispatcher import Dispatcher

dispatcher = Dispatcher()
# Route incoming "/loop_gen" OSC messages to the loop_gen handler.
dispatcher.map("/loop_gen", loop_gen)

#%%

# OSC client that receives heartbeats and "/generated" notifications.
# NOTE(review): 172.17.128.1 is presumably the Windows host as seen from WSL.
# For a peer on the same machine, use
# udp_client.SimpleUDPClient('127.0.0.1', 10018) instead.
client = udp_client.SimpleUDPClient('172.17.128.1', 10018)


def beacon():
    """Send a ``/heartbeat`` OSC message to ``client`` once per second, forever."""
    while True:
        client.send_message("/heartbeat", 1)
        time.sleep(1.0)


# Daemon thread, so the heartbeat never keeps the interpreter alive on exit.
# Thread(daemon=True) replaces Thread.setDaemon(), deprecated since Python 3.10.
t1 = Thread(target=beacon, daemon=True)
t1.start()

# OSC server that dispatches incoming messages (e.g. "/loop_gen") via `dispatcher`.
server = osc_server.ThreadingOSCUDPServer(
    ('172.17.140.208', 10026), dispatcher)
print("Serving on {}".format(server.server_address))
try:
    # Blocks this cell until it is interrupted.
    server.serve_forever()
finally:
    # Release the UDP port so that re-running the cell after an interrupt
    # does not fail with "Address already in use".
    server.server_close()

  best = np.argmax(np.asarray(results)[:, 1])
  best = np.argmax(np.asarray(results)[:, 1])
