In [1]:
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import torch

# Load the pretrained MusicGen checkpoint together with its matching processor.
CHECKPOINT = "facebook/musicgen-small"
processor = AutoProcessor.from_pretrained(CHECKPOINT)
model = MusicgenForConditionalGeneration.from_pretrained(CHECKPOINT)

# Turn two text prompts into one padded batch of PyTorch tensors.
prompts = [
    "80s pop track with bassy drums and synth",
    "90s rock song with loud guitars and heavy drums",
]
inputs = processor(text=prompts, padding=True, return_tensors="pt")

# Start the decoder from a single pad token for every (sample, codebook) row.
pad_token_id = model.generation_config.pad_token_id
num_decoder_rows = inputs.input_ids.shape[0] * model.decoder.num_codebooks
decoder_input_ids = torch.full((num_decoder_rows, 1), pad_token_id, dtype=torch.long)

# One forward pass; the cell displays the shape of the resulting logits.
logits = model(**inputs, decoder_input_ids=decoder_input_ids).logits
logits.shape  # (bsz * num_codebooks, tgt_len, vocab_size)

  WeightNorm.apply(module, name, dim)
  self.register_buffer("padding_total", torch.tensor(kernel_size - stride, dtype=torch.int64), persistent=False)


torch.Size([8, 1, 2048])

In [19]:
from datasets import load_dataset

# Shape of the placeholder tensor returned in mock mode:
# (entries along the layer axis, hidden size).
MUSICGEN_NUM_DECODER_HIDDEN_LAYERS = 25
MUSICGEN_HIDDEN_SIZE = 1024

def get_decoder_hidden_states(audio_batch, text_batch, sampling_rate, mock=False, device='cpu'):
    """Return the decoder hidden states of one MusicGen forward pass.

    The batch is run through the notebook-level ``processor`` and ``model``
    with a single pad token per (sample, codebook) row as decoder input.

    Args:
        audio_batch: Sequence of audio clips handed to ``processor`` as ``audio``.
        text_batch: Sequence of captions handed to ``processor`` as ``text``.
        sampling_rate: Sampling rate forwarded to ``processor``.
        mock: When True, skip the model entirely and return random values of
            shape ``(len(audio_batch), MUSICGEN_NUM_DECODER_HIDDEN_LAYERS,
            MUSICGEN_HIDDEN_SIZE)``.
        device: Device the inputs are moved to; the caller is responsible for
            having moved ``model`` to the same device.

    Returns:
        A tensor that does not require grad. The per-layer states are stacked,
        the first two axes are swapped and axis 2 (the single decoder
        position) is squeezed, giving
        ``(batch_size, num_layers, hidden_size)``.

    NOTE(review): ``decoder_input_ids`` is supplied explicitly, so the audio
    inputs may not influence the decoder hidden states at all -- confirm
    against the MusicGen ``forward`` implementation.
    """
    if mock:
        return torch.randn(len(audio_batch), MUSICGEN_NUM_DECODER_HIDDEN_LAYERS, MUSICGEN_HIDDEN_SIZE)

    inputs = processor(
        text=text_batch,
        audio=audio_batch,
        sampling_rate=sampling_rate,
        padding=True,
        return_tensors="pt",
    ).to(device)

    # One pad token per (sample, codebook) row, created directly on the target device.
    batch_size = inputs.input_ids.shape[0]
    decoder_input_ids = torch.full(
        (batch_size * model.decoder.num_codebooks, 1),
        model.generation_config.pad_token_id,
        dtype=torch.long,
        device=device,
    )

    # Inference only: without no_grad the outputs carry an autograd graph,
    # which wastes memory and makes a later `.numpy()` call raise.
    with torch.no_grad():
        output = model(**inputs, decoder_input_ids=decoder_input_ids, output_hidden_states=True)

    # (num_layers, batch_size, 1, hidden_size) -> (batch_size, num_layers, hidden_size)
    stacked = torch.stack(output.decoder_hidden_states)
    return stacked.transpose(0, 1).squeeze(2)

In [14]:
import os
from datasets import load_dataset
import h5py
import torchaudio
import torch
from pytube import YouTube
from tqdm import tqdm

