In [None]:
from tqdm.auto import tqdm

import torch
from torch.utils.data import DataLoader
from transformers import Wav2Vec2FeatureExtractor
from transformers import StoppingCriteria, StoppingCriteriaList
from transformers import AutoModel


In [None]:
musilingo = AutoModel.from_pretrained("m-a-p/MusiLingo-short-v1", trust_remote_code=True)
musilingo.to("cuda")

In [None]:
class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = stops
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        for stop in self.stops:
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True
        return False

def get_musilingo_pred(model, text, audio_path, stopping, length_penalty=1, temperature=0.1,
    max_new_tokens=300, num_beams=1, min_length=1, top_p=0.5, repetition_penalty=1.0):

    # see https://huggingface.co/m-a-p/MusiLingo-musicqa-v1 for load_audio function definition
    audio = load_audio(audio_path, target_sr=24000,
                        is_mono=True,
                        is_normalize=False,
                        crop_to_length_in_sample_points=int(30*16000)+1,
                        crop_randomly=True,
                        pad=False).cuda()
    processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-330M",trust_remote_code=True)
    audio = processor(audio,
                    sampling_rate=24000,
                    return_tensors="pt")['input_values'][0].cuda()

    audio_embeds, atts_audio = model.encode_audio(audio)

    prompt = '<Audio><AudioHere></Audio> ' + text
    instruction_prompt = [model.prompt_template.format(prompt)]
    audio_embeds, atts_audio = model.instruction_prompt_wrap(audio_embeds, atts_audio, instruction_prompt)

    model.llama_tokenizer.padding_side = "right"
    batch_size = audio_embeds.shape[0]
    bos = torch.ones([batch_size, 1],
                    dtype=torch.long,
                    device=torch.device('cuda')) * model.llama_tokenizer.bos_token_id
    bos_embeds = model.llama_model.model.embed_tokens(bos)
    # atts_bos = atts_audio[:, :1]
    inputs_embeds = torch.cat([bos_embeds, audio_embeds], dim=1)
    # attention_mask = torch.cat([atts_bos, atts_audio], dim=1)
    outputs = model.llama_model.generate(
        inputs_embeds=inputs_embeds,
        max_new_tokens=max_new_tokens,
        stopping_criteria=stopping,
        num_beams=num_beams,
        do_sample=True,
        min_length=min_length,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        temperature=temperature,
    )
    output_token = outputs[0]
    if output_token[0] == 0:  # the model might output a unknow token <unk> at the beginning. remove it
        output_token = output_token[1:]
    if output_token[0] == 1:  # if there is a start token <s> at the beginning. remove it
        output_token = output_token[1:]
    output_text = model.llama_tokenizer.decode(output_token, add_special_tokens=False)
    output_text = output_text.split('###')[0]  # remove the stop sign '###'
    output_text = output_text.split('Assistant:')[-1].strip()
    return output_text

In [None]:
def load_audio(
    file_path,
    target_sr,
    is_mono=True,
    is_normalize=False,
    crop_to_length_in_sec=None,
    crop_to_length_in_sample_points=None,
    crop_randomly=False,
    pad=False,
    return_start=False,
    device=torch.device('cpu')
):
    """Load audio file and convert to target sample rate.
    Supports cropping and padding.

    Args:
        file_path (str): path to audio file
        target_sr (int): target sample rate, if not equal to sample rate of audio file, resample to target_sr
        is_mono (bool, optional): convert to mono. Defaults to True.
        is_normalize (bool, optional): normalize to [-1, 1]. Defaults to False.
        crop_to_length_in_sec (float, optional): crop to specified length in seconds. Defaults to None.
        crop_to_length_in_sample_points (int, optional): crop to specified length in sample points. Defaults to None. Note that the crop length in sample points is calculated before resampling.
        crop_randomly (bool, optional): crop randomly. Defaults to False.
        pad (bool, optional): pad to specified length if waveform is shorter than specified length. Defaults to False.
        device (torch.device, optional): device to use for resampling. Defaults to torch.device('cpu').

    Returns:
        torch.Tensor: waveform of shape (1, n_sample)
    """
    # TODO: deal with target_depth
    try:
        waveform, sample_rate = torchaudio.load(file_path)
    except Exception as e:
        waveform, sample_rate = torchaudio.backend.soundfile_backend.load(file_path)
    if waveform.shape[0] > 1:
        if is_mono:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

    if is_normalize:
        waveform = waveform / waveform.abs().max()

    waveform, start = crop_audio(
        waveform,
        sample_rate,
        crop_to_length_in_sec=crop_to_length_in_sec,
        crop_to_length_in_sample_points=crop_to_length_in_sample_points,
        crop_randomly=crop_randomly,
        pad=pad,
    )

    if sample_rate != target_sr:
        resampler = torchaudio.transforms.Resample(sample_rate, target_sr)
        waveform = waveform.to(device)
        resampler = resampler.to(device)
        waveform = resampler(waveform)

    if return_start:
        return waveform, start
    return waveform

