### 加载模型

In [1]:
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", device)

model_name = "/root/autodl-tmp/musicgen-large"  # 可选：small, medium, large
# 初次使用记得去掉local_files_only=True
processor = AutoProcessor.from_pretrained(model_name, local_files_only=True)
base_model = MusicgenForConditionalGeneration.from_pretrained(model_name, local_files_only=True).half().to(device)
# model.half()解决精度问题报错

from peft import PeftModel

# lora_path = "./outputs/musicgen-lora/checkpoint-1600"
# lora_path = "./outputs/musicgen-lora/initial_lora"
# model = PeftModel.from_pretrained("/root/autodl-tmp/musicgen-large", lora_path)

device: cuda


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

### 生成

In [None]:
inputs = processor(
    text=[
        "90s rock song but with guzheng and erhu",
        #"Indian classical music with sitar and tabla",
        # "Traditional Chinese music with guzheng and flute",
        # "Hip-hop beats with 808 drums and synth",
        # "80s pop track with bassy drums and synth", 
        # "90s rock song with loud guitars and heavy drums", 
        # "This music is an intense instrumental.The tempo is fast with vigorous violin harmony that slows down to the accompaniment of a grim Piano harmony. The music is a Ritardando and has a grim, dark, intense,serious,bleak, dreary, and dangerous vibe to it. The chord sequence is Em7b5/D, D, Dm. The beat counts to 2. The tempo of this song is 169.0 beats per minute. The key is D minor."
        ],
    padding=True,
    return_tensors="pt",
).to(device)

inputs

{'input_ids': tensor([[ 2557, 11702,   723,    28,   108,  2046,    11,  3808,   521,     1]],
       device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}

In [17]:
audio_values = base_model.generate(
    **inputs, 
    max_new_tokens=256, 
    )
audio_values

tensor([[[-2.8580e-02, -3.4302e-02, -2.4521e-02,  ..., -1.8444e-03,
          -7.2021e-03,  2.0027e-05]]], device='cuda:0', dtype=torch.float16)

In [18]:
from IPython.display import Audio

sampling_rate = base_model.config.audio_encoder.sampling_rate
print("Sampling rate:", sampling_rate)

Sampling rate: 32000


### 试听

In [19]:
Audio(audio_values[0].cpu().numpy(), rate=sampling_rate)

In [20]:
Audio(audio_values[1].cpu().numpy(), rate=sampling_rate)

IndexError: index 1 is out of bounds for dimension 0 with size 1

In [15]:
Audio(audio_values[2].cpu().numpy(), rate=sampling_rate)

### 保存到文件

In [8]:
import scipy
import numpy as np

sampling_rate = base_model.config.audio_encoder.sampling_rate
print("Sampling rate:", sampling_rate)
for i in range(len(audio_values)):
    scipy.io.wavfile.write(f"outputs/musicgen_out_{i}.wav", rate=sampling_rate, data=np.asarray(audio_values[0, 0].cpu(), dtype=np.float32))

Sampling rate: 32000
