In [None]:
from tqdm.auto import tqdm

import torch
from torch.utils.data import DataLoader
from transformers import Wav2Vec2FeatureExtractor
from transformers import StoppingCriteria, StoppingCriteriaList
from transformers import AutoModel



class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the first sequence in the batch ends with a stop pattern.

    Args:
        stops: iterable of 1-D token-id tensors. Generation is stopped as soon
            as ``input_ids[0]`` ends with any one of them. ``None`` (the
            default) or an empty iterable means the criterion never fires.
        encounters: accepted for backward compatibility; it is not used.
    """

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # Avoid a mutable default argument: every instance created without
        # `stops` gets its own fresh list.
        self.stops = stops if stops is not None else []

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        """Return True when the tail of ``input_ids[0]`` equals one of the stop patterns.

        Only the first row of ``input_ids`` is inspected. ``scores`` and any
        extra keyword arguments are accepted and ignored.
        """
        sequence = input_ids[0]
        for stop in self.stops:
            stop_len = len(stop)
            # A pattern longer than what has been generated so far cannot match.
            # Without this guard the comparison below would broadcast a shorter
            # slice against the pattern and could report a false match.
            if stop_len == 0 or stop_len > sequence.shape[0]:
                continue
            tail = sequence[-stop_len:]
            # Compare on the sequence's device so stop tensors created on a
            # different device do not raise.
            if torch.all(stop.to(tail.device) == tail).item():
                return True
        return False

def get_musilingo_pred(model, text, audio_path, stopping, length_penalty=1, temperature=0.1,
    max_new_tokens=300, num_beams=1, min_length=1, top_p=0.5, repetition_penalty=1.0):
    """Generate a text answer about one audio file with a MusiLingo model.

    The audio is loaded, run through the MERT feature extractor, encoded by
    ``model.encode_audio`` and wrapped into the model's instruction prompt
    together with ``text``. The LLaMA language model then samples a
    continuation (``do_sample=True`` is always used, so results vary between
    calls unless the caller seeds the RNG).

    Args:
        model: object exposing ``encode_audio``, ``prompt_template``,
            ``instruction_prompt_wrap``, ``llama_tokenizer`` and ``llama_model``.
        text: the question / instruction appended after the audio placeholder.
        audio_path: path of the audio file handed to ``load_audio``.
        stopping: stopping criteria forwarded to ``generate``.
        length_penalty, temperature, max_new_tokens, num_beams, min_length,
        top_p, repetition_penalty: forwarded unchanged to ``generate``.

    Returns:
        str: the decoded output, cut at the first ``'###'``, reduced to the
        part after the last ``'Assistant:'`` and stripped of surrounding
        whitespace.

    Side effects:
        Sets ``model.llama_tokenizer.padding_side`` to ``"right"``.
    """
    # see https://huggingface.co/m-a-p/MusiLingo-musicqa-v1 for load_audio function definition
    # NOTE(review): the crop length is computed with 16000 samples/s while
    # target_sr is 24000 -- confirm this is intended before changing it.
    audio = load_audio(audio_path, target_sr=24000,
                        is_mono=True,
                        is_normalize=False,
                        crop_to_length_in_sample_points=int(30*16000)+1,
                        crop_randomly=True,
                        pad=False).cuda()
    # The feature extractor is re-created on every call.
    processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-330M",trust_remote_code=True)
    audio = processor(audio,
                    sampling_rate=24000,
                    return_tensors="pt")['input_values'][0].cuda()

    # Pure inference: nothing below needs gradients, so skip autograd
    # bookkeeping for the audio encoder and the embedding lookup.
    with torch.no_grad():
        audio_embeds, atts_audio = model.encode_audio(audio)

        prompt = '<Audio><AudioHere></Audio> ' + text
        instruction_prompt = [model.prompt_template.format(prompt)]
        audio_embeds, atts_audio = model.instruction_prompt_wrap(audio_embeds, atts_audio, instruction_prompt)

        model.llama_tokenizer.padding_side = "right"
        batch_size = audio_embeds.shape[0]
        # Build the BOS ids on the same device as the embeddings they are
        # concatenated with, instead of hard-coding 'cuda'.
        bos = torch.ones([batch_size, 1],
                        dtype=torch.long,
                        device=audio_embeds.device) * model.llama_tokenizer.bos_token_id
        bos_embeds = model.llama_model.model.embed_tokens(bos)
        inputs_embeds = torch.cat([bos_embeds, audio_embeds], dim=1)
        # No attention mask is passed to generate(); atts_audio is unused here.
        outputs = model.llama_model.generate(
            inputs_embeds=inputs_embeds,
            max_new_tokens=max_new_tokens,
            stopping_criteria=stopping,
            num_beams=num_beams,
            do_sample=True,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
        )

    output_token = outputs[0]
    # Each check is guarded by a length test: stripping the first token can
    # leave an empty tensor, and indexing it would raise IndexError.
    if len(output_token) > 0 and output_token[0] == 0:  # the model might output an unknown token <unk> at the beginning. remove it
        output_token = output_token[1:]
    if len(output_token) > 0 and output_token[0] == 1:  # if there is a start token <s> at the beginning. remove it
        output_token = output_token[1:]
    # NOTE(review): `add_special_tokens` is normally an encode-time option;
    # confirm whether `skip_special_tokens` was intended here.
    output_text = model.llama_tokenizer.decode(output_token, add_special_tokens=False)
    output_text = output_text.split('###')[0]  # remove the stop sign '###'
    output_text = output_text.split('Assistant:')[-1].strip()
    return output_text

# Load the MusiLingo checkpoint, move it to the GPU and switch to eval mode.
# trust_remote_code=True lets the checkpoint's own model code run locally.
musilingo = AutoModel.from_pretrained("m-a-p/MusiLingo-short-v1", trust_remote_code=True)
musilingo.to("cuda")
musilingo.eval()

prompt = "describe this song"
audio = "data/wav_files/wav-48/FFQVVwFjy7s.wav"
# Generation stops when the output ends with either token-id pattern.
# Presumably these ids encode the '###' separator that get_musilingo_pred
# strips from the decoded text -- verify against the tokenizer.
stopping = StoppingCriteriaList([StoppingCriteriaSub([torch.tensor([835]).cuda(),
                                  torch.tensor([2277, 29937]).cuda()])])
# Pass the path bound to `audio` above; `audio_path` is not defined in this
# cell, so the previous call raised NameError on a fresh kernel.
response = get_musilingo_pred(musilingo.model, prompt, audio, stopping, length_penalty=100, temperature=0.1)




model.safetensors.index.json:   0%|          | 0.00/75.4k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

In [4]:
!pip install timm

Collecting timm
  Downloading timm-1.0.12-py3-none-any.whl.metadata (51 kB)
Downloading timm-1.0.12-py3-none-any.whl (2.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.4/2.4 MB[0m [31m65.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: timm
Successfully installed timm-1.0.12
