## 1. Install dependencies

In [None]:
!pip install librosa pretty_midi wandb jams datasets huggingface_hub

In [None]:
!apt-get update
!apt-get install -y abcmidi

In [None]:
!abc2midi --version

## Data

**These datasets are copied from my Google Drive, so mount it at `/content/drive` before running the cells below.**

### Copy Maestro from Google Drive

In [None]:
%%time
!cp -r "/content/drive/My Drive/maestro/" "/content/maestro/"

### Copy and extract URMP from Google Drive

In [None]:
%%time
!cp "/content/drive/My Drive/automatic-music-transcription/URMP_Dataset.tar.gz" "/content/URMP_Dataset.tar.gz"
!mkdir -p /content/URMP_Dataset
!tar -xzvf /content/URMP_Dataset.tar.gz -C /content/URMP_Dataset
!rm /content/URMP_Dataset.tar.gz

## MIDI to ABC notation

In [None]:
import subprocess
from pathlib import Path
from typing import Optional

def remove_backslash_after_bar(text):
    """Drop a trailing line-continuation backslash that follows a bar line.

    Every occurrence of the three-character sequence ``"| \\"`` (bar, space,
    backslash) is replaced by ``"| "``; everything else is returned unchanged.
    """
    bar_then_backslash = "| \\"
    return text.replace(bar_then_backslash, "| ")

