This notebook runs inference and analysis for the [ReconVAT](https://arxiv.org/abs/2107.04954) model, using the ReconVAT [GitHub repository](https://github.com/KinWaiCheuk/ReconVAT).

To run the notebook, runtime type should be set to 'GPU' for best performance; otherwise use the script `transcribe_files_cpu.py`.

This notebook assumes validation and test data are located in the subdirectories `ReconVAT/MusicNet/validation` and `ReconVAT/MusicNet/test`. For exact reproducibility of our results using ReconVAT, these should be `.flac` files exported from MusicNet audio as described in the repo above, divided into 20s chunks as [suggested](https://github.com/KinWaiCheuk/ReconVAT/issues/1) by the ReconVAT authors.

It is also necessary to manually download the weight files for `String_MusicNet` linked in the ReconVAT git repo and place them at `ReconVAT/Weight` (the listing cell later in this notebook expects them under `ReconVAT/Weight/String_MusicNet/`).

# Setup

In [3]:
import os
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# Working directory on Google Drive; the ReconVAT repo is cloned beneath it.
# Drive is used because it persists across sessions, unlike the ephemeral
# Colab storage.
DIRPATH = '/content/drive/MyDrive/mt3-baselines'

# exist_ok=True keeps this cell re-runnable and avoids a separate
# check-then-create step.
os.makedirs(DIRPATH, exist_ok=True)
os.chdir(DIRPATH)

Mounted at /content/drive


In [4]:
!git clone https://github.com/KinWaiCheuk/ReconVAT.git && cd ReconVAT

Cloning into 'ReconVAT'...
remote: Enumerating objects: 147, done.[K
remote: Counting objects: 100% (147/147), done.[K
remote: Compressing objects: 100% (116/116), done.[K
remote: Total 147 (delta 37), reused 132 (delta 22), pack-reused 0[K
Receiving objects: 100% (147/147), 7.06 MiB | 16.78 MiB/s, done.
Resolving deltas: 100% (37/37), done.


In [5]:
# Make the cloned repo the working directory for the rest of the notebook;
# later cells use paths relative to it (e.g. ./MusicNet, Application/Output).
os.chdir(os.path.join(DIRPATH, 'ReconVAT'))
!pwd

/content/drive/MyDrive/mt3-baselines/ReconVAT


In [6]:
# Application/Input receives the audio chunks copied by initialize_input();
# Application/Output is where the transcribed MIDI files are later read from.
# exist_ok=True lets this cell be re-run (the directories live on persistent
# Drive storage, so they usually already exist after the first session).
os.makedirs(os.path.join(DIRPATH, 'ReconVAT', 'Application', 'Input'), exist_ok=True)
os.makedirs(os.path.join(DIRPATH, 'ReconVAT', 'Application', 'Output'), exist_ok=True)

In [7]:
# Install the packages listed in requirements.txt; the path is resolved
# relative to the repo root, which the os.chdir() cell above made the working
# directory.
!pip3 install -r requirements.txt



In [8]:
# note_seq supplies the NoteSequence / MIDI helpers used by the metric
# utilities defined below.
!pip3 install note_seq

Collecting joblib>=0.14
  Downloading joblib-1.1.0-py2.py3-none-any.whl (306 kB)
[K     |████████████████████████████████| 306 kB 14.0 MB/s 
Installing collected packages: joblib
  Attempting uninstall: joblib
    Found existing installation: joblib 0.13.2
    Uninstalling joblib-0.13.2:
      Successfully uninstalled joblib-0.13.2
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
imbalanced-learn 0.8.1 requires scikit-learn>=0.24, but you have scikit-learn 0.19.2 which is incompatible.[0m
Successfully installed joblib-1.1.0


In [9]:
import note_seq
import pretty_midi
import numpy as np
from typing import Tuple
import sklearn
import mir_eval

# Reference piano-roll values must exceed this threshold to count as active
# frames (used by frame_metrics and compute_transcription_metrics).
VELOCITY_THRESHOLD = 30

## Utility functions for computing transcription metrics.

def aggregate_results(df):
  """Print count-weighted f-measure / precision / recall over all rows of `df`.

  Each row's precision is weighted by that row's number of estimated items and
  each row's recall by its number of reference items (note intervals for the
  onset/offset and onset-only metrics, frames for the frame metric). The
  f-measure is then derived from the aggregated precision and recall via
  `mir_eval.util.f_measure`.

  Args:
    df: DataFrame holding the `*_precision`, `*_recall` and `len_*` columns
      read below.
  """
  def _weighted_average(score_col, count_col):
    # sum(score * count) / sum(count) over all rows.
    counts = df[count_col]
    return (df[score_col] * counts).sum() / counts.sum()

  # Onset-offset
  note_p = _weighted_average('onoff_precision', 'len_est_intervals')
  note_r = _weighted_average('onoff_recall', 'len_ref_intervals')
  note_f = mir_eval.util.f_measure(note_p, note_r)

  # Onset only
  onset_p = _weighted_average('on_precision', 'len_est_intervals')
  onset_r = _weighted_average('on_recall', 'len_ref_intervals')
  onset_f = mir_eval.util.f_measure(onset_p, onset_r)

  # Frame
  frame_p = _weighted_average('frame_precision', 'len_est_frames')
  frame_r = _weighted_average('frame_recall', 'len_ref_frames')
  frame_f = mir_eval.util.f_measure(frame_p, frame_r)

  print(f" f / prec / rec: \n## ONSET/OFFSET: {note_f} / {note_p} / {note_r} \n## ONSET: {onset_f} / {onset_p} / {onset_r} \n## FRAME : {frame_f} / {frame_p} / {frame_r}")


def get_prettymidi_pianoroll(ns: note_seq.NoteSequence, fps: float,
                             is_drum: bool):
  """Convert NoteSequence to pianoroll through pretty_midi.

  Args:
    ns: NoteSequence to render. It is left untouched; note lengths are
      adjusted on a private copy.
    fps: Piano-roll sampling rate, forwarded to pretty_midi as `fs`.
    is_drum: If True, every note is given a fixed length of 0.05 and the
      instruments' drum flags are cleared before rendering. If False, only
      notes shorter than 0.05 are lengthened to 0.05.

  Returns:
    The array returned by `PrettyMIDI.get_piano_roll(fs=fps)`.
  """
  import copy  # Local import: only this function needs it.

  # Work on a copy so the caller's sequence keeps its original note end times
  # (they are rewritten below).
  ns = copy.deepcopy(ns)
  for note in ns.notes:
    if is_drum or note.end_time - note.start_time < 0.05:
      # Give all drum notes a fixed length, and all others a min length
      note.end_time = note.start_time + 0.05

  pm = note_seq.note_sequence_to_pretty_midi(ns)
  end_time = pm.get_end_time()
  cc = [
      # all sound off
      pretty_midi.ControlChange(number=120, value=0, time=end_time),
      # all notes off
      pretty_midi.ControlChange(number=123, value=0, time=end_time)
  ]
  # Replace the first instrument's control changes with the two events above.
  pm.instruments[0].control_changes = cc
  if is_drum:
    # If inst.is_drum is set, pretty_midi will return an all zero pianoroll.
    for inst in pm.instruments:
      inst.is_drum = False
  pianoroll = pm.get_piano_roll(fs=fps)
  return pianoroll


def frame_metrics(ref_pianoroll: np.ndarray,
                  est_pianoroll: np.ndarray,
                  velocity_threshold: int) -> Tuple[float, float, float]:
  """Frame Precision, Recall, and F1.

  Args:
    ref_pianoroll: Reference piano roll; axis 1 is the one that gets padded.
    est_pianoroll: Estimated piano roll, with the same layout.
    velocity_threshold: Reference values must exceed this to count as active.

  Returns:
    (precision, recall, f1) for the active (True) class, computed over all
    flattened piano-roll cells.
  """
  # `import sklearn` on its own does not guarantee that the `metrics`
  # submodule is loaded, so import it explicitly rather than relying on
  # another library having imported it first.
  import sklearn.metrics

  # Pad to same length
  n_frames = max(ref_pianoroll.shape[1], est_pianoroll.shape[1])
  if est_pianoroll.shape[1] < n_frames:
    diff = n_frames - est_pianoroll.shape[1]
    est_pianoroll = np.pad(est_pianoroll, [(0, 0), (0, diff)], mode='constant')
  elif ref_pianoroll.shape[1] < n_frames:
    diff = n_frames - ref_pianoroll.shape[1]
    ref_pianoroll = np.pad(ref_pianoroll, [(0, 0), (0, diff)], mode='constant')

  # For ref, remove any notes that are too quiet (consistent with Cerberus.)
  ref_frames_bool = ref_pianoroll > velocity_threshold
  # For est, keep all predicted notes.
  est_frames_bool = est_pianoroll > 0

  # labels=[True, False] puts the active class at index 0 of each result.
  precision, recall, f1, _ = sklearn.metrics.precision_recall_fscore_support(
      ref_frames_bool.flatten(),
      est_frames_bool.flatten(),
      labels=[True, False])

  return precision[0], recall[0], f1[0]




def compute_transcription_metrics(ns_ref, ns_est, fps: float = 62.5):
  """Helper function to compute onset/offset, onset only, and frame metrics.

  Args:
    ns_ref: Reference NoteSequence.
    ns_est: Estimated NoteSequence.
    fps: Piano-roll frame rate used for the frame metrics. Defaults to 62.5,
      the value that was previously hard-coded here.

  Returns:
    Dict with precision / recall / F1 / overlap for onset+offset matching
    ('onoff_*') and onset-only matching ('on_*'), precision / recall / F1 for
    frames ('frame_*'), and the interval and active-frame counts ('len_*')
    that aggregate_results uses as weights.
  """
  # Note intervals are extracted first, from the sequences as passed in; the
  # piano-roll rendering further down enforces a minimum note length that must
  # not influence the note-level metrics.
  intervals_ref, pitches_ref, _ = note_seq.sequences_lib.sequence_to_valued_intervals(ns_ref)
  intervals_est, pitches_est, _ = note_seq.sequences_lib.sequence_to_valued_intervals(ns_est)
  len_est_intervals = len(intervals_est)
  len_ref_intervals = len(intervals_ref)

  # onset-offset
  onoff_precision, onoff_recall, onoff_f1, onoff_overlap = mir_eval.transcription.precision_recall_f1_overlap(
    intervals_ref, pitches_ref, intervals_est, pitches_est)

  # onset-only (offset_ratio=None disables offset matching)
  on_precision, on_recall, on_f1, on_overlap = mir_eval.transcription.precision_recall_f1_overlap(
    intervals_ref, pitches_ref, intervals_est, pitches_est, offset_ratio=None)

  # frame
  ref_pr = get_prettymidi_pianoroll(ns_ref, fps=fps, is_drum=False)
  est_pr = get_prettymidi_pianoroll(ns_est, fps=fps, is_drum=False)
  # For ref, remove any notes that are too quiet (consistent with Cerberus.)
  len_ref_frames = (ref_pr > VELOCITY_THRESHOLD).sum()
  # For est, keep all predicted notes.
  len_est_frames = (est_pr > 0).sum()

  frame_precision, frame_recall, frame_f1 = frame_metrics(
      ref_pr, est_pr, velocity_threshold=VELOCITY_THRESHOLD)
  return {
      'len_ref_intervals': len_ref_intervals,
      'len_est_intervals': len_est_intervals,
      'onoff_precision': onoff_precision,
      'onoff_recall': onoff_recall,
      'onoff_f1': onoff_f1,
      'onoff_overlap': onoff_overlap,
      'on_precision': on_precision,
      'on_recall': on_recall,
      'on_f1': on_f1,
      'on_overlap': on_overlap,
      'frame_precision': frame_precision,
      'frame_recall': frame_recall,
      'frame_f1': frame_f1,
      'len_ref_frames': len_ref_frames,
      'len_est_frames': len_est_frames,
  }


In [10]:
def initialize_input(split):
  """Stage a split's audio chunks in the ReconVAT input directory.

  For every track ID listed under `split` in TRACKS_BY_SPLIT, each matching
  `<track_id>_*.flac` file in `./MusicNet/<split>/` is copied into
  `ReconVAT/Application/Input` under DIRPATH, logging one line per file.
  """
  input_dir = os.path.join(DIRPATH, "ReconVAT/Application/Input")
  for track_id in TRACKS_BY_SPLIT[split]:
    for src in glob.glob(f"./MusicNet/{split}/{track_id}_*.flac"):
      dest = os.path.join(input_dir, os.path.basename(src))
      print(f"copying {src} to {dest}")
      shutil.copyfile(src, dest)

In [11]:
import glob
import tqdm


def _parse_filename(filepath):
  """Split a filepath to extract the track ID and chunk number.

  Expects a basename of the form `<track_id>_<chunk_id>.<ext>`, for example
  `1733_0.midi`. The extension is removed with `os.path.splitext`, so any
  extension length works (`.midi`, `.flac`, `.mid`), and the name is split on
  its last underscore only.

  Args:
    filepath: Path (or bare filename) of a chunk file.

  Returns:
    Tuple `(track_id, chunk_id)` of strings.

  Raises:
    ValueError: If the basename contains no underscore.
  """
  stem, _ = os.path.splitext(os.path.basename(filepath))
  track_id, chunk_id = stem.rsplit("_", 1)
  return track_id, chunk_id


def collect_inference_results(split):
  """Fetch the transcription results for a split, and compute relevant metrics.

  Every reference `MusicNet/<split>/*.midi` chunk whose track ID belongs to
  the split is compared against the matching
  `Application/Output/ReconVAT-<chunk>.mid` transcription. Chunks whose
  reference contains no notes are skipped.

  Args:
    split: Key into TRACKS_BY_SPLIT (e.g. "validation" or "test").

  Returns:
    List with one metrics dict per evaluated chunk (the output of
    compute_transcription_metrics plus 'track_id' and 'chunk_id'), ordered by
    reference file path.
  """
  results = []
  split_track_ids = set(TRACKS_BY_SPLIT[split])
  # glob returns paths in arbitrary, filesystem-dependent order; sorting keeps
  # the row order of the results (and of the CSV built from them) reproducible.
  ref_paths = sorted(glob.glob(f"MusicNet/{split}/*.midi"))
  for midi_ref in tqdm.tqdm(ref_paths):
    track_id, chunk_id = _parse_filename(midi_ref)
    if track_id not in split_track_ids:
      continue
    ns_ref = note_seq.midi_file_to_note_sequence(midi_ref)
    if not len(ns_ref.notes):
      # Nothing to score against for a chunk with no reference notes.
      continue

    # output uses .mid, not .midi, so drop the trailing "i".
    midi_est = os.path.join("Application/Output", "ReconVAT-" + os.path.basename(midi_ref)[:-1])
    ns_est = note_seq.midi_file_to_note_sequence(midi_est)

    track_metrics = compute_transcription_metrics(ns_ref, ns_est)
    track_metrics["track_id"] = str(track_id)
    track_metrics["chunk_id"] = str(chunk_id)
    results.append(track_metrics)
  return results

# Inference

ReconVAT requires `.flac` files for inference.

This portion of the code assumes that a copy of MusicNet audio and MIDI files has been exported to `.flac`  and `.midi` format, respectively, and stored in the directory structure below.

```
./MusicNet/
|---validation
   1733_0.flac
   1733_0.midi
   1733_1.flac
   1733_1.midi
   ...
   2611_27.flac
   2611_27.midi

|---test
   1729_0.flac
   1729_0.midi
   ...
   2621_32.flac
   2621_32.midi

```

In [19]:
# After adding MusicNet files and trained model weights to the ReconVAT directory, you may need to run the following:

# drive.mount('/content/drive', force_remount=True)
# os.chdir(os.path.join(DIRPATH, "ReconVAT"))

Mounted at /content/drive


In [22]:
# Sanity check: list the MusicNet split directories and the first few
# validation chunk files...
!ls ./MusicNet && ls -l ./MusicNet/validation | head -n 5 

# ...and the contents of the String_MusicNet weight directory.
!ls -l ./Weight/String_MusicNet/

test  validation
total 98341
-rw------- 1 root root 237705 Sep 30 06:37 1733_0.flac
-rw------- 1 root root    334 Sep 30 06:37 1733_0.midi
-rw------- 1 root root 366016 Sep 30 06:37 1733_10.flac
-rw------- 1 root root    761 Sep 30 06:37 1733_10.midi
total 8
drwx------ 2 root root 4096 Sep 21 06:03  baseline_Multi_Inst
drwx------ 2 root root 4096 Sep 21 06:02 'Unet_R_VAT-XI=1e-06-eps=1.3-String_MusicNet-lr=0.001'


In [23]:
import glob
import os
import shutil

# Track IDs (as strings, matching the `<track_id>_<chunk>` file-name prefix)
# that make up each evaluation split.
TRACKS_BY_SPLIT = dict(
    validation=[
        '2336', '2466', '2160', '1818', '1733',
        '1765', '2198', '2300', '2308', '2477',
        '2611', '2289', '1790', '2315', '2504',
    ],
    test=[
        '2118', '2501', '1813', '1729', '1893',
        '2296', '1776', '2487', '2537', '2186',
        '2431', '2432', '2497', '2621', '2507',
    ],
)

In [24]:
# Copy the validation-split .flac chunks into Application/Input so that the
# transcription step below can be run on them.
initialize_input("validation")

copying ./MusicNet/validation/2336_8.flac to /content/drive/MyDrive/mt3-baselines/ReconVAT/Application/Input/2336_8.flac
copying ./MusicNet/validation/2336_4.flac to /content/drive/MyDrive/mt3-baselines/ReconVAT/Application/Input/2336_4.flac
copying ./MusicNet/validation/2336_3.flac to /content/drive/MyDrive/mt3-baselines/ReconVAT/Application/Input/2336_3.flac
copying ./MusicNet/validation/2336_1.flac to /content/drive/MyDrive/mt3-baselines/ReconVAT/Application/Input/2336_1.flac
copying ./MusicNet/validation/2336_5.flac to /content/drive/MyDrive/mt3-baselines/ReconVAT/Application/Input/2336_5.flac
copying ./MusicNet/validation/2336_2.flac to /content/drive/MyDrive/mt3-baselines/ReconVAT/Application/Input/2336_2.flac
copying ./MusicNet/validation/2336_6.flac to /content/drive/MyDrive/mt3-baselines/ReconVAT/Application/Input/2336_6.flac
copying ./MusicNet/validation/2336_7.flac to /content/drive/MyDrive/mt3-baselines/ReconVAT/Application/Input/2336_7.flac
copying ./MusicNet/validation/23

In [25]:
# CPU alternative (uncomment to use): this can allow for larger tracks, although
# it is still likely to run out of memory for full-length audio tracks.
# !python3 transcribe_files_cpu.py

# GPU transcription: ensure the Colab accelerator type is set to 'GPU' before
# running this.
!python3 transcribe_files.py

INFO - transcription - Running command 'main'
INFO - transcription - Started
Loading files: 100% 300/300 [00:02<00:00, 144.42it/s]
STFT kernels created, time used = 0.2966 seconds
STFT filter created, time used = 0.0037 seconds
Mel filter created, time used = 0.0037 seconds
Loading model weight
Loading done
Transcribing Music
  0% 0/300 [00:00<?, ?it/s]midi_path = Application/Output/ReconVAT-2336_8.mid
  0% 1/300 [00:00<00:32,  9.15it/s]midi_path = Application/Output/ReconVAT-2336_4.mid
  1% 2/300 [00:00<00:32,  9.14it/s]midi_path = Application/Output/ReconVAT-2336_3.mid
midi_path = Application/Output/ReconVAT-2336_1.mid
  1% 4/300 [00:00<00:29, 10.12it/s]midi_path = Application/Output/ReconVAT-2336_5.mid
  2% 5/300 [00:00<00:30,  9.64it/s]midi_path = Application/Output/ReconVAT-2336_2.mid
midi_path = Application/Output/ReconVAT-2336_6.mid
  2% 7/300 [00:00<00:27, 10.81it/s]midi_path = Application/Output/ReconVAT-2336_7.mid
midi_path = Application/Output/ReconVAT-2336_0.mid
  3% 9/300 

In [26]:
# Spot-check the first 20 files in the transcription output directory.
!ls Application/Output | head -n 20

ReconVAT-1733_0.mid
ReconVAT-1733_10.mid
ReconVAT-1733_11.mid
ReconVAT-1733_12.mid
ReconVAT-1733_13.mid
ReconVAT-1733_14.mid
ReconVAT-1733_15.mid
ReconVAT-1733_16.mid
ReconVAT-1733_17.mid
ReconVAT-1733_18.mid
ReconVAT-1733_19.mid
ReconVAT-1733_1.mid
ReconVAT-1733_20.mid
ReconVAT-1733_21.mid
ReconVAT-1733_22.mid
ReconVAT-1733_23.mid
ReconVAT-1733_24.mid
ReconVAT-1733_25.mid
ReconVAT-1733_26.mid
ReconVAT-1733_27.mid


# Compute validation set transcription metrics

In [27]:
# Score each validation chunk against its reference MIDI; `results` is a list
# with one metrics dict per evaluated chunk.
results = collect_inference_results("validation")

100%|██████████| 300/300 [00:15<00:00, 19.54it/s]


In [28]:
import pandas as pd

# One row per evaluated chunk; keep the per-chunk table on disk.
results_df = pd.DataFrame(results)
results_df.to_csv(
    "reconvat_musicnet_validation_inference_results_by_track.csv",
    index=False)

# Print the aggregated f-measure / precision / recall for the validation split.
aggregate_results(results_df)

 f / prec / rec: 
## ONSET/OFFSET: 0.09019987886129618 / 0.08838401823231963 / 0.09209191876716057 
## ONSET: 0.2721502119927317 / 0.2666714146665717 / 0.2778588567046776 
## FRAME : 0.4633491428492872 / 0.6905161506034759 / 0.3486498183951744


# Inference (Test)

In [29]:
# Remove the validation audio chunks and transcriptions so that the test run
# below only sees (and only scores) test-split files.
!rm Application/Input/*
!rm Application/Output/*

In [30]:
# Copy the test-split .flac chunks into Application/Input so that the
# transcription step below can be run on them.
initialize_input("test")

copying ./MusicNet/test/2118_6.flac to /content/drive/MyDrive/mt3-baselines/ReconVAT/Application/Input/2118_6.flac
copying ./MusicNet/test/2118_4.flac to /content/drive/MyDrive/mt3-baselines/ReconVAT/Application/Input/2118_4.flac
copying ./MusicNet/test/2118_8.flac to /content/drive/MyDrive/mt3-baselines/ReconVAT/Application/Input/2118_8.flac
copying ./MusicNet/test/2118_2.flac to /content/drive/MyDrive/mt3-baselines/ReconVAT/Application/Input/2118_2.flac
copying ./MusicNet/test/2118_1.flac to /content/drive/MyDrive/mt3-baselines/ReconVAT/Application/Input/2118_1.flac
copying ./MusicNet/test/2118_11.flac to /content/drive/MyDrive/mt3-baselines/ReconVAT/Application/Input/2118_11.flac
copying ./MusicNet/test/2118_7.flac to /content/drive/MyDrive/mt3-baselines/ReconVAT/Application/Input/2118_7.flac
copying ./MusicNet/test/2118_9.flac to /content/drive/MyDrive/mt3-baselines/ReconVAT/Application/Input/2118_9.flac
copying ./MusicNet/test/2118_0.flac to /content/drive/MyDrive/mt3-baselines/Re

In [31]:
# CPU alternative (uncomment to use): this can allow for larger tracks, although
# it is still likely to run out of memory for full-length audio tracks.
# !python3 transcribe_files_cpu.py

# GPU transcription: ensure the Colab accelerator type is set to 'GPU' before
# running this.
!python3 transcribe_files.py

INFO - transcription - Running command 'main'
INFO - transcription - Started
Loading files: 100% 279/279 [00:01<00:00, 149.24it/s]
STFT kernels created, time used = 0.2750 seconds
STFT filter created, time used = 0.0045 seconds
Mel filter created, time used = 0.0045 seconds
Loading model weight
Loading done
Transcribing Music
  0% 0/279 [00:00<?, ?it/s]midi_path = Application/Output/ReconVAT-2118_6.mid
  0% 1/279 [00:00<00:30,  9.01it/s]midi_path = Application/Output/ReconVAT-2118_4.mid
midi_path = Application/Output/ReconVAT-2118_8.mid
  1% 3/279 [00:00<00:27,  9.96it/s]midi_path = Application/Output/ReconVAT-2118_2.mid
midi_path = Application/Output/ReconVAT-2118_1.mid
  2% 5/279 [00:00<00:24, 11.21it/s]midi_path = Application/Output/ReconVAT-2118_11.mid
midi_path = Application/Output/ReconVAT-2118_7.mid
  3% 7/279 [00:00<00:21, 12.90it/s]midi_path = Application/Output/ReconVAT-2118_9.mid
midi_path = Application/Output/ReconVAT-2118_0.mid
  3% 9/279 [00:00<00:20, 13.14it/s]midi_path 

In [32]:
# Spot-check the first 20 files in the transcription output directory.
!ls Application/Output | head -n 20

ReconVAT-1729_0.mid
ReconVAT-1729_10.mid
ReconVAT-1729_11.mid
ReconVAT-1729_12.mid
ReconVAT-1729_13.mid
ReconVAT-1729_14.mid
ReconVAT-1729_15.mid
ReconVAT-1729_16.mid
ReconVAT-1729_17.mid
ReconVAT-1729_18.mid
ReconVAT-1729_19.mid
ReconVAT-1729_1.mid
ReconVAT-1729_20.mid
ReconVAT-1729_21.mid
ReconVAT-1729_2.mid
ReconVAT-1729_3.mid
ReconVAT-1729_4.mid
ReconVAT-1729_5.mid
ReconVAT-1729_6.mid
ReconVAT-1729_7.mid


In [33]:
# Score each test chunk against its reference MIDI; `results` is a list with
# one metrics dict per evaluated chunk.
results = collect_inference_results("test")

100%|██████████| 279/279 [00:14<00:00, 19.61it/s]


In [34]:
import pandas as pd

# One row per evaluated chunk; keep the per-chunk table on disk.
results_df = pd.DataFrame(results)
results_df.to_csv(
    "reconvat_musicnet_test_inference_results_by_track.csv",
    index=False)

# Computes results shown in https://arxiv.org/pdf/2111.03017.pdf, Table 2
aggregate_results(results_df)

 f / prec / rec: 
## ONSET/OFFSET: 0.11311974588016308 / 0.1265161115133961 / 0.10228873561536068 
## ONSET: 0.2947466388807397 / 0.3296524257784214 / 0.26652518340986847 
## FRAME : 0.48004957645964097 / 0.6751211756880742 / 0.37243659541886054
