# Setup

In [1]:
from pathlib import Path
import os
# set CA bundle path for requests to work via Zscaler
os.environ['CURL_CA_BUNDLE'] = str(Path.home() / '.zscaler-cert-app-store/Bundle.pem')
# needed so bitsandbytes can find correct cuda path!
%env LD_LIBRARY_PATH=/usr/local/cuda-12.2/lib64

env: LD_LIBRARY_PATH=/usr/local/cuda-12.2/lib64


# Download Model and Inference

In [90]:
from transformers import AutoProcessor, MusicgenForConditionalGeneration, BitsAndBytesConfig
from accelerate import Accelerator

# Create the Accelerate helper and report the device it selected.
accelerator = Accelerator()
print(f'Using device: {accelerator.device}')

# Quantization experiment -- kept for reference, NOT passed to from_pretrained below.
# Findings so far:
#   * 4-bit: really messes the output up.
#   * 8-bit: loading is faster, but inference gets significantly slower because
#     enc_to_dec_proj has to be skipped.
#     NOTE(review): the original note said "reduces inference time"; read here as a
#     slowdown, which matches the linked issue -- confirm.
# 4/8-bit may simply not be suited to inference, see
# https://github.com/TimDettmers/bitsandbytes/issues/490
quant_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_has_fp16_weight=True,
    # skip the final layer since weight_norm is not deepcopy-able
    llm_int8_skip_modules=['enc_to_dec_proj'],
    # 4-bit alternative that was tried:
    #   load_in_4bit=True,
    #   bnb_4bit_quant_type="nf4",
    #   bnb_4bit_use_double_quant=True,
    #   bnb_4bit_compute_dtype=torch.bfloat16,
)

model_name = "facebook/musicgen-small"

# Tokenizer for the text and melody inputs.
processor = AutoProcessor.from_pretrained(model_name)

# Load the actual encoder/decoder models (unquantized) and hand them to Accelerate.
# To retry quantization, pass quantization_config=quant_config to from_pretrained.
model = accelerator.prepare(
    MusicgenForConditionalGeneration.from_pretrained(model_name)
)

Using device: cuda




In [110]:

# Tokenize the text prompt into padded PyTorch tensors and move the whole batch to
# the Accelerate device in one call. The batch object's .to() replaces the manual
# per-key loop and no longer leaves a stray loop variable in the notebook namespace.
inputs = processor(
    text=["80s pop track with bassy drums and synth and dominant piano"],
    return_tensors="pt",
    padding=True
).to(accelerator.device)

# Last expression of the cell: show the encoded batch.
inputs

{'input_ids': tensor([[ 2775,     7,  2783,  1463,    28,  7981,    63,  5253,     7,    11,
         13353,    11, 12613,  8355,     1]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}

In [111]:
# Generate audio conditioned on the tokenized prompt in `inputs`.
audio_values = model.generate(
    **inputs,
    do_sample=True,      # sample instead of decoding greedily
    guidance_scale=3,    # NOTE(review): presumably the classifier-free guidance weight -- confirm
    max_new_tokens=256,  # upper bound on newly generated tokens
)

In [112]:
from IPython.display import Audio
# Read the sampling rate from the loaded model's audio-encoder config instead of
# hard-coding 32000, so playback speed/pitch stays correct if `model_name` is
# switched to a checkpoint with a different rate.
sampling_rate = model.config.audio_encoder.sampling_rate
# audio_values[0][0]: presumably first batch item, first channel
# (assumes a (batch, channels, samples) layout -- TODO confirm).
# .cpu() brings the tensor to host memory before handing it to the audio widget.
audio = Audio(data=audio_values[0][0].cpu(), rate=sampling_rate)
display(audio)

# Data processing