# MAGNeT
Welcome to MAGNeT's demo jupyter notebook. 
Here you will find a self-contained example of how to use MAGNeT for music/sound-effect generation.

First, we start by initializing MAGNeT for music generation, you can choose a model from the following selection:
1. facebook/magnet-small-10secs - a 300M non-autoregressive transformer capable of generating 10-second music conditioned on text.
2. facebook/magnet-medium-10secs - 1.5B parameters, 10 seconds music samples.
3. facebook/magnet-small-30secs - 300M parameters, 30 seconds music samples.
4. facebook/magnet-medium-30secs - 1.5B parameters, 30 seconds music samples.

We will use the `facebook/magnet-small-10secs` variant for the purpose of this demonstration.

In [1]:
# Import required libraries
import openai
import json5

# Initialize OpenAI API key
openai.api_key = ""  # Replace with your actual API key

In [2]:

# prompt_input = f"""""
# Summerize '{item['text']}' and format the caption with other information into a JSON file.
#"""""
example_output = {
    "comment": "Yeah, so this is a really cool track. It's got a great beat and a really catchy melody.The bassline is super funky and the whole thing just makes you want to dance. It's got a real African vibe to it, with the percussion drums and the 808 bass. The Shakuhachi adds a nice touch of Japanese flair. The koto melody is really beautiful and the bassline is just so funky. It's a really cool track and I think it would be perfect for a dance party or a club night. ",
    "captions": [
        "Minimal percussion drums with 808 bass, influenced by early 90's chicago house, bpm 120",
        "Japanese koto melody with jazz influence, tribal vibes for rituals, bpm 120",
        "A very funky minimal jazzy bassline, good for big parties, bpm 120"
    ]   
}

params = ["aggressive", "minimal", "experimental", "dubby", "jazzy", "pop"]

example_params = {
    "aggressive": 0.,
    "minimal": 0.8,
    "experimental": 0.0,
    "dubby": 0.0,
    "jazzy": 1.0,
    "pop": 0.2 
}

prompt_sample = f"""""I need text prompts for MusicGen, text-to-music model. Think step by step!\
    Please give me 3 text promots in a python array with a comment. one for each item in this array, (rhythm, melody, bass).\
    BPM should be inbetween 110 to 170. Be specific, get creative and use random instruments! \
    All three prompts should refleact the input parameters: 
"""""

print(prompt_sample)

isFirst = True

def get_captions(settings):
    global isFirst
    # Query

    param_settings =  dict(zip(params, settings))
    print(param_settings) 
       
    if isFirst or 1:
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are a world-famous dance music DJ performing at the biggest music festival."},
                {"role": "user", "content": f"{prompt_sample} {example_params}"},
                {"role": "assistant", "content": f"{example_output}"},
                {"role": "user", "content": f"{prompt_sample} {param_settings}"},
            ]
        )
    else:
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[
                # {"role": "system", "content": "You are a world-famous dance music DJ."},
                # {"role": "user", "content": f"{prompt_sample}"},
                # {"role": "assistant", "content": f"{example_captions}"},
                {"role": "user", "content": f"{prompt_sample}"},
            ]
        )
        isFirst = False

    #print(response)

    # Extract caption from API response
    json_output = response['choices'][0]['message']['content'].strip()
#    json_output = json_output.replace('\'', '\"') 
    json_output = json5.loads(json_output)
    print(json_output)

    # parse the output
    comment = json_output["comment"]
    captions = json_output["captions"]
    if len(captions) == 3:
        captions[1] = "no drums. " + captions[1]
        captions[2] = "bassline only, " + captions[2]

    return comment, captions

""I need text prompts for MusicGen, text-to-music model. Think step by step!    Please give me 3 text promots in a python array with a comment. one for each item in this array, (rhythm, melody, bass).    BPM should be inbetween 110 to 170. Be specific, get creative and use random instruments!     All three prompts should refleact the input parameters: 



In [3]:
get_captions([0.5,0.5,0.5,0.5,0.5,0.5])

