## 1. Install dependencies

In [None]:
# Install the third-party packages for this notebook. `datasets`, `huggingface_hub`
# and `wandb` are imported by later cells; the audio/score packages (librosa,
# pretty_midi, jams, music21) are not imported in the cells shown here —
# NOTE(review): presumably needed by other parts of the project; confirm.
!pip install librosa pretty_midi jams music21 datasets huggingface_hub wandb

## Data

In [None]:
from datasets import load_dataset, concatenate_datasets

# Fetch both pre-computed embedding datasets from the Hugging Face Hub.
urmp_dataset = load_dataset("jonflynn/urmp_jukebox_embeddings_qa")
musicnet_dataset = load_dataset("jonflynn/musicnet_jukebox_embeddings_abc")

# Only the training splits are used. They are concatenated into one dataset;
# the collate function defined later reads the 'question', 'answer' and
# 'embedding' columns, so both datasets are assumed to provide them.
urmp_train, musicnet_train = urmp_dataset['train'], musicnet_dataset['train']
combined_dataset = concatenate_datasets([urmp_train, musicnet_train])

## Model

### With Unsloth

In [None]:
%%capture
# NOTE: the %%capture cell magic above has to remain the first line of this cell.
# Install the released Unsloth package first.
!pip install unsloth
# Also get the latest nightly Unsloth!
# The command below removes the release installed above and reinstalls
# Unsloth (with the "colab-new" extra) straight from the GitHub repository.
!pip uninstall unsloth -y && pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"

In [None]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from unsloth import FastLanguageModel
from transformers import AutoProcessor, MusicgenConfig, MusicgenForConditionalGeneration
from transformers import AutoConfig
from torch import Tensor
import math
import json
import os