def crop_audio(
    waveform,
    sample_rate,
    crop_to_length_in_sec=None,
    crop_to_length_in_sample_points=None,
    crop_randomly=False,
    pad=False,
):
    """Crop waveform to specified length in seconds or sample points.
    Supports random cropping and padding.

    Args:
        waveform (torch.Tensor): waveform of shape (1, n_sample)
        sample_rate (int): sample rate of waveform
        crop_to_length_in_sec (float, optional): crop to specified length in seconds. Defaults to None.
        crop_to_length_in_sample_points (int, optional): crop to specified length in sample points. Defaults to None.
        crop_randomly (bool, optional): crop randomly. Defaults to False.
        pad (bool, optional): pad to specified length if waveform is shorter than specified length. Defaults to False.

    Returns:
        torch.Tensor: cropped waveform
        int: start index of cropped waveform in original waveform
    """
    assert crop_to_length_in_sec is None or crop_to_length_in_sample_points is None, \
    "Only one of crop_to_length_in_sec and crop_to_length_in_sample_points can be specified"

    # convert crop length to sample points
    crop_duration_in_sample = None
    if crop_to_length_in_sec:
        crop_duration_in_sample = int(sample_rate * crop_to_length_in_sec)
    elif crop_to_length_in_sample_points:
        crop_duration_in_sample = crop_to_length_in_sample_points

    # crop
    start = 0
    if crop_duration_in_sample:
        if waveform.shape[-1] > crop_duration_in_sample:
            if crop_randomly:
                start = random.randint(0, waveform.shape[-1] - crop_duration_in_sample)
            waveform = waveform[..., start:start + crop_duration_in_sample]

        elif waveform.shape[-1] < crop_duration_in_sample:
            if pad:
                waveform = torch.nn.functional.pad(waveform, (0, crop_duration_in_sample - waveform.shape[-1]))

    return waveform, start

In [None]:
import torchaudio
prompt = "describe this song"
audio_path = "/content/_8OIugVSFeE.wav"
stopping = StoppingCriteriaList([StoppingCriteriaSub([torch.tensor([835]).cuda(),
                                  torch.tensor([2277, 29937]).cuda()])])
response = get_musilingo_pred(musilingo.model, prompt, audio_path, stopping, length_penalty=100, temperature=0.1)

In [3]:
musilingo = AutoModel.from_pretrained("m-a-p/MusiLingo-short-v1", trust_remote_code=True)

In [11]:
musilingo.eval()