{'aggressive': 0.5, 'minimal': 0.5, 'experimental': 0.5, 'dubby': 0.5, 'jazzy': 0.5, 'pop': 0.5}
{'comment': "This track is a wild experiment in sound. The rhythm is aggressive and in your face, with heavy bass and sharp percussion. The melody is experimental, with dissonant tones and unexpected twists. The bassline is dubby and deep, rumbling beneath the chaotic mix of sounds. It's a jazzy, pop-influenced track that defies convention and pushes boundaries.", 'captions': ['Aggressive industrial rhythm with distorted bass, BPM 150', 'Experimental electronic melody with glitchy effects, BPM 140', 'Dubby sub bassline with syncopated rhythms, BPM 130']}


("This track is a wild experiment in sound. The rhythm is aggressive and in your face, with heavy bass and sharp percussion. The melody is experimental, with dissonant tones and unexpected twists. The bassline is dubby and deep, rumbling beneath the chaotic mix of sounds. It's a jazzy, pop-influenced track that defies convention and pushes boundaries.",
 ['Aggressive industrial rhythm with distorted bass, BPM 150',
  'no drums. Experimental electronic melody with glitchy effects, BPM 140',
  'bassline only, Dubby sub bassline with syncopated rhythms, BPM 130'])

In [4]:
from audiocraft.models import MAGNeT
import torch, torchaudio
from audiocraft.utils.notebook import display_audio
from audiocraft.data.audio import audio_write

model = MAGNeT.get_pretrained('facebook/magnet-small-30secs')

