In [1]:
import torch
import intel_extension_for_pytorch as ipex

  warn(
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from transformers import AutoProcessor, MusicgenForConditionalGeneration

# Single source of truth for the checkpoint so the processor and the model can never
# be loaded from different checkpoints by accident.
MODEL_ID = "facebook/musicgen-small"

processor = AutoProcessor.from_pretrained(MODEL_ID)
# Move the weights to the "xpu" device used throughout this notebook
# (presumably made available by the intel_extension_for_pytorch import above).
model = MusicgenForConditionalGeneration.from_pretrained(MODEL_ID).to("xpu")

  self.register_buffer("padding_total", torch.tensor(kernel_size - stride, dtype=torch.int64), persistent=False)


In [3]:
# Text-only generation of the first clip.
# NOTE(review): MusicGen's default generation config samples tokens, so without a fixed
# seed every Restart & Run All produces different audio. torch.manual_seed seeds all
# devices; confirm it also covers the XPU generator with this IPEX version.
SEED = 0
# Number of audio tokens to generate; the resulting waveform below has 163200 samples.
MAX_NEW_TOKENS = 258
torch.manual_seed(SEED)

inputs = processor(
    text=["Dramatic piano music"],
    padding=True,
    return_tensors="pt",
)
inputs = {k: v.to('xpu') for k, v in inputs.items()}

with torch.no_grad():
    audio_values = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS).cpu()

In [4]:
from IPython.display import Audio

# Playback rate (Hz) of the waveform emitted by the model's audio decoder.
sampling_rate = model.config.audio_encoder.sampling_rate
# First batch item; shape (channels, samples), which IPython.display.Audio accepts directly.
Audio(data=audio_values[0].numpy(), rate=sampling_rate)

In [5]:
# Shape of the generated waveform batch; presumably (batch, channels, samples) — confirm.
audio_values.size()

torch.Size([1, 1, 163200])

In [6]:
# Audio-prompted generation: condition on a new text prompt AND the waveform generated
# above, so MusicGen continues that clip.
inputs = processor(
    audio=audio_values[0][0].numpy(),
    # Tell the feature extractor the rate the prompt waveform was produced at. Omitting it
    # triggers the "strongly recommended to pass the `sampling_rate`" warning and skips
    # the extractor's rate check.
    sampling_rate=model.config.audio_encoder.sampling_rate,
    text=["Drums coming in"],
    padding=True,
    return_tensors="pt",
)
inputs = {k: v.to('xpu') for k, v in inputs.items()}

with torch.no_grad():
    # NOTE: this overwrites `audio_values`; every later cell (playback, WAV export,
    # EnCodec encoding) therefore works on this continuation, not the first clip.
    audio_values = model.generate(**inputs, max_new_tokens=248).cpu()

It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


In [7]:
# Play the continuation produced by the audio-prompted generation cell.
Audio(data=audio_values[0].numpy(), rate=sampling_rate)

In [None]:
# Import the submodule explicitly: a bare `import scipy` does not guarantee that
# `scipy.io.wavfile` is loaded (older SciPy releases do not lazily import submodules).
import scipy.io.wavfile

# Re-read the rate so this cell does not depend on the playback cell having run.
sampling_rate = model.config.audio_encoder.sampling_rate
# audio_values[0, 0]: first batch item, first channel, as a 1-D float array.
scipy.io.wavfile.write("musicgen_out.wav", rate=sampling_rate, data=audio_values[0, 0].numpy())

In [23]:
# Gather every submodule reachable from each of the four top-level components, so the
# next cell can check whether `model` owns anything outside them.
text, audio_enc, audio_dec, proj = (
    set(component.modules())
    for component in (
        model.text_encoder,
        model.audio_encoder,
        model.decoder,
        model.enc_to_dec_proj,
    )
)

all_layers = set().union(text, audio_enc, audio_dec, proj)

In [24]:
# Types of modules in `model` that belong to none of the four components collected above.
[type(module) for module in model.modules() if module not in all_layers]

[transformers.models.musicgen.modeling_musicgen.MusicgenForConditionalGeneration]

In [32]:
# Run only the text encoder on the continuation prompt, with every tokenizer output
# (ids and mask) moved to the XPU.
text_inputs = {
    name: tensor.xpu()
    for name, tensor in processor.tokenizer(["Drums coming in"], return_tensors='pt').items()
}

with torch.no_grad():
    text_emb = model.text_encoder(**text_inputs)

text_emb

BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=tensor([[[-1.8343e-01, -1.6788e-01, -1.3638e-01,  ...,  7.4312e-02,
          -5.7504e-01, -3.2075e-01],
         [-5.2078e-01,  1.2646e-01, -1.2453e-01,  ...,  2.9757e-01,
           4.8505e-02, -2.4622e-01],
         [-6.2492e-01,  1.0176e-01, -9.7727e-02,  ...,  1.6174e-01,
          -2.8732e-02, -4.1226e-02],
         [-3.7238e-01, -1.1041e-01,  2.1421e-01,  ..., -2.7740e-02,
          -3.6800e-01, -5.7025e-02],
         [-7.5804e-03, -3.5105e-04,  8.4690e-03,  ..., -1.6502e-03,
          -2.4512e-03,  1.8491e-03]]], device='xpu:0'), past_key_values=None, hidden_states=None, attentions=None, cross_attentions=None)

In [34]:
# Field names actually populated on the text-encoder output (fields left as None are not listed).
text_emb.keys()

odict_keys(['last_hidden_state'])

In [35]:
# Copy the encoder's final hidden states back to host memory for inspection.
text_emb["last_hidden_state"].to("cpu")

tensor([[[-1.8343e-01, -1.6788e-01, -1.3638e-01,  ...,  7.4312e-02,
          -5.7504e-01, -3.2075e-01],
         [-5.2078e-01,  1.2646e-01, -1.2453e-01,  ...,  2.9757e-01,
           4.8505e-02, -2.4622e-01],
         [-6.2492e-01,  1.0176e-01, -9.7727e-02,  ...,  1.6174e-01,
          -2.8732e-02, -4.1226e-02],
         [-3.7238e-01, -1.1041e-01,  2.1421e-01,  ..., -2.7740e-02,
          -3.6800e-01, -5.7025e-02],
         [-7.5804e-03, -3.5105e-04,  8.4690e-03,  ..., -1.6502e-03,
          -2.4512e-03,  1.8491e-03]]])

In [58]:
# Pass the current waveform through the EnCodec audio model.
# NOTE(review): because the audio-prompted cell overwrote `audio_values`, this encodes
# that continuation rather than the first clip — the result depends on execution order.
with torch.no_grad():
    encoded = model.audio_encoder(
        input_values=audio_values.to("xpu"),
    )

In [61]:
# Show the EnCodec output: the discrete `audio_codes` plus its decoded `audio_values` waveform.
encoded

EncodecOutput(audio_codes=tensor([[[[ 648,  315, 1771,  ..., 1788,  773,  801],
          [1519,  971, 1048,  ..., 1954, 1958, 1753],
          [ 924, 1895, 1190,  ..., 2025, 1974, 1878],
          [1628, 1595, 1456,  ..., 1116, 1116, 1409]]]], device='xpu:0'), audio_values=tensor([[[-0.0230, -0.0213, -0.0169,  ...,  0.0069,  0.0074,  0.0118]]],
       device='xpu:0'))

In [73]:
# Shapes of the audio codes and the text hidden states, which feed the decoder below.
encoded.audio_codes.size(), text_emb.last_hidden_state.size()

(torch.Size([1, 1, 4, 500]), torch.Size([1, 5, 768]))

In [69]:
# Map the text-encoder hidden states into the decoder's hidden size
# (last dim goes 768 -> 1024 according to the shapes shown).
with torch.no_grad():
    x = model.enc_to_dec_proj(text_emb["last_hidden_state"])

x.size()

torch.Size([1, 5, 1024])

In [79]:
# Single forward pass of the MusicGen decoder: feed the EnCodec codes as decoder input
# and cross-attend to the projected text states `x`.
# NOTE(review): the raw codes are fed without MusicGen's codebook delay pattern or start
# token, so these logits are not what `generate` would compute — confirm this is intended.
with torch.no_grad():
    ans = model.decoder(
        input_ids=encoded.audio_codes,
        encoder_hidden_states=x,
        # Exclude padded text positions from cross-attention. This mask comes from the
        # same tokenization that produced `x`; it is all ones for this single unpadded
        # prompt, so results are unchanged here.
        encoder_attention_mask=text_inputs["attention_mask"],
    )

In [81]:
# Output fields populated by the decoder forward pass.
ans.keys()

odict_keys(['logits', 'past_key_values'])

In [83]:
# Decoder logits shape; the leading 4 appears to be batch (1) x codebooks (4), followed by
# frames and vocabulary size — confirm against the MusicGen decoder docs.
ans["logits"].size()

torch.Size([4, 500, 2048])