class CustomLlarkModel(nn.Module):
    """Causal language model trained on chat prompts with spliced-in audio embeddings.

    Pre-computed audio embeddings are mapped by a linear projection to the
    language model's hidden size and inserted between the ``<AUDIO_START>``
    and ``<AUDIO_END>`` marker tokens of a chat-formatted prompt. The loss is
    computed only on the assistant segment of each prompt; every other
    position (including the audio slots and padding) is labelled ``-100``.
    """

    # Chat templates this class knows how to build and label.
    SUPPORTED_MODEL_TYPES = ("llama", "gemma", "qwen2")

    # Line break used inside the prompt templates. The templates are
    # triple-quoted literals, so every continuation line carries 12 spaces.
    _PROMPT_LINE_BREAK = "\n" + " " * 12

    # Tokens that close the assistant answer in the training prompt.
    _ANSWER_TERMINATORS = {
        "llama": "<|eot_id|><|end_of_text|>",
        "gemma": "<end_of_turn><eos>",
        "qwen2": "<|im_end|><|endoftext|>",
    }

    def __init__(self, model_name, model_type, device, use_lora=True, audio_embedding_dim=4800):
        """Load the language model, attach LoRA adapters and build the audio projection.

        Args:
            model_name: Model identifier/path passed to ``FastLanguageModel.from_pretrained``.
            model_type: Chat template family; one of ``SUPPORTED_MODEL_TYPES``.
            device: Device the audio projection and all input tensors are moved to.
            use_lora: When true, wrap the language model with LoRA adapters.
            audio_embedding_dim: Size of the last dimension of the incoming
                audio embeddings (input width of the projection layer).

        Raises:
            ValueError: If ``model_type`` is not supported. Raised before any
                model is loaded so a typo does not cost a full download.
        """
        super().__init__()

        if model_type not in self.SUPPORTED_MODEL_TYPES:
            raise ValueError(
                f"Unsupported model_type {model_type!r}; expected one of {self.SUPPORTED_MODEL_TYPES}"
            )

        self.model_type = model_type
        self.device = device
        self.target_sr = 44100

        # Define special tokens for audio start and end
        self.AUDIO_START_TOKEN = "<AUDIO_START>"
        self.AUDIO_END_TOKEN = "<AUDIO_END>"

        max_seq_length = 4096
        # Fixed to bfloat16 (not auto-detected); forward() also builds its
        # audio tensor in bfloat16.
        dtype = torch.bfloat16
        self.language_model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name=model_name,
            max_seq_length=max_seq_length,
            dtype=dtype,
            load_in_4bit=False,
            trust_remote_code=True
        )

        if use_lora:
            self.language_model = FastLanguageModel.get_peft_model(
                self.language_model,
                r=128,
                target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                                "gate_proj", "up_proj", "down_proj"],
                lora_alpha=256,
                lora_dropout=0,
                bias="none",
                use_gradient_checkpointing=True,
                random_state=3407
            )

        # Add special tokens to tokenizer and grow the embedding matrix to match
        self.tokenizer.add_tokens([self.AUDIO_START_TOKEN, self.AUDIO_END_TOKEN])
        self.language_model.resize_token_embeddings(len(self.tokenizer))

        # Get token ids for special tokens
        self.audio_start_token_id = self.tokenizer.convert_tokens_to_ids(self.AUDIO_START_TOKEN)
        self.audio_end_token_id = self.tokenizer.convert_tokens_to_ids(self.AUDIO_END_TOKEN)

        # Projection layer to map audio embeddings to language model space
        self.audio_projection = nn.Linear(
            audio_embedding_dim, self.language_model.config.hidden_size
        ).to(device)

        # Token ids used by forward() to locate the assistant segment.
        # NOTE(review): the numeric ids are hard-coded; presumably they are the
        # header/turn start tokens of the respective tokenizers — verify.
        # The qwen2 eot id is captured here, *before* forward() overrides
        # tokenizer.pad_token with the eos token.
        if self.model_type == "llama":
            self.start_of_header_token = 128006
            self.eot_token = 128001
        elif self.model_type == "gemma":
            self.start_of_turn_token = 106
            self.eot_token = self.tokenizer.eos_token_id
        elif self.model_type == "qwen2":
            self.start_of_turn_token = 151644
            self.eot_token = self.tokenizer.pad_token_id

    def _prompt_prefix(self, query):
        """Return the chat prompt up to (and including) the assistant header.

        This is the single source of truth for the prompt wording so that the
        training prompt and the query-only prompt cannot drift apart.

        Raises:
            ValueError: If ``self.model_type`` is not supported.
        """
        if self.model_type == "llama":
            return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
            You are a helpful AI assistant, you're given audio encoded as a sequence of tokens below and must transcribe it precisely.
            {self.AUDIO_START_TOKEN}{self.AUDIO_END_TOKEN}<|eot_id|><|start_header_id|>user<|end_header_id|>
            {query}<|eot_id|>
            <|start_header_id|>assistant<|end_header_id|>"""
        elif self.model_type == "gemma":
            return f"""<bos><start_of_turn>user
            You're given audio encoded as a sequence of 300 tokens: {self.AUDIO_START_TOKEN}{self.AUDIO_END_TOKEN} {query}<end_of_turn>
            <start_of_turn>model"""
        elif self.model_type == "qwen2":
            return f"""<|im_start|>system
            You are a helpful AI assistant, you're given the following audio encoded as a sequence of 300 tokens and must transcribe it precisely. {self.AUDIO_START_TOKEN}{self.AUDIO_END_TOKEN}<|im_end|>
            <|im_start|>user
            {query}<|im_end|>
            <|im_start|>assistant"""
        raise ValueError(
            f"Unsupported model_type {self.model_type!r}; expected one of {self.SUPPORTED_MODEL_TYPES}"
        )

    def get_prompt(self, query, answer):
        """Return the full training prompt: query prefix, answer and closing tokens.

        Raises:
            ValueError: If ``self.model_type`` is not supported.
        """
        prefix = self._prompt_prefix(query)
        terminator = self._ANSWER_TERMINATORS[self.model_type]
        return f"{prefix}{self._PROMPT_LINE_BREAK}{answer}{terminator}"

    def get_query_prompt(self, query):
        """Return the prompt without an answer, ending at the assistant header.

        The result is always a prefix of ``get_prompt(query, answer)``, i.e. it
        uses exactly the wording the model is trained on.

        Raises:
            ValueError: If ``self.model_type`` is not supported.
        """
        return self._prompt_prefix(query)

    def forward(self, **kwargs):
        """Run one training forward pass and return the language model outputs.

        Keyword Args:
            embedding: Batch of audio embeddings (nested lists are accepted);
                converted to a bfloat16 tensor. Dimension 1 is used as the
                number of audio tokens inserted per sample.
            query: One question string per sample.
            answer: One answer string per sample.

        Returns:
            The output of ``self.language_model(...)`` called with the spliced
            embeddings, attention mask, position ids and labels.

        Raises:
            ValueError: If a prompt does not contain exactly one audio start
                and one audio end token, or if the expanded labels end up with
                an unexpected length.
        """
        audio_embedding = kwargs.get('embedding')
        queries = kwargs.get('query')
        answers = kwargs.get('answer')

        audio_embedding = torch.tensor(audio_embedding, dtype=torch.bfloat16).to(self.device)

        # Project audio embeddings to match the language model's hidden size
        audio_features = self.audio_projection(audio_embedding).to(self.device)

        prompts = [self.get_prompt(query, answer) for query, answer in zip(queries, answers)]

        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "right"

        # Tokenize the batch, right-padding every prompt to the longest one
        tokenizer_output = self.tokenizer(
            prompts,
            return_tensors='pt',
            truncation=True,
            padding=True,
        )

        input_ids = tokenizer_output['input_ids'].to(self.device)
        attention_mask = tokenizer_output['attention_mask'].to(self.device)

        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        number_of_audio_tokens = audio_embedding.shape[1]

        batch_size, sequence_length = input_ids.shape
        max_seq_length = sequence_length + number_of_audio_tokens

        new_inputs_embeds = torch.zeros(
            (batch_size, max_seq_length, inputs_embeds.size(2)),
            device=inputs_embeds.device
        )
        new_attention_mask = torch.zeros((batch_size, max_seq_length), device=inputs_embeds.device)

        # Insert the projected audio embeddings between the AUDIO_START_TOKEN and AUDIO_END_TOKEN for each sample in the batch
        for i in range(batch_size):
            start_pos = (input_ids[i] == self.audio_start_token_id).nonzero(as_tuple=True)[0]
            end_pos = (input_ids[i] == self.audio_end_token_id).nonzero(as_tuple=True)[0]

            if not len(start_pos) == 1:
                raise ValueError(f"Incorrect number of audio start tokens in the input. Got {len(start_pos)} tokens")

            if not len(end_pos) == 1:
                raise ValueError(f"Incorrect number of audio end tokens in the input. Got {len(end_pos)} tokens")

            # Exactly one marker of each kind is guaranteed by the checks above.
            start_pos = start_pos[0].item()
            end_pos = end_pos[0].item()

            # Text up to and including the start marker, then the audio
            # features, then the end marker and everything after it.
            new_inputs_embeds[i] = torch.cat((
                inputs_embeds[i, :start_pos + 1],
                audio_features[i],
                inputs_embeds[i, end_pos:]
            ), dim=0)

            # Adjust attention mask for the inserted audio embeddings
            new_attention_mask[i] = torch.cat((
                attention_mask[i, :start_pos + 1],
                torch.ones(number_of_audio_tokens, device=inputs_embeds.device),
                attention_mask[i, end_pos:]
            ), dim=0)

        # Positions count attended tokens only; masked (padding) slots get 1.
        position_ids = (new_attention_mask.cumsum(-1) - 1).masked_fill_((new_attention_mask == 0), 1).long()

        labels_list = []

        for i in range(input_ids.size(0)):
            # Start with everything ignored (-100); only the assistant span is filled in
            sample_labels = torch.full_like(input_ids[i], -100)

            if self.model_type == "llama":
                # Get the third start_of_header_token which is the one where the assistant's response starts
                assistant_start_pos = (input_ids[i] == self.start_of_header_token).nonzero(as_tuple=True)[0][2].item()
            elif self.model_type == "gemma":
                # Get the second start_of_turn_token which is the one where the model's response starts then +1
                assistant_start_pos = (input_ids[i] == self.start_of_turn_token).nonzero(as_tuple=True)[0][1].item() + 1
            elif self.model_type == "qwen2":
                # Get the third start_of_turn_token which is the one where the model's response starts then +1
                assistant_start_pos = (input_ids[i] == self.start_of_turn_token).nonzero(as_tuple=True)[0][2].item() + 1

            # First occurrence of the end-of-text token closes the labelled span
            eot_pos = (input_ids[i] == self.eot_token).nonzero(as_tuple=True)[0][0].item()

            # Fill in the input_ids for the assistant response part
            sample_labels[assistant_start_pos:eot_pos] = input_ids[i, assistant_start_pos:eot_pos]

            # Find AUDIO_START_TOKEN and AUDIO_END_TOKEN positions
            audio_start_pos = (input_ids[i] == self.audio_start_token_id).nonzero(as_tuple=True)[0].item()
            audio_end_pos = (input_ids[i] == self.audio_end_token_id).nonzero(as_tuple=True)[0].item()

            # Insert ignored (-100) labels for the audio slots so labels stay
            # aligned with the spliced embeddings
            sample_labels = torch.cat((
                sample_labels[:audio_start_pos + 1],
                torch.full((number_of_audio_tokens,), -100, dtype=input_ids.dtype, device=input_ids.device),
                sample_labels[audio_start_pos + 1:audio_end_pos + 1],
                sample_labels[audio_end_pos + 1:]
            ), dim=0)

            # Verify the length after concatenation
            expected_length = input_ids.size(1) + number_of_audio_tokens
            if sample_labels.size(0) != expected_length:
                raise ValueError(f"Concatenation error: expected length {expected_length} but got {sample_labels.size(0)}")

            labels_list.append(sample_labels)

        labels = torch.stack(labels_list)

        print(f"Sequence shape: {new_inputs_embeds.shape}")
        print(f"Number of tokens used as labels: {(labels != -100).sum().item()}")

        assert max_seq_length == position_ids.size(1) == new_inputs_embeds.size(1) == labels.size(1) == new_attention_mask.size(1), "position_ids, new_inputs_embeds, new_labels and new_attention_mask must have the same sequence length equal to the max_seq_length"

        outputs = self.language_model(
            inputs_embeds=new_inputs_embeds,
            position_ids=position_ids,
            attention_mask=new_attention_mask,
            labels=labels,
            return_dict=True
        )

        if 'loss' in outputs:
            print(f"Loss: {outputs.loss.item()}")

        return outputs

## Training

### Load model

In [None]:
from huggingface_hub import login

# Log in to Hugging Face.
# The access token is read from the HF_TOKEN environment variable rather than
# being written into the notebook, so it cannot be committed or shared by
# accident. When the variable is unset, `token=None` makes `login` ask for
# the token interactively.
import os

login(token=os.environ.get("HF_TOKEN"))

In [None]:
# Pick exactly one base model. The second argument selects the chat template
# and label logic inside CustomLlarkModel and therefore has to match the
# model family being loaded ("gemma", "qwen2" or "llama").
# model = CustomLlarkModel("unsloth/gemma-2b-it", "gemma", "cuda", use_lora=True)
model = CustomLlarkModel("Qwen/Qwen2-7B-Instruct", "qwen2", "cuda", use_lora=True)
# model = CustomLlarkModel("unsloth/Llama-3.1-Storm-8B", "llama", "cuda", use_lora=True)

### Load wandb

In [None]:
import wandb, os
# Authenticate with Weights & Biases.
wandb.login()

# Name of the W&B project to log to; an empty string leaves the
# WANDB_PROJECT environment variable untouched.
wandb_project = "llark"
if wandb_project:
    os.environ["WANDB_PROJECT"] = wandb_project

In [None]:
# Naming for this training run; the output directory is derived from it.
project, run_name = "LLarK", "run"
project_and_run_name = f"{project}-{run_name}"
output_dir = f"./{project_and_run_name}"

### Train with the Hugging Face `Trainer`

In [None]:
from transformers import Trainer, TrainingArguments
from datetime import datetime

# Base name of the W&B run; a timestamp is appended when TrainingArguments is built.
wandbname = f"{project}-{run_name}"

import torch
import torch.nn.functional as F

def custom_collate_fn(batch):
    """Regroup a list of dataset rows into per-field lists for the model.

    Each row must provide the keys 'question', 'answer' and 'embedding'.
    The dataset's 'question' column is exposed under the key 'query'.
    Values are passed through as plain Python objects.
    """
    # can't return tensors here or unsloth has weird error
    collated = {'query': [], 'answer': [], 'embedding': []}
    for row in batch:
        collated['query'].append(row['question'])
        collated['answer'].append(row['answer'])
        collated['embedding'].append(row['embedding'])
    return collated

# Use bf16 mixed precision when the GPU supports it, otherwise fall back to fp16.
bf16_available = torch.cuda.is_bf16_supported()
run_timestamp = datetime.now().strftime('%m-%d-%H-%M')

# Evaluation and checkpointing are switched off for this run
# (save_strategy="no"); the model is saved explicitly in a later cell.
# Options that were tried and left disabled: eval_strategy="epoch",
# eval_steps=75, save_steps=100, save_strategy="epoch",
# load_best_model_at_end=True.
training_args = TrainingArguments(
    # Run bookkeeping
    output_dir=output_dir,
    logging_dir='./logs',
    logging_steps=1,
    report_to="wandb",
    run_name=f"{wandbname}-{run_timestamp}",
    save_strategy="no",
    seed=3407,
    # Batching and schedule
    num_train_epochs=1,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=1,
    # Optimisation
    optim="adamw_8bit",
    learning_rate=2e-5,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    weight_decay=0.001,
    max_grad_norm=10.0,
    # Precision
    bf16=bf16_available,
    fp16=not bf16_available,
    # The collate function reads the raw 'question' / 'answer' / 'embedding'
    # columns, so the Trainer must not drop them.
    remove_unused_columns=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=combined_dataset,
    data_collator=custom_collate_fn,
)

trainer.train()

## Save model and checkpoints

In [None]:
# Persist the two trained pieces separately: the audio projection weights as a
# plain state dict, and the language model through Unsloth's merged 16-bit
# export (written to "model_16bit_merged" relative to the working directory).
projector_state = model.audio_projection.state_dict()
torch.save(projector_state, "/content/llark_multi_modal_projector_weights.pth")

model.language_model.save_pretrained_merged(
    "model_16bit_merged",
    model.tokenizer,
    save_method="merged_16bit",
)

In [None]:
# Copy the saved projector weights and the merged model directory to Google Drive.
# NOTE(review): assumes Drive is already mounted at /content/drive and that the
# merged model was written under /content — no mount call is visible in this notebook.
!cp -r "/content/llark_multi_modal_projector_weights.pth" "/content/drive/My Drive/automatic-music-transcription/saved_models/"
!cp -r "/content/model_16bit_merged" "/content/drive/My Drive/automatic-music-transcription/saved_models/"

## Try model

### Without Unsloth for inference

Unsloth doesn't support passing `inputs_embeds` to the `generate` function, which we need in order to splice in the audio tokens, so we use plain `transformers` for inference instead.

https://github.com/unslothai/unsloth/issues/862

Load the model in full 32-bit precision; otherwise errors occur.

In [None]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM, AutoConfig, MusicgenConfig, MusicgenForConditionalGeneration
from torch import Tensor
import math
from peft import PeftModel

class CustomLlarkModel(nn.Module):
    """Inference-only model: plain ``transformers`` language model plus audio projection.

    The tokenizer loaded from ``language_model_name`` is expected to already
    contain the ``<AUDIO_START>`` / ``<AUDIO_END>`` tokens (they are looked up,
    not added). Projected audio embeddings are spliced between those two
    marker tokens and the result is passed to ``generate`` as ``inputs_embeds``.
    """

    # Chat templates this class knows how to build.
    SUPPORTED_MODEL_TYPES = ("llama", "gemma", "qwen2")

    def __init__(self, language_model_name, model_type, device, use_lora=True):
        """Load tokenizer and language model and create the audio projection.

        Args:
            language_model_name: Path or hub id of the (merged) language model.
            model_type: Chat template family; one of ``SUPPORTED_MODEL_TYPES``.
            device: Device the model, the projection and all inputs are moved to.
            use_lora: Accepted for signature compatibility; not used here.

        Raises:
            ValueError: If ``model_type`` is not supported. Raised before any
                model is loaded.
        """
        super().__init__()

        if model_type not in self.SUPPORTED_MODEL_TYPES:
            raise ValueError(
                f"Unsupported model_type {model_type!r}; expected one of {self.SUPPORTED_MODEL_TYPES}"
            )

        self.model_type = model_type
        self.device = device

        # Initialize tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(language_model_name)
        # NOTE(review): device_map="auto" already places the model; the extra
        # .to(device) is kept from the original — confirm it is still wanted.
        self.language_model = AutoModelForCausalLM.from_pretrained(language_model_name, device_map="auto").to(self.device)
        self.language_model.gradient_checkpointing_disable()

        # Define special tokens for audio start and end
        self.AUDIO_START_TOKEN = "<AUDIO_START>"
        self.AUDIO_END_TOKEN = "<AUDIO_END>"

        self.audio_start_token_id = self.tokenizer.convert_tokens_to_ids(self.AUDIO_START_TOKEN)
        self.audio_end_token_id = self.tokenizer.convert_tokens_to_ids(self.AUDIO_END_TOKEN)

        # Randomly initialised here; trained weights come from load_trained_weights()
        self.audio_projection = nn.Linear(4800, self.language_model.config.hidden_size).to(device)

        # Template token ids; set for parity with training, not read by generate().
        # NOTE(review): numeric ids are hard-coded — verify against the tokenizer.
        if self.model_type == "llama":
            self.start_of_header_token = 128006
            self.eot_token = 128001
        elif self.model_type == "gemma":
            self.start_of_turn_token = 106
            self.eot_token = self.tokenizer.eos_token_id
        elif self.model_type == "qwen2":
            self.start_of_turn_token = 151644
            self.eot_token = self.tokenizer.pad_token_id

    def get_prompt(self, query):
        """Return the chat prompt for ``query``, ending at the assistant header.

        The wording matches the training prompts defined earlier in this
        notebook; in particular the llama prompt says "a sequence of tokens"
        (no "300"), exactly as the llama training prompt does.

        Raises:
            ValueError: If ``self.model_type`` is not supported.
        """
        if self.model_type == "llama":
            return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
            You are a helpful AI assistant, you're given audio encoded as a sequence of tokens below and must transcribe it precisely.
            {self.AUDIO_START_TOKEN}{self.AUDIO_END_TOKEN}<|eot_id|><|start_header_id|>user<|end_header_id|>
            {query}<|eot_id|>
            <|start_header_id|>assistant<|end_header_id|>"""
        elif self.model_type == "gemma":
            return f"""<bos><start_of_turn>user
            You're given audio encoded as a sequence of 300 tokens: {self.AUDIO_START_TOKEN}{self.AUDIO_END_TOKEN} {query}<end_of_turn>
            <start_of_turn>model"""
        elif self.model_type == "qwen2":
            return f"""<|im_start|>system
            You are a helpful AI assistant, you're given the following audio encoded as a sequence of 300 tokens and must transcribe it precisely. {self.AUDIO_START_TOKEN}{self.AUDIO_END_TOKEN}<|im_end|>
            <|im_start|>user
            {query}<|im_end|>
            <|im_start|>assistant"""
        raise ValueError(
            f"Unsupported model_type {self.model_type!r}; expected one of {self.SUPPORTED_MODEL_TYPES}"
        )

    def load_trained_weights(self, projector_weights_path):
        """Load the trained audio projection weights from ``projector_weights_path``.

        The file is expected to hold the projection's ``state_dict``. Tensors
        are mapped onto ``self.device`` so weights saved on a GPU can also be
        loaded elsewhere, and ``weights_only=True`` restricts unpickling to
        tensor data.
        """
        projector_state_dict = torch.load(
            projector_weights_path,
            map_location=self.device,
            weights_only=True,
        )
        self.audio_projection.load_state_dict(projector_state_dict)

        print("Trained weights loaded successfully.")

    def _project_audio(self, audio_embedding):
        """Turn one audio embedding into a batch of one projected feature sequence."""
        # Build the tensor in bfloat16 first (as the original code and the
        # training forward pass do), then cast to the projection's own dtype:
        # nn.Linear rejects inputs whose dtype differs from its weights.
        projection_dtype = self.audio_projection.weight.dtype
        audio_tensor = torch.tensor([audio_embedding], dtype=torch.bfloat16).to(
            device=self.device, dtype=projection_dtype
        )
        return self.audio_projection(audio_tensor)

    @torch.no_grad()
    def generate(self, audio_embedding, query, max_new_tokens=4096, num_beams=1, do_sample=False, top_k=None, top_p=None, temperature=1.0):
        """Generate text for one audio embedding and one query.

        Args:
            audio_embedding: A single (unbatched) audio embedding; nested lists
                are accepted. It is wrapped into a batch of one.
            query: The user question inserted into the chat prompt.
            max_new_tokens, num_beams, do_sample, temperature: Forwarded to
                the language model's ``generate``.
            top_k, top_p: Forwarded only when not ``None``.

        Returns:
            The decoded output string (special tokens skipped).

        Raises:
            ValueError: If the tokenized prompt does not contain exactly one
                audio start and one audio end token.
        """
        audio_features = self._project_audio(audio_embedding)
        number_of_audio_tokens = audio_features.shape[1]

        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "right"

        prompt = self.get_prompt(query)

        # Tokenize the input without padding
        tokenizer_output = self.tokenizer(
            prompt,
            return_tensors='pt',
            truncation=True,
        )

        input_ids = tokenizer_output['input_ids'].to(self.device)
        attention_mask = tokenizer_output['attention_mask'].to(self.device)

        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

        batch_size, sequence_length = input_ids.shape
        max_seq_length = sequence_length + number_of_audio_tokens

        # Same dtype as the text embeddings so the language model receives
        # inputs in the precision it was loaded with.
        new_inputs_embeds = torch.zeros(
            (batch_size, max_seq_length, inputs_embeds.size(2)),
            device=inputs_embeds.device,
            dtype=inputs_embeds.dtype
        )
        new_attention_mask = torch.zeros((batch_size, max_seq_length), device=inputs_embeds.device)

        # Insert the projected audio embeddings between the AUDIO_START_TOKEN and AUDIO_END_TOKEN for each sample in the batch
        for i in range(batch_size):
            start_pos = (input_ids[i] == self.audio_start_token_id).nonzero(as_tuple=True)[0]
            end_pos = (input_ids[i] == self.audio_end_token_id).nonzero(as_tuple=True)[0]

            if not len(start_pos) == 1:
                raise ValueError(f"Incorrect number of audio start tokens in the input. Got {len(start_pos)} tokens")

            if not len(end_pos) == 1:
                raise ValueError(f"Incorrect number of audio end tokens in the input. Got {len(end_pos)} tokens")

            # Exactly one marker of each kind is guaranteed by the checks above.
            start_pos = start_pos[0].item()
            end_pos = end_pos[0].item()

            # Text up to and including the start marker, then the audio
            # features, then the end marker and everything after it.
            new_inputs_embeds[i] = torch.cat((
                inputs_embeds[i, :start_pos + 1],
                audio_features[i],
                inputs_embeds[i, end_pos:]
            ), dim=0)

            # Adjust attention mask for the inserted audio embeddings
            new_attention_mask[i] = torch.cat((
                attention_mask[i, :start_pos + 1],
                torch.ones(number_of_audio_tokens, device=inputs_embeds.device),
                attention_mask[i, end_pos:]
            ), dim=0)

        # `position_ids` are deliberately not built or passed:
        # generate() currently breaks if passed in `position_ids`.
        generation_params = {
            "inputs_embeds": new_inputs_embeds,
            "attention_mask": new_attention_mask,
            "max_new_tokens": max_new_tokens,
            "num_beams": num_beams,
            "use_cache": False,
            "do_sample": do_sample,
            "temperature": temperature,
        }

        if top_k is not None:
            generation_params["top_k"] = top_k
        if top_p is not None:
            generation_params["top_p"] = top_p

        output_ids = self.language_model.generate(**generation_params)
        return self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