def cmd_midi_to_abc(
    midi_path: str,
    extract_all_tracks: bool = False,
    verbose: bool = False,
    min_note_length: Optional[int] = None
) -> Optional[str]:
    """Convert a MIDI file to ABC text by running the ``midi2abc`` command.

    Args:
        midi_path: Path to an existing MIDI file.
        extract_all_tracks: If True, pass ``-xa`` to midi2abc.
        verbose: If True, pass ``-v`` to midi2abc.
        min_note_length: If not None, pass ``-mpl <value>`` to midi2abc.

    Returns:
        midi2abc's standard output, or ``None`` when the command exits with a
        non-zero status (the error and its stderr are printed).

    Raises:
        FileNotFoundError: If ``midi_path`` is not an existing file.
    """
    if not Path(midi_path).is_file():
        raise FileNotFoundError(f"MIDI file not found: {midi_path}")

    # (enabled?, extra arguments) in the order they are appended to the command line.
    optional_flags = (
        (extract_all_tracks, ['-xa']),
        (verbose, ['-v']),
        (min_note_length is not None, ['-mpl', str(min_note_length)]),
    )
    cmd = ['midi2abc', midi_path]
    for enabled, extra_args in optional_flags:
        if enabled:
            cmd += extra_args

    try:
        completed = subprocess.run(cmd, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as err:
        print(f"Error converting MIDI to ABC: {err}")
        print(f"Error output: {err.stderr}")
        return None
    return completed.stdout

# Position of each natural note letter on the circle of fifths, relative to C.
_LETTER_FIFTHS = {'F': -1, 'C': 0, 'G': 1, 'D': 2, 'A': 3, 'E': 4, 'B': 5}

# Offset (in fifths) of each ABC mode name relative to the major key on the same tonic.
_MODE_FIFTHS = {
    '': 0, 'maj': 0, 'major': 0, 'ion': 0, 'ionian': 0,
    'm': -3, 'min': -3, 'minor': -3, 'aeo': -3, 'aeolian': -3,
    'mix': -1, 'mixolydian': -1,
    'dor': -2, 'dorian': -2,
    'phr': -4, 'phrygian': -4,
    'lyd': 1, 'lydian': 1,
    'loc': -5, 'locrian': -5,
}

_SHARP_ORDER = 'FCGDAEB'
_FLAT_ORDER = 'BEADGCF'
_NOTE_LETTERS = frozenset('ABCDEFGabcdefg')
_NATURAL = '='


def _key_accidentals(key_name):
    """Map upper-case note letters to '^' or '_' for a key name such as 'G', 'Bb', 'Gm', 'F#min', 'Ddor'.

    Returns an empty dict for unrecognised keys or keys with more than 7 sharps/flats.
    """
    name = key_name.strip()
    if not name or name[0].upper() not in _LETTER_FIFTHS:
        return {}
    fifths = _LETTER_FIFTHS[name[0].upper()]
    rest = name[1:]
    if rest[:1] == '#':
        fifths += 7
        rest = rest[1:]
    elif rest[:1] == 'b':
        fifths -= 7
        rest = rest[1:]
    mode = rest.strip().lower()
    if mode not in _MODE_FIFTHS:
        return {}
    fifths += _MODE_FIFTHS[mode]
    if not -7 <= fifths <= 7:
        return {}
    if fifths >= 0:
        return {letter: '^' for letter in _SHARP_ORDER[:fifths]}
    return {letter: '_' for letter in _FLAT_ORDER[:-fifths]}


def _is_header_line(line):
    """True for '%' comment/directive lines and ABC field lines such as 'X:', 'T:', 'K:', 'V:', 'w:'."""
    if line.startswith('%'):
        return True
    return len(line) > 1 and line[1] == ':' and line[0].isascii() and line[0].isalpha()


def _rewrite_music_line(line, key_accidentals):
    """Respell one music line so it sounds the same under ``K: none``.

    Two per-bar accidental states are tracked: the one in force in the original
    text (explicit marks override the key until the next '|') and the one in
    force in the rewritten text (which starts natural). A mark is written only
    where the rewritten bar would otherwise give the wrong pitch; explicit
    sharps/flats from the original are always kept.
    """
    out = []
    in_carry = {}   # pitch -> accidental in force in the ORIGINAL bar
    out_carry = {}  # pitch -> accidental in force in the REWRITTEN bar
    i, n = 0, len(line)
    while i < n:
        ch = line[i]
        if ch == '%':
            # Rest of the line is a comment.
            out.append(line[i:])
            break
        if ch == '"' or (ch == '[' and i + 2 < n and line[i + 1].isalpha() and line[i + 2] == ':'):
            # Quoted chord symbol/annotation or inline field like [K:...]: copy verbatim.
            closer = '"' if ch == '"' else ']'
            end = line.find(closer, i + 1)
            stop = n if end == -1 else end + 1
            out.append(line[i:stop])
            i = stop
            continue
        if ch == '|':
            # Accidentals only last until the next bar line.
            in_carry.clear()
            out_carry.clear()
            out.append(ch)
            i += 1
            continue
        j = i
        while j < n and line[j] in '^_=':
            j += 1
        if j < n and line[j] in _NOTE_LETTERS:
            k = j + 1
            while k < n and line[k] in ",'":
                k += 1
            mark, pitch = line[i:j], line[j:k]
            if mark:
                actual = mark
                in_carry[pitch] = mark
                # A natural is only needed if the rewritten bar has already altered this pitch.
                needs_mark = mark != _NATURAL or out_carry.get(pitch, _NATURAL) != _NATURAL
                emit = mark if needs_mark else ''
            else:
                actual = in_carry.get(pitch, key_accidentals.get(pitch[0].upper(), _NATURAL))
                emit = '' if actual == out_carry.get(pitch, _NATURAL) else actual
            out_carry[pitch] = actual
            out.append(emit + pitch)
            i = k
            continue
        if j > i:
            # Stray accidental characters not followed by a note.
            out.append(line[i:j])
            i = j
        else:
            out.append(ch)
            i += 1
    return ''.join(out)


def remove_key_signature(abc_string):
    """Rewrite ABC text so the key is ``K: none`` and every pitch is spelled explicitly.

    The first ``K:`` line is located. If there is none, or the key is already
    ``none``, the text is returned unchanged. Otherwise that line becomes
    ``K: none`` and every music line is respelled so it sounds the same without
    a key signature:

    * notes altered by the key get an explicit ``^``/``_``;
    * ABC's rule that an accidental lasts until the next bar line is honoured
      for both the original and the rewritten text, so ``=B B`` in F major
      stays natural, and a ``=`` natural is written only where the rewritten
      bar needs it.

    Major, minor (``m``/``min``) and modal (``dor``, ``mix``, ...) key names
    with up to 7 sharps or flats are understood. Unrecognised keys are treated
    as having no accidentals, and the key line is still replaced.

    Header fields (``X:``, ``T:``, ``M:``, ``V:``...), ``%`` comments, quoted
    strings and inline fields such as ``[K:...]`` are left untouched.
    """
    lines = abc_string.split('\n')

    key_line_index = None
    current_key = None
    for idx, line in enumerate(lines):
        if line.startswith('K:'):
            # Key name without any trailing '%' comment, e.g. "K:G % 1 sharps" -> "G".
            current_key = line.partition(':')[2].split('%')[0].strip()
            key_line_index = idx
            break

    if not current_key or current_key == 'none':
        return abc_string

    key_accidentals = _key_accidentals(current_key)
    lines[key_line_index] = 'K: none'

    for idx, line in enumerate(lines):
        if not _is_header_line(line):
            lines[idx] = _rewrite_music_line(line, key_accidentals)

    return '\n'.join(lines)

def remove_comment_lines(abc_string):
    """Strip ABC comment lines and title lines.

    A line is dropped when, ignoring surrounding whitespace, it starts with
    ``%`` (comments and ``%%`` directives) or with ``"T: "`` (a title field
    written with a space after the colon). Kept lines are re-joined with
    ``"\\n"``. Because ``splitlines`` is used, a trailing newline is not kept.
    """
    kept = []
    for raw_line in abc_string.splitlines():
        if raw_line.strip().startswith(('%', 'T: ')):
            continue
        kept.append(raw_line)
    return "\n".join(kept)

# Prompt(s) for the ABC transcription task; one is picked with random.choice per sample.
# The cheat-sheet follows the ABC 2.1 standard: a dot *before* a note means staccato
# (it is not a dotted note), and durations are multiples of the L: unit length, so
# examples are phrased relative to L rather than as fixed note values.
abc_questions = [
    """Transcribe this music clip using ABC notation. Follow this template:

```
M: 4/4
L: 1/16
Q:1/4=120
K: none

V:1 name="Instrument 1"
(notes here)

(Add more voices if present)
```

Key points for ABC notation:
- Notes: A-G (lowercase for higher octaves)
- Accidentals: ^ (sharp), _ (flat), = (natural)
- Note length: Numbers after note (C2 = twice as long as C)
- Dotted notes: multiply the length by 3/2 (C3/2 = dotted C, C3 = dotted C2)
- Rests: z with optional duration (z2 = rest twice the default length)
- Chords: [CEG]
- Ties: -
- Bar lines: |
- Broken rhythms: > or < between notes (C>D = C3/2 D/2, i.e. dotted C + halved D)

Important:
- The key should always be set to "none" as this is a short clip and it's hard to identify a key.
- Always explicitly show accidentals as the key is set to none.
- Include a separate voice (V:) for each distinct instrument you can identify.
- List the instrument names using one of the 128 General MIDI instrument names.
    """
]

In [None]:
from typing import Tuple, Dict
from pretty_midi import pretty_midi
from mido import MidiFile, MidiTrack, Message, MetaMessage, bpm2tempo

# Cache of parsed MIDI files keyed by path, kept in least-recently-used order.
# Each PrettyMIDI object holds every note/control event of a whole file, so the
# cache is bounded. The dataset builders extract all chunks of one file
# consecutively, so a small cache keeps the hit rate while preventing memory
# from growing with the number of files.
from collections import OrderedDict

MIDI_DATA_CACHE_SIZE = 16
midi_data_cache: Dict[str, pretty_midi.PrettyMIDI] = OrderedDict()

def get_midi_data(midi_file: str) -> pretty_midi.PrettyMIDI:
    """Return the parsed ``PrettyMIDI`` object for ``midi_file``, using a bounded LRU cache.

    Args:
        midi_file: Path to a MIDI file.

    Returns:
        The cached or freshly parsed ``pretty_midi.PrettyMIDI`` instance. The same
        object is returned while it remains cached, so callers must not mutate it.
    """
    cached = midi_data_cache.get(midi_file)
    if cached is not None:
        midi_data_cache.move_to_end(midi_file)
        return cached

    midi_data = pretty_midi.PrettyMIDI(midi_file)
    midi_data_cache[midi_file] = midi_data
    # Evict least-recently-used entries; `while` also handles a lowered limit.
    while len(midi_data_cache) > MIDI_DATA_CACHE_SIZE:
        midi_data_cache.popitem(last=False)
    return midi_data

def get_tempo_and_metre(midi_file: str, midi_data: pretty_midi.PrettyMIDI, chunk_start_time: float, chunk_end_time: float, max_tempo_change: float = 10.0) -> Tuple[float, int, int]:
    """Return the tempo and time signature in effect at the start of a chunk.

    Args:
        midi_file: Path of the MIDI file. It is not used here and is kept for API compatibility.
        midi_data: Parsed MIDI data.
        chunk_start_time: Chunk start, in seconds.
        chunk_end_time: Chunk end, in seconds.
        max_tempo_change: Largest allowed difference, in BPM, between the tempo at
            the chunk start and any tempo change strictly inside the chunk.

    Returns:
        ``(tempo_bpm, numerator, denominator)``.

    Raises:
        ValueError: If a tempo change inside the chunk differs from the starting
            tempo by more than ``max_tempo_change`` BPM, or if the file has no
            time signature events.
    """
    change_times, change_tempi = midi_data.get_tempo_changes()

    # Tempo of the last change at or before the chunk start (or the first tempo).
    initial_tempo = change_tempi[0]
    for t, tempo in zip(change_times, change_tempi):
        if t > chunk_start_time:
            break
        initial_tempo = tempo

    for t, tempo in zip(change_times, change_tempi):
        if chunk_start_time < t < chunk_end_time:
            tempo_change = abs(tempo - initial_tempo)
            if tempo_change > max_tempo_change:
                raise ValueError(
                    f"Significant tempo change ({tempo_change:.2f} BPM > {max_tempo_change} BPM) "
                    f"detected within the chunk ({chunk_start_time} - {chunk_end_time})"
                )

    time_signatures = midi_data.time_signature_changes
    if not time_signatures:
        raise ValueError("Error: The time_signatures list is empty.")

    # Last time signature at or before the chunk start; the first one if all start later.
    relevant_time_signature = time_signatures[0]
    for ts in time_signatures:
        if ts.time > chunk_start_time:
            break
        relevant_time_signature = ts

    if any(chunk_start_time < ts.time < chunk_end_time for ts in time_signatures):
        print(f"Note: Time signature change(s) detected within the chunk ({chunk_start_time} - {chunk_end_time}). "
              f"Using the time signature in effect at the chunk start for the entire chunk.")

    return (initial_tempo, relevant_time_signature.numerator, relevant_time_signature.denominator)

def get_tempo_and_metre_simp(midi_data: pretty_midi.PrettyMIDI) -> Tuple[float, int, int]:
    """Return ``(tempo_bpm, numerator, denominator)`` from the first tempo and first time signature.

    Later tempo or time-signature changes are ignored. An empty
    ``time_signature_changes`` list raises ``IndexError``.
    """
    _, tempi = midi_data.get_tempo_changes()
    first_signature = midi_data.time_signature_changes[0]
    return tempi[0], first_signature.numerator, first_signature.denominator

from mido import MidiFile, MidiTrack, Message, MetaMessage, bpm2tempo
from typing import Tuple, Dict
from pretty_midi import pretty_midi

# Non-note mido message types copied into extracted segments.
# mido names pitch-bend messages 'pitchwheel' (there is no 'pitch_bend' type).
_SEGMENT_CONTROL_TYPES = ('pitchwheel', 'control_change', 'program_change')


def _is_note_on(msg) -> bool:
    """True for a note_on with non-zero velocity (a real note start)."""
    return msg.type == 'note_on' and msg.velocity > 0


def _is_note_off(msg) -> bool:
    """True for note_off, or note_on with velocity 0 (the common note-off idiom)."""
    return msg.type == 'note_off' or (msg.type == 'note_on' and msg.velocity == 0)


def _extract_track_window(track, start_ticks: int, end_ticks: int):
    """Copy the events of ``track`` that fall in [start_ticks, end_ticks) into a new MidiTrack.

    Times are re-based so that ``start_ticks`` becomes tick 0 while every event
    keeps its position relative to the window start. This keeps tracks aligned
    with each other and with the audio chunk. Notes still sounding when the
    window opens are re-struck at tick 0 with velocity 64. Notes still sounding
    at the window end get a note_off exactly at the window end.

    Returns:
        The new track, or ``None`` when it would contain no note/control events.
    """
    window_length = end_ticks - start_ticks
    if window_length <= 0:
        return None

    # Pass 1: notes still sounding when the window opens (dict keeps insertion order).
    held = {}
    abs_tick = 0
    for msg in track:
        abs_tick += msg.time
        if abs_tick >= start_ticks:
            break
        if _is_note_on(msg):
            held[(msg.note, msg.channel)] = None
        elif _is_note_off(msg):
            held.pop((msg.note, msg.channel), None)

    events = [(0, Message('note_on', note=note, velocity=64, channel=channel, time=0))
              for note, channel in held]
    sounding = dict(held)

    # Pass 2: events inside the window, recorded with their offset from the window start.
    abs_tick = 0
    for msg in track:
        abs_tick += msg.time
        if abs_tick >= end_ticks:
            break  # events are time-ordered, so nothing later is inside the window
        if abs_tick < start_ticks:
            continue
        offset = abs_tick - start_ticks
        if _is_note_on(msg):
            sounding[(msg.note, msg.channel)] = None
            events.append((offset, msg))
        elif _is_note_off(msg):
            if (msg.note, msg.channel) in sounding:
                del sounding[(msg.note, msg.channel)]
                events.append((offset, msg))
        elif msg.type in _SEGMENT_CONTROL_TYPES:
            events.append((offset, msg))

    # Close every note still sounding at the window end.
    for note, channel in sounding:
        events.append((window_length, Message('note_off', note=note, velocity=0, channel=channel, time=0)))

    if not events:
        return None

    # Convert absolute offsets back into MIDI delta times.
    new_track = MidiTrack()
    previous = 0
    for offset, msg in events:
        new_track.append(msg.copy(time=offset - previous))
        previous = offset
    # Make the track span the whole window even if its last event is earlier.
    new_track.append(MetaMessage('end_of_track', time=max(0, window_length - previous)))
    return new_track


def extract_midi_segment(input_file: str, output_file: str, start_time: float, duration: float) -> None:
    """Write the part of ``input_file`` between ``start_time`` and ``start_time + duration`` to ``output_file``.

    The output starts with a meta track holding a single time signature and
    tempo (those in effect at the chunk start, from ``get_tempo_and_metre``).
    It is followed by one track per input track that has note, pitch-wheel,
    control-change or program-change events in the window. Event positions are
    preserved relative to the window start.

    Args:
        input_file: Source MIDI path.
        output_file: Destination MIDI path.
        start_time: Window start, in seconds.
        duration: Window length, in seconds.

    Raises:
        ValueError: Propagated from ``get_tempo_and_metre`` when the tempo changes
            too much inside the window or there is no time signature.
    """
    end_time = start_time + duration

    # Get cached MIDI data
    midi_data = get_midi_data(input_file)

    tempo_bpm, numerator, denominator = get_tempo_and_metre(input_file, midi_data, start_time, end_time)
    mid = MidiFile(input_file)
    new_midi = MidiFile(ticks_per_beat=mid.ticks_per_beat)

    meta_track = MidiTrack()
    new_midi.tracks.append(meta_track)
    meta_track.append(MetaMessage('time_signature',
                                  numerator=numerator,
                                  denominator=denominator,
                                  clocks_per_click=24,
                                  notated_32nd_notes_per_beat=8,
                                  time=0))
    meta_track.append(MetaMessage('set_tempo',
                                  tempo=bpm2tempo(tempo_bpm),
                                  time=0))

    # Locate the window via the file's full tempo map, not just the tempo at the
    # chunk start, so earlier tempo changes do not shift the window.
    start_ticks = midi_data.time_to_tick(start_time)
    end_ticks = midi_data.time_to_tick(end_time)

    # Scan every track (not only tracks[1:]) so type-0 files and first tracks that
    # contain notes are handled; tracks without events in the window are dropped.
    for track in mid.tracks:
        new_track = _extract_track_window(track, start_ticks, end_ticks)
        if new_track is not None:
            new_midi.tracks.append(new_track)

    new_midi.save(output_file)

## Datasets

In [None]:
import os
from torch.utils.data import Dataset
from collections import OrderedDict
import random
from typing import List, Tuple
import numpy as np
import librosa

def _with_indefinite_article(name: str) -> str:
    """Return ``name`` prefixed with "an" when it starts with a vowel letter, otherwise "a"."""
    article = "an" if name and name[0].lower() in "aeiou" else "a"
    return f"{article} {name}"


def _join_instrument_names(names: List[str]) -> str:
    """Join names as an English list: "A", "A and B", "A, B, and C"."""
    if len(names) == 1:
        return names[0]
    if len(names) == 2:
        return f"{names[0]} and {names[1]}"
    return f"{', '.join(names[:-1])}, and {names[-1]}"


def generate_short_qa_data(tempo: float, instruments: List[str]) -> Tuple[str, str]:
    """Create a random short question/answer pair about the tempo or the instruments.

    About 20% of pairs ask about the tempo and 80% about the instruments. When
    ``instruments`` is empty there is nothing to name, so the tempo question is
    used instead. Previously this raised IndexError.

    Args:
        tempo: Tempo in BPM; truncated to an int in the answer.
        instruments: Instrument names in the clip. Duplicates are listed as given.

    Returns:
        ``(question, answer)``.
    """
    tempo = int(tempo)
    question_type = random.choices(["tempo", "instrument"], weights=[20, 80], k=1)[0]
    if not instruments:
        question_type = "tempo"

    if question_type == "tempo":
        tempo_questions = [
            "What is the tempo of this audio clip?",
            "Can you tell me the tempo of the track?",
            "How fast is the tempo in this music?",
            "What beats per minute (BPM) is this audio clip playing at?",
            "Identify the tempo in this audio track."
        ]
        tempo_answers = [
            f"The tempo of this audio clip is {tempo} BPM.",
            f"This track has a tempo of {tempo} beats per minute.",
            f"The BPM for this audio is {tempo}.",
            f"The music plays at a tempo of {tempo} BPM.",
            f"The speed of this track is {tempo} beats per minute."
        ]
        question = random.choice(tempo_questions)
        answer = random.choice(tempo_answers)
        return question, answer

    instrument_questions = [
        "What instrument is playing in this audio clip?",
        "What instruments are playing in this audio clip?",
        "Can you identify the instrument in this track?",
        "Can you name the instruments playing in this track?",
        "Which instruments are present in this audio?",
        "List the instruments that are featured in this recording."
    ]
    question = random.choice(instrument_questions)

    if len(instruments) == 1:
        # "a Violin" / "an Oboe"
        one = _with_indefinite_article(instruments[0])
        instrument_answers = [
            f"In this audio clip, {one} is playing.",
            f"You can hear {one} playing in this audio.",
            f"This track features {one}.",
            f"The sound in this audio is produced by {one}.",
            f"In this recording, {one} is the main instrument."
        ]
    else:
        listed = _join_instrument_names(instruments)
        instrument_answers = [
            f"In this audio clip, the following instruments are playing: {listed}.",
            f"This track features these instruments: {listed}.",
            f"The instruments heard in this recording are: {listed}.",
            f"You can hear the following instruments in this audio: {listed}.",
            f"In this audio, you can hear a combination of: {listed}."
        ]
    answer = random.choice(instrument_answers)

    return question, answer

class BaseAudioDataset(Dataset):
    """Base class for chunked audio datasets backed by a process-wide LRU audio cache.

    Subclasses append tuples
    ``(audio_path, midi_path, start_time, chunk_duration, folder_path, synthetic_flag)``
    to ``self.samples`` and implement ``process_item``.

    Decoded audio is cached in the class attribute ``audio_cache``, which all
    instances share, keyed by ``(audio_path, target_sr)``. Each entry holds a
    whole resampled file, so lower ``BaseAudioDataset.max_cache_size`` if memory
    is tight. Constructing a dataset no longer resets that limit.
    """
    # Shared by every instance: (path, sample rate) -> mono audio array.
    audio_cache = OrderedDict()
    max_cache_size = 1000

    def __init__(self, root_dir, dataset_name, duration=25, target_sr=16000, return_audio=True):
        self.root_dir = root_dir
        self.dataset_name = dataset_name
        self.duration = duration
        self.target_sr = target_sr
        self.samples = []
        self.return_audio = return_audio

    def load_audio(self, audio_path):
        """Return the whole file as mono audio at ``self.target_sr``, or None when audio is disabled."""
        if not self.return_audio:
            return None

        cache = BaseAudioDataset.audio_cache
        # Include the sample rate in the key: the cache is shared by datasets that may differ in target_sr.
        key = (audio_path, self.target_sr)
        if key in cache:
            cache.move_to_end(key)
            return cache[key]

        audio, sr = librosa.load(audio_path, sr=None, mono=True)
        if sr != self.target_sr:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=self.target_sr)
        cache[key] = audio
        # `while` rather than `if` so a limit lowered at runtime shrinks the cache.
        while len(cache) > BaseAudioDataset.max_cache_size:
            cache.popitem(last=False)
        return audio

    def get_audio_chunk(self, audio, start_time, duration):
        """Slice ``duration`` seconds starting at ``start_time`` and standardise it with normalize_audio."""
        start_sample = int(start_time * self.target_sr)
        end_sample = int((start_time + duration) * self.target_sr)
        audio_chunk = audio[start_sample:end_sample]

        return normalize_audio(audio_chunk)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        if idx >= len(self.samples):
            raise IndexError(f"Index {idx} is out of range for dataset with {len(self.samples)} samples")

        print(f"Processing sample {idx}")
        audio_path, _, start_time, chunk_duration, folder_path, _ = self.samples[idx]

        audio_chunk = None
        if self.return_audio:
            audio = self.load_audio(audio_path)
            audio_chunk = self.get_audio_chunk(audio, start_time, chunk_duration)

        result = self.process_item(idx, audio_chunk)
        if result is not None:
            return result

    def process_item(self, idx, audio_chunk):
        # Implemented by child classes
        raise NotImplementedError