model.set_generation_params(
    use_sampling=True,
    top_k=0,
    top_p=0.9,
    temperature=3.0,
    max_cfg_coef=10.0,
    min_cfg_coef=1.0,
    decoding_steps=[int(20 * model.lm.cfg.dataset.segment_duration // 10),  10, 10, 10],
    span_arrangement='stride1'
)

  from .autonotebook import tqdm as notebook_tqdm


Next, let us configure the generation parameters. Specifically, you can control the following:
* `use_sampling` (bool, optional): use sampling if True, else do argmax decoding. Defaults to True.
* `top_k` (int, optional): top_k used for sampling. Defaults to 0.
* `top_p` (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.9.
* `temperature` (float, optional): Initial softmax temperature parameter. Defaults to 3.0.
* `max_clsfg_coef` (float, optional): Initial coefficient used for classifier free guidance. Defaults to 10.0.
* `min_clsfg_coef` (float, optional): Final coefficient used for classifier free guidance. Defaults to 1.0.
* `decoding_steps` (list of n_q ints, optional): The number of iterative decoding steps, for each of the n_q RVQ codebooks.
* `span_arrangement` (str, optional): Use either non-overlapping spans ('nonoverlap') or overlapping spans ('stride1') 
                                      in the masking scheme. 

When left unchanged, MAGNeT will revert to its default parameters.

Next, we can go ahead and start generating music given textual prompts.

### Text-conditional Generation - Music

In [5]:
import numpy as np

def apply_fade(audio, sr, duration=.025):
    # convert to audio indices (samples)
    length = int(duration*sr)
    end = audio.shape[0]
    start = end - length

    # compute fade out curve
    # linear fade
    fade_curve = np.linspace(1.0, 0.0, length)

    # apply the curve
    audio[start:end] = audio[start:end] * fade_curve

    fade_curve = np.linspace(0.0, 1.0, length)

    # apply the curve
    audio[:length] = audio[:length] * fade_curve

    return audio

In [6]:
# LOOP Extraction
from madmom.features.downbeats import DBNDownBeatTrackingProcessor
from madmom.features.downbeats import RNNDownBeatProcessor
import os
import soundfile as sf
import pyrubberband as pyrb
import librosa

proc = DBNDownBeatTrackingProcessor(beats_per_bar=4, fps = 100, verbose=False)

def extract_loop(file_path, desired_bpm, num_bars=2):
    try:
        _, sr = librosa.core.load(file_path, sr=None) # sr = None will retrieve the original sampling rate = 44100
    except:
        print('load file failed')
        return None
    try:
        act = RNNDownBeatProcessor(verbose=False)(file_path)
        down_beat=proc(act, verbose=False) # [..., 2] 2d-shape numpy array
    except Exception as exp:
        print('except happended', exp)
        return None
    count = 0
    name = file_path.replace('.wav', '')
    for i in range(0, down_beat.shape[0]):
        if down_beat[i][1] == 1 and i + 4*num_bars < down_beat.shape[0] and down_beat[i+4*num_bars][1] == 1:
            start_time = down_beat[i][0]
            end_time = down_beat[i + 4*num_bars][0]
            count += 1
            out_path = os.path.join("./", f'{name}_{count}.wav')
            print(out_path, "outpath")
            y_one_bar, _ = librosa.core.load(file_path, offset=start_time, duration = end_time - start_time, sr=None)
            # desired_duration = 60./desired_bpm * (4*num_bars)
            # y_stretch = pyrb.time_stretch(y_one_bar, sr,  (end_time - start_time) / desired_duration)
            y_one_bar = apply_fade(y_one_bar, sr)
            sf.write(out_path, y_one_bar, sr)
            return out_path
    return None

In [7]:
from audiocraft.utils.notebook import display_audio

batch_size = 3

def _loop_gen(filepath, bpm, temperature=1.0, settings=[]):

    try:
        print(settings)
        comment, descs = get_captions(settings)
    except Exception as e:
        print("Error: ", e) 
        comment, descs = "", ["","",""] #dummy captions

    descs = descs[:batch_size]
    print(descs)

    #descs = None
    # descs = ["", "", ""]

    # prompt audio
    prompt_waveform, prompt_sr = torchaudio.load(filepath)
    prompt_duration = 60. / bpm * 4 * 2 # the first 2 bar 
    prompt_waveform = prompt_waveform[..., :int(prompt_duration * prompt_sr)]
    prompt_waveform = prompt_waveform[None, ...].repeat(batch_size, 1, 1)

    # desc = [description] * batch_size
    output = model.generate_continuation(prompt_waveform, prompt_sample_rate=prompt_sr, 
                                         progress=True, return_tokens=True, descriptions=descs)
    # output = model.generate_continuation(prompt_waveform, prompt_sample_rate=prompt_sr, descriptions=descriptions, 
                                        #  progress=True, return_tokens=True)
    print("generated", output.shape)

    sr = model.sample_rate
    start_pos = int(prompt_duration * sr)  # skip the first 2 bars
    display_audio(output[0], sample_rate=32000)

    outpaths = []
    for index in range(batch_size):
        file_path = "/mnt/c/tmp/output%d" % index
        # start_pos = 0
        output_data = output[0][index].cpu().squeeze()[start_pos:]
        audio_write(file_path, output_data, sr, strategy="loudness", loudness_compressor=True)
        outpath = extract_loop(file_path + ".wav", bpm, num_bars=4)
        if outpath:        
            client.send_message("/prompt", (index, descs[index]))
        outpaths.append(outpath)

    return outpaths

In [8]:
# %%
import os 
import time
from threading import Thread
from pythonosc import dispatcher
from pythonosc import osc_server, udp_client
from IPython.display import clear_output

isGenerating = False

def loop_gen(unused_addr, filepath, bpm=120., temperature=1.0, p1=0.5, p2=0.5, p3=0.5, p4=0.5, p5=0.5, p6=0.5):
    global isGenerating

    if isGenerating: # avoid duplicate generation
        return
    
    # parameters
    #model.set_generation_params(temperature=temperature)

    isGenerating = True
    filepath = filepath.replace("C:/", "/mnt/c/") # file path is in windows format

    if os.path.exists(filepath):
        try:
            settings = [p1, p2, p3, p4, p5, p6]
            outpaths = _loop_gen(filepath, bpm, temperature, settings)
        except Exception as e:
            print(e)
            outpaths = None
        # if outpath is not None:    
        client.send_message("/generated", (1))
        client.send_message("/generated", (1))
    else:
        print("file not found", filepath)

    clear_output(wait=True)
    isGenerating = False

dispatcher = dispatcher.Dispatcher()
dispatcher.map("/loop_gen", loop_gen)

#%%

# client = udp_client.SimpleUDPClient('127.0.1.1', 10018)
client = udp_client.SimpleUDPClient('172.17.128.1', 10018)


def beacon():
    while True:
        client.send_message("/heartbeat", 1)
        time.sleep(1.0)
def generating():
    while True:
        if isGenerating:
            client.send_message("/generating", 1)
        time.sleep(0.333)

t1 = Thread(target = beacon)
t2 = Thread(target = generating)
t1.setDaemon(True)
t2.setDaemon(True)
t1.start()
t2.start()

server = osc_server.ThreadingOSCUDPServer(
    ('172.17.140.208', 10026), dispatcher)
print("Serving on {}".format(server.server_address))
server.serve_forever()

KeyboardInterrupt: 