In [1]:
import torch
import intel_extension_for_pytorch as ipex

  warn(
  from .autonotebook import tqdm as notebook_tqdm


In [49]:
from transformers import AutoProcessor, MusicgenForConditionalGeneration

processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small").to("xpu")

In [89]:
inputs = processor(
    text=["dramatic modern piano solo in F minor"],
    padding=True,
    return_tensors="pt",
)
inputs = {k:v.to('xpu') for k,v in inputs.items()}

with torch.no_grad():
    audio_values = model.generate(**inputs, max_new_tokens=256).cpu()

In [90]:

#help(model.generate)
base_audio_values=audio_values

In [91]:
from IPython.display import Audio

sampling_rate = model.config.audio_encoder.sampling_rate
Audio(audio_values[0].numpy(), rate=sampling_rate)

In [92]:
audio_values.shape

torch.Size([1, 1, 161920])

In [101]:
inputs = processor(
    audio=base_audio_values[0][0],
    text=["80s Rock gitar"],
    padding=True,
    return_tensors="pt",
)
inputs = {k:v.to('xpu') for k,v in inputs.items()}

with torch.no_grad():
    audio_values = model.generate(**inputs, max_new_tokens=256).cpu()

It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


In [102]:
Audio(audio_values[0].numpy(), rate=sampling_rate)

In [103]:
# import scipy

# sampling_rate = model.config.audio_encoder.sampling_rate
# scipy.io.wavfile.write("musicgen_out.wav", rate=sampling_rate, data=audio_values[0, 0].numpy())

# lookig on the internals

we will now look at the 3 fundemental models that make this work.

1. a text encoder model that takes text and outputs  embeddings
2. an audio encoding model that translates between audio and its own tokens
3. a new token predictor for these audio tokens that can use model 1s embedding as conditionals


the pipeline goes as follows:

1. encode the user prompt into embeddings
2. encode any past musical context into tokens
3. use the decoder to predict the next tokens
4. decode the tokens back into sound

In [104]:
#we will now look at all the parts of the model

#text -> embeddings
text=set(model.text_encoder.modules())
proj=set(model.enc_to_dec_proj.modules()) 

#audio -> tokens OR tokens -> audio
audio_enc=set(model.audio_encoder.modules())

#text_embeddings + audio_tokens -> next_audion_tokens
audio_dec=set(model.decoder.modules())

all_layers = text.union(audio_enc).union(audio_dec).union(proj)

In [133]:
model.text_encoder

T5EncoderModel(
  (shared): Embedding(32128, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=768, out_features=3072, bias=False)
              (wo): Linear(in_features=3072, out_features=768, bias=False)
              (dropout): Dropout(p=0.1, 

In [105]:
[type(x) for x in model.modules() if x not in all_layers]

[transformers.models.musicgen.modeling_musicgen.MusicgenForConditionalGeneration]

In [106]:
proj

{Linear(in_features=768, out_features=1024, bias=True)}

#### text to embedding

In [107]:

text_inputs = processor.tokenizer(["Drums coming in"],return_tensors='pt')
text_inputs = {k:v.xpu() for k,v in text_inputs.items()}
with torch.no_grad():
    text_emb=model.text_encoder(**text_inputs)
text_emb

BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=tensor([[[-1.8343e-01, -1.6788e-01, -1.3638e-01,  ...,  7.4312e-02,
          -5.7504e-01, -3.2075e-01],
         [-5.2078e-01,  1.2646e-01, -1.2453e-01,  ...,  2.9757e-01,
           4.8505e-02, -2.4622e-01],
         [-6.2492e-01,  1.0176e-01, -9.7727e-02,  ...,  1.6174e-01,
          -2.8732e-02, -4.1226e-02],
         [-3.7238e-01, -1.1041e-01,  2.1421e-01,  ..., -2.7740e-02,
          -3.6800e-01, -5.7025e-02],
         [-7.5804e-03, -3.5105e-04,  8.4690e-03,  ..., -1.6502e-03,
          -2.4512e-03,  1.8491e-03]]], device='xpu:0'), past_key_values=None, hidden_states=None, attentions=None, cross_attentions=None)

In [108]:
text_emb.keys()

odict_keys(['last_hidden_state'])

In [109]:
with torch.no_grad():
    text_emb = model.enc_to_dec_proj(text_emb.last_hidden_state)
text_emb.shape

torch.Size([1, 5, 1024])

#### audio to tokens

In [110]:
with torch.no_grad():
    encoded=model.audio_encoder(input_values=audio_values.xpu())

In [111]:
encoded.audio_codes

tensor([[[[  68,  564,  564,  ..., 1548, 1968, 1131],
          [1941, 2041, 1658,  ..., 1889, 1814,  970],
          [1895, 1689, 2024,  ...,  369,  743, 1441],
          [2034, 2034,  704,  ...,  708, 1310, 1661]]]], device='xpu:0')

In [112]:
encoded.keys()

odict_keys(['audio_codes', 'audio_values'])

In [113]:
encoded.audio_codes.shape,text_emb.shape

(torch.Size([1, 1, 4, 506]), torch.Size([1, 5, 1024]))

#### embeddings + tokens to next tokens

In [114]:
with torch.no_grad():
    ans=model.decoder(encoded.audio_codes,encoder_hidden_states=text_emb)

In [115]:
ans.keys() 

odict_keys(['logits', 'past_key_values'])

In [116]:
#note that the answer acts just like any other generative transformer. we have logits and past key values
ans.logits.shape

torch.Size([4, 506, 2048])

In [117]:
model.decoder

MusicgenForCausalLM(
  (model): MusicgenModel(
    (decoder): MusicgenDecoder(
      (embed_tokens): ModuleList(
        (0-3): 4 x Embedding(2049, 1024)
      )
      (embed_positions): MusicgenSinusoidalPositionalEmbedding()
      (layers): ModuleList(
        (0-23): 24 x MusicgenDecoderLayer(
          (self_attn): MusicgenAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=False)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=False)
          )
          (activation_fn): GELUActivation()
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (encoder_attn): MusicgenAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=False)


# playing with codes

to demonstrate how the music tokens are translated to audio we will try and modify them directly

In [118]:
codes=encoded.audio_codes.clone()
codes.shape

torch.Size([1, 1, 4, 506])

In [119]:
type(model.audio_encoder)

transformers.models.encodec.modeling_encodec.EncodecModel

In [120]:
model.decoder.config.audio_channels

1

In [121]:
with torch.no_grad():
    ans=model.audio_encoder.decode(codes,audio_scales=[None])
ans

EncodecDecoderOutput(audio_values=tensor([[[-0.0094, -0.0124, -0.0113,  ..., -0.0233, -0.0251, -0.0233]]],
       device='xpu:0'))

In [122]:
ans.audio_values.shape

torch.Size([1, 1, 323840])

In [123]:
#here we will just see the original audio
Audio(ans.audio_values[0].cpu().numpy(), rate=sampling_rate)

In [124]:
#lets automate this
@torch.no_grad
def codes_to_audio(codes):
    ans=model.audio_encoder.decode(codes,audio_scales=[None])
    return ans.audio_values[0].cpu().numpy()

In [125]:
Audio(codes_to_audio(codes[:,:,:,300:500]), rate=sampling_rate)

In [126]:
#changing a few things barely matters
codes[0,0,0,300]=23
Audio(codes_to_audio(codes[:,:,:,300:500]), rate=sampling_rate)

In [127]:
codes[0,0,0,300:320]=23
Audio(codes_to_audio(codes[:,:,:,300:500]), rate=sampling_rate)

In [128]:
#now we take more of it
codes[0,0,0,300:400]=23
Audio(codes_to_audio(codes[:,:,:,300:500]), rate=sampling_rate)

In [130]:
#now we took over more chanels it will break
codes[0,0,2,300:400]=23
Audio(codes_to_audio(codes[:,:,:,300:500]), rate=sampling_rate)

In [131]:
#start will still sound fine
Audio(codes_to_audio(codes), rate=sampling_rate)

In [132]:
#looking at the effects of chanels
codes=encoded.audio_codes.clone()
codes[0,0,2,300:400]=23
Audio(codes_to_audio(codes[:,:,:,300:500]), rate=sampling_rate)

In [37]:
#looking at the effects of chanels
codes=encoded.audio_codes.clone()
codes[0,0,1,300:400]=23
Audio(codes_to_audio(codes[:,:,:,300:500]), rate=sampling_rate)

In [38]:
#looking at the effects of chanels
codes=encoded.audio_codes.clone()
codes[0,0,3,300:400]=23
Audio(codes_to_audio(codes[:,:,:,300:500]), rate=sampling_rate)