def normalize_audio(audio_chunk):
    """Standardise an audio chunk to zero mean and, when possible, unit variance.

    The mean is always removed. Division by the standard deviation is skipped
    for constant (zero-variance) input, which avoids dividing by zero.
    """
    spread = np.std(audio_chunk)
    centered = audio_chunk - np.mean(audio_chunk)
    return centered / spread if spread > 0 else centered

### Maestro

In [None]:
def process_maestro_chunk(midi_file, synthetic_data: bool) -> Tuple[str, str]:
    """Build a (question, answer) pair for one extracted MAESTRO MIDI chunk.

    Args:
        midi_file: Path to the chunk's MIDI file, as written by extract_midi_segment.
        synthetic_data: If True, return a short tempo/instrument QA pair.
            Otherwise return an ABC-transcription prompt and the cleaned ABC text.

    Returns:
        ``(question, answer)``.

    Raises:
        RuntimeError: If midi2abc fails to convert the chunk.
    """
    import re

    midi_data = pretty_midi.PrettyMIDI(midi_file)
    tempo, _, _ = get_tempo_and_metre_simp(midi_data)

    if synthetic_data:
        instruments = [
            pretty_midi.program_to_instrument_name(instrument.program)
            for instrument in midi_data.instruments
        ]
        return generate_short_qa_data(tempo, instruments)

    abc_notation = cmd_midi_to_abc(midi_file, extract_all_tracks=True)
    if abc_notation is None:
        # cmd_midi_to_abc already printed midi2abc's error; fail with a clear message
        # instead of an AttributeError on None in the cleanup steps below.
        raise RuntimeError(f"midi2abc failed to convert {midi_file}")
    abc_notation = remove_key_signature(abc_notation)
    abc_notation = remove_comment_lines(abc_notation)
    abc_notation = remove_backslash_after_bar(abc_notation)
    # Name voice 1 only; a plain str.replace('V:1', ...) would also rewrite V:10, V:11, ...
    abc_notation = re.sub(r'V:1(?!\d)', 'V:1 name="Acoustic Grand Piano"', abc_notation)

    return random.choice(abc_questions), abc_notation

In [None]:
!mkdir /content/maestro/temp_midi/

In [None]:
import os
import yaml
import librosa
import numpy as np
import pretty_midi
import torch
from torch.utils.data import Dataset
import random
import pickle
import os
import torch.nn.functional as F
import glob
import mido