def load_and_preprocess_audio(file_path):
    """Load an audio file and return it as a mono 1-D NumPy array.

    Args:
        file_path: Path of the audio file, passed to ``torchaudio.load``.

    Returns:
        A ``(samples, sampling_rate)`` tuple: ``samples`` is a 1-D NumPy array
        and ``sampling_rate`` is the rate reported by ``torchaudio.load``.
    """
    # Load a local music file using torchaudio.
    # Assumes the waveform is laid out as (channels, frames) -- TODO confirm.
    waveform, sampling_rate = torchaudio.load(file_path)

    # Downmix to mono by averaging over the channel axis when there are several channels.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Drop only the channel axis. A bare squeeze() would also remove the frame
    # axis of a one-frame clip and return a 0-d array instead of a 1-D one.
    return waveform.squeeze(0).numpy(), sampling_rate

def process_and_save_embeddings(output_file, batch_size=32, device='cpu'):
    """Compute decoder hidden states for the local MusicCaps clips and store them in HDF5.

    The "google/MusicCaps" train split is filtered down to records whose audio
    exists as ``./music_data/{ytid}.wav``. Each batch of captions and clips is
    run through ``get_decoder_hidden_states``, and the results are appended to
    a resizable, gzip-compressed dataset named ``'embeddings'``.

    Args:
        output_file: Path of the HDF5 file to create; opened in ``'w'`` mode.
        batch_size: Number of records per forward pass.
        device: Device the notebook-level ``model`` is moved to and on which
            the batches are processed.

    Raises:
        ValueError: If the clips of one batch do not share a single sampling
            rate, because one rate is passed along for the whole batch.

    Note:
        If no record survives the filter, the file is created without an
        ``'embeddings'`` dataset.
    """
    model.to(device)
    dataset = load_dataset("google/MusicCaps", split="train")
    # Keep only the records whose audio exists locally as ./music_data/{ytid}.wav
    dataset = dataset.filter(lambda x: os.path.exists(f"./music_data/{x['ytid']}.wav"))
    print(f"Filtered dataset to {len(dataset)} samples")

    # Ceiling division, so the progress bar knows how many batches to expect.
    num_batches = (len(dataset) + batch_size - 1) // batch_size

    with h5py.File(output_file, 'w') as f:
        dset = None  # Created lazily, once the embedding shape is known

        batches = enumerate(dataset.iter(batch_size))
        for i, batch in tqdm(batches, total=num_batches, desc="Processing batches"):
            text_batch = batch['caption']
            audio_batch = []
            sampling_rates = set()
            for ytid in batch['ytid']:
                # Load the audio file from ./music_data/{ytid}.wav
                waveform, sampling_rate = load_and_preprocess_audio(f"./music_data/{ytid}.wav")
                audio_batch.append(waveform)
                sampling_rates.add(sampling_rate)

            # A single rate is forwarded for the whole batch; refuse to mislabel clips.
            if len(sampling_rates) != 1:
                raise ValueError(
                    f"Batch {i} must contain clips with exactly one sampling rate, "
                    f"found {sorted(sampling_rates)}"
                )

            hidden_states = get_decoder_hidden_states(audio_batch, text_batch, sampling_rate, device=device, mock=False)
            # detach() keeps the NumPy conversion valid even if the tensor still requires grad.
            embeddings = hidden_states.detach().cpu().numpy()

            if dset is None:
                # Create the dataset once we know the per-sample shape
                dset = f.create_dataset('embeddings',
                                        shape=(0,) + embeddings.shape[1:],
                                        maxshape=(None,) + embeddings.shape[1:],
                                        chunks=True, compression='gzip')

            # Grow the dataset along axis 0 and write the new rows at the end
            start = dset.shape[0]
            dset.resize(start + embeddings.shape[0], axis=0)
            dset[start:] = embeddings

            if (i+1) % 10 == 0:
                # Report rows actually written; (i+1)*batch_size over-counts a partial batch.
                print(f"Processed {dset.shape[0]} samples")

# Usage: run on the GPU when one is available, otherwise fall back to the CPU
# so the cell does not fail on machines without CUDA.
process_and_save_embeddings(
    'musiccaps_embeddings.h5',
    batch_size=32,
    device='cuda' if torch.cuda.is_available() else 'cpu',
)

Filtered dataset to 696 samples


Processing batches: 10it [00:06,  1.42it/s]

Processed 320 samples


Processing batches: 20it [00:13,  1.47it/s]

Processed 640 samples


Processing batches: 22it [00:14,  1.49it/s]


In [15]:
# Read the whole 'embeddings' dataset back into memory and report its shape.
# NOTE(review): the stored shape reflects whichever version of
# get_decoder_hidden_states was defined when the file was written.
with h5py.File('musiccaps_embeddings.h5', 'r') as h5_file:
    stored = h5_file['embeddings']
    embeddings = stored[...]
    print(embeddings.shape)

(696, 25, 1, 1024)