In [None]:
# Copy the merged model directory and the projector weights from Google Drive
# back into /content, where the following cells load them from.
# NOTE(review): assumes Drive is already mounted at /content/drive.
!cp -r "/content/drive/My Drive/automatic-music-transcription/saved_models/model_16bit_merged/" "/content/"
!cp -r "/content/drive/My Drive/automatic-music-transcription/saved_models/llark_multi_modal_projector_weights.pth" "/content/llark_multi_modal_projector_weights.pth"

In [None]:
# Build the inference model from the merged checkpoint. The model type must
# match the one used for training; `use_lora` is accepted by the constructor
# but not read by the inference class.
model = CustomLlarkModel("/content/model_16bit_merged", "qwen2", "cuda", use_lora=False)

In [None]:
# The audio projection is created with fresh weights in the constructor, so the
# trained projector weights have to be loaded before generating.
model.load_trained_weights(projector_weights_path="/content/llark_multi_modal_projector_weights.pth")

In [None]:
# Take the first row of the combined training dataset as a smoke-test input
# (this is a sample the model was trained on, not held-out data).
audio_embedding = combined_dataset[0]['embedding']
query = combined_dataset[0]['question']

In [None]:
# Sample a transcription. Despite its name, `output_ids` holds the decoded
# text: CustomLlarkModel.generate returns the result of tokenizer.decode.
output_ids = model.generate(audio_embedding, query, max_new_tokens=2048, do_sample=True, top_p=0.7, temperature=1.2)
output_ids