class MaestroDataset(BaseAudioDataset):
    """MAESTRO piano recordings split into fixed-length chunks with matching MIDI.

    Expects ``root_dir/<year>/<name>.wav`` next to ``<name>.midi``. Only full
    chunks of ``duration`` seconds are used. Each chunk's MIDI is written to
    ``/content/maestro/temp_midi/`` during construction, and chunks whose MIDI
    extraction fails are skipped.
    """

    def __init__(self, root_dir, dataset_name, duration=25, target_sr=16000, return_audio=True):
        super().__init__(root_dir, dataset_name, duration, target_sr, return_audio)
        self.initialize_samples()

    def initialize_samples(self):
        # sorted() makes the sample order reproducible; os.listdir/glob order is arbitrary.
        for year_folder in sorted(os.listdir(self.root_dir)):
            year_path = os.path.join(self.root_dir, year_folder)
            if not os.path.isdir(year_path):
                continue

            for wav_file in sorted(glob.glob(os.path.join(year_path, '*.wav'))):
                # Get the base name without extension
                base_name = os.path.splitext(os.path.basename(wav_file))[0]
                midi_path = os.path.join(year_path, base_name + '.midi')
                midi_file_name = os.path.basename(midi_path)

                total_duration = librosa.get_duration(path=wav_file)
                num_chunks = int(total_duration // self.duration)  # full chunks only

                for i in range(num_chunks):
                    start_time = i * self.duration
                    output_midi_path = f'/content/maestro/temp_midi/{midi_file_name}_{str(start_time)}.mid'
                    try:
                        extract_midi_segment(midi_path, output_midi_path, float(start_time), float(self.duration))
                    except Exception as e:
                        print(f"Skipping sample due to MIDI extraction error: {e}")
                        continue
                    self.samples.append((wav_file, output_midi_path, start_time, self.duration, "", False))

        # To train on a random 25% subset, uncomment:
        # self.samples = random.sample(self.samples, len(self.samples) // 4)

    def process_item(self, idx, audio_chunk):
        audio_path, midi_path, start_time, _, _, synthetic_flag = self.samples[idx]

        query, answer = process_maestro_chunk(midi_path, synthetic_flag)
        sample_type = "qa" if synthetic_flag else "abc"

        # Same 7-field layout as URMPDataset.process_item, so CombinedDataset yields uniform items.
        return audio_chunk, query, answer, self.dataset_name, audio_path, int(start_time), sample_type

### URMP

In [None]:
from pretty_midi import Instrument
from typing import Optional, List, Tuple
import re
import random

def insert_instrument_names(abc_notation, instrument_names):
    """Attach ``name="..."`` to ABC voice header lines, in order.

    The k-th line starting with ``V:`` becomes ``V:<id> name="<instrument_names[k]>"``,
    where ``<id>`` is the stripped text between the first and second ``:`` of
    that line. Voice lines beyond the number of supplied names are left as they are.
    """
    out_lines = abc_notation.split('\n')
    names_used = 0
    for pos, text in enumerate(out_lines):
        if not text.startswith('V:') or names_used >= len(instrument_names):
            continue
        voice_id = text.split(':')[1].strip()
        out_lines[pos] = f'V:{voice_id} name="{instrument_names[names_used]}"'
        names_used += 1
    return '\n'.join(out_lines)

def process_urmp_chunk(midi_track_number: int, instruments: List[str], midi_file: str,
                      chunk_start_time: int, sample_duration: int, synthetic_data: bool) -> Tuple[str, str]:
    """Build a (question, answer) pair for one extracted URMP MIDI chunk.

    Args:
        midi_track_number: Track number of a separated stem, or None for a mix.
        instruments: Instrument names parsed from the audio file name.
        midi_file: Path to the chunk's MIDI file.
        chunk_start_time: Chunk start (unused; kept for API compatibility).
        sample_duration: Chunk length (unused; kept for API compatibility).
        synthetic_data: If True, return a short tempo/instrument QA pair.
            Otherwise return an ABC-transcription prompt and the cleaned ABC text.

    Raises:
        RuntimeError: If midi2abc fails to convert the chunk.

    NOTE(review): for a separated stem (midi_track_number is not None) the ABC
    answer is still built from every track in ``midi_file``; only the naming
    uses the stem. Confirm whether the MIDI should be restricted to that track.
    """
    midi_data = pretty_midi.PrettyMIDI(midi_file)
    tempo, _, _ = get_tempo_and_metre_simp(midi_data)

    def get_instrument_name(instrument: Instrument, index: int) -> str:
        # Prefer the MIDI track name, then the file-name instrument at the same
        # position, then the General MIDI name of the instrument's program.
        if instrument.name:
            return instrument.name
        print(f"Warning: instrument names not available: {midi_data.instruments}")
        if index < len(instruments):
            return instruments[index]
        return pretty_midi.program_to_instrument_name(instrument.program)

    if midi_track_number is not None:
        # Slice instead of [0] so an empty list does not raise here.
        instrument_names = instruments[:1] if synthetic_data else instruments
    else:
        instrument_names = [
            get_instrument_name(instrument, i)
            for i, instrument in enumerate(midi_data.instruments)
        ]

    if synthetic_data:
        return generate_short_qa_data(tempo, instrument_names)

    abc_notation = cmd_midi_to_abc(midi_file, extract_all_tracks=True)
    if abc_notation is None:
        # cmd_midi_to_abc already printed midi2abc's error.
        raise RuntimeError(f"midi2abc failed to convert {midi_file}")
    abc_notation = remove_key_signature(abc_notation)
    abc_notation = remove_comment_lines(abc_notation)
    abc_notation = remove_backslash_after_bar(abc_notation)
    abc_notation = insert_instrument_names(abc_notation, instrument_names)

    return random.choice(abc_questions), abc_notation

# Map of instrument abbreviations (as used in URMP file names) to full names
instrument_map = {
    "Vn": "Violin",
    "Va": "Viola",
    "Vc": "Cello",
    "Db": "Double Bass",
    "Fl": "Flute",
    "Ob": "Oboe",
    "Cl": "Clarinet",
    "Sax": "Saxophone",
    "Bn": "Bassoon",
    "Tpt": "Trumpet",
    "Hn": "Horn",
    "Tbn": "Trombone",
    "Tba": "Tuba"
}

def extract_instruments(string):
    """Return the full names of the instruments abbreviated in ``string``, in order of appearance.

    Abbreviations from ``instrument_map`` are matched case-insensitively and only
    when not adjacent to other letters, so "vn" matches in
    "AuMix_01_Jupiter_vn_vc.wav" but not inside a word. Every occurrence yields
    one entry, so duplicates are kept. The result follows the order in
    ``string``, not the order of ``instrument_map``. The index-based fallback in
    process_urmp_chunk relies on this; presumably the file-name order is the
    track order (verify against the dataset).
    """
    if not instrument_map:
        return []
    lookup = {abbr.lower(): name for abbr, name in instrument_map.items()}
    # Longest first so no abbreviation can shadow a longer one that starts the same way.
    alternatives = '|'.join(re.escape(abbr) for abbr in sorted(instrument_map, key=len, reverse=True))
    pattern = rf'(?<![a-zA-Z])(?:{alternatives})(?![a-zA-Z])'
    return [lookup[match.group(0).lower()] for match in re.finditer(pattern, string, re.IGNORECASE)]

In [None]:
!mkdir /content/URMP_Dataset/temp_midi

In [None]:
import os
import librosa

class URMPDataset(BaseAudioDataset):
    """URMP recordings (mixes and separated stems) split into chunks with matching score MIDI.

    Each ``root_dir/<piece>/`` folder is expected to contain
    ``Sco_<piece>.mid`` plus the piece's ``.wav`` files. Chunk MIDI files are
    written to ``/content/URMP_Dataset/temp_midi/``. Every chunk yields an ABC
    transcription sample and, with probability ``qa_sample_probability``
    (always, by default), a short synthetic QA sample.
    """

    # Probability of adding a short QA sample next to each transcription sample.
    qa_sample_probability = 1.0

    def __init__(self, root_dir, dataset_name, duration=25, target_sr=16000, return_audio=True):
        super().__init__(root_dir, dataset_name, duration, target_sr, return_audio)
        self.initialize_samples()

    def initialize_samples(self):
        import math  # local import: this cell does not import math itself

        # sorted() makes the sample order reproducible; os.listdir order is arbitrary.
        for sub_folder in sorted(os.listdir(self.root_dir)):
            sub_folder_path = os.path.join(self.root_dir, sub_folder)
            if not os.path.isdir(sub_folder_path):
                continue

            sub_folder_name = os.path.basename(sub_folder_path)
            midi_path = os.path.join(sub_folder_path, f"Sco_{sub_folder_name}.mid")
            if not os.path.exists(midi_path):
                continue
            midi_file_name = os.path.basename(midi_path)

            for file in sorted(os.listdir(sub_folder_path)):
                # Skip macOS metadata files ("._*") and anything that is not a wav.
                if file.startswith('._') or not file.endswith('.wav'):
                    continue

                audio_path = os.path.join(sub_folder_path, file)
                total_duration = librosa.get_duration(path=audio_path)
                num_chunks = math.ceil(total_duration / self.duration)

                for i in range(num_chunks):
                    start_time = i * self.duration
                    # The final chunk may be shorter than self.duration.
                    chunk_duration = min(self.duration, total_duration - start_time)
                    output_midi_path = f'/content/URMP_Dataset/temp_midi/{midi_file_name}_{str(start_time)}.mid'

                    try:
                        extract_midi_segment(midi_path, output_midi_path, start_time, chunk_duration)
                    except Exception as e:
                        print(f"Skipping sample due to MIDI extraction error: {e}")
                        continue

                    self.samples.append((audio_path, output_midi_path, start_time, chunk_duration, "", False))
                    # Short QA sample; no RNG draw when the probability is 1.
                    if self.qa_sample_probability >= 1.0 or random.random() < self.qa_sample_probability:
                        self.samples.append((audio_path, output_midi_path, start_time, chunk_duration, "", True))

    def process_item(self, idx, audio_chunk):
        audio_path, midi_path, start_time, chunk_duration, _, synthetic_flag = self.samples[idx]

        sample_id = os.path.basename(audio_path)
        instruments = extract_instruments(sample_id)

        # Separated stems are named "AuSep_<n>_..."; character 6 is read as the track digit.
        midi_track_number = int(sample_id[6]) if "AuSep" in sample_id else None

        query, answer = process_urmp_chunk(midi_track_number, instruments, midi_path, start_time, chunk_duration, synthetic_flag)

        sample_type = "qa" if synthetic_flag else "abc"

        return audio_chunk, query, answer, self.dataset_name, sample_id, start_time, sample_type

In [None]:
import re
import math

# Build both datasets, URMP first and then MAESTRO.
# Construction extracts every MIDI chunk up front, so this cell is slow.
URMP_ROOT = "/content/URMP_Dataset/Dataset"
MAESTRO_ROOT = "/content/maestro"

urmp_dataset = URMPDataset(root_dir=URMP_ROOT, dataset_name="urmp")
maestro_dataset = MaestroDataset(root_dir=MAESTRO_ROOT, dataset_name="maestro")

In [None]:
import torch
from torch.utils.data import Dataset
import random

class CombinedDataset(Dataset):
    """Concatenate several datasets and serve their samples in one fixed random order.

    The order is shuffled once, at construction, with the global ``random``
    module, and it stays the same across epochs. Lengths are captured at
    construction time.
    """

    def __init__(self, datasets):
        self.datasets = datasets
        self.lengths = [len(ds) for ds in datasets]
        self.cumulative_lengths = torch.cumsum(torch.tensor(self.lengths), dim=0)

        # One (dataset position, sample position) pair per sample, shuffled in place.
        pairs = []
        for ds_pos, ds in enumerate(datasets):
            pairs.extend((ds_pos, sample_pos) for sample_pos in range(len(ds)))
        random.shuffle(pairs)
        self.indices = pairs

    def __len__(self):
        return sum(self.lengths)

    def __getitem__(self, idx):
        ds_pos, sample_pos = self.indices[idx]
        return self.datasets[ds_pos][sample_pos]

In [None]:
# URMP and MAESTRO samples interleaved in a single shuffled order.
combined_dataset = CombinedDataset(datasets=[urmp_dataset, maestro_dataset])

## Model

In [None]:
# Log in to Hugging Face without committing a secret to the notebook.
# Set the HF_TOKEN environment variable beforehand (e.g. from Colab secrets).
# If it is unset, login() prompts for the token interactively instead.
import os
from huggingface_hub import login

login(token=os.environ.get("HF_TOKEN"))

In [None]:
!pip install -q -U bitsandbytes
!pip install git+https://github.com/huggingface/transformers
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install -q -U git+https://github.com/huggingface/accelerate.git
!pip install datasets
!pip install trl==0.11.4

**Restart the session after running the installs above; otherwise the newly installed packages are not picked up and the trainer won't work.**

In [None]:
from transformers import TrainingArguments, Qwen2AudioForConditionalGeneration
from peft import get_peft_model, LoraConfig, TaskType
from trl import SFTTrainer
import torch

MODEL_ID = "Qwen/Qwen2-Audio-7B-Instruct"

# bf16 weights with automatic device placement.
model = Qwen2AudioForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# 4-bit alternative:
# model = Qwen2AudioForConditionalGeneration.from_pretrained(MODEL_ID, load_in_4bit=True, device_map="auto")

# LoRA adapters on the attention and MLP projection layers.
# NOTE(review): these module names may also match projection layers outside the
# language model (e.g. in the audio encoder); confirm that is intended.
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=128,
    lora_alpha=256,
    lora_dropout=0,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)

# Typically required for gradients to reach the adapters when gradient
# checkpointing is enabled (the trainer below sets gradient_checkpointing=True).
model.enable_input_require_grads()
model = get_peft_model(model, peft_config)

## Training

### Load wandb

In [None]:
import wandb, os
wandb.login()

# Route runs to this W&B project; an empty name leaves WANDB_PROJECT unset.
wandb_project = "autotab"
if wandb_project:
    os.environ["WANDB_PROJECT"] = wandb_project

In [None]:
# Names used for the output directory and the W&B run.
project = "Qwen2Audio-1"
run_name = "run"
project_and_run_name = f"{project}-{run_name}"
output_dir = f"./{project_and_run_name}"

### Train with SFT

In [None]:
from transformers import Trainer, TrainingArguments
from datetime import datetime
from trl import SFTTrainer

# W&B run-name prefix (same value as project_and_run_name above); a timestamp is appended below.
wandbname = project + "-" + run_name

import torch
import torch.nn.functional as F

