# Music Generation
This notebook shows how to generate music from a text description using Meta's (Facebook's) MusicGen model (`facebook/musicgen-small`) through the Hugging Face `transformers` library.


<div>
<img src='https://ichef.bbci.co.uk/images/ic/640x360/p09h5gp2.jpg'
    width="500"/>
</div>

[Image source](https://www.bbc.co.uk/programmes/w3ct1rl3)

# 1. Settings

In [1]:
#%pip install torchaudio
#%pip install transformers
#%pip install matplotlib

In [2]:
import os
import torch
import torchaudio
from transformers import AutoProcessor, MusicgenForConditionalGeneration
from IPython.display import Audio

import warnings
# Suppress all warnings: no category or module filter is given, so this applies
# to every warning raised after this cell runs (presumably to keep cell output clean).
warnings.filterwarnings("ignore")


# 2. What kind of music would you like to have?

In [3]:
# ---- user input ----
# Instrument to feature; it is inserted into the prompt and used for the file name.
instrument = "ukulele"

# Text prompt(s) describing the music to generate.
descriptions = [
    f"lofi {instrument} music with breeze-like, warm, tuneful melody at 60bpm"
]

# ---- output location ----
# File name: the instrument with spaces turned into underscores, plus ".wav".
filename = "_".join(instrument.split(" ")) + ".wav"

# Full path of the output file inside ../data/audio.
filepath = os.path.join("../data/audio", filename)


# 3. Generate music

In [4]:
# model
# checkpoint identifier; MusicGen hands it to `from_pretrained` for both the
# processor and the model
music_model = "facebook/musicgen-small"

# a function to generate music
## clip length is controlled by max_new_tokens (approximate figures):
## 5 seconds music: max_new_tokens= 256
## 10 seconds music: max_new_tokens= 512
def MusicGen(descriptions,
             filepath=filepath,
             music_model=music_model,
             max_new_tokens=512,
             seed=None):
    """Generate music from text prompts and save it as an audio file.

    Loads the processor and model named by ``music_model``, encodes
    ``descriptions``, runs generation and writes ``music[0]`` of the result
    to ``filepath`` with ``torchaudio.save``.

    Args:
        descriptions: Text prompt, or list of text prompts, describing the
            music. Only ``music[0]`` is saved -- presumably the clip for the
            first prompt; confirm against the model's output layout before
            relying on multi-prompt batches.
        filepath: Destination of the audio file. Missing parent directories
            are created. Note that the default value is captured when this
            function is defined, so pass ``filepath`` explicitly if the
            settings were changed afterwards.
        music_model: Checkpoint identifier passed to ``from_pretrained``.
        max_new_tokens: Number of new tokens to generate; larger values give
            longer clips.
        seed: Optional integer. When given, ``torch.manual_seed(seed)`` is
            called before generation so that a sampling-based run can be
            repeated. ``None`` (the default) leaves the RNG state untouched.

    Returns:
        The ``filepath`` that the audio was written to.
    """
    # processor that turns the text prompts into model inputs
    processor = AutoProcessor.from_pretrained(music_model)

    # model
    model = MusicgenForConditionalGeneration.from_pretrained(music_model)

    # input: padded PyTorch tensors built from the prompts
    inputs = processor(
        text=descriptions,
        padding=True,
        return_tensors="pt")

    # optional seeding, done right before generation
    if seed is not None:
        torch.manual_seed(seed)

    # generate music
    music = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # sampling rate of the model's audio encoder
    sampling_rate = model.config.audio_encoder.sampling_rate

    # make sure the output directory exists; torchaudio.save does not create it.
    # dirname is "" for a bare file name, and os.makedirs("") would raise.
    output_dir = os.path.dirname(filepath)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # save
    torchaudio.save(filepath, music[0], sampling_rate)
    return filepath



In [5]:
# generate music and save as a wav file
# filepath is passed explicitly: MusicGen's default for it was bound when the
# function was defined, so it would be stale if the settings cell was re-run since
music_path = MusicGen(descriptions, filepath=filepath, music_model=music_model)

In [6]:
# load generated music
# `music` holds whatever torchaudio.load returns; the listening cell uses
# music[0] as the audio data and music[1] as the playback rate
music = torchaudio.load(music_path)

# 4. Listen

In [7]:
# listen: split the loaded pair into audio data and playback rate, then render a player
waveform, sample_rate = music
Audio(waveform, rate=sample_rate)