MusilingoModel(
  (model): MusiLingo(
    (audio_encoder): MERTModel(
      (feature_extractor): HubertFeatureEncoder(
        (conv_layers): ModuleList(
          (0): HubertGroupNormConvLayer(
            (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
            (activation): GELUActivation()
            (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
          )
          (1-4): 4 x HubertNoLayerNormConvLayer(
            (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
            (activation): GELUActivation()
          )
          (5-6): 2 x HubertNoLayerNormConvLayer(
            (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
            (activation): GELUActivation()
          )
        )
      )
      (feature_projection): MERTFeatureProjection(
        (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (projection): Linear(in_features=512, out_features=1024, bias=True)
        (dro

In [10]:
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from transformers import AutoModel

to_exclude = ["lwdDm3UO5WM", "sETUDPPoDuo", "W58kioYp1Ms"]
data_path = "data/wav_files/wav-48"
train_data = pd.read_csv("train_labels.csv")[["ytid", "caption"]]
train_data = train_data[~train_data['ytid'].isin(to_exclude)]
train_data["ytid"] = [f"{data_path}/{filename}.wav" for filename in train_data["ytid"]]

test_data = pd.read_csv("test_labels.csv")[["ytid", "caption"]]
test_data = test_data[~test_data['ytid'].isin(to_exclude)]
test_data["ytid"] = [f"{data_path}/{filename}.wav" for filename in test_data["ytid"]]

class AudioTextDataset(Dataset):
    def __init__(self, audio_paths, targets, processor):
        self.audio_paths = audio_paths
        self.targets = targets
        self.processor = processor

    def __len__(self):
        return len(self.audio_paths)

    def __getitem__(self, idx):
        waveform, sample_rate = torchaudio.load(self.audio_paths[idx])
    
        if waveform.size(0) > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        if sample_rate != self.resample_rate:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)
            waveform = resampler(waveform)
            
        waveform = torch.nn.functional.pad(waveform, (0, 240000 - len(waveform[0])), mode="constant", value=0)
        
        audio_input = self.processor(waveform.squeeze().numpy(), sampling_rate=24000, return_tensors="pt")

        
        audio = self.processor(audio, 
                        sampling_rate=24000, 
                        return_tensors="pt")['input_values'][0].cuda() 
            
        audio_embeds, atts_audio = model.encode_audio(audio)
            
        prompt = '<Audio><AudioHere></Audio> ' + text
        instruction_prompt = [model.prompt_template.format(prompt)]
        audio_embeds, atts_audio = model.instruction_prompt_wrap(audio_embeds, atts_audio, instruction_prompt)
        
        model.llama_tokenizer.padding_side = "right"
        batch_size = audio_embeds.shape[0]
        bos = torch.ones([batch_size, 1],
                        dtype=torch.long,
                        device=torch.device('cuda')) * model.llama_tokenizer.bos_token_id
        bos_embeds = model.llama_model.model.embed_tokens(bos)
        # atts_bos = atts_audio[:, :1]
        inputs_embeds = torch.cat([bos_embeds, audio_embeds], dim=1)
        
        prompt = self.prompts[idx]
        target = self.targets[idx]
        
        # Tokenize text
        tokenized_prompt = self.tokenizer("describe this song", return_tensors="pt", padding=True)
        tokenized_target = self.tokenizer(target, return_tensors="pt", padding=True)

        return audio, tokenized_prompt, tokenized_target


# Dataset
dataset = AudioTextDataset(list(train_data["ytid"]), list(train_data["caption"]), musilingo.llama_tokenizer)

# DataLoader
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# Training Setup
optimizer = optim.AdamW(musilingo.parameters(), lr=5e-5)
loss_fn = nn.CrossEntropyLoss()

# Training Loop
for epoch in range(3):  # Number of epochs
    musilingo.train()
    for batch in dataloader:
        audio, tokenized_prompt, tokenized_target = batch

        # Forward pass
        audio_embeds, atts_audio = musilingo.model.encode_audio(audio)
        inputs_embeds = musilingo.model.instruction_prompt_wrap(audio_embeds, atts_audio, tokenized_prompt)

        # Generate predictions
        outputs = musilingo.model.llama_model(
            inputs_embeds=inputs_embeds["inputs_embeds"],
            attention_mask=inputs_embeds["attention_mask"],
            labels=tokenized_target["input_ids"]
        )

        # Compute loss
        loss = outputs.loss

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1} - Loss: {loss.item()}")
    torch.save(model.state_dict(), f"musilingo_weights_{epoch}.pth")




Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00002-of-00003.safetensors:  23%|##3       | 1.15G/4.95G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/4.82G [00:00<?, ?B/s]

Loading Audio Encoder
Loading Audio Encoder Done
Loading LLAMA


tokenizer_config.json:   0%|          | 0.00/727 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/411 [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message


config.json:   0%|          | 0.00/619 [00:00<?, ?B/s]

pytorch_model.bin.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

pytorch_model-00001-of-00002.bin:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

pytorch_model-00002-of-00002.bin:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

LlamaForCausalLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

Loading LLAMA Done


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

AttributeError: 'MusilingoModel' object has no attribute 'llama_tokenizer'