# Single-sample batches for 3 epochs with a cosine LR schedule after a 10% warm-up.
training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=3,
    #max_steps=1000,
    # per_device_eval_batch_size = 0,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1,
    warmup_ratio = 0.1,
    logging_dir='./logs',
    learning_rate = 1e-5,
    logging_steps = 1,
    #evaluation_strategy="steps",
    #eval_steps=75,
    #save_steps=100,
    max_grad_norm=10.0,
    # Mixed precision follows the hardware: bf16 when supported, fp16 otherwise.
    # NOTE(review): the model was loaded in bfloat16, so the fp16 fallback may not work on GPUs
    # without bf16 support — confirm before relying on it.
    fp16 = not torch.cuda.is_bf16_supported(),
    gradient_checkpointing=True,
    bf16 = torch.cuda.is_bf16_supported(),
    optim = "adamw_8bit",
    weight_decay = 0.001,
    #seed = 3407,
    # The Trainer writes no checkpoints; the model is saved explicitly with save_pretrained later.
    save_strategy="no",
    #save_strategy="epoch",
    lr_scheduler_type = "cosine",
    #load_best_model_at_end=True,
    report_to="wandb",
    run_name=f"{wandbname}-{datetime.now().strftime('%m-%d-%H-%M')}"
)

# NOTE(review): train_dataset and data_collator are not defined in this part of the notebook
# (only combined_dataset is built above) — confirm they are defined earlier and which data is intended.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    # eval_dataset=test_dataset,
    data_collator=data_collator
)

trainer.train()

### Train with PPO

In [None]:
# Same experiment identifiers as the SFT section, redefined so this section can run on its own.
project, run_name = "Qwen2Audio-1", "run"
project_and_run_name = f"{project}-{run_name}"
output_dir = f"./{project_and_run_name}"

In [None]:
import wandb, os
# Authenticate with Weights & Biases and log PPO runs to the "autotab" project.
wandb.login()

wandb_project = "autotab"
if wandb_project:  # nothing is exported when the name is empty
    os.environ["WANDB_PROJECT"] = wandb_project

#### Custom loss/reward model

In [None]:
import torch
import torch.nn as nn
import re
import math
import pretty_midi
import numpy as np
import difflib

class ABCLoss(nn.Module):
    """Heuristic distance between a predicted and a ground-truth ABC-notation transcription.

    ``forward(predicted_abc, ground_truth_abc)`` combines four component losses, each in [0, 1]:

    * pitch (weight 0.5): for every predicted voice, the smallest normalised Levenshtein
      distance between its pitch sequence and any ground-truth voice, averaged over the
      predicted voices (1.0 when the prediction has no voices);
    * metre (0.15) from the ``M:`` field;
    * tempo (0.15) from the ``Q:`` field;
    * instrument (0.2) from ``name=...`` in the ``V:`` headers.

    Values are computed in plain Python and wrapped in tensors, so the loss carries no gradient;
    this notebook uses ``1 - loss`` as a PPO reward.
    """

    # A voice starts at "V:<n>" plus the rest of its header line. Its body runs up to the next line
    # that starts with "V:" or to the end of the string (\Z), so a final voice whose last line has
    # no trailing newline is still captured.
    _VOICE_RE = re.compile(r'V:\d+[^\n]*\n?(.*?)(?=^V:|\Z)', re.DOTALL | re.MULTILINE)
    # Non-musical text inside a voice whose letters must not be mistaken for notes.
    _COMMENT_RE = re.compile(r'%[^\n]*')                                  # "%" comments, "%%" directives
    _FIELD_LINE_RE = re.compile(r'^[ \t]*[A-Za-z]:[^\n]*', re.MULTILINE)  # field lines such as "w:" or "L:"
    _INLINE_FIELD_RE = re.compile(r'\[[A-Za-z]:[^\]\n]*\]')               # inline fields such as "[K:G]"
    _ANNOTATION_RE = re.compile(r'"[^"\n]*"')                             # chord symbols / text annotations
    # A note: optional accidentals (^ sharp, _ flat) *before* the letter, then octave marks.
    _NOTE_RE = re.compile(r"[\^_]*[A-Ga-g][,']*")
    # Semitone offset of each natural note within an octave.
    _NOTE_OFFSETS = {'C': 0, 'D': 2, 'E': 4, 'F': 5, 'G': 7, 'A': 9, 'B': 11}

    def __init__(self):
        super().__init__()

        # Position of each key on the circle of fifths (sharps positive, flats negative).
        # A relative minor shares the position of its relative major (e.g. Am and C are both 0).
        self.key_to_circle_position = {
            'C': 0, 'G': 1, 'D': 2, 'A': 3, 'E': 4, 'B': 5, 'F#': 6, 'C#': 7,
            'F': -1, 'Bb': -2, 'Eb': -3, 'Ab': -4, 'Db': -5, 'Gb': -6, 'Cb': -7,
            'Am': 0, 'Em': 1, 'Bm': 2, 'F#m': 3, 'C#m': 4, 'G#m': 5, 'D#m': 6, 'A#m': 7,
            'Dm': -1, 'Gm': -2, 'Cm': -3, 'Fm': -4, 'Bbm': -5, 'Ebm': -6, 'Abm': -7
        }

    # ------------------------------------------------------------------ header fields

    def extract_metre(self, abc_string):
        """Return the first ``M:`` value as ``'num/denom'``, or None."""
        match = re.search(r'M:\s*\(?\s*(\d+/\d+)\s*\)?', abc_string)
        return match.group(1) if match else None

    def extract_length(self, abc_string):
        """Return the first ``L:`` (default note length) value as ``'num/denom'``, or None."""
        match = re.search(r'L:\s*\(?\s*(\d+/\d+)\s*\)?', abc_string)
        return match.group(1) if match else None

    def extract_tempo(self, abc_string):
        """Return the first ``Q:`` value as a digit string, or None."""
        match = re.search(r'Q:\s*\(?\s*(\d+)\s*\)?', abc_string)
        return match.group(1) if match else None

    def extract_key(self, abc_string):
        """Return the first whitespace-free token after ``K:``, or None."""
        match = re.search(r'K:\s*\(?\s*(\S+)\s*\)?', abc_string)
        return match.group(1) if match else None

    # ------------------------------------------------------------------ voices and notes

    def extract_voice_notes(self, abc_string):
        """Return the body of every ``V:`` voice in order, excluding its header line."""
        return self._VOICE_RE.findall(abc_string)

    def parse_fraction(self, fraction_string):
        """Convert a fraction string like '4/4' into separate numerator and denominator."""
        num, denom = fraction_string.split('/')
        return int(num), int(denom)

    def note_to_number(self, note):
        """Map one ABC note token (e.g. ``'^c,'``) to a pitch number (middle-C octave ``C`` -> 60).

        Leading ``^``/``_`` raise/lower by a semitone each. A lower-case letter is one octave above
        the upper-case one. Each ``'`` / ``,`` after the letter shifts up/down an octave. Accidentals
        written after the letter are also accepted for backward compatibility. An unknown letter
        counts as C.
        """
        idx = 0
        accidental = 0
        # ABC writes accidentals before the letter (the extraction regex keeps them there).
        while idx < len(note) and note[idx] in '^_':
            accidental += 1 if note[idx] == '^' else -1
            idx += 1

        letter = note[idx] if idx < len(note) else ''
        base = self._NOTE_OFFSETS.get(letter.upper(), 0)
        octave = 1 if (letter.islower() and letter.upper() in self._NOTE_OFFSETS) else 0

        for char in note[idx + 1:]:
            if char == ',':
                octave -= 1
            elif char == "'":
                octave += 1
            elif char == '^':
                accidental += 1
            elif char == '_':
                accidental -= 1
            else:
                break

        # The +5 octave offset keeps typical pitches positive.
        return base + accidental + (octave + 5) * 12

    def extract_notes_from_voice(self, voice_string):
        """Return the pitch numbers of all notes in a voice body, ignoring non-musical text."""
        text = self._COMMENT_RE.sub('', voice_string)
        text = self._FIELD_LINE_RE.sub('', text)
        text = self._INLINE_FIELD_RE.sub('', text)
        text = self._ANNOTATION_RE.sub('', text)
        text = re.sub(r'\s+', '', text)
        return [self.note_to_number(note) for note in self._NOTE_RE.findall(text)]

    def normalized_levenshtein_distance(self, seq1, seq2):
        """Edit distance between two sequences divided by the longer length (0.0 for two empties)."""
        max_len = max(len(seq1), len(seq2))
        if max_len == 0:
            return 0.0
        # Classic DP keeping only the previous row (O(len(seq2)) memory).
        previous = list(range(len(seq2) + 1))
        for i, item1 in enumerate(seq1, 1):
            current = [i] + [0] * len(seq2)
            for j, item2 in enumerate(seq2, 1):
                current[j] = min(
                    previous[j] + 1,                       # deletion
                    current[j - 1] + 1,                    # insertion
                    previous[j - 1] + (item1 != item2),    # substitution
                )
            previous = current
        return previous[-1] / max_len

    # ------------------------------------------------------------------ component losses

    @staticmethod
    def _piecewise_loss(diff, threshold, decay=None):
        """Map a non-negative difference to [0, 1].

        The loss rises linearly to 0.5 at ``threshold``, then saturates exponentially towards 1.0
        with length scale ``decay`` (defaults to ``threshold``). The two pieces meet at 0.5.
        """
        if decay is None:
            decay = threshold
        if diff <= threshold:
            return diff / threshold * 0.5
        return 0.5 + 0.5 * (1 - math.exp(-(diff - threshold) / decay))

    def compute_metre_loss(self, pred_metre, gt_metre):
        """Loss in [0, 1] between two ``'num/denom'`` metres.

        Weighting is 0.3 numerator, 0.3 denominator and 0.4 bar-length ratio. A zero denominator
        is not a valid metre and yields the maximum loss of 1.0 instead of raising.
        """
        pred_num, pred_denom = self.parse_fraction(pred_metre)
        gt_num, gt_denom = self.parse_fraction(gt_metre)
        if pred_denom == 0 or gt_denom == 0:
            return torch.tensor(1.0, dtype=torch.float32)

        num_loss = self._piecewise_loss(abs(pred_num - gt_num), 1)
        denom_loss = self._piecewise_loss(abs(pred_denom - gt_denom), 2)
        ratio_loss = self._piecewise_loss(abs(pred_num / pred_denom - gt_num / gt_denom), 0.25)

        total_loss = 0.3 * num_loss + 0.3 * denom_loss + 0.4 * ratio_loss
        return torch.tensor(min(total_loss, 1.0), dtype=torch.float32)

    def compute_tempo_loss(self, pred_tempo, gt_tempo):
        """Loss in [0, 1] between two tempo strings in BPM (0.0 if either is missing).

        The loss is linear up to 0.5 at a 10 BPM difference, then saturates towards 1.0 with
        a 50 BPM length scale. The linear slope matches the 0.5 starting point of the exponential
        part, so the loss has no jump at the threshold.
        """
        if not pred_tempo or not gt_tempo:
            return torch.tensor(0.0)

        diff = abs(float(pred_tempo) - float(gt_tempo))
        loss = self._piecewise_loss(diff, 10, decay=50)
        return torch.tensor(min(loss, 1.0), dtype=torch.float32)

    def compute_key_loss(self, pred_key, gt_key):
        """Loss in [0, 1] from the distance between two keys on the circle of fifths.

        Identical keys give 0. Keys six fifths apart (e.g. C vs F#) give 1. Enharmonic spellings
        (e.g. Gb vs F#) give 0. Any key not in ``key_to_circle_position`` (e.g. ``'none'``) gives 1
        unless both keys are the same string. This loss is not part of ``forward``'s total.
        """
        if pred_key == gt_key:
            return torch.tensor(0.0, dtype=torch.float32)
        pred_pos = self.key_to_circle_position.get(pred_key)
        gt_pos = self.key_to_circle_position.get(gt_key)
        if pred_pos is None or gt_pos is None:
            return torch.tensor(1.0, dtype=torch.float32)
        steps = abs(pred_pos - gt_pos) % 12
        steps = min(steps, 12 - steps)  # the circle wraps, so keys are at most 6 steps apart
        return torch.tensor(steps / 6, dtype=torch.float32)

    def extract_instruments(self, abc_string):
        """Return the General-MIDI program (or -1) for every ``V:`` header carrying ``name=``."""
        # Handle instrument names with or without quotes
        voice_headers = re.findall(r'V:\d+[^\n]*name="?([^"\n]*)"?', abc_string)
        return [self.instrument_name_to_program(name) for name in voice_headers]

    def instrument_name_to_program(self, name):
        """Fuzzy-match an instrument name to a General-MIDI program number (0-127), or -1."""
        name = name.strip().lower()
        instrument_names = [pretty_midi.program_to_instrument_name(p).lower() for p in range(128)]
        matches = difflib.get_close_matches(name, instrument_names, n=1, cutoff=0.6)
        if not matches:
            return -1
        # The list was built from range(128), so a name's index is its program number.
        return instrument_names.index(matches[0])

    def program_to_family(self, program):
        """Return the General-MIDI family index (groups of 8 programs)."""
        return program // 8

    def compute_instrument_loss(self, pred_abc, gt_abc):
        """Average per-voice instrument mismatch in [0, 1], pairing voices in order.

        Per pair:
        * exact match: 0;
        * either instrument unknown: 1;
        * same family: 0.5;
        * otherwise: 1 - name similarity.

        Voices present in only one transcription have no counterpart and cost 1 each. Otherwise
        predicting fewer voices would lower the average and be rewarded.
        """
        pred_instruments = self.extract_instruments(pred_abc)
        gt_instruments = self.extract_instruments(gt_abc)

        total_instruments = max(len(pred_instruments), len(gt_instruments))
        if total_instruments == 0:
            return torch.tensor(0.0)  # No instruments in either prediction or ground truth

        total_loss = 0.0
        for pred_inst, gt_inst in zip(pred_instruments, gt_instruments):
            if pred_inst == gt_inst:
                loss = 0.0
            elif pred_inst == -1 or gt_inst == -1:
                loss = 1.0
            elif self.program_to_family(pred_inst) == self.program_to_family(gt_inst):
                loss = 0.5
            else:
                pred_name = pretty_midi.program_to_instrument_name(pred_inst).lower()
                gt_name = pretty_midi.program_to_instrument_name(gt_inst).lower()
                similarity = difflib.SequenceMatcher(None, pred_name, gt_name).ratio()
                loss = 1 - similarity
            total_loss += loss

        # Unpaired voices (zip stops at the shorter list) count as full mismatches.
        total_loss += abs(len(pred_instruments) - len(gt_instruments)) * 1.0

        avg_loss = total_loss / total_instruments
        return torch.tensor(avg_loss, dtype=torch.float32)

    # ------------------------------------------------------------------ total

    def forward(self, predicted_abc, ground_truth_abc):
        """Return ``(total_loss, loss_dict)`` for a predicted vs. ground-truth ABC string."""
        pred_metre = self.extract_metre(predicted_abc)
        gt_metre = self.extract_metre(ground_truth_abc)

        pred_tempo = self.extract_tempo(predicted_abc)
        gt_tempo = self.extract_tempo(ground_truth_abc)

        pred_voice_notes = self.extract_voice_notes(predicted_abc)
        gt_voice_notes = self.extract_voice_notes(ground_truth_abc)

        metre_loss = self.compute_metre_loss(pred_metre, gt_metre) if pred_metre and gt_metre else torch.tensor(0.0)
        tempo_loss = self.compute_tempo_loss(pred_tempo, gt_tempo) if pred_tempo and gt_tempo else torch.tensor(0.0)

        # Pitch loss: match each predicted voice to its closest ground-truth voice.
        pred_voice_notes_list = [self.extract_notes_from_voice(voice) for voice in pred_voice_notes]
        gt_voice_notes_list = [self.extract_notes_from_voice(voice) for voice in gt_voice_notes]

        pitch_losses = []
        for pred_notes in pred_voice_notes_list:
            distances = [self.normalized_levenshtein_distance(pred_notes, gt_notes)
                         for gt_notes in gt_voice_notes_list]
            # No ground-truth voices at all means maximum loss for this voice.
            pitch_losses.append(min(distances) if distances else 1.0)

        if pitch_losses:
            pitch_loss = torch.tensor(sum(pitch_losses) / len(pitch_losses), dtype=torch.float32)
        else:
            pitch_loss = torch.tensor(1.0)

        instrument_loss = self.compute_instrument_loss(predicted_abc, ground_truth_abc)

        total_loss = (0.5 * pitch_loss +
                      0.15 * metre_loss +
                      0.15 * tempo_loss +
                      0.2 * instrument_loss)

        loss_dict = {
            'Total Loss': total_loss.item(),
            'Pitch Loss': pitch_loss.item(),
            'Metre Loss': metre_loss.item(),
            'Tempo Loss': tempo_loss.item(),
            'Instrument Loss': instrument_loss.item()
        }

        return total_loss, loss_dict

# Example usage: score a sample single-voice ("Piano") prediction against a four-voice ground truth
# (two trumpets, horn, trombone) and print the per-component breakdown.
predicted_abc = """
M: 1/16 L: 1/16 Q: 120 K: none V:1 name="Piano"

c2 | c4 e3 d3 c'3 | G3 A3 B3 c'3 d3 E3 F3 |
G3 A3 B3 c'3 d3 E3 F3 | c2 d2 c2 e2 f2 |
g2 f2 e2 d2 c2 B2 A2 | G2 F2 E2 D2 C2 B2 A2 |
c2 d2 c2 e2 f2 | g2 f2 e2 d2 c2 B2 A2 |
c2 d2 c2 e2 f2 | g2 f2 e2 d2 c2 B2 A2 |
G2 F2 E2 D2 C2 B2 A2 | c2 d2 c2 e2 f2 |
g2 f2 e2 d2 c2 B2 A2 | G2 F2 E2 D2 C2 B2 A2 |
c2 d2 c2 e2 f2 | g2 f2 e2 d2 c2 B2 A2 |
G2 F2 E2 D2 C2 B2 A2 | c2 d2 c2 e2 f2 |
g2 f2 e2 d2 c2 B2 A2 | G2 F2 E2 D2 C2 B2 A2 |
c2 d2 c2 e2 f2 | g2 f2 e2 d2 c2 B2 A2 |
G2 F2 E2 D2 C2 B2 A2 | c2 d2 c2 e2 f2 |
g2 f2 e2 d2 c2 B2 A2 | G2 F2 E2 D2 C2 B2 A2 |
c2 d2 c2 e2 f2 | g2 f2 e2 d2 c2 B2 A2 |
G2 F2 E2 D2 C2 B2 A2 | c2 d2 c2 e2 f2 |
g2 f2 e2 d2 c2 B2 A2 | G2 F2 E2 D2 C2 B2 A2 |
c2 d2 c2 e2 f2 | g2 f2 e2 d2 c2 B2 A2 |
G2 F2 E2 D2 C2 B2 A2 | c2 d2 c2 e2 f2 |
g2 f2 e2 d2 c2 B2 A2 | G2 F2 E2 D2 C2 B2 A2 |
c2 d2 c2 e2 f2 | g2 f2 e2 d2 c2 B2 A2 |
G2 F2 E2 D2 C2 B2 A2 | c2 d2 c2 e2 f2 |
g2 f2 e2 d2 c2 B2 A2 | G2 F2 E2 D2 C2 B2 A2 |
c2 d2 c2 e2 f2 | g2 f2 e2 d2 c2 B2 A2 |
G2 F2 E2 D2 C2 B2 A2 | c2 d2 c2 e2 f2 |
g2 f2 e2 d2 c2 B2 A2 | G2 F2 E2 D2 C2 B2 A2 |
c2 d2 c2 e2 f2 | g2 f2 e2 d2 c2 B2 A2 |
G2 F2 E2 D2 C2 B2 A2 | c2 d2 c2 e2 f2 |
g2 f2 e2 d2 c2 B2 A2 | G2 F2 E2 D2 C2 B2 A2 |
c2 d2 c2 e2 f2 | g2 f2 e2 d2 c2 B2 A2 |
G2 F2 E2 D2 C2 B2 A2 | c2 d2 c2 e2 f2 |
g2 f2 e2 d2 c2 B2 A2 | G2 F2 E2 D2 C2 B2 A2
"""

ground_truth_abc = """
X:1
M:3/4
L:1/16
K:none
Q:220

V:1 name="Trumpet in Bb"
%%octave-default C6
B,,B,,2 | C,C,2B,, B,,G2 D^2DC | DDD^F4G2D^2D |
CDDCA^,4 G2D^2 | DCDD D^ F4G2D^2 |
DCD DCA^,4G2 | D2A^,G, A,A,A^,G, A,A,A^,G,G2 |
D2B, G,A,A,B, A,2<G,2 | G2D2 A^,G,A, A, A^,G,A,A, |

V:2 name="Trumpet in Bb"
%%octave-default C5
DDD | EF^G2 G,D^'2 A2AA | A^2>D'2 D'D^'F'D^'2A2A |
A2<A^2 D'D'C'A^ D^'2A2 | AAA^2>D'2D'D^' F'D^'2A2 |
AA2<A^2D' D' C'A^A^2 | A^2GG F^2D F^2DB2 |
B2G GF^2 F^2<D2 | A^2A^2 GGF^2DF^ |

V:3 name="Horn in F"
%%octave-default C5
A,B,C | B,A,G,D G,G2 G2FD^ | D2>A^,2 A^,CDG2G2F |
D^2<D2 A^,FD^D G2G2 | FD^D2>A^,2A^,C DG2G2 |
FD^2<D2A^, F D^DD2 | D2DA^, C2A^, C2A^,D2 |
D2D B,C2 C2<B,2 | D2D2 DA^,C2A^,C |

V:4 name="Trombone"
%%octave-default C3
D'3 | G3C'2 F2F2 | A^A^C'D'4C'2F2F2 |
A^F'D^'D'4 C'2F2 | F2A^A^ C' D'4C'2F2 |
F2A^ F'D^'D'4GA | A^C'D'D^' D'F^G D'F^GG |
GABBC'D' D^'D'E'F^' D'GGAB | GAA^C' D'D^'D' F^GD'F^ |
"""

# Calling the module runs ABCLoss.forward(predicted, ground_truth) -> (total tensor, component dict).
abc_loss_fn = ABCLoss()
loss, loss_dict = abc_loss_fn(predicted_abc, ground_truth_abc)

print(f"loss: {loss}")
print(f"Detailed Losses: {loss_dict}")


#### Custom loss tests

In [None]:
import unittest
import torch

class TestABCLoss(unittest.TestCase):
    """Unit tests for the ABCLoss component losses and the combined forward pass."""

    def setUp(self):
        self.loss_fn = ABCLoss()

    def test_extract_metre(self):
        abc_string = "M:4/4\nL:1/16\nK:C"
        self.assertEqual(self.loss_fn.extract_metre(abc_string), "4/4")

        abc_string = "M:3/4\nL:1/16\nK:C"
        self.assertEqual(self.loss_fn.extract_metre(abc_string), "3/4")

    def test_extract_tempo(self):
        abc_string = "Q:120\nM:4/4\nL:1/16\nK:C"
        self.assertEqual(self.loss_fn.extract_tempo(abc_string), "120")

        # Test when tempo is missing
        abc_string = "M:4/4\nL:1/16\nK:C"
        self.assertIsNone(self.loss_fn.extract_tempo(abc_string))

    def test_compute_metre_loss(self):
        # Test exact match
        loss = self.loss_fn.compute_metre_loss("4/4", "4/4")
        self.assertEqual(loss.item(), 0.0)

        # Test similar metres
        loss = self.loss_fn.compute_metre_loss("3/4", "4/4")
        self.assertLess(loss.item(), 1.0)
        self.assertGreater(loss.item(), 0.0)

        # Test very different metres
        loss = self.loss_fn.compute_metre_loss("2/4", "6/8")
        self.assertLess(loss.item(), 1.0)
        self.assertGreater(loss.item(), 0.0)

    def test_compute_tempo_loss(self):
        # Test exact match
        loss = self.loss_fn.compute_tempo_loss("120", "120")
        self.assertEqual(loss.item(), 0.0)

        # Test small difference
        loss = self.loss_fn.compute_tempo_loss("125", "120")
        self.assertLess(loss.item(), 0.5)

        # Test large difference
        loss = self.loss_fn.compute_tempo_loss("200", "120")
        self.assertGreater(loss.item(), 0.5)

    # Skipped (rather than erroring with AttributeError) when the ABCLoss in scope
    # does not implement a key loss.
    @unittest.skipUnless(hasattr(ABCLoss, "compute_key_loss"),
                         "ABCLoss does not implement compute_key_loss")
    def test_compute_key_loss(self):
        # Test exact match
        loss = self.loss_fn.compute_key_loss("C", "C")
        self.assertEqual(loss.item(), 0.0)

        # Test adjacent keys on circle of fifths
        loss = self.loss_fn.compute_key_loss("C", "G")
        self.assertLess(loss.item(), 0.5)

        # Test opposite keys on circle of fifths
        loss = self.loss_fn.compute_key_loss("C", "F#")
        self.assertGreater(loss.item(), 0.5)

    def test_instrument_name_to_program(self):
        # Test common instruments
        self.assertGreaterEqual(self.loss_fn.instrument_name_to_program("Piano"), 0)
        self.assertGreaterEqual(self.loss_fn.instrument_name_to_program("Guitar"), 0)

        # Test invalid instrument
        self.assertEqual(self.loss_fn.instrument_name_to_program("InvalidInstrument"), -1)

    def test_complete_loss_computation(self):
        simple_pred = """X:1
M:4/4
L:1/16
K:C
Q:120
V:1 name="Piano"
CDEF|"""

        simple_gt = """X:1
M:4/4
L:1/16
K:C
Q:120
V:1 name="Piano"
CDEF|"""

        total_loss, loss_dict = self.loss_fn(simple_pred, simple_gt)

        # Check if loss is there
        self.assertIsInstance(total_loss, torch.Tensor)
        self.assertGreaterEqual(total_loss.item(), 0.0)
        self.assertLessEqual(total_loss.item(), 1.0)

        # Check if loss dictionary contains all components
        expected_keys = ['Total Loss', 'Pitch Loss', 'Metre Loss',
                        'Tempo Loss', 'Instrument Loss']
        for key in expected_keys:
            self.assertIn(key, loss_dict)

if __name__ == '__main__':
    # In a notebook, sys.argv holds the kernel's own arguments (e.g. "-f <connection file>"),
    # which unittest.main would try to parse. Its default exit=True would also raise SystemExit
    # and abort the cell. Pass a dummy argv and keep the kernel alive.
    unittest.main(argv=['ignored'], exit=False)

#### Training

**Make sure trl is pinned to version 0.11.4: the original (v1) `PPOTrainer` API used below was replaced by a new, incompatible `PPOTrainer` in 0.12.0!**

In [None]:
from transformers import AutoTokenizer, AutoProcessor
from transformers import Qwen2AudioForConditionalGeneration
from trl import AutoModelForCausalLMWithValueHead
import torch.nn as nn

class Qwen2AudioForPPO(AutoModelForCausalLMWithValueHead):
    """Adapt Qwen2-Audio to trl's (legacy) value-head model interface for PPOTrainer.

    trl's ``AutoModelForCausalLMWithValueHead`` expects a causal LM exposing ``lm_head`` and
    ``config``. Qwen2AudioForConditionalGeneration keeps those on its inner ``language_model``,
    so the parent is initialised with a thin wrapper around that language model, from which it
    builds its value head. Afterwards ``pretrained_model`` is pointed back at the full audio model,
    so forward/generate can consume audio features.
    """

    def __init__(self, pretrained_model):
        # We create a wrapper that exposes the lm_head and includes necessary attributes
        class LanguageModelWrapper(nn.Module):
            def __init__(self, language_model):
                super().__init__()
                self.language_model = language_model
                self.lm_head = language_model.lm_head  # Expose the lm_head
                self.config = language_model.config  # Include the config
                # Include other necessary attributes
                self.prepare_inputs_for_generation = language_model.prepare_inputs_for_generation
                self.get_output_embeddings = language_model.get_output_embeddings
                self.get_input_embeddings = language_model.get_input_embeddings

            def forward(self, *args, **kwargs):
                return self.language_model(*args, **kwargs)

        # Pass the wrapped language model to the parent class
        wrapped_lm = LanguageModelWrapper(pretrained_model.language_model)
        super().__init__(wrapped_lm)
        # Replace the wrapper with the full audio model for forward() and generate().
        self.pretrained_model = pretrained_model

    def forward(self, input_features=None, feature_attention_mask=None, input_ids=None, attention_mask=None, **kwargs):
        """Run the full audio model and score every position with the value head.

        Returns ``(logits, hidden_states, values)``. trl's legacy PPOTrainer unpacks a 3-tuple
        ``(logits, <ignored>, values)``; trl's own implementation puts a loss in the middle slot.
        As in trl's implementation, logits are returned in float32 so that log-softmax over the
        vocabulary is not computed in bf16. The hidden state is also moved to the value head's
        device, which may differ under ``device_map="auto"``.
        """
        outputs = self.pretrained_model(
            input_features=input_features,
            feature_attention_mask=feature_attention_mask,
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            **kwargs
        )

        last_hidden_state = outputs.hidden_states[-1]
        v_head_device = self.v_head.summary.weight.device
        if last_hidden_state.device != v_head_device:
            last_hidden_state = last_hidden_state.to(v_head_device)
        value = self.v_head(last_hidden_state).squeeze(-1)

        lm_logits = outputs.logits.float()  # no-op when already float32

        return lm_logits, outputs.hidden_states, value

    def generate(self, *args, **kwargs):
        """Delegate generation to the full audio model so audio inputs are honoured."""
        return self.pretrained_model.generate(*args, **kwargs)

# Qwen2-Audio processor: formats text + audio inputs; its tokenizer is also handed to PPOTrainer.
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")

# Wrap the LoRA model from the "Model" section with a value head for PPO.
model_for_ppo = Qwen2AudioForPPO(model)

In [None]:
import torch
from torch import nn
from transformers import (
    Qwen2AudioForConditionalGeneration,
    AutoProcessor,
    PreTrainedModel,
)
from trl import PPOConfig, PPOTrainer, AutoModelForCausalLMWithValueHead
from torch.utils.data import DataLoader, Dataset
from transformers import AdamW
import bitsandbytes as bnb

ppo_config = PPOConfig(
    model_name="Qwen/Qwen2-Audio-7B-Instruct",
    learning_rate=5e-5,          # Other options: [5e-5]
    batch_size=8,                # Other options: [16, 32]
    mini_batch_size=4,           # Should divide batch_size; other options: [2, 8]
    gradient_accumulation_steps=1,  # Increase if batch_size is too large
    ppo_epochs=4,                # Other options: [5]
    max_grad_norm=10.0,           # Other options: [0.5]
    cliprange=0.2,               # Other options: [0.1]
    # gradient_checkpointing=True,  # Uncomment to save memory
    # seed=3407,
    log_with='wandb',
)

# 8-bit AdamW over exactly the parameters PPO trains: the LoRA weights plus the value head.
# The value head lives only on model_for_ppo, so building this from model.parameters() would miss
# it. The optimizer must also be passed to PPOTrainer below; otherwise the trainer creates its own
# default optimizer and this one is never used.
optimizer = bnb.optim.AdamW8bit(
    [p for p in model_for_ppo.parameters() if p.requires_grad],
    lr=ppo_config.learning_rate,
    weight_decay=0.001,
)

def your_data_collator(batch):
    """Regroup dataset items (audio, query, answer, ...) into per-field lists for the PPO loop."""
    queries = [item[1] for item in batch]
    audios = [item[0] for item in batch]
    answers = [item[2] for item in batch]
    return {"queries": queries, "audios": audios, "answers": answers}

ppo_trainer = PPOTrainer(
    config=ppo_config,
    model=model_for_ppo,
    tokenizer=processor.tokenizer,
    dataset=urmp_dataset,
    data_collator=your_data_collator,
    optimizer=optimizer,
)

abc_loss_fn = ABCLoss()

def compute_rewards(answers, responses, loss_fn=None, device=None):
    """Turn ABC-notation losses into PPO rewards (``1 - loss``), one per response.

    Args:
        answers: ground-truth ABC strings.
        responses: decoded model outputs, aligned with ``answers``.
        loss_fn: callable ``(predicted_abc, ground_truth_abc) -> (loss, loss_dict)``;
            defaults to the notebook-global ``abc_loss_fn``.
        device: device for the reward tensors. Defaults to CUDA when available (where the
            generation inputs live), else CPU.

    Returns:
        A list of 0-dim float32 tensors.
    """
    if loss_fn is None:
        loss_fn = abc_loss_fn
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    rewards = []
    for answer, response in zip(answers, responses):
        # ABCLoss.forward takes (predicted, ground_truth): the model output goes first.
        loss, _ = loss_fn(response, answer)
        loss_value = loss.item() if isinstance(loss, torch.Tensor) else float(loss)
        rewards.append(torch.tensor(1.0 - loss_value, device=device, dtype=torch.float32))

    print(f"rewards {rewards}")
    return rewards

import json

def log_to_file(predicted_responses, target_responses, rewards, file_path='ppo_training_log.txt'):
    """Append one JSON object per line for each aligned (prediction, target, reward) triple.

    Tensor rewards are stored as Python numbers via ``.item()``; other values are written as-is.
    """
    with open(file_path, 'a') as handle:
        for prediction, target, reward in zip(predicted_responses, target_responses, rewards):
            reward_value = reward.item() if isinstance(reward, torch.Tensor) else reward
            record = {
                'predicted_response': prediction,
                'target_response': target,
                'reward': reward_value,
            }
            handle.write(json.dumps(record) + '\n')

# NOTE(review): ppo_config.ppo_epochs is also the number of optimisation passes PPOTrainer.step makes
# over each batch; reusing it here as the number of passes over the dataset couples the two settings.
for epoch in range(ppo_config.ppo_epochs):
    for batch in ppo_trainer.dataloader:
        queries = batch["queries"]
        audios = batch["audios"]
        answers = batch["answers"]

        # Build one chat prompt per query. The chat template needs an audio entry to emit the audio
        # placeholder; the waveform itself comes from `audios`.
        texts = []
        for query in queries:
            conversation = [
                {'role': 'system', 'content': 'You are a helpful assistant.'},
                {"role": "user", "content": [
                    {"type": "audio", "audio_url": "https://example.com/audio.mp3"},
                    {"type": "text", "text": query},
                ]},
            ]
            text = processor.apply_chat_template(
                conversation, tokenize=False, add_generation_prompt=True
            )
            texts.append(text)

        inputs = processor(
            text=texts,
            audios=audios,
            return_tensors="pt",
            padding=True,
            sampling_rate=16000,
        )

        inputs = {k: v.to('cuda') for k, v in inputs.items()}

        with torch.no_grad():
            response_ids = model_for_ppo.generate(
                **inputs,
                do_sample=True, top_p=0.8, temperature=1, max_length=2048, use_cache=True
            )

        # generate() returns prompt + continuation; everything past the padded prompt is new.
        prompt_length = inputs['input_ids'].size(1)
        generated_responses = response_ids[:, prompt_length:]

        # Decode responses
        responses = processor.tokenizer.batch_decode(generated_responses, skip_special_tokens=True)

        print(f"responses: {responses}")

        # PPO needs the exact token sequences the model saw and produced.
        # Queries: the full templated prompt (incl. audio placeholder tokens) without padding.
        # Responses: the newly generated tokens without padding. (Slicing by the length of the
        # raw query text would cut the prompt short and push most of it into the "response".)
        query_tensors = [
            ids[mask.bool()] for ids, mask in zip(inputs['input_ids'], inputs['attention_mask'])
        ]
        pad_id = processor.tokenizer.pad_token_id
        response_tensors = []
        for resp in generated_responses:
            if pad_id is not None:
                kept = resp[resp != pad_id]
                # PPOTrainer rejects empty responses; keep one token if everything was padding.
                resp = kept if kept.numel() > 0 else resp[:1]
            response_tensors.append(resp)
        # NOTE(review): PPOTrainer.step recomputes log-probs from token ids only, so the audio
        # features are not seen during the PPO update — confirm this is acceptable.

        rewards = compute_rewards(answers, responses)

        log_to_file(responses, answers, rewards)

        with torch.amp.autocast(device_type='cuda'):
            stats = ppo_trainer.step(
                queries=query_tensors,
                responses=response_tensors,
                scores=rewards
            )

            ppo_trainer.log_stats(stats, batch, rewards, columns_to_log=['queries', 'answers'])

In [None]:
# Save the fine-tuned model locally.
# NOTE(review): `model` is the PeftModel returned by get_peft_model, so this presumably writes only
# the LoRA adapter. The PPO value head on model_for_ppo is not saved.
model.save_pretrained("/content/qwen_audio_finetune_3")

In [None]:
# Back up the saved model directory to Google Drive.
!cp -r "/content/qwen_audio_finetune_3" "/content/drive/My Drive/automatic-music-transcription/saved_models/"

## Try model

### Test model manually

In [None]:
# Copy a previously saved fine-tune from Drive for manual testing.
# NOTE(review): training above saves qwen_audio_finetune_3; confirm which version should be evaluated.
!cp -r "/content/drive/My Drive/automatic-music-transcription/saved_models/qwen_audio_finetune_2" "/content/"

In [None]:
from transformers import TrainingArguments, Qwen2AudioForConditionalGeneration
from peft import get_peft_model, LoraConfig, TaskType
from trl import SFTTrainer
import torch

# Reload the base Qwen2-Audio model in bfloat16; the fine-tuned LoRA weights are attached below.
model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct", torch_dtype=torch.bfloat16, device_map="auto")

In [None]:
# Wrap the base model with a LoRA adapter using the training-time hyper-parameters (inference mode).
# get_peft_model creates a freshly initialised "default" adapter; the trained weights still have
# to be loaded separately.
peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=True, r=128, lora_alpha=256, lora_dropout=0, target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"])
model = get_peft_model(model, peft_config)

In [None]:
# Load the fine-tuned LoRA weights as a separate adapter named "test" and make it the active one.
# PeftModel.load_adapter does not switch the active adapter. Without set_adapter, generation would
# keep using the untrained "default" adapter created by get_peft_model.
model.load_adapter("/content/qwen_audio_finetune_2", adapter_name="test")
model.set_adapter("test")

In [None]:
# Re-create the URMP dataset for evaluation; return_audio=True presumably makes items include the waveform.
urmp_dataset = URMPDataset(root_dir="/content/URMP_Dataset/Dataset", dataset_name="urmp", return_audio=True)

In [None]:
from transformers import AutoProcessor

# Load the Qwen2-Audio processor (text tokenizer + audio feature extraction).
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct" ,trust_remote_code=True)

In [None]:
# Fetch one example and unpack (audio, question, ground-truth answer).
# Indexing once calls the dataset's __getitem__ once instead of three times.
sample = urmp_dataset[22]
audio, query, answer = sample[0], sample[1], sample[2]
query, answer

In [None]:
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct" ,trust_remote_code=True)

# Single-turn chat: a system prompt, then a user turn containing the audio and the question.
conversation = [
    {'role': 'system', 'content': 'You are a helpful assistant.'},
    {"role": "user", "content": [
        # The chat template needs an audio entry to insert the audio placeholder; the URL is not
        # downloaded here — the actual waveform is passed via `audios` below.
        {"type": "audio", "audio_url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"},
        {"type": "text", "text": query},
    ]},
]

# Render the prompt as text (tokenisation happens in the processor call).
text = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
texts = [text]  # NOTE(review): unused below; the processor is called with `text`
audios = [audio]

# Tokenise the prompt and extract audio features at 16 kHz, then move the tensors to the GPU.
inputs = processor(text=text, audios=audios, return_tensors="pt", padding=True, sampling_rate=16000)
inputs.to("cuda")

In [None]:
# Sample a transcription (nucleus sampling, temperature 1.3), keep only the tokens generated after
# the prompt, and decode them.
generate_ids = model.generate(**inputs, do_sample=True, top_p=0.8, temperature=1.3, max_length=2048, use_cache=True)
generate_ids = generate_ids[:, inputs.input_ids.size(1):]

response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
response

### Run model over evaluation dataset

In [None]:
import locale
# Force "UTF-8" as the preferred encoding.
# NOTE(review): presumably a workaround for Colab `!` shell commands failing with a locale/encoding
# error in this session — confirm it is still needed.
locale.getpreferredencoding = lambda: "UTF-8"

In [None]:
import random
import torch
import datetime

# The log file name carries the time this cell ran, so repeated evaluations don't overwrite each other.
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
log_file = "/content/qwen2audio_original_predictions_" + timestamp + ".txt"

# Seeded draw of 30 distinct dataset indices, so every run evaluates the same samples.
random.seed(42)
dataset_size = len(urmp_dataset)
random_indices = random.sample(range(dataset_size), 30)

def run_inference(audio, query):
    """Generate one sampled transcription for ``audio`` answering ``query``.

    Returns ``(prediction, None)`` on success, or ``(None, error_message)`` if any step raises.
    Failures are reported rather than raised so that one bad sample does not stop the evaluation.
    """
    placeholder_audio_url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"
    try:
        user_turn = {"role": "user", "content": [
            {"type": "audio", "audio_url": placeholder_audio_url},
            {"type": "text", "text": query},
        ]}
        conversation = [{'role': 'system', 'content': 'You are a helpful assistant.'}, user_turn]

        prompt = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        model_inputs = processor(text=prompt, audios=[audio], return_tensors="pt", padding=True, sampling_rate=16000).to("cuda")

        output_ids = model.generate(**model_inputs, do_sample=True, top_p=0.7, temperature=1.3, max_length=2048, use_cache=True)
        continuation = output_ids[:, model_inputs.input_ids.size(1):]

        decoded = processor.batch_decode(continuation, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        prediction = decoded[0]
    except Exception as exc:
        return None, str(exc)
    return prediction, None

# Process each sample and write to log file.
# For every sampled index, record the question and ground truth plus several sampled predictions.
# Output is flushed after each prediction so partial results survive a crash or disconnect.
num_predictions = 3  # sampled generations per evaluation example
total_samples = len(random_indices)

with open(log_file, 'w', encoding='utf-8') as f:
    for sample_num, idx in enumerate(random_indices, 1):
        # Progress uses the real number of drawn samples rather than a hard-coded count.
        print(f"Processing sample {idx} ({sample_num}/{total_samples})")

        try:
            # Only the first three fields (audio, question, ground-truth ABC) are needed here.
            audio, query, answer = urmp_dataset[idx][:3]

            # Write sample information
            f.write(f"Sample {sample_num} (Index: {idx})\n")
            f.write(f"Question: {query}\n")
            f.write(f"Ground Truth: {answer}\n")

            f.write("Model Predictions:\n")
            for i in range(num_predictions):
                print(f"  Running prediction {i+1}/{num_predictions} for sample {idx}")
                prediction, error = run_inference(audio, query)

                if prediction is not None:
                    f.write(f"Prediction {i+1}: {prediction}\n")
                else:
                    f.write(f"Prediction {i+1} failed: {error}\n")

                f.flush()
                # Release cached GPU memory between generations.
                torch.cuda.empty_cache()

            f.write("\n" + "="*30 + "\n\n")
            f.flush()

        except Exception as e:
            # Best effort: log the failure for this sample and continue with the next one.
            print(f"Error processing sample {idx}: {str(e)}")
            f.write(f"Error processing sample {idx}: {str(e)}\n")
            f.write("\n" + "="*30 + "\n\n")
            f.flush()
            continue

print(f"Results have been logged to: {log_file}")

In [None]:
# Copy the evaluation log to Google Drive (IPython substitutes the Python variable into {log_file}).
!cp "{log_file}" "/content/drive/My Drive/automatic-music-transcription/logs/"

In [None]:
from google.colab import runtime
# Disconnect and release this Colab runtime once all cells above have finished.
runtime.unassign()