# MusicGen-Stem
Welcome to the MusicGen-Stem demo Jupyter notebook. Here you will find a series of self-contained examples of how to use MusicGen-Stem in different settings.

First, we initialize MusicGen-Stem. You can choose one of the following models:
1. `facebook/musicgen-stem-6cb` - 1.5B transformer decoder with 1 codebook for bass, 1 codebook for drums and 4 codebooks for other. It is the model that is showcased in the MusicGen-Stem paper.
2. `facebook/musicgen-stem-7cb` - 1.5B transformer decoder with 2 codebooks for bass, 1 codebook for drums and 4 codebooks for other. This model is not showcased in the MusicGen-Stem paper; it was developed to obtain a better-sounding bass.
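
To make the difference between the two checkpoints concrete, here is a minimal sketch of how the codebooks could be laid out per stem. The codebook counts come from the descriptions above; the ordering (bass, then drums, then other) and the helper name `codebook_slices` are assumptions made for illustration, not something read from the library.

```python
# Codebook counts per stem, taken from the model descriptions above.
CODEBOOKS_PER_STEM = {
    "facebook/musicgen-stem-6cb": {"bass": 1, "drums": 1, "other": 4},
    "facebook/musicgen-stem-7cb": {"bass": 2, "drums": 1, "other": 4},
}


def codebook_slices(model_name):
    """Return {stem: range of codebook indices}.

    Hypothetical helper: it assumes the stems are stacked in the order
    bass, drums, other, which the notebook does not state.
    """
    slices = {}
    start = 0
    for stem in ("bass", "drums", "other"):
        count = CODEBOOKS_PER_STEM[model_name][stem]
        slices[stem] = range(start, start + count)
        start += count
    return slices


for name in CODEBOOKS_PER_STEM:
    layout = codebook_slices(name)
    total = sum(len(indices) for indices in layout.values())
    print(name, {stem: list(indices) for stem, indices in layout.items()}, "total:", total)
```

The totals (6 and 7) match the `6cb` and `7cb` suffixes in the model names.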


In [None]:
import torchaudio
from audiocraft.utils.notebook import display_audio


In [None]:
from audiocraft.models import MusicGenStem

# Choose a model between these two:

# model = MusicGenStem.get_pretrained('facebook/musicgen-stem-6cb')
model = MusicGenStem.get_pretrained('facebook/musicgen-stem-7cb')


Next, let us configure the generation parameters. Specifically, you can control the following:
* `use_sampling` (bool, optional): use sampling if True, else do argmax decoding. Defaults to True.
* `top_k` (int, optional): top_k used for sampling. Defaults to 250.
* `top_p` (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
* `temperature` (float, optional): softmax temperature parameter. Defaults to 1.0.
* `duration` (float, optional): duration of the generated waveform. Defaults to 30.0.
* `cfg_coef` (float, optional): coefficient used for classifier free guidance. Defaults to 3.0.
* `double_cfg` (bool, optional): If True, use double CFG. Defaults to False.


Any parameter left unset keeps its default value.
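
To give some intuition for what these parameters control, the following is a minimal pure-Python sketch of a single decoding step. It is not the library's implementation: the function names `apply_cfg`, `softmax` and `sample_token` are hypothetical, and the classifier-free guidance formula is the commonly used one, assumed here to apply to MusicGen-Stem as well. `double_cfg` is not covered.

```python
import math
import random


def apply_cfg(cond_logits, uncond_logits, cfg_coef=3.0):
    """Classifier-free guidance (assumed standard form):
    move from the unconditional logits toward the conditional ones."""
    return [u + cfg_coef * (c - u) for c, u in zip(cond_logits, uncond_logits)]


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_token(logits, use_sampling=True, top_k=250, top_p=0.0,
                 temperature=1.0, rng=random):
    """Pick one token index from a list of logits."""
    if not use_sampling:
        # Argmax decoding: always take the most likely token.
        return max(range(len(logits)), key=lambda i: logits[i])

    probs = softmax(logits, temperature)
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    if top_p > 0.0:
        # Nucleus sampling: keep the smallest prefix whose mass reaches top_p.
        kept, mass = [], 0.0
        for i in order:
            kept.append(i)
            mass += probs[i]
            if mass >= top_p:
                break
    elif top_k > 0:
        # top_p is 0, so fall back to the top_k most likely tokens.
        kept = order[:top_k]
    else:
        kept = order

    weights = [probs[i] for i in kept]
    return rng.choices(kept, weights=weights, k=1)[0]


# Toy logits over a 4-token vocabulary.
cond = [2.0, 0.5, 0.1, -1.0]
uncond = [1.0, 0.6, 0.2, -0.5]
guided = apply_cfg(cond, uncond, cfg_coef=3.0)
print("argmax token:", sample_token(guided, use_sampling=False))
print("top-k sample:", sample_token(guided, top_k=2, rng=random.Random(0)))
```

A higher `temperature` flattens the distribution before sampling, while a smaller `top_k` (or a `top_p` below 1) restricts sampling to the most likely tokens.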


In [None]:
model.set_generation_params(
    use_sampling=True,
    top_k=250,
    duration=30
)

The model can perform text-to-music generation, music continuation, and text-conditioned instrument regeneration, either from a mixture or from stems.
* Text-to-music can be done using `model.generate`, or `model.generate_with_chroma` with the wav condition being None. 
* Style-to-music and Text-and-Style-to-music can be done using `model.generate_with_chroma`

## Text-to-Music

In [None]:

model.set_generation_params(
    duration=8, # generate 8 seconds, can go up to 30
    use_sampling=True, 
    top_k=250,
    cfg_coef=3., # Classifier Free Guidance coefficient 
)

descriptions=[
        '80s pop track with bassy drums and synth',
        '90s rock song with loud guitars and heavy drums',
        'EDM inspiring song',
        'Bluesy guitar instrumental with soulful licks and a driving rhythm section',
        'Funky song with a strong bassline and a dancy feeling',
    ]


output = model.generate(
    descriptions=descriptions,
    progress=True, return_tokens=True
)

# We create the mixture by summing the bass, drums and other
output[0]['mixture'] = sum(output[0].values())


In [None]:
# Now we listen to each stem as well as the mixture, song by song

for idx, description in enumerate(descriptions):
    print(description)
    for stem in ['bass', 'drums', 'other', 'mixture']:
        print(stem)
        display_audio(output[0][stem][idx], sample_rate=32000)


## Generate the continuation of an existing mixture

In [None]:
mixture, sr = torchaudio.load('../assets/electronic.mp3')
display_audio(mixture, sample_rate=sr)


In [None]:
output = model.generate_continuation_from_mixture(
    mixture=mixture, mixture_sample_rate=sr,
    descriptions=['dancy electronic song'],
    progress=True, return_tokens=True
)

# We create the mixture by summing the bass, drums and other
output[0]['mixture'] = sum(output[0].values())


In [None]:
for stem in ['bass', 'drums', 'other', 'mixture']:
    print(stem)
    display_audio(output[0][stem], sample_rate=32000)


## Generate the continuation of a song from its codes (tokens)
We first need some codes, which we obtain here by generating a short song:

In [None]:
model.set_generation_params(
    duration=5, # generate 5 seconds, can go up to 30
    use_sampling=True, 
    top_k=250,
    cfg_coef=3., # Classifier Free Guidance coefficient 
)

output = model.generate(
    descriptions=['Folk song with an acoustic guitar'],
    progress=True, return_tokens=True
)

output[0]['mixture'] = sum(output[0].values())

for stem in ['bass', 'drums', 'other', 'mixture']:
    print(stem)
    display_audio(output[0][stem], sample_rate=32000)

# We can extract the codes, since we used the argument return_tokens=True
codes = output[1]
print(codes)


Now, given the codes of the generated song, we can extend it with a different prompt and hear how the model switches from one genre to another. The transition is not necessarily very smooth, but the continuation function is useful when the user generates a few seconds, listens to them, and then decides to generate the continuation if they enjoy the beginning.
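
As a rough guide for choosing `duration` when continuing, the sketch below splits a target duration into the part already covered by the prompt and the part that is newly generated. It rests on two assumptions that the notebook does not confirm: that `duration` is the total length including the prompt, and that the codes run at 50 frames per second. The helper `continuation_budget` is hypothetical.

```python
FRAME_RATE_HZ = 50  # assumed token frame rate; check your model before relying on it


def continuation_budget(prompt_frames, target_duration_s, frame_rate=FRAME_RATE_HZ):
    """Split a target duration into the prompt part and the newly generated part.

    Assumes the target duration counts the prompt as well.
    """
    total_frames = int(target_duration_s * frame_rate)
    new_frames = total_frames - prompt_frames
    if new_frames <= 0:
        raise ValueError("duration must be longer than the prompt to generate anything new")
    return {
        "prompt_seconds": prompt_frames / frame_rate,
        "new_seconds": new_frames / frame_rate,
        "total_frames": total_frames,
    }


# A 5-second prompt (250 frames at the assumed 50 Hz) extended to 10 seconds.
budget = continuation_budget(prompt_frames=250, target_duration_s=10)
print(budget)
```

Under these assumptions, a 5-second prompt with `duration=10` leaves 5 seconds of new material.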

In [None]:
# We have to extend the duration, let's say to 10 seconds

model.set_generation_params(
    duration=10,
    use_sampling=True, 
    top_k=250,
    cfg_coef=3., # Classifier Free Guidance coefficient 
)

output = model.generate_continuation_from_codes(
    codes=codes,
    descriptions=['Folk song with drums and an electric guitar'],
    progress=True, return_tokens=True
)

output[0]['mixture'] = sum(output[0].values())

for stem in ['bass', 'drums', 'other', 'mixture']:
    print(stem)
    display_audio(output[0][stem], sample_rate=32000)


## Regenerate stems on an existing song (mixture)
To do so, first load an existing song. The song can be mono or stereo at any sample rate: the model converts it to 32 kHz mono, separates it with Demucs, and replaces the requested stems. The model is designed for 25-second excerpts. If the song is longer, only the first 25 seconds are used; if it is shorter, it is padded with zeros.
Use the `regenerate_instruments_from_mixture` function.
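
The length rules above can be sketched as a small planning helper. The 25-second limit and the zero padding come from the text; the 5-second offset comes from the comments in the cells below ("the final length will be duration-5"). The helper `regeneration_plan` and its `suggested_duration` field are hypothetical, and reflect one reading of those comments rather than documented behavior.

```python
MAX_EXCERPT_S = 25.0      # from the text: the model is made for 25-second excerpts
DURATION_OFFSET_S = 5.0   # from the notebook comments: final length is duration - 5


def regeneration_plan(num_samples, sample_rate):
    """Summarize how an input of `num_samples` at `sample_rate` would be handled."""
    input_seconds = num_samples / sample_rate
    used_seconds = min(input_seconds, MAX_EXCERPT_S)  # longer inputs are truncated
    padding_seconds = MAX_EXCERPT_S - used_seconds    # shorter inputs are zero-padded
    return {
        "input_seconds": input_seconds,
        "used_seconds": used_seconds,
        "zero_padding_seconds": padding_seconds,
        # Assumption: pass used length + 5 as `duration` to cover the whole excerpt.
        "suggested_duration": used_seconds + DURATION_OFFSET_S,
    }


# A 40-second song at 44.1 kHz: only the first 25 seconds are used.
print(regeneration_plan(num_samples=44_100 * 40, sample_rate=44_100))
# A 10-second loop at 32 kHz: padded with 15 seconds of zeros.
print(regeneration_plan(num_samples=32_000 * 10, sample_rate=32_000))
```

With a full 25-second excerpt this gives `duration=30`, which agrees with the value the notebook comments recommend for regenerating all 25 seconds.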

In [None]:
# We load a 25-second excerpt of a song
path_source = '../assets/pop_song.wav'

mixture, sr = torchaudio.load(path_source)

display_audio(mixture, sample_rate=sr)


### Let's regenerate the bass:

In [None]:
model.set_generation_params(
    duration=15, # The final length will be duration-5. You need to put 30 if you want to regenerate all of the 25 secs
    use_sampling=True, 
    top_k=250,
    cfg_coef=3., # Classifier Free Guidance coefficient 
)

output = model.regenerate_instruments_from_mixture(
    mixture=mixture,
    mixture_sample_rate=sr,
    which_instruments_regenerate=['bass'], # list of stems that you want to replace
    descriptions=['Pop song with a funky bass'], # put any prompt that you want
    progress=True, return_tokens=True
)


output[0]['mixture'] = sum(output[0].values())

for stem in ['bass', 'drums', 'other', 'mixture']:
    print(stem)
    display_audio(output[0][stem], sample_rate=32000)


### Let's regenerate the drums:

In [None]:
model.set_generation_params(
    duration=15, # The final length will be duration-5. You need to put 30 if you want to regenerate all of the 25 secs
    use_sampling=True, 
    top_k=250,
    cfg_coef=3., # Classifier Free Guidance coefficient 
)

output = model.regenerate_instruments_from_mixture(
    mixture=mixture,
    mixture_sample_rate=sr,
    which_instruments_regenerate=['drums'], # list of stems that you want to replace
    descriptions=['Upbeat drums, pop song'], # put any prompt that you want
    progress=True, return_tokens=True
)


output[0]['mixture'] = sum(output[0].values())

for stem in ['bass', 'drums', 'other', 'mixture']:
    print(stem)
    display_audio(output[0][stem], sample_rate=32000)


### Let's regenerate the other stems:

In [None]:
model.set_generation_params(
    duration=15, # The final length will be duration-5. You need to put 30 if you want to regenerate all of the 25 secs
    use_sampling=True, 
    top_k=250,
    cfg_coef=3., # Classifier Free Guidance coefficient 
)

output = model.regenerate_instruments_from_mixture(
    mixture=mixture,
    mixture_sample_rate=sr,
    which_instruments_regenerate=['other'], # list of stems that you want to replace
    descriptions=['Pop song with a piano'], # put any prompt that you want
    progress=True, return_tokens=True
)


output[0]['mixture'] = sum(output[0].values())

for stem in ['bass', 'drums', 'other', 'mixture']:
    print(stem)
    display_audio(output[0][stem], sample_rate=32000)


### Let's regenerate the bass and the other stems:

In [None]:
model.set_generation_params(
    duration=15, # The final length will be duration-5. You need to put 30 if you want to regenerate all of the 25 secs
    use_sampling=True, 
    top_k=250,
    cfg_coef=3., # Classifier Free Guidance coefficient 
)

output = model.regenerate_instruments_from_mixture(
    mixture=mixture,
    mixture_sample_rate=sr,
    which_instruments_regenerate=['bass', 'other'], # list of stems that you want to replace
    descriptions=['Pop song with an acoustic guitar and an upbeat drums'], # put any prompt that you want
    progress=True, return_tokens=True
)


output[0]['mixture'] = sum(output[0].values())

for stem in ['bass', 'drums', 'other', 'mixture']:
    print(stem)
    display_audio(output[0][stem], sample_rate=32000)


## Regenerate from audio stems
If you already have audio stems (e.g. a drum loop), you can generate the bass and other instruments.
For this, use the `regenerate_instruments_from_stems` function.

In [None]:
# load some drums

drums, sr = torchaudio.load('../assets/drum_loop.wav')

display_audio(drums, sample_rate=sr)

stems = {'drums': drums}

In [None]:
model.set_generation_params(
    duration=15, # The final length will be duration-5. You need to put 30 if you want to regenerate all of the 25 secs
    use_sampling=True, 
    top_k=250,
    cfg_coef=3., # Classifier Free Guidance coefficient 
)

output = model.regenerate_instruments_from_stems(
    stems=stems,
    stems_sample_rate=sr,
    which_instruments_regenerate=['bass', 'other'], # list of stems that you want to replace
    descriptions=['House song with synth pads with a groovy bassline. Uplifting feeling'],
    progress=True, return_tokens=True,
    return_non_compressed_stems=True, # for the input stems, we return the original ones instead of the compressed ones
)


output[0]['mixture'] = sum(output[0].values())

for stem in ['bass', 'drums', 'other', 'mixture']:
    print(stem)
    display_audio(output[0][stem], sample_rate=32000)


## Regenerate from codes
This function regenerates some stems from codes. The typical use case is generating a song from scratch with MusicGen-Stem and then regenerating some of its stems. For this, use the `regenerate_from_codes` function.

In [None]:
model.set_generation_params(
    duration=10, # generate 10 seconds, can go up to 30
    use_sampling=True, 
    top_k=250,
    cfg_coef=3., # Classifier Free Guidance coefficient 
)

output = model.generate(
    descriptions=['Folk song with drums, bass and an acoustic guitar'],
    progress=True, return_tokens=True
)

output[0]['mixture'] = sum(output[0].values())

for stem in ['bass', 'drums', 'other', 'mixture']:
    print(stem)
    display_audio(output[0][stem], sample_rate=32000)

# We can extract the codes, since we used the argument return_tokens=True
codes = output[1]
print(codes)


Now, given the codes, we regenerate some of the stems (e.g. the bass):

In [None]:
model.set_generation_params(
    duration=15, # The final length will be duration-5. You need to put 30 if you want to regenerate all of the 25 secs
    use_sampling=True, 
    top_k=250,
    cfg_coef=3., # Classifier Free Guidance coefficient 
)

output = model.regenerate_from_codes(
    codes=codes,
    which_instruments_regenerate=['bass'], # list of stems that you want to replace
    descriptions=['Folk song with drums, bass and an acoustic guitar'], # put any prompt that you want
    progress=True, return_tokens=True
)


output[0]['mixture'] = sum(output[0].values())

for stem in ['bass', 'drums', 'other', 'mixture']:
    print(stem)
    display_audio(output[0][stem], sample_rate=32000)
