<a href="https://colab.research.google.com/github/marinaniet0/groove_beat_tracking/blob/main/Groove_Beat_Downbeat_Tracking.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Beat and downbeat tracking

In this notebook, two models are trained and evaluated on the task of beat and downbeat tracking using the Groove Midi Dataset. The data formatting, the models and the hyper parameters are taken (with some small modifications to fit the dataset used) from [[1]](#1).

<a id="1">[1]</a> 
Y. -C. Chuang and L. Su, "Beat and Downbeat Tracking of Symbolic Music Data Using Deep Recurrent Neural Networks", *2020 Asia-Pacific Signal and Information Processing Association Annual Summit and Conference (APSIPA ASC)*, 2020, pp. 346-352. Implementation: https://github.com/chuang76/symbolic-beat-tracking/

## Model architecture

The two models trained are a "vanilla" BLSTM and a BLSTM with an attention mechanism. In both cases we have a Recurrent Neural Network with 2 layers and
25 units per layer. The models output the probability of a time step being a beat or not, and the probability of it being a down beat or not. The models are trained over 50 epochs, using the Binary Cross Entropy loss of beats and downbeats. To transform the probabilities, the outputs are thresholded (0.3 for the BLSTM and 0.2 for the BLSTM + attention). An overview of the architecture is shown below in images.

### BLSTM

![](https://drive.google.com/uc?export=view&id=1BKMz9rfHZhdABfL9Soo1X1Xy-ZIPvcoE)

### BLSTM + Attention

![](https://drive.google.com/uc?export=view&id=11WpRqaU_20G6AP-C-cls-BpjSCOk3Vv0)

## Dataset preprocessing

The MIDI-only version of the [Groove MIDI Dataset](https://magenta.tensorflow.org/datasets/groove) is downloaded and processed here.  
Since the dataset is already split in train, test and validation sets, I use the train subset for its intended purpose and the validation set to compute the evaluation later on.

In [1]:
# Installing pretty_midi - Necessary for the preprocessing
%%time
!pip install pretty_midi

Collecting pretty_midi
[?25l  Downloading https://files.pythonhosted.org/packages/bc/8e/63c6e39a7a64623a9cd6aec530070c70827f6f8f40deec938f323d7b1e15/pretty_midi-0.2.9.tar.gz (5.6MB)
[K     |████████████████████████████████| 5.6MB 4.7MB/s 
Collecting mido>=1.1.16
[?25l  Downloading https://files.pythonhosted.org/packages/b5/6d/e18a5b59ff086e1cd61d7fbf943d86c5f593a4e68bfc60215ab74210b22b/mido-1.2.10-py2.py3-none-any.whl (51kB)
[K     |████████████████████████████████| 51kB 6.4MB/s 
Building wheels for collected packages: pretty-midi
  Building wheel for pretty-midi (setup.py) ... [?25l[?25hdone
  Created wheel for pretty-midi: filename=pretty_midi-0.2.9-cp37-none-any.whl size=5591954 sha256=e85c2072f5c7a5489813ad2653378b8c835a4ff16a5553515db3f7a497a60b84
  Stored in directory: /root/.cache/pip/wheels/4c/a1/c6/b5697841db1112c6e5866d75a6b6bf1bef73b874782556ba66
Successfully built pretty-midi
Installing collected packages: mido, pretty-midi
Successfully installed mido-1.2.10 pretty-mi

In [None]:
# Downloading and unzipping the Groove MIDI Dataset
!wget https://storage.googleapis.com/magentadata/datasets/groove/groove-v1.0.0-midionly.zip
!unzip /content/groove-v1.0.0-midionly.zip

In [3]:
# Imports
import os
import pandas as pd
import pretty_midi
import numpy as np
from scipy.interpolate import interp1d
import pickle
import glob

In [4]:
current_dir = os.getcwd() + '/'
groove_dir = current_dir + 'groove/'
train_dir = current_dir + 'train/'
validation_dir = current_dir + 'validation/'

if not os.path.exists(train_dir):
  os.makedirs(train_dir)

if not os.path.exists(validation_dir):
  os.makedirs(validation_dir)

In [6]:
# Load dataframe and display Groove info examples
groove_info = pd.read_csv(groove_dir + 'info.csv')
groove_info.head()

Unnamed: 0,drummer,session,id,style,bpm,beat_type,time_signature,midi_filename,audio_filename,duration,split
0,drummer1,drummer1/eval_session,drummer1/eval_session/1,funk/groove1,138,beat,4-4,drummer1/eval_session/1_funk-groove1_138_beat_...,drummer1/eval_session/1_funk-groove1_138_beat_...,27.872308,test
1,drummer1,drummer1/eval_session,drummer1/eval_session/10,soul/groove10,102,beat,4-4,drummer1/eval_session/10_soul-groove10_102_bea...,drummer1/eval_session/10_soul-groove10_102_bea...,37.691158,test
2,drummer1,drummer1/eval_session,drummer1/eval_session/2,funk/groove2,105,beat,4-4,drummer1/eval_session/2_funk-groove2_105_beat_...,drummer1/eval_session/2_funk-groove2_105_beat_...,36.351218,test
3,drummer1,drummer1/eval_session,drummer1/eval_session/3,soul/groove3,86,beat,4-4,drummer1/eval_session/3_soul-groove3_86_beat_4...,drummer1/eval_session/3_soul-groove3_86_beat_4...,44.716543,test
4,drummer1,drummer1/eval_session,drummer1/eval_session/4,soul/groove4,80,beat,4-4,drummer1/eval_session/4_soul-groove4_80_beat_4...,drummer1/eval_session/4_soul-groove4_80_beat_4...,47.9875,test


In order to save and load the processing and the models, I used Google Drive, but for reproduction's sake, I will do the processing in the current working directory.

In [27]:
# Each beat value is mapped to a grid of resolution 16 divisions per beat
def find_nearest_to_grid(value):
  grid = np.arange(17) / 16.0
  idx = (np.abs(grid - value)).argmin()
  return grid[idx]

# Sync notes that are on the same start beat to the start time of the first of
# those notes
def align_notes_per_beat(dataframe):
  for idx, row in dataframe.iterrows():
    dataframe.loc[dataframe.index[dataframe['start_beat'] ==\
                  row['start_beat']],'start_time'] =\
                  dataframe.loc[dataframe.index[dataframe['start_beat'] ==\
                  row['start_beat']][0],'start_time']

# With this function, CSV files with the same format as those from Chuang et al.
# are created from the midi files
def midi2csv(midi_file, bpm, time_signature, save_dir):

  if not os.path.exists(save_dir):
    os.makedirs(save_dir)

  data = pretty_midi.PrettyMIDI(midi_file)
  name = midi_file.split('/')[-1].split('.')[0]

  time_signature_num = int(time_signature.split('-')[0])
  time_signature_den = int(time_signature.split('-')[1])

  beat_secs = (60/bpm) / (time_signature_den/4) # n beats in a second
  downbeat_samples = beat_secs * time_signature_num

  midi_list = []

  for instrument in data.instruments:
    for note in instrument.notes:
      start = int(note.start * 44100)
      end = int(note.end * 44100)
      pitch = note.pitch
      start_beat = note.start / beat_secs
      start_beat_decimal_grid = find_nearest_to_grid(start_beat % 1)
      start_beat_grid = float(str(start_beat).split('.')[0]) + start_beat_decimal_grid
      midi_list.append([start, end, start_beat_grid, pitch])

  midi_list = sorted(midi_list, key=lambda x: (x[0], x[3]))
  df = pd.DataFrame(midi_list, columns=['start_time', 'end_time', 'start_beat', 'note'])
  align_notes_per_beat(df)
  new_path = save_dir + midi_file.split('.')[0].split('/')[-1]
  new_filename = new_path +'.csv'
  df.to_csv(new_filename)
  return new_filename

In [8]:
# Create a list with the subsets of midi files
train_file_list = []
validation_file_list = []

groove_info = pd.read_csv(groove_dir + 'info.csv')

for i in np.arange(len(groove_info)):
  if groove_info.iloc[i]['split'] == 'train':
    train_file_list.append(groove_dir + groove_info.iloc[i]['midi_filename'])
  else:
    validation_file_list.append(groove_dir + groove_info.iloc[i]['midi_filename'])

In [28]:
%%time
spinner=['⠛','⠗','⠕','⠕','⠧','⠑']
count = 0

train_csv_list = []
validation_csv_list = []

train_csv_dir = train_dir + 'csv/'
validation_csv_dir = validation_dir + 'csv/'

if not os.path.exists(train_csv_dir):
  os.makedirs(train_csv_dir)
if not os.path.exists(validation_csv_dir):
  os.makedirs(validation_csv_dir)

i = 0
for file in train_file_list:
  time_signature = file.split('.')[0].split('_')[-1]
  bpm = int(file.split('_')[-3])
  print("\r" + spinner[count%6] + "  " + str(i) + "/" +\
        str(len(train_file_list)), file, time_signature, str(bpm), sep=" | ",\
        end="")
  train_csv_list.append(midi2csv(file, bpm, time_signature, train_csv_dir))
  count += 1
  i += 1

i = 0
for file in validation_file_list:
  time_signature = file.split('.')[0].split('_')[-1]
  bpm = int(file.split('_')[-3])
  print("\r" + spinner[count%6] + "  " + str(i) + "/" +\
        str(len(validation_file_list)), file, time_signature, str(bpm),\
        sep=" | ", end="")
  validation_csv_list.append(midi2csv(file, bpm, time_signature,\
                                      validation_csv_dir))
  count += 1
  i += 1

⠕  252/253 | /content/groove/drummer2/session2/2_rock_130_beat_4-4.mid | 4-4 | 130CPU times: user 9min 23s, sys: 25.8 s, total: 9min 49s
Wall time: 9min 15s


Now that we have the csv files similar to the ones used in Chuang et al., we calculate the X, Yb and Yd that will serve as input and output to the model.

In [29]:
midi_notes_groove = [22,26,36,37,38,40,42,43,44,45,46,47,48,49,50,51,52,53,55,
                      57,58,59]
def preprocess_data(dir_csv, dir_npz, csv_list, window_sz, large_save_dir,\
                    save_name, join=True):
  # No need to calculate onsets, offsets, etc. for notes not used in the dataset
  # and this is the reduced mapping of the Groove Dataset
  
  n_notes = len(midi_notes_groove)
  print(str(len(csv_list)) + " files!")

  # Template arrays for the onsets and durations
  N = []
  for i in midi_notes_groove:
    N.append(str(i) + '/onset')
  for i in midi_notes_groove:
    N.append(str(i) + '/dur')

  # Iterate through csv files
  for file_idx in range(len(csv_list)):
    data = pd.read_csv(csv_list[file_idx])
    name = csv_list[file_idx].split('/')[-1]

    print("\r" + "  " + str(file_idx+1) + '/' +\
          str(len(csv_list)), name, end="")

    # Create array of resolution 100 frames per second up to maximum end_time
    length = int(data['end_time'].max() / 44100)
    length = length if length > 0 else 1
    z = pd.DataFrame(np.zeros((length * 100, 1)))

    # Initialize IOI, beat and downbeat columns
    new = pd.DataFrame()
    for i in range(len(N)):
      new[N[i]] = z[0]
    new['IOI'], new['beat'], new['downbeat'] = z[0], z[0], z[0]

    # Get time signature from file name
    time_signature = name.split('_')[-1].split('.')[0]
    time_signature_den = int(time_signature.split('-')[1])
    # to account for 6/8 time signature, divide by denom/4
    m = int(time_signature.split('-')[0]) / (time_signature_den/4)

    # Get last beat
    B = int(data['start_beat'].max())

    # These are the beats that should be detected - reference list
    ref_list = np.arange(B+1)


    # Creating x,y and interpolation lists
    x_list, y_list = [], []
    x_interp = data['start_beat'].to_numpy()
    y_interp = np.around(data['start_time'] / 44100 * 100, 0).to_numpy()

    # Appending integer existing beats to x and y lists
    for i in range(len(data)):
      if data['start_beat'][i].is_integer():
        x_list.append(data['start_beat'][i])
        y_list.append(np.around(data['start_time'][i] / 44100 * 100, 0))

    # Creating list of 'missing' beats
    lost_list = []
    for i in range(len(ref_list)):
      if ref_list[i] not in y_list and ref_list[i] >= data['start_beat'].min():
        lost_list.append(ref_list[i])

    # Interpolation to create array of insert beats & downbeats
    f = interp1d(y_interp, x_interp, kind='linear', fill_value='extrapolate')
    insert_beat = []
    for i in range(len(lost_list)):
      idx = lost_list[i]
      t = np.round(f(idx), 0)
      if idx % m == 0:
        d = 1
      else:
        d = 0 
      insert_beat.append([t, 1, d])

    # Onset and duration calculations
    onset_arr = []                     
    for i in range(len(data)):
      pitch = data['note'][i]
      onset = int(np.round(data['start_time'][i] / 44100 * 100, 0))
      offset = int(np.round(data['end_time'][i] / 44100 * 100, 0))
      beat = data['start_beat'][i] 
      onset_arr.append(onset)

      # Assign onsets and durations
      new[str(pitch) + '/onset'][onset] = 1
      new[str(pitch) + '/dur'][onset:offset+1] = 1 
      
      # Add beat and downbeat if it corresponds
      if beat.is_integer(): # beat
        new['beat'][onset] = 1 
        if beat % int(m) == 0:
          new['downbeat'][onset] = 1

    onset_arr = np.array(onset_arr)
    onset_arr = np.unique(onset_arr)
    onset_list = onset_arr.tolist()

    # IOI 
    for i in range(len(onset_list)):
      if i == 0:
        num = 0.0
        new['IOI'][onset_list[i]] = num
      else:
        num = np.round(float(onset_list[i] - onset_list[i-1]) * 0.01, 2)
        new['IOI'][onset_list[i]] = num

    # Add missing beats and downbeats
    for i in range(len(insert_beat)):
      new['beat'][insert_beat[i][0]] = insert_beat[i][1]
      new['downbeat'][insert_beat[i][0]] = insert_beat[i][2]

    # Save to temporary csv
    path = dir_csv + name
    new.to_csv(path)

    # =========================================================================

    # Load Yb with beat column, Yd with downbeat column, and X with the
    # remaining info
    X = pd.read_csv(path)
    name = path.split('/')[-1].split('.')[0]
    Yb, Yd = pd.DataFrame(X, columns=['beat']).to_numpy(), pd.DataFrame(X,\
                                            columns=['downbeat']).to_numpy()
    X = X.drop(['beat', 'downbeat', 'Unnamed: 0'], axis=1).to_numpy()   

    # setting 
    rows = len(X)
    z = np.zeros((rows, 1))
    z = pd.DataFrame(z)

    # Calculate spectral flux 
    X_sf = np.zeros((len(X), len(midi_notes_groove)))
    for i in range(len(X)):
      if i != len(X) - 1:
        X_sf[i] = np.maximum(X[i+1][:len(midi_notes_groove)] \
                             - X[i][:len(midi_notes_groove)], 0)

    X_sf = np.sum(X_sf, axis=1)
    X_sf = X_sf[:, np.newaxis]
    X = np.concatenate((X, X_sf), axis=1)

    # Split the data in overlapping fragments of window_sz size
    num_frag = int(len(X) / window_sz)
    if num_frag * window_sz + window_sz > len(X):
      num_frag = num_frag - 1

    idx_list = []
    for i in range(num_frag):
      start_idx = window_sz * i
      end_idx = start_idx + (window_sz * 2)
      idx = int((start_idx + end_idx) / 2)
      idx_list.append(idx)

    # Transform to frame-level 
    X_data = []
    for i in idx_list:    
      X_data.append(X[i-window_sz:i+window_sz])

    Yb_data, Yd_data = [], []
    for i in idx_list:    
      Yb_data.append(Yb[i-window_sz:i+window_sz])
      Yd_data.append(Yd[i-window_sz:i+window_sz])
    Yb_data, Yd_data = np.squeeze(Yb_data, axis=2), np.squeeze(Yd_data, axis=2)

    # Save to temporary individual npz files
    with open(dir_npz + name + '.npz', 'wb') as f:
      pickle.dump([X_data, Yb_data, Yd_data], f)


  if(join):
    print('\n----------------------------')
    print('Joining files...')

    data_list = np.sort(glob.glob(dir_npz + '*.npz')).tolist()
    X_data, Yb_data, Yd_data, Yt_data = [], [], [], []

    for idx in range(len(data_list)):
      with open(data_list[idx], 'rb') as f:
        data = pickle.load(f)
        X_data.append(data[0])
        Yb_data.append(data[1])
        Yd_data.append(data[2])

    X = np.concatenate(X_data)
    Yb = np.concatenate(Yb_data)
    Yd = np.concatenate(Yd_data)

    with open(large_save_dir + save_name, 'wb') as f:
      pickle.dump([X, Yb, Yd], f, protocol=4)

In [30]:
# Process training data
print('Training data...')
dir_csv_training = train_dir + 'tmp/csv/'
dir_npz_training = train_dir + 'tmp/npz/'
if not os.path.exists(dir_csv_training):
    os.makedirs(dir_csv_training)
if not os.path.exists(dir_npz_training):
    os.makedirs(dir_npz_training)
preprocess_data(dir_csv=dir_csv_training, dir_npz=dir_npz_training,\
                csv_list=train_csv_list, window_sz=50,\
                large_save_dir=train_dir, save_name='train_data.npz')

print('\n\nValidation data...')
dir_csv_validation = validation_dir + 'tmp/csv/'
dir_npz_validation = validation_dir + 'tmp/npz/'
if not os.path.exists(dir_csv_validation):
    os.makedirs(dir_csv_validation)
if not os.path.exists(dir_npz_validation):
    os.makedirs(dir_npz_validation)
preprocess_data(dir_csv=dir_csv_validation, dir_npz=dir_npz_validation,\
                csv_list=validation_csv_list, window_sz=50,\
                large_save_dir=validation_dir, save_name='validation_data.npz',\
                join=False)


Training data...
897 files!
  13/897 110_funk_95_fill_4-4.csv

  slope = (y_hi - y_lo) / (x_hi - x_lo)[:, None]


  226/897 11_country_114_fill_4-4.csv

  slope = (y_hi - y_lo) / (x_hi - x_lo)[:, None]


  897/897 15_rock_130_beat_4-4.csv
----------------------------
Joining files...


Validation data...
253 files!
  253/253 2_rock_130_beat_4-4.csv

In [None]:
#@markdown Download processed files for later use
from google.colab import files
import os
if not os.path.exists(current_dir + 'train_data.zip'):
  print("Zipping training file...")
  !zip -j -r /content/train_data.zip /content/train/train_data.npz
files.download(current_dir + 'train_data.zip')
if not os.path.exists(current_dir + 'validation_data.zip'):
  print("Zipping validation files folder...")
  !zip -j -r /content/validation_data.zip /content/validation/tmp/npz
files.download(current_dir + 'validation_data.zip')

## Training

Now that we have the data processed and saved, we proceed to the training of the model.

### Load data

In [1]:
#@markdown ### Loading preprocessed data
#@markdown If the preprocessing step was skipped, the data can be loaded from
#@markdown the GitHub repo, 
import os
!wget https://github.com/marinaniet0/groove_beat_tracking/blob/main/train_data.zip?raw=true -O train_data.zip
if not os.path.exists(os.getcwd() + '/train/'):
  os.makedirs(os.getcwd() + '/train/')
!unzip -d /content/train/ /content/train_data.zip

--2021-05-28 21:36:49--  https://github.com/marinaniet0/groove_beat_tracking/blob/main/train_data.zip?raw=true
Resolving github.com (github.com)... 140.82.113.4
Connecting to github.com (github.com)|140.82.113.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://github.com/marinaniet0/groove_beat_tracking/raw/main/train_data.zip [following]
--2021-05-28 21:36:49--  https://github.com/marinaniet0/groove_beat_tracking/raw/main/train_data.zip
Reusing existing connection to github.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/marinaniet0/groove_beat_tracking/main/train_data.zip [following]
--2021-05-28 21:36:50--  https://raw.githubusercontent.com/marinaniet0/groove_beat_tracking/main/train_data.zip
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.

In [2]:
# Mount Google Drive to save and load models - in case runtime disconnects!
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
# In case we're starting from training directly, define dirs
current_dir = os.getcwd() + '/'
train_dir = current_dir + 'train/'

In [4]:
drive_save_dir = current_dir + 'drive/MyDrive/marinanieto_mir/models/'
if not os.path.exists(drive_save_dir):
  os.makedirs(drive_save_dir)

### BLSTM training

In [5]:
# Imports & model parameters
import os
import glob 
import numpy as np
import pickle
import torch 
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as func
from torch.utils.data import Dataset, DataLoader

torch.manual_seed(1)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Same chunking/windowing size as in the preprocessing
sz = 50

# Sequences are 2 * the window size
seq_len = sz * 2 

# 46 is the input dim -> 22 notes onsets + 22 notes durations + IOI + SF
input_dim = 46

# 25 is the hidden dim -> 25 units per layer
hidden_dim = 25

# beat & down beat dim -> 2 * seq_len since outputs are seq_len x probab of beat,
# seq_len x probab of not a beat, same for downbeat  
beat_dim, downbeat_dim = seq_len * 2, seq_len * 2      

lr, epochs, batch_sz, n_weights, gamma = 0.01, 50, 8, 5, 0.1

model_name = 'blstm'
training_file = train_dir + 'train_data.npz'

In [6]:
# Class to load data to torch tensors
class custom_dataset(Dataset):
    def __init__(self, x, yb, yd):
        self.x = x 
        self.yb, self.yd = yb, yd
        
    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return torch.FloatTensor(self.x[idx]), torch.FloatTensor(self.yb[idx]), torch.FloatTensor(self.yd[idx])

# BLSTM model class
class Model(nn.Module):
    
    def __init__(self, input_dim, hidden_dim):
        super(Model, self).__init__()

        # setting 
        self.flag = True
        self.layer_sz = 2
        self.bi_num = 2

        # layer 
        self.norm = nn.LayerNorm([seq_len, input_dim])
        self.rnn = nn.LSTM(input_dim, hidden_dim, batch_first=True, bidirectional=self.flag, num_layers=self.layer_sz)  
        for name, param in self.rnn.named_parameters():                   # initialization
            if 'bias' in name:
                nn.init.uniform_(param, a=-0.1, b=0.1)
            elif 'weight' in name:
                nn.init.uniform_(param, a=-0.1, b=0.1)

        self.beat = nn.ModuleList([nn.Linear(hidden_dim * self.bi_num, 1), \
                                   nn.Linear(hidden_dim * self.bi_num, 1)])         
        self.act = nn.Sigmoid()

    def forward(self, raw):     
        x = self.norm(raw)
        out, (hn, cn) = self.rnn(x)
        b, d = self.beat[0](out), self.beat[1](out)
        b, d = self.act(b), self.act(d)
        b, d = b.squeeze(-1), d.squeeze(-1)
        return b, d, out    

# Training function
def main():

    print('Loading dataset...')
    with open(training_file, 'rb') as f:
        data = pickle.load(f)
    X_data, Yb_data, Yd_data = data[0], data[1], data[2]
    dataset = custom_dataset(X_data, Yb_data, Yd_data)
    train_loader = DataLoader(dataset, batch_size=batch_sz, drop_last=True)
    print('Dataset loaded.')

    del dataset
    del X_data, Yb_data, Yd_data

    model = Model(input_dim, hidden_dim)
    model = model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=gamma)           
    beat_loss = nn.BCELoss()

    train_arr, test_arr = [], []

    current_epoch = 0

    # load latest checkpoint in drive_save_dir
    if os.path.exists(drive_save_dir):
      latest_ckpt = -1
      latest_file = ''
      for root, dirs, files in os.walk(drive_save_dir, topdown=False):
        for name in files:
          if name.startswith(model_name) and int(name.split('.')[0].split('_')[-1])\
           > latest_ckpt:
              latest_ckpt = int(name.split('.')[0].split('_')[-1])
              latest_file = os.path.join(root, name)

      if latest_ckpt > -1 and len(latest_file) > 0:
          ckpt = torch.load(latest_file)
          model.load_state_dict(ckpt['model_state_dict'])
          optimizer.load_state_dict(ckpt['optimizer_state_dict'])
          current_epoch = latest_ckpt + 1
          print('loaded!')

    for epoch in range(current_epoch, epochs):

        print('\n[info] epoch %d' %(epoch), end='\t')
        epoch = "%02d" %(epoch)

        lr_data = optimizer.param_groups[0]['lr']

        model.train()                             
        bl_loss, dl_loss = 0, 0      
        for idx, (x, yb, yd) in enumerate(train_loader):

            x, yb, yd = x.to(device), yb.to(device), yd.to(device)
            b, d, out = model.forward(x)
            bl, dl = beat_loss(b, yb), beat_loss(d, yd)
            bl_loss += bl.item()                
            dl_loss += dl.item()
            train_loss = bl + dl * n_weights

            optimizer.zero_grad()                # update 
            train_loss.backward()
            optimizer.step()

        print('train loss = %.4f | beat loss = %.4f | downbeat loss = %.4f' % ((bl_loss + dl_loss) / idx, bl_loss / idx, dl_loss / idx))
        train_arr.append([epoch, np.round(bl_loss / idx, 4), np.round(dl_loss / idx, 4)])

        scheduler.step()

        save_name = drive_save_dir + model_name + '_' + str(epoch) +'.pkl'
        torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                }, save_name)

# Run main() :)
if __name__ == '__main__':
    main()

Loading dataset...
Dataset loaded.
loaded!

[info] epoch 49	train loss = 0.0322 | beat loss = 0.0238 | downbeat loss = 0.0084


### BLSTM + Attention training

In [5]:
# Imports & model parameters
import os
import glob 
import numpy as np
import pickle
import torch 
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as func
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader

torch.manual_seed(1)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Same chunking/windowing size as in the preprocessing
sz = 50

# Sequences are 2 * the window size
seq_len = sz * 2 

# 46 is the input dim -> 22 notes onsets + 22 notes durations + IOI + SF
input_dim = 46

# 25 is the hidden dim -> 25 units per layer
hidden_dim = 25

# beat & down beat dim -> 2 * seq_len since outputs are seq_len x probab of beat,
# seq_len x probab of not a beat, same for downbeat  
beat_dim, downbeat_dim = seq_len * 2, seq_len * 2      

lr, epochs, batch_sz, n_weights, gamma = 0.01, 50, 8, 5, 0.1

model_name = 'attn'
training_file = train_dir + 'train_data.npz'

In [6]:
# Load data to tensor function
class custom_dataset(Dataset):
    def __init__(self, x, yb, yd):
        self.x = x 
        self.yb, self.yd = yb, yd
        
    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return torch.FloatTensor(self.x[idx]), torch.FloatTensor(self.yb[idx]), torch.FloatTensor(self.yd[idx])

# Attention mechanism definition
class Attention(nn.Module):
    def __init__(self, dimensions, attention_type='general'):
        super(Attention, self).__init__()

        if attention_type not in ['dot', 'general']:
            raise ValueError('Invalid attention type selected.')

        self.attention_type = attention_type
        if self.attention_type == 'general':
            self.linear_in = nn.Linear(dimensions, dimensions, bias=False)

        self.linear_out = nn.Linear(dimensions * 2, dimensions, bias=False)
        self.softmax = nn.Softmax(dim=-1)
        self.tanh = nn.Tanh()

    def forward(self, query, context):

        batch_size, output_len, dimensions = query.size()
        query_len = context.size(1)

        if self.attention_type == "general":
            query = query.reshape(batch_size * output_len, dimensions)
            query = self.linear_in(query)
            query = query.reshape(batch_size, output_len, dimensions)

        # Computing attention scores
        attention_scores = torch.bmm(query, context.transpose(1, 2).contiguous())
        attention_scores = attention_scores.view(batch_size * output_len, query_len)
        attention_weights = self.softmax(attention_scores)
        attention_weights = attention_weights.view(batch_size, output_len, query_len)

        mix = torch.bmm(attention_weights, context)
        combined = torch.cat((mix, query), dim=2)
        combined = combined.view(batch_size * output_len, 2 * dimensions)

        # Pass through linear layer & hyperbolic tangent
        output = self.linear_out(combined).view(batch_size, output_len, dimensions)
        output = self.tanh(output)

        return output, attention_weights

# BLSTM + Attention model class
class Model(nn.Module):
    
    def __init__(self, input_dim, hidden_dim, beat_dim, downbeat_dim):
        super(Model, self).__init__()

        # setting 
        self.flag = True
        self.layer_sz = 2

        # layer 
        self.norm = nn.LayerNorm([seq_len, input_dim])
        self.rnn = nn.LSTM(input_dim, hidden_dim, batch_first=True, bidirectional=self.flag, num_layers=self.layer_sz)  
        for name, param in self.rnn.named_parameters():                   # initialization
            if 'bias' in name:
                nn.init.uniform_(param, a=-0.1, b=0.1)
            elif 'weight' in name:
                nn.init.uniform_(param, a=-0.1, b=0.1)

        self.beat = nn.ModuleList([nn.Linear(hidden_dim, 1), \
                                   nn.Linear(hidden_dim, 1)])          

        self.attn = Attention(hidden_dim)
        self.act = nn.Sigmoid()
    
    def forward(self, raw):      

        x = self.norm(raw)
        out, (hn, cn) = self.rnn(x) 
        out_tmp = torch.chunk(out, 2, -1)
        out_tmp = out_tmp[0] + out_tmp[1] 
        hn = hn.permute(1, 0, 2)
        attn_out, weights = self.attn(out_tmp, hn)     # (8, 1201, 25)

        # beat 
        b1, d1 = self.beat[0](attn_out), self.beat[1](attn_out)                   # (8, 1201, 1)
        b, d = self.act(b1), self.act(d1)
        b, d = b.squeeze(-1), d.squeeze(-1)                   # (8, 1201)

        return b, d 

def main():
    print('Loading dataset...')
    with open(training_file, 'rb') as f:
        data = pickle.load(f)
    X_data, Yb_data, Yd_data = data[0], data[1], data[2]
    dataset = custom_dataset(X_data, Yb_data, Yd_data)
    train_loader = DataLoader(dataset, batch_size=batch_sz, drop_last=True)
    print('Dataset loaded.')

    del dataset
    del X_data, Yb_data, Yd_data

    # model, optim, loss 
    model = Model(input_dim, hidden_dim, beat_dim, downbeat_dim)
    model = model.to(device)

    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=gamma)           
    beat_loss = nn.BCELoss()

    train_arr, test_arr, lr_arr = [], [], []

    current_epoch = 0

    if os.path.exists(drive_save_dir):
      latest_ckpt = -1
      latest_file = ''
      for root, dirs, files in os.walk(drive_save_dir, topdown=False):
        for name in files:
          if name.startswith(model_name) and int(name.split('.')[0].split('_')[-1])\
           > latest_ckpt:
              latest_ckpt = int(name.split('.')[0].split('_')[-1])
              latest_file = os.path.join(root, name)

      if latest_ckpt > -1 and len(latest_file) > 0:
          ckpt = torch.load(latest_file)
          model.load_state_dict(ckpt['model_state_dict'])
          optimizer.load_state_dict(ckpt['optimizer_state_dict'])
          current_epoch = latest_ckpt + 1
          print('Model loaded')

    for epoch in range(current_epoch, epochs):

        lr_data = optimizer.param_groups[0]['lr']
        # print('lr_data =', lr_data)

        print('\n[info] epoch %d' %(epoch), end='\t')

        model.train()                               
        bl_loss, dl_loss = 0, 0    

        for idx, (x, yb, yd) in enumerate(train_loader):

            x, yb, yd = x.to(device), yb.to(device), yd.to(device)
            b, d = model.forward(x)

            bl, dl = beat_loss(b, yb), beat_loss(d, yd)
            bl_loss += bl.item()                
            dl_loss += dl.item()
            train_loss = bl + dl * n_weights

            optimizer.zero_grad()                # update 
            train_loss.backward()
            optimizer.step()

        print('train loss = %.4f | beat loss = %.4f | downbeat loss = %.4f' % ((bl_loss + dl_loss) / idx, bl_loss / idx, dl_loss / idx))
        train_arr.append([epoch, np.round(bl_loss / idx, 4), np.round(dl_loss / idx, 4)])

        scheduler.step()

        save_name = drive_save_dir + model_name + '_' + str(epoch) +'.pkl'
        torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                }, save_name)   

if __name__ == '__main__':
    main()

Loading dataset...
Dataset loaded.
Model loaded

[info] epoch 46	train loss = 0.0324 | beat loss = 0.0240 | downbeat loss = 0.0084

[info] epoch 47	train loss = 0.0324 | beat loss = 0.0240 | downbeat loss = 0.0084

[info] epoch 48	train loss = 0.0324 | beat loss = 0.0240 | downbeat loss = 0.0084

[info] epoch 49	train loss = 0.0324 | beat loss = 0.0240 | downbeat loss = 0.0084


## Evaluation

### Reference beats for evaluation

Downloading the validation preprocessed data in order to evaluate the models.

In [None]:
#@markdown ### Loading preprocessed data
#@markdown If the preprocessing step was skipped, the validation data can be
#@markdown loaded from the GitHub repo.
import os
!wget https://github.com/marinaniet0/groove_beat_tracking/blob/main/validation_data.zip?raw=true -O validation_data.zip
if not os.path.exists(os.getcwd() + '/eval/npz/'):
  os.makedirs(os.getcwd() + '/eval/npz/')
!unzip -d /content/eval/npz/ /content/validation_data.zip

Once downloaded, the data is processed to match the output format (txt file with
beat positions in seconds).

In [8]:
import os
import glob 
import numpy as np
import pickle 

file_list = np.sort(glob.glob(os.getcwd() + '/eval/npz/*.npz', recursive=True)).tolist()
save_dir = os.getcwd() + '/eval/txt/reference/'

if not os.path.exists(save_dir):
    os.makedirs(save_dir)

for file_idx in range(len(file_list)):

    name = file_list[file_idx].split('/')[-1].split('.')[0]
    print("\r " + str(file_idx+1) + "/" + str(len(file_list)) + " " + name, end = "")
    with open(file_list[file_idx], 'rb') as f: 
        data = pickle.load(f)

    b, d = data[1], data[2]
    b, d = np.squeeze(b), np.squeeze(d)

    if b.shape[0] > 1 and len(b.shape) > 1:
      b_dechunked = b[0,:]
      for chunk_idx in np.arange(1, b.shape[0]):
        b_dechunked = np.append(b_dechunked, b[chunk_idx,50:])
      Yb = b_dechunked
    else:
      Yb = b
    if d.shape[0] > 1 and len(d.shape) > 1:
      d_dechunked = d[0,:]
      for chunk_idx in np.arange(1, d.shape[0]):
        d_dechunked = np.append(d_dechunked, d[chunk_idx,50:])
      Yd = d_dechunked
    else:
      Yd = d

    Yb = np.argwhere(Yb == 1).flatten().astype(float)/100.0
    Yd = np.argwhere(Yd == 1).flatten().astype(float)/100.0
    
    with open(save_dir + name + '_beat.txt', 'w') as f:
      f.writelines('\n'.join(Yb.astype(np.str)))
    with open(save_dir + name + '_downbeat.txt', 'w') as f:
      f.writelines('\n'.join(Yd.astype(np.str)))

 223/223 9_soul-groove9_105_beat_4-4

### BLSTM output

First we download & unzip the model to be used from GitHub

In [9]:
#@markdown ### Loading pre-trained model
#@markdown If the training was skipped, we load the model from GitHub
import os
!wget https://github.com/marinaniet0/groove_beat_tracking/blob/main/models/blstm_49.pkl?raw=true -O blstm.pkl
blstm_pkl = os.getcwd() + '/blstm.pkl'

--2021-05-28 22:05:33--  https://github.com/marinaniet0/groove_beat_tracking/blob/main/models/blstm_49.pkl?raw=true
Resolving github.com (github.com)... 140.82.114.3
Connecting to github.com (github.com)|140.82.114.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://github.com/marinaniet0/groove_beat_tracking/raw/main/models/blstm_49.pkl [following]
--2021-05-28 22:05:34--  https://github.com/marinaniet0/groove_beat_tracking/raw/main/models/blstm_49.pkl
Reusing existing connection to github.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/marinaniet0/groove_beat_tracking/main/models/blstm_49.pkl [following]
--2021-05-28 22:05:34--  https://raw.githubusercontent.com/marinaniet0/groove_beat_tracking/main/models/blstm_49.pkl
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubuser

Then we run the validation files through the model to get the predictions

In [10]:
import os
import glob 
import numpy as np
import pickle
import torch 
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as func
from torch.utils.data import Dataset, DataLoader

torch.manual_seed(1)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
sz = 50
seq_len, batch_sz = sz * 2, 8
input_dim, hidden_dim, beat_dim, downbeat_dim = 46, 25, seq_len * 2, seq_len * 2
model_path = os.getcwd() + '/blstm.pkl'
npz_path = os.getcwd() + '/eval/npz/'
output_path = os.getcwd() + '/eval/txt/blstm/tmp/'

In [11]:
class custom_dataset(Dataset):
    def __init__(self, x, yb, yd):
        self.x = x 
        self.yb, self.yd = yb, yd
        
    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return torch.FloatTensor(self.x[idx]), torch.FloatTensor(self.yb[idx]), torch.FloatTensor(self.yd[idx])

class Model(nn.Module):
    
    def __init__(self, input_dim, hidden_dim, beat_dim, downbeat_dim):
        super(Model, self).__init__()

        # setting 
        self.flag = True
        self.layer_sz = 2
        self.bi_num = 2
        self.norm = nn.LayerNorm([seq_len, input_dim])
        self.rnn = nn.LSTM(input_dim, hidden_dim, batch_first=True, bidirectional=self.flag, num_layers=self.layer_sz)  

        # initialization 
        for name, param in self.rnn.named_parameters():                   
            if 'bias' in name:
                nn.init.uniform_(param, a=-0.1, b=0.1)
            elif 'weight' in name:
                nn.init.uniform_(param, a=-0.1, b=0.1)

        self.beat = nn.ModuleList([nn.Linear(hidden_dim * self.bi_num, 1), \
                                   nn.Linear(hidden_dim * self.bi_num, 1)])         
        self.act = nn.Sigmoid()

    def forward(self, x):      

        x = self.norm(x)
        out, (hn, cn) = self.rnn(x)
        b, d = self.beat[0](out), self.beat[1](out)                 
        b, d = self.act(b), self.act(d)
        b, d = b.squeeze(-1), d.squeeze(-1)               

        return b, d 

def main():

    print('Calculate tracking results of BLSTM.')
    
    file_list = np.sort(glob.glob(npz_path + '*.npz', recursive=True)).tolist() 
    model = Model(input_dim, hidden_dim, beat_dim, downbeat_dim)
    model = model.to(device)
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['model_state_dict'])

    s = output_path
    
    for file_idx in range(len(file_list)):

        # create a folder to store frames
        name = file_list[file_idx].split('/')[-1].split('.')[0]
        directory = s + str(name)
        if not os.path.exists(directory):
            os.makedirs(directory)     

        # prepare dataset and loader 
        with open(file_list[file_idx], 'rb') as f:                      
            data = pickle.load(f)
        X_data, Yb_data, Yd_data = data[0], data[1], data[2]
        dataset = custom_dataset(X_data, Yb_data, Yd_data)
        loader = DataLoader(dataset, batch_size=batch_sz, drop_last=False)

        model.eval()  
        for idx, (x, yb, yd) in enumerate(loader):
            x, yb, yd = x.to(device), yb.to(device), yd.to(device)
            b, d = model.forward(x)
            b, d = b.cpu().detach().numpy(), d.cpu().detach().numpy()

            with open(s + str(name) + '/' + str(idx) + '_unit.npz', 'wb') as fp:
                pickle.dump([b, d], fp)

        # output 
        tmp_list = np.sort(glob.glob(s + str(name) + '/*.npz', recursive=True)).tolist() 
        b_arr, d_arr = [], []
        for idx in range(len(tmp_list)):
            with open(tmp_list[idx], 'rb') as f:
                data = pickle.load(f)
            b_arr.append(data[0])
            d_arr.append(data[1])

        beat, downbeat = np.concatenate(b_arr), np.concatenate(d_arr)
        beat, downbeat = beat.reshape(-1, 1), downbeat.reshape(-1, 1)

        with open(s + str(name) + '.npz', 'wb') as f:
            pickle.dump([beat, downbeat], f) 


if __name__ == '__main__':
    main()

Calculate tracking results of BLSTM.


We apply thresholds to the outputs: 0.3 to the beats and 0.2 to the downbeats

In [12]:
import os
import glob 
import numpy as np
import pickle 

pred_path = os.getcwd() + "/eval/txt/blstm/pred/"
if not os.path.exists(pred_path):
  os.makedirs(pred_path)

# Threshold -> probabilities to beat or no beat, downbeat or no downbeat
beat_thres, db_thres = 0.3, 0.2
file_list = np.sort(glob.glob(output_path + '/*.npz', recursive=True)).tolist() 

for file_idx in range(len(file_list)):

    name = file_list[file_idx].split('/')[-1].split('.')[0]

    with open(file_list[file_idx], 'rb') as f: 
        data = pickle.load(f)

    b, d = data[0], data[1]
    b, d = np.squeeze(b), np.squeeze(d)
    b, d = b.tolist(), d.tolist()

    b_out, d_out = [], []
    for i in range(len(b)):
        if b[i] >= beat_thres:
            b_out.append(str(np.round(float(i) * 0.01, 2)) + '\n')
        if d[i] >= db_thres:
            d_out.append(str(np.round(float(i) * 0.01, 2)) + '\n')

    s = pred_path + str(name) + '_beat.txt'
    if os.path.exists(s):
        os.system("rm " + s)
    with open(pred_path + str(name) + '_beat.txt', 'a') as fp:
        fp.writelines(b_out)

    s = pred_path + str(name) + '_downbeat.txt'
    if os.path.exists(s):
        os.system("rm " + s)
    with open(pred_path + str(name) + '_downbeat.txt', 'a') as fp:
        fp.writelines(d_out)

### BLSTM + Attention output

First we download & unzip the model to be used from GitHub

In [13]:
#@markdown ### Loading pre-trained model
#@markdown If the training was skipped, we load the model from GitHub
import os
!wget https://github.com/marinaniet0/groove_beat_tracking/blob/main/models/attn_49.pkl?raw=true -O blstm-attn.pkl
blstm_pkl = os.getcwd() + '/blstm-attn.pkl'

--2021-05-28 22:07:03--  https://github.com/marinaniet0/groove_beat_tracking/blob/main/models/attn_49.pkl?raw=true
Resolving github.com (github.com)... 140.82.113.4
Connecting to github.com (github.com)|140.82.113.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://github.com/marinaniet0/groove_beat_tracking/raw/main/models/attn_49.pkl [following]
--2021-05-28 22:07:03--  https://github.com/marinaniet0/groove_beat_tracking/raw/main/models/attn_49.pkl
Reusing existing connection to github.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/marinaniet0/groove_beat_tracking/main/models/attn_49.pkl [following]
--2021-05-28 22:07:03--  https://raw.githubusercontent.com/marinaniet0/groove_beat_tracking/main/models/attn_49.pkl
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubuserconte

Then we run the validation files through the model to get the predictions

In [14]:
import os
import glob 
import numpy as np
import pickle
import torch 
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as func
from torch.utils.data import Dataset, DataLoader

torch.manual_seed(1)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
sz = 50
seq_len, batch_sz = sz * 2, 8
input_dim, hidden_dim, beat_dim, downbeat_dim = 46, 25, seq_len * 2, seq_len * 2
model_path = os.getcwd() + '/blstm-attn.pkl'
npz_path = os.getcwd() + '/eval/npz/'
output_path = os.getcwd() + '/eval/txt/attn/tmp/'

In [15]:
class custom_dataset(Dataset):
    def __init__(self, x, yb, yd):
        self.x = x 
        self.yb, self.yd = yb, yd
        
    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return torch.FloatTensor(self.x[idx]), torch.FloatTensor(self.yb[idx]), torch.FloatTensor(self.yd[idx])

class Attention(nn.Module):

    def __init__(self, dimensions, attention_type='general'):
        super(Attention, self).__init__()

        if attention_type not in ['dot', 'general']:
            raise ValueError('Invalid attention type selected.')

        self.attention_type = attention_type
        if self.attention_type == 'general':
            self.linear_in = nn.Linear(dimensions, dimensions, bias=False)

        self.linear_out = nn.Linear(dimensions * 2, dimensions, bias=False)
        self.softmax = nn.Softmax(dim=-1)
        self.tanh = nn.Tanh()

    def forward(self, query, context):

        batch_size, output_len, dimensions = query.size()
        query_len = context.size(1)

        if self.attention_type == "general":
            query = query.reshape(batch_size * output_len, dimensions)
            query = self.linear_in(query)
            query = query.reshape(batch_size, output_len, dimensions)

        attention_scores = torch.bmm(query, context.transpose(1, 2).contiguous())
        attention_scores = attention_scores.view(batch_size * output_len, query_len)

        attention_weights = self.softmax(attention_scores)
        attention_weights = attention_weights.view(batch_size, output_len, query_len)

        mix = torch.bmm(attention_weights, context)

        combined = torch.cat((mix, query), dim=2)
        combined = combined.view(batch_size * output_len, 2 * dimensions)

        output = self.linear_out(combined).view(batch_size, output_len, dimensions)
        output = self.tanh(output)

        return output, attention_weights

class Model(nn.Module):
    
    def __init__(self, input_dim, hidden_dim, beat_dim, downbeat_dim):
        super(Model, self).__init__()

        # setting 
        self.flag = True
        self.layer_sz = 2

        # layer 
        self.norm = nn.LayerNorm([seq_len, input_dim])
        self.rnn = nn.LSTM(input_dim, hidden_dim, batch_first=True, bidirectional=self.flag, num_layers=self.layer_sz)  
        for name, param in self.rnn.named_parameters():                   # initialization
            if 'bias' in name:
                nn.init.uniform_(param, a=-0.1, b=0.1)
            elif 'weight' in name:
                nn.init.uniform_(param, a=-0.1, b=0.1)

        self.beat = nn.ModuleList([nn.Linear(hidden_dim, 1), \
                                   nn.Linear(hidden_dim, 1)])       

        self.attn = Attention(hidden_dim)
        self.act = nn.Sigmoid()
    
    def forward(self, raw):     

        x = self.norm(raw)
        out, (hn, cn) = self.rnn(x) 
        out_tmp = torch.chunk(out, 2, -1)
        out_tmp = out_tmp[0] + out_tmp[1] 
        hn = hn.permute(1, 0, 2)
        attn_out, weights = self.attn(out_tmp, hn)    

        # beat 
        b, d = self.beat[0](attn_out), self.beat[1](attn_out) 
        b, d = self.act(b), self.act(d)
        b, d = b.squeeze(-1), d.squeeze(-1)                   

        return b, d, out, attn_out   

def main():

    print('Calculate tracking results of BLSTM-Attn.')

    file_list = np.sort(glob.glob(npz_path + '*.npz', recursive=True)).tolist() 
    model = Model(input_dim, hidden_dim, beat_dim, downbeat_dim)
    model = model.to(device)
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['model_state_dict'])

    s = output_path
    if not os.path.exists(s):
      os.makedirs(s)
    
    for file_idx in range(len(file_list)):

        print('\r' + file_list[file_idx], end="")
        # create a folder to store frames
        name = file_list[file_idx].split('/')[-1].split('.')[0]
        directory = s + str(name)
        if not os.path.exists(directory):
            os.makedirs(directory)     

        # prepare dataset and loader 
        with open(file_list[file_idx], 'rb') as f:                      
            data = pickle.load(f)
        X_data, Yb_data, Yd_data = data[0], data[1], data[2]
        dataset = custom_dataset(X_data, Yb_data, Yd_data)
        loader = DataLoader(dataset, batch_size=batch_sz, drop_last=False)

        model.eval()  
        for idx, (x, yb, yd) in enumerate(loader):

            x, yb, yd = x.to(device), yb.to(device), yd.to(device)
            b, d, _, _ = model.forward(x)
            b, d = b.cpu().detach().numpy(), d.cpu().detach().numpy()

            with open(s + str(name) + '/' + str(idx) + '_unit.npz', 'wb') as fp:
                pickle.dump([b, d], fp)

        # output 
        tmp_list = np.sort(glob.glob(s + str(name) + '/*.npz', recursive=True)).tolist() 
        b_arr, d_arr = [], []
        for idx in range(len(tmp_list)):
            with open(tmp_list[idx], 'rb') as f:
                data = pickle.load(f)
            b_arr.append(data[0])
            d_arr.append(data[1])

        beat, downbeat = np.concatenate(b_arr), np.concatenate(d_arr)
        beat, downbeat = beat.reshape(-1, 1), downbeat.reshape(-1, 1)

        with open(s + str(name) + '.npz', 'wb') as f:
            pickle.dump([beat, downbeat], f) 

if __name__ == '__main__':
    main()

Calculate tracking results of BLSTM-Attn.
/content/eval/npz/9_soul-groove9_105_beat_4-4.npz

We apply thresholds to the outputs: 0.3 to the beats and 0.2 to the downbeats

In [16]:
import os
import glob 
import numpy as np
import pickle 

pred_path = os.getcwd() + "/eval/txt/attn/pred/"
if not os.path.exists(pred_path):
  os.makedirs(pred_path)

beat_thres, db_thres = 0.3, 0.2
file_list = np.sort(glob.glob(output_path + '/*.npz', recursive=True)).tolist() 

for file_idx in range(len(file_list)):

    name = file_list[file_idx].split('/')[-1].split('.')[0]

    with open(file_list[file_idx], 'rb') as f: 
        data = pickle.load(f)

    b, d = data[0], data[1]
    b, d = np.squeeze(b), np.squeeze(d)
    b, d = b.tolist(), d.tolist()

    b_out, d_out = [], []
    for i in range(len(b)):
        if b[i] >= beat_thres:
            b_out.append(str(np.round(float(i) * 0.01, 2)) + '\n')
        if d[i] >= db_thres:
            d_out.append(str(np.round(float(i) * 0.01, 2)) + '\n')

    s = pred_path + str(name) + '_beat.txt'
    if os.path.exists(s):
        os.system("rm " + s)
    with open(pred_path + str(name) + '_beat.txt', 'a') as fp:
        fp.writelines(b_out)

    s = pred_path + str(name) + '_downbeat.txt'
    if os.path.exists(s):
        os.system("rm " + s)
    with open(pred_path + str(name) + '_downbeat.txt', 'a') as fp:
        fp.writelines(d_out)

### Getting scores with mir_eval

In [17]:
!pip install mir_eval

Collecting mir_eval
[?25l  Downloading https://files.pythonhosted.org/packages/0a/fe/be4f7a59ed71938e21e89f23afe93eea0d39eb3e77f83754a12028cf1a68/mir_eval-0.6.tar.gz (87kB)
[K     |███▊                            | 10kB 11.4MB/s eta 0:00:01[K     |███████▌                        | 20kB 14.9MB/s eta 0:00:01[K     |███████████▏                    | 30kB 9.7MB/s eta 0:00:01[K     |███████████████                 | 40kB 8.1MB/s eta 0:00:01[K     |██████████████████▋             | 51kB 5.7MB/s eta 0:00:01[K     |██████████████████████▍         | 61kB 5.3MB/s eta 0:00:01[K     |██████████████████████████      | 71kB 6.0MB/s eta 0:00:01[K     |█████████████████████████████▉  | 81kB 6.5MB/s eta 0:00:01[K     |████████████████████████████████| 92kB 4.1MB/s 
Building wheels for collected packages: mir-eval
  Building wheel for mir-eval (setup.py) ... [?25l[?25hdone
  Created wheel for mir-eval: filename=mir_eval-0.6-cp37-none-any.whl size=96515 sha256=2fabfa78cfff35af39cce52

In [18]:
ref_beat_dir = os.getcwd() + '/eval/txt/reference/'
attn_beat_dir = os.getcwd() + '/eval/txt/attn/pred/'
blstm_beat_dir = os.getcwd() + '/eval/txt/blstm/pred/'

beat_end = '_beat.txt'
downbeat_end = '_downbeat.txt'

file_list = []

for root, dirs, files in os.walk(ref_beat_dir):
  for f in files:
    file_list.append(f)

In [19]:
import mir_eval

n_files = len(file_list)
scores_beats_blstm = []
scores_downbeats_blstm = []
scores_beats_attn = []
scores_downbeats_attn = []

print(n_files)
for file in file_list:
  print('\r' + file, end="")
  if file.endswith(beat_end):
    ref_beats = mir_eval.io.load_events(ref_beat_dir + file)
    attn_pred_beats = mir_eval.io.load_events(attn_beat_dir + file)
    blstm_pred_beats = mir_eval.io.load_events(blstm_beat_dir + file)
    
    scores_beats_attn.append(mir_eval.beat.evaluate(ref_beats, attn_pred_beats))
    scores_beats_blstm.append(mir_eval.beat.evaluate(ref_beats, blstm_pred_beats))

  elif file.endswith(downbeat_end):
    ref_downbeats = mir_eval.io.load_events(ref_beat_dir + file)
    attn_pred_downbeats = mir_eval.io.load_events(attn_beat_dir + file)
    blstm_pred_downbeats = mir_eval.io.load_events(blstm_beat_dir + file)
    
    scores_downbeats_attn.append(mir_eval.beat.evaluate(ref_beats, attn_pred_beats))
    scores_downbeats_blstm.append(mir_eval.beat.evaluate(ref_beats, blstm_pred_beats))

446
11_rock_100_beat_4-4_beat.txt109_rock_95_beat_4-4_beat.txt33_hiphop_100_fill_4-4_downbeat.txt88_neworleans-funk_93_fill_4-4_beat.txt167_afrocuban-rhumba_110_fill_4-4_downbeat.txt257_funk-purdieshuffle_130_fill_4-4_beat.txt54_jazz_125_fill_4-4_downbeat.txt205_rock-halftime_140_fill_4-4_downbeat.txt24_rock_90_beat_4-4_beat.txt



218_rock-halftime_140_fill_4-4_downbeat.txt



7_latin-brazilian-maracatu_96_beat_4-4_beat.txt

  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)


26_rock_120_beat_4-4_downbeat.txt

In [20]:
beat_blstm_pscore=0
beat_blstm_fmeasure=0
downbeat_blstm_pscore=0
downbeat_blstm_fmeasure=0
for score in scores_beats_blstm:
  beat_blstm_fmeasure+=score["F-measure"]
  beat_blstm_pscore+=score["P-score"]
for score in scores_downbeats_blstm:
  downbeat_blstm_fmeasure+=score["F-measure"]
  downbeat_blstm_pscore+=score["P-score"]

beat_blstm_pscore=beat_blstm_pscore/len(scores_beats_blstm)
beat_blstm_fmeasure=beat_blstm_fmeasure/len(scores_beats_blstm)
downbeat_blstm_pscore=downbeat_blstm_pscore/len(scores_beats_blstm)
downbeat_blstm_fmeasure=downbeat_blstm_fmeasure/len(scores_beats_blstm)


beat_attn_pscore=0
beat_attn_fmeasure=0
downbeat_attn_pscore=0
downbeat_attn_fmeasure=0
for score in scores_beats_attn:
  beat_attn_fmeasure+=score["F-measure"]
  beat_attn_pscore+=score["P-score"]
for score in scores_downbeats_attn:
  downbeat_attn_fmeasure+=score["F-measure"]
  downbeat_attn_pscore+=score["P-score"]

beat_attn_pscore=beat_attn_pscore/len(scores_beats_attn)
beat_attn_fmeasure=beat_attn_fmeasure/len(scores_beats_attn)
downbeat_attn_pscore=downbeat_attn_pscore/len(scores_beats_attn)
downbeat_attn_fmeasure=downbeat_attn_fmeasure/len(scores_beats_attn)

print("BLSTM -------------------------")
print("P-score beats: " + str(beat_blstm_pscore))
print("F-measure beats: " + str(beat_blstm_fmeasure))
print("P-score downbeats: " + str(downbeat_blstm_pscore))
print("F-measure downbeats: " + str(downbeat_blstm_fmeasure))

print("BLSTM+Attn ---------------------")
print("P-score beats: " + str(beat_attn_pscore))
print("F-measure beats: " + str(beat_attn_fmeasure))
print("P-score downbeats: " + str(downbeat_attn_pscore))
print("F-measure downbeats: " + str(downbeat_attn_fmeasure))


BLSTM -------------------------
P-score beats: 0.05231426141500952
F-measure beats: 0.04412704099978594
P-score downbeats: 0.04108000833190337
F-measure downbeats: 0.031770015576410265
BLSTM+Attn ---------------------
P-score beats: 0.05247023405771346
F-measure beats: 0.04348409893576909
P-score downbeats: 0.0405724957697746
F-measure downbeats: 0.031342496367446654
