# Outro Transition Models
In this notebook, we will train two models which, taken together, allow us to identify the start and end points of the transition in the outro of an input song. We will use an approach similar to the one described in the [Introduction Models notebook](2.%20Introduction%20Transition%20Models.ipynb). However, the training task is more difficult in the outro: we do not have the start of the song as a direct anchor point, and the end of a song does not provide reliable beat/downbeat information in the way that the start does.

We will therefore need to invert the process used for our Introduction models. There, we first determined the timestamp of the first downbeat of the first phrase, then trained a timing model which was agnostic to specific phrase locations. We could then use the first downbeat timestamp to pick out where the transition points should occur in every subsequent phrase by taking 32-beat jumps. For the outro models we need to apply the timing model first, in order to determine the period of the song where the outro actually begins and therefore where the transition should begin. We will then apply a separate Start Bar Finder, similar to the one trained for the Introduction models, to narrow down that period and pinpoint the specific bar where the transition should begin. For songs whose first downbeat the Introduction models predicted to fall on the first beat, we will calculate the BPM of the song and use it to build a downbeat grid, taking advantage of the uniform tempo structure of EDM to identify phrase points in the outro.

In [1]:
import pickle
import numpy as np
import pandas as pd
import librosa
import random
import time

from tensorflow.keras.layers import Input, Dense, Lambda, Concatenate, \
Embedding,ReLU,Flatten,Dropout,BatchNormalization,Activation,Dot
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Conv1D,MaxPooling1D,LSTM,Bidirectional
from tensorflow.keras.callbacks import EarlyStopping

import tensorflow.keras.backend as K
import tensorflow as tf
tf.compat.v1.logging.set_verbosity(40)
from sklearn.metrics import confusion_matrix
from scipy import stats

## Data imports
We begin by importing our labelled data and chromagram/spectrogram audio data. We have stored these as pickled dictionaries of dataframes and numpy arrays respectively. We also define the same helper functions as for the Introduction models, and perform the same processing to slice the spectrograms, and truncate/pad the slices to a uniform size.

In [2]:
with open('grams_full.pkl','rb') as f:
    full_grams = pickle.load(f)

with open('labels_dict','rb') as f:
    labels_dict = pickle.load(f)


In [3]:
def get_ohs(df):
    """Given a labelled bar/beats input, appends columns with binary
        indicators at each beat, with 1 at the appropriate transition
        points and 0 otherwise.
        
        Args:
            df: Bar/beats dataframe with 'Start' and 'End' transition 
                labels in intro and outro
        
        Returns:
            df_copy: Copy of the dataframe with four columns of binary
                labels appended: Incoming Start, Incoming End, Outgoing
                Start, Outgoing End
        """
    df_copy = df.copy(deep=True)
    
    df_copy['Incoming Start'] = df_copy['Intro Label'].apply(
        lambda x: int('Start' in str(x)))
    df_copy['Incoming End'] = df_copy['Intro Label'].apply(
        lambda x: int('End' in str(x)))
    df_copy['Outgoing Start'] = df_copy['Outro Label'].apply(
        lambda x: int('Start' in str(x)))
    df_copy['Outgoing End'] = df_copy['Outro Label'].apply(
        lambda x: int('End' in str(x)))
    
    df_copy = df_copy.drop(['Intro Label','Outro Label'],axis=1)
    return df_copy


def get_slices(gram,frames):
    """Utility function for slicing a spectrogram/chromagram according
        to frames.
        
        Args:
            gram: Spectrogram or chromagram numpy array
            frames: indices at which to slice array
            
        Returns:
            List of array slices
        """
    return [gram[frames[i]:frames[i+1]] for i in range(len(frames)-1)]


In [4]:
def truncate_pad(outputs,length= 175):
    """Truncates or pads gram slices to be of input length
        
        Args: 
            outputs: length two list containing chromagram and spectrogram 
                inputs, i.e. list of four-beat slices
            length: axis 0 length of output of each slice
        
        Returns:
            length two list of truncated/padded chromagrams and spectrograms
    """
    chromagram,spectrogram = outputs
    
    size = spectrogram.shape[0]
    #We convert the spectrogram power values to db and divide by -80 
    #so that all values are between 0 and 1
    spectrogram = librosa.power_to_db(spectrogram.T, ref=np.max).T/-80.0
    
    
    if size>=length:
        return [x[:length] for x in [chromagram,spectrogram]]
    else:
        zeros_x = length-size
        zeros_chromagram = np.zeros((zeros_x,12))
        zeros_spectrogram = np.zeros((zeros_x,128))
        
        return [np.concatenate([chromagram,zeros_chromagram],axis = 0).astype(np.float32),
               np.concatenate([spectrogram,zeros_spectrogram],axis = 0).astype(np.float32)]
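
The scaling in `truncate_pad` relies on `librosa.power_to_db` being called with `ref=np.max` and its default `top_db=80`, so the decibel values fall between -80 and 0 before the division by -80. The following is a minimal pure-Python sketch of that arithmetic on a handful of power values; the helper name `normalised_db` is hypothetical and is not part of the notebook.

```python
import math

def normalised_db(power_values, top_db=80.0, amin=1e-10):
    """Convert power values to dB relative to their maximum, floor the
    result at -top_db, then divide by -top_db so values lie in [0, 1].

    Hypothetical helper mirroring the scaling applied in truncate_pad.
    """
    ref = max(max(power_values), amin)
    scaled = []
    for p in power_values:
        db = 10.0 * math.log10(max(p, amin)) - 10.0 * math.log10(ref)
        db = max(db, -top_db)        # floor at top_db below the peak
        scaled.append(db / -top_db)  # 0.0 = peak, 1.0 = at or below the floor
    return scaled

# Peak, 10 dB below, 40 dB below, and far below the 80 dB floor
scaled = normalised_db([1.0, 0.1, 1e-4, 1e-12])
```

Note that after this scaling the loudest frames map to 0 and the quietest to 1.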

In [5]:
slice_length = 175
gram_slices_tp = {}
gram_slice_times = {}
for song in [x for x in labels_dict if x in full_grams]:
    grams = full_grams[song]
    full_gram_shape = grams[0].shape[0]
    
    tags = labels_dict[song]
    tags['Frame'] = librosa.time_to_frames(tags.values[:,0],sr=22050,hop_length=256)
    
    if tags.shape[0]%4==0:
        indices = [i*4 for i in range(tags.shape[0]//4)]
    else:
        indices = [i*4 for i in range(1+tags.shape[0]//4)]
    frames = tags.values[indices,-1].astype(np.int32).tolist()
    if full_gram_shape not in frames:
        frames.append(full_gram_shape)
    
    times = tags.values[indices,0].tolist()
    gram_slice_times[song] = times
    
    chromagrams,spectrograms = [get_slices(gram,frames) for gram in grams]
    
    #We check to make sure there are no empty slices, and add zeros at the start and end
    non_zero_inds = [x for x in range(len(spectrograms)) if spectrograms[x].shape[0]>0]
    
    chromagrams = [chromagrams[i] for i in non_zero_inds]
    chromagrams = [np.zeros((slice_length,12))]+chromagrams+[np.zeros((slice_length,12))]
    
    spectrograms = [spectrograms[i] for i in non_zero_inds]
    spectrograms = [np.zeros((slice_length,128))]+spectrograms+[np.zeros((slice_length,128))]
    
    #We now perform the truncation/padding
    gram_slices_tp[song] = list(zip(*[truncate_pad(
        x) for x in zip(*[chromagrams,spectrograms])]))
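
To make the slicing step concrete, here is a small sketch of how the bar boundaries are chosen: every fourth beat frame is kept, and the final frame of the gram is appended so that the last bar is closed. The function names and the toy frame values are assumptions made for illustration only.

```python
def bar_boundaries(beat_frames, total_frames, beats_per_bar=4):
    """Keep the frame index of every fourth beat and close the final
    bar with the total number of frames (hypothetical helper)."""
    boundaries = list(beat_frames[::beats_per_bar])
    if total_frames not in boundaries:
        boundaries.append(total_frames)
    return boundaries

def slice_by_boundaries(rows, boundaries):
    """Same idea as get_slices: cut rows between consecutive boundaries."""
    return [rows[boundaries[i]:boundaries[i + 1]]
            for i in range(len(boundaries) - 1)]

# Ten beats spaced 40 frames apart in a gram of 400 frames (toy values)
beat_frames = [0, 40, 80, 120, 160, 200, 240, 280, 320, 360]
rows = list(range(400))

bounds = bar_boundaries(beat_frames, total_frames=400)
bars = slice_by_boundaries(rows, bounds)
bar_lengths = [len(bar) for bar in bars]
```

Here the two full bars each span 160 frames, while the trailing two-beat remainder spans 80 frames and would be zero-padded up to `slice_length` by `truncate_pad`.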

## Outro Transition Timing
We will first train a model similar to the Introduction transition timing model trained in the [Introduction Models notebook](2.%20Introduction%20Transition%20Models.ipynb). As before, it will consist of 1D convolution on four-beat chunks which are then input into a Bidirectional LSTM. The input covers the last 256 beats of the song (64 four-beat bars, plus the trailing zero slice added above), and the training targets are binary labels indicating, for each bar, whether the transition should start or end on that bar.
### Data Preparation
We need to extract the binary start/end labels from the labelled beats for each song.

In [6]:
tm_seq_len = 64

In [7]:
ohs_dict = {}
timing_model_labels = {}
for song in gram_slices_tp:
    tags = labels_dict[song]
    ohs = get_ohs(tags)
    ohs_dict[song] = ohs
    
    indices = [i*4 for i in range(tags.shape[0]//4)]
    ohs_slices = [ohs.values[indices[i]:indices[i+1],-2:] for i in range(len(indices)-1)]
    ohs_slices += [ohs.values[indices[-1]:,-2:]]
    ohs_slices = ohs_slices[-1*tm_seq_len:]
    slice_labels = [np.max(slce,axis = 0) for slce in ohs_slices if slce.shape[0]!=0] 
    slice_labels.append(np.array([0,0]))
    while len(slice_labels) < tm_seq_len + 1:
        slice_labels = [np.array([0,0])] + slice_labels
    timing_model_labels[song] = slice_labels
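
The label construction above can be illustrated on a toy song. The sketch below collapses per-beat `(start, end)` labels into per-bar labels by taking the maximum over each bar, keeps the last `seq_len` bars, appends one empty bar, and pads at the front. It is a simplified version which assumes the number of beats is a multiple of four; the function name `bar_labels` is hypothetical.

```python
def bar_labels(beat_labels, seq_len, beats_per_bar=4):
    """Collapse per-beat (start, end) labels into per-bar labels.

    Simplified sketch: assumes len(beat_labels) is a multiple of
    beats_per_bar. Returns seq_len + 1 bar labels, padded at the front.
    """
    bars = [beat_labels[i:i + beats_per_bar]
            for i in range(0, len(beat_labels), beats_per_bar)]
    labels = [(max(b[0] for b in bar), max(b[1] for b in bar)) for bar in bars]
    labels = labels[-seq_len:] + [(0, 0)]               # trailing empty bar
    padding = [(0, 0)] * (seq_len + 1 - len(labels))    # front padding
    return padding + labels

# Twelve beats: transition starts on beat 4 and ends on beat 8
beats = [(0, 0)] * 12
beats[4] = (1, 0)
beats[8] = (0, 1)

toy_labels = bar_labels(beats, seq_len=4)
```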

In [8]:
def get_timing_model_inputs(song):
    """Takes a song as input and returns stacked and concatenated
        array slices representing the last 256 beats of the song.
    """
    chromagrams,spectrograms = gram_slices_tp[song]
    
    chromagrams_inp = np.stack(chromagrams[-(tm_seq_len+1):])
    spectrograms_inp = np.stack(spectrograms[-(tm_seq_len+1):])
    
    if chromagrams_inp.shape[0] < tm_seq_len + 1:
        padding_needed = tm_seq_len + 1 - chromagrams_inp.shape[0]
        zeros_pad_chromagram = np.zeros((padding_needed,slice_length,12))
        chromagrams_inp = np.concatenate([zeros_pad_chromagram,chromagrams_inp],
                                        axis = 0)
        
        zeros_pad_spectrogram = np.zeros((padding_needed,slice_length,128))
        spectrograms_inp = np.concatenate([zeros_pad_spectrogram,spectrograms_inp],
                                        axis = 0)
        
    return np.concatenate([chromagrams_inp,spectrograms_inp],axis = -1).astype(np.float32)

As explained in the [Introduction Models notebook](2.%20Introduction%20Transition%20Models.ipynb), we implement an approach to sample weighting which allows the model to focus on the positive labels. In the Introduction Models, we only needed to take into consideration the previous label when determining the weight, but for the outro we need to take into account the previous and the subsequent labels, as some outros end more than 32 beats before the end of the song. We also want to place more weight on the first positive label and surrounding predictions, as this is what will be used to determine the overall transition timing.

In [9]:
def get_single_weight(i,sums,other_weight=0.01):
    """Determines training weights for transition timing model. 
        All bars with positive labels are set at 1, along with the 
        other bars which are multiples of eight bars (i.e. a phrase) 
        away and within 32 bars (or four phrases).
        
        Args:
            i: Index in sliced label input
            sums: List of sum of labels at each index.
            other_weight: Scaling weight for less important inputs
            
        Returns:
            Scaled weight (either 1 or other_weight)
    """
    factor = other_weight/(1-other_weight)
    if i > len(sums)-9:
        return (int(
            sums[i]!=0 or sums[i-8]!=0 or sums[i-16]!=0)+factor)/(1+factor)
    elif i > len(sums) - 17:
        return (int(
            sums[i]!=0 or sums[i+8]!=0 or sums[i-8]!=0 or sums[i-16]!=0)+factor)/(1+factor)
    elif i > len(sums) - 25:
        return (int(
            sums[i]!=0 or sums[i+8]!=0 or sums[i+16]!=0 or sums[i-8]!=0)+factor)/(1+factor)
    elif i > len(sums) - 33:
        return (int(
            sums[i]!=0 or sums[i+8]!=0 or sums[i+16]!=0 or sums[i+24]!=0)+factor)/(1+factor)
    else:
        return (int(
            sums[i]!=0 or sums[i+8]!=0 or sums[i+16]!=0 or sums[i+24]!=0 or sums[i+32]!=0)+factor)/(
            1+factor)


def get_weights(song):
    """Wrapper function for get_single_weight function to apply
        to full label input for a song. Multiplies the weight of the
        first positive example, and the weight eight bars (one phrase)
        before it, by 1.5
    """
    
    labels = timing_model_labels[song]
    sums = [np.sum(label) for label in labels]
    weights = [get_single_weight(i,sums) for i in range(len(sums))]
    pos_weight_indices = [i for i in range(len(weights)) if weights[i] > 0.5 and sums[i]!=0]
    first_pos_weight_ind = pos_weight_indices[0]
    weights[first_pos_weight_ind] *= 1.5
    if first_pos_weight_ind >= 8:
        weights[first_pos_weight_ind-8] *= 1.5
        
    return weights
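
The scaling used in `get_single_weight` maps an important bar to a weight of exactly 1 and every other bar to `other_weight`. The sketch below checks that arithmetic and then applies the 1.5 boost on a toy label sequence. It is deliberately simplified: it only looks one phrase (eight bars) either side of a positive label, rather than the position-dependent offsets used above, and the names `scaled_weight` and `toy_weights` are hypothetical.

```python
def scaled_weight(is_important, other_weight=0.01):
    """Return 1.0 for important bars and other_weight otherwise."""
    factor = other_weight / (1 - other_weight)
    return (int(is_important) + factor) / (1 + factor)

def toy_weights(sums, boost=1.5):
    """Simplified weighting: a bar is important if it, or the bar one
    phrase (eight bars) before or after it, carries a positive label."""
    n = len(sums)
    flags = [
        sums[i] != 0
        or (i + 8 < n and sums[i + 8] != 0)
        or (i - 8 >= 0 and sums[i - 8] != 0)
        for i in range(n)
    ]
    weights = [scaled_weight(flag) for flag in flags]
    first_pos = next(i for i in range(n) if sums[i] != 0)
    weights[first_pos] *= boost            # first positive label
    if first_pos >= 8:
        weights[first_pos - 8] *= boost    # one phrase earlier
    return weights

sums = [0] * 16
sums[10] = 1
weights = toy_weights(sums)
```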
            


We load the same validation and test sets as were used for the intro models.

In [10]:
with open('sc_vad_set.pkl','rb') as f:
    vad_set = pickle.load(f)

with open('sc_test_set.pkl','rb') as f:
    test_set = pickle.load(f)
    
train_set = [x for x in gram_slices_tp if x not in vad_set and x not in test_set]

In [11]:
tm_train_input = np.stack(
    [get_timing_model_inputs(song) for song in train_set]).astype('float32')
tm_train_target = np.stack(
    [timing_model_labels[song] for song in train_set]).astype('float32')
tm_train_weights = np.stack(
    [get_weights(song) for song in train_set]).astype('float32')

tm_vad_input = np.stack(
    [get_timing_model_inputs(song) for song in vad_set]).astype('float32')
tm_vad_target = np.stack(
    [timing_model_labels[song] for song in vad_set]).astype('float32')
tm_vad_weights = np.stack(
    [get_weights(song) for song in vad_set]).astype('float32')

tm_test_input = np.stack(
    [get_timing_model_inputs(song) for song in test_set]).astype('float32')
tm_test_target = np.stack(
    [timing_model_labels[song] for song in test_set]).astype('float32')
tm_test_weights = np.stack(
    [get_weights(song) for song in test_set]).astype('float32')


### Model Definition

In [12]:
tm_gram_in = Input((slice_length,140),name = 'tm_analysis_in')

tm_conv_bar_c = Conv1D(filters = 16,kernel_size = 11,activation = 'relu',strides = 3)
tm_pool_bar_c = MaxPooling1D(pool_size = 2,strides = 2)
tm_bar_out_c = BatchNormalization()(tm_pool_bar_c(tm_conv_bar_c(tm_gram_in)))

tm_conv_bar_2_c = Conv1D(filters = 8,kernel_size = 2,activation = 'relu',strides = 2)
tm_pool_bar_2_c = MaxPooling1D(pool_size = 1,strides =1)
tm_bar_out_2_c = BatchNormalization()(tm_pool_bar_2_c(tm_conv_bar_2_c(tm_bar_out_c)))

tm_bar_out_c_flat = Flatten()(tm_bar_out_2_c)
tm_gram_model = Model(tm_gram_in,tm_bar_out_c_flat)

In [13]:
tm_gram_input = Input((tm_seq_len+1,175,140))
tm_gram_flat = Lambda(lambda x: K.reshape(x,(-1,175,140)))(tm_gram_input)

tm_conv = tm_gram_model(tm_gram_flat)

tm_conv_seq = Lambda(lambda x: K.reshape(x,(-1,tm_seq_len+1,tm_conv.shape[-1])))(tm_conv)


tm_conv_dense = Dense(48,activation='tanh')(Dropout(rate=0.4)(tm_conv_seq))
tm_conv_dense_2 = Dense(32,activation='tanh')(Dropout(rate=0.4)(tm_conv_dense))

tm_lstm_out = Bidirectional(LSTM(
    48,return_sequences=True,recurrent_dropout = 0.45,dropout=0.45))(tm_conv_dense_2)
tm_dense_1 = Dense(16,activation='tanh')(Dropout(rate=0.4)(tm_lstm_out))
tm_out = Dense(2,activation='sigmoid')(Dropout(rate=0.3)(tm_dense_1))
tm_final_model = Model(tm_gram_input,tm_out)

### Model Training
We train the model using the same approach as in the Introduction case: an initial 100 epochs, followed by up to 200 further epochs with early stopping based on the validation loss.

In [14]:
#The `lr` argument is deprecated in TF2 in favour of `learning_rate`
tm_adam_opt = tf.keras.optimizers.Adam(learning_rate = 2e-4)

tm_final_model.compile(optimizer = tm_adam_opt, loss = 'binary_crossentropy',
                    weighted_metrics = ['accuracy'],sample_weight_mode='temporal')

In [15]:
tm_final_model.fit(tm_train_input,tm_train_target,batch_size = 16,
                sample_weight=tm_train_weights,epochs = 100,
                   validation_data = (tm_vad_input,tm_vad_target,tm_vad_weights),
                  verbose = 0)

tm_es = EarlyStopping(restore_best_weights=True,monitor='val_loss',patience=20)
tm_final_model.fit(tm_train_input,tm_train_target,batch_size = 16,
                sample_weight=tm_train_weights,epochs = 200,
                   validation_data = (tm_vad_input,tm_vad_target,tm_vad_weights),
                  callbacks = [tm_es],verbose = 0)

In [16]:
tm_train_pred = tm_final_model.predict(tm_train_input)

In [17]:
print('Training performance:')
tm_final_model.evaluate(tm_train_input,tm_train_target,sample_weight=tm_train_weights)

Training performance:


[0.03390813618898392, 0.8424115777015686]

In [18]:
print('Validation performance:')
tm_final_model.evaluate(tm_vad_input,tm_vad_target,sample_weight=tm_vad_weights)

Validation performance:


[0.03266860172152519, 0.85056072473526]

In [19]:
print('Test performance:')
tm_final_model.evaluate(tm_test_input,tm_test_target,sample_weight=tm_test_weights)

Test performance:


[0.036313679069280624, 0.8491887450218201]

We can examine an example song to see what our predictions look like.

In [20]:
example_df = pd.DataFrame(np.stack(timing_model_labels['Chris Lake - Lose My Mind']),
            columns = ['Transition Start Label','Transition End Label'])

example_pred = np.round(
    tm_final_model.predict(
        tm_train_input[[train_set.index('Chris Lake - Lose My Mind')]]),2)[0]

example_df = pd.concat([example_df,
                        pd.DataFrame(example_pred,
                                     columns = ['Start Probability','End Probability'])],axis=1)

example_df.iloc[-34:]

| | Transition Start Label | Transition End Label | Start Probability | End Probability |
|---|---|---|---|---|
| 31 | 0.0 | 0.0 | 0.08 | 0.02 |
| 32 | 0.0 | 0.0 | 0.09 | 0.02 |
| 33 | 0.0 | 0.0 | 0.11 | 0.02 |
| 34 | 0.0 | 0.0 | 0.15 | 0.03 |
| 35 | 0.0 | 0.0 | 0.2 | 0.03 |
| 36 | 0.0 | 0.0 | 0.27 | 0.03 |
| 37 | 0.0 | 0.0 | 0.35 | 0.04 |
| 38 | 0.0 | 0.0 | 0.47 | 0.04 |
| 39 | 0.0 | 0.0 | 0.51 | 0.06 |
| 40 | 1.0 | 0.0 | 0.52 | 0.08 |


## Outro Start Bar Finder
Now that we have a model which can determine rough timings for the transition, we need to narrow those timings down. As discussed above, we do not have the advantage of the start of the song providing a fixed reference point from which to identify the starting bar. Instead, we will use the output of the timing model to give us a window in which we can narrow down the correct bar. We want the model to make its prediction around the point where the Start Probability first exceeds 0.5; in the example directly above, this is at index 39, so we might want the model to predict the starting bar over bars 36 to 44, with the correct bar being bar 40.

We will train a model which is similar in architecture to the Intro Start Bar Finder for this purpose. However, the training regime will be different, as we want to provide the model with a more diverse range of starting points around the first label. We will therefore train the model on randomly sampled batches of songs and starting points which are within eight bars of the first Start label. We will also need to give the model more context; in the introduction model, only the 'right' context is available since to the 'left' is the start of the song. In this case however it is likely that both left and right context will be relevant.
### Data Preparation
Rather than building a fixed training set, we define a method for sampling individual training batches. To make this easier, we first extract a fixed-length sequence from the sliced grams for each song. The model input has a total sequence length of 32, so we construct sequences of length 40 with the correct starting label exactly in the middle at index 20. Sliding a 32-bar window over this sequence then yields training examples in which the correct label sits at any index from 12 to 19, which are the eight positions the model scores.

In [21]:
gram_slices_sbf = {}
for song in gram_slices_tp:    
    slice_labels = timing_model_labels[song]
    pos_labels = [i for i in range(len(slice_labels)) if np.sum(slice_labels[i])>0]
    first_pos_label = pos_labels[0]
    
    slice_tuples = list(zip(*gram_slices_tp[song])) 
    slices_concat = [np.concatenate(
        tup,axis=-1).astype('float32') for tup in slice_tuples][-65:]
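    #Pad at the front so that indices stay aligned with timing_model_labels,
    #which is also padded at the front for short songs
    while len(slices_concat) < 65:
        slices_concat.insert(0,np.zeros((slice_length,140)))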
    if first_pos_label >= len(slice_labels) - 20:
        slices = slices_concat[first_pos_label-20:]
        slices += [np.zeros((slice_length,140)) for _ in range(
            first_pos_label-(len(slice_labels)-20))]
        gram_slices_sbf[song] = slices
    elif first_pos_label < 20:
        slices = slices_concat[:first_pos_label+20]
        slices = [np.zeros((slice_length,140)) for _ in range(20 - first_pos_label)] + slices
        gram_slices_sbf[song] = slices
    else:
        gram_slices_sbf[song] = slices_concat[first_pos_label-20:first_pos_label+20]


In [22]:
del tm_train_input
import gc
gc.collect()

sbf_fixed_data_train_input = np.stack([np.stack(
    gram_slices_sbf[song]) for song in train_set]).astype('float32')
sbf_fixed_data_vad_input = np.stack([np.stack(
    gram_slices_sbf[song]) for song in vad_set]).astype('float32')
sbf_fixed_data_test_input = np.stack([np.stack(
    gram_slices_sbf[song]) for song in test_set]).astype('float32')

In [23]:
def generate_sbf_training_data(batch_inds):
    """Generates a training batch for input into the outro Start Bar
        Finder model by randomly sampling how much the correct 
        target label should be shifted"""
    batch_size = len(batch_inds)
    fixed_batch_input = sbf_fixed_data_train_input[batch_inds]
    
    batch_targets = random.choices(range(8),k=batch_size)
    
    batch_input = np.stack(
        [fixed_batch_input[i][8-batch_targets[i]:40-batch_targets[i]] \
         for i in range(batch_size)])
    
    #The model expects float32 inputs, so we avoid an implicit cast from float64
    return tf.constant(batch_input,dtype = tf.float32),\
            tf.constant(batch_targets,dtype = tf.int64)
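
To see where the correct label lands inside each sampled window, we can trace the slice indices used in `generate_sbf_training_data`. The sketch below is plain index arithmetic with no model involved; the constant and function names are chosen here for illustration.

```python
FIXED_LEN = 40     # bars stored per song
LABEL_INDEX = 20   # position of the first Start label in the stored sequence
WINDOW_LEN = 32    # bars fed to the Start Bar Finder

def window_for_target(target):
    """Return the (start, stop) slice bounds used for a sampled target."""
    return 8 - target, FIXED_LEN - target

label_positions = {}
for target in range(8):
    start, stop = window_for_target(target)
    assert stop - start == WINDOW_LEN
    # Index of the correct bar within the 32-bar window
    label_positions[target] = LABEL_INDEX - start
```

A target of `t` therefore places the correct bar at index `12 + t` of the window, so the label always falls within indices 12 to 19.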

We also build fixed validation and test sets: the validation set is used for evaluation during training, and the test set for the final evaluation.

In [24]:
sbf_vad_targets = random.choices(range(8),k=len(vad_set))
sbf_vad_input = np.stack(
        [sbf_fixed_data_vad_input[i][8-sbf_vad_targets[i]:40-sbf_vad_targets[i]]\
         for i in range(len(vad_set))])

sbf_test_targets = random.choices(range(8),k=len(test_set))
sbf_test_input = np.stack(
        [sbf_fixed_data_test_input[i][8-sbf_test_targets[i]:40-sbf_test_targets[i]]\
         for i in range(len(test_set))])

### Model Definition

In [25]:
sbf_seq_len = 32

In [26]:
sbf_gram_in = Input((slice_length,140),name = 'sbf_analysis_in')

sbf_conv_bar_c = Conv1D(filters = 16,kernel_size = 11,activation = 'relu',strides = 3)
sbf_pool_bar_c = MaxPooling1D(pool_size = 2,strides = 2)
sbf_bar_out_c = BatchNormalization()(sbf_pool_bar_c(sbf_conv_bar_c(sbf_gram_in)))

sbf_conv_bar_2_c = Conv1D(filters = 8,kernel_size = 2,activation = 'relu',strides = 2)
sbf_pool_bar_2_c = MaxPooling1D(pool_size = 1,strides =1)
sbf_bar_out_2_c = BatchNormalization()(sbf_pool_bar_2_c(sbf_conv_bar_2_c(sbf_bar_out_c)))
sbf_bar_out_c_flat = Flatten()(sbf_bar_out_2_c)
sbf_gram_model = Model(sbf_gram_in,sbf_bar_out_c_flat)

In [27]:
sbf_gram_input = Input((sbf_seq_len,slice_length,140))

sbf_gram_flat = Lambda(lambda x: K.reshape(x,(-1,slice_length,140)))(sbf_gram_input)

sbf_conv = sbf_gram_model(sbf_gram_flat)

sbf_conv_seq = Lambda(lambda x: K.reshape(x,(-1,sbf_seq_len,sbf_conv.shape[-1])))(sbf_conv)


sbf_conv_dense = Dense(64,activation='tanh')(Dropout(rate=0.6)(sbf_conv_seq))
sbf_conv_dense_2 = Dense(32,activation='tanh')(Dropout(rate=0.6)(sbf_conv_dense))

sbf_lstm_out = Bidirectional(
    LSTM(32,return_sequences=True,recurrent_dropout = 0.5,dropout=0.5))(sbf_conv_dense_2)
sbf_dense_1 = Dense(32,activation='tanh')(Dropout(rate=0.45)(sbf_lstm_out))

sbf_zeros = Lambda(lambda x: K.zeros_like(x)[:,:8])(sbf_dense_1)
sbf_dense_1_left = Lambda(lambda x: x[:,:-8])(sbf_dense_1)
sbf_dense_1_right = Lambda(lambda x: x[:,8:])(sbf_dense_1)
sbf_left_attention = Concatenate(axis=1)([sbf_zeros,sbf_dense_1_right])
sbf_right_attention = Concatenate(axis=1)([sbf_dense_1_left,sbf_zeros])

sbf_dense_1_attention = Concatenate(axis=-1)([sbf_left_attention,sbf_lstm_out,sbf_right_attention])
sbf_dense_1_attention_trimmed = Lambda(lambda x: x[:,12:-12])(sbf_dense_1_attention)

sbf_dense_2 = Dense(24,activation='tanh')(Dropout(rate=0.5)(sbf_dense_1_attention_trimmed))
sbf_dense_3 = Dense(8,activation='tanh')(Dropout(rate=0.4)(sbf_dense_2))
sbf_out = Dense(1)(Dropout(rate=0.3)(sbf_dense_3))
sbf_out_soft = Activation('softmax')(Lambda(lambda x: x[:,:,0])(sbf_out))
sbf_final_model = Model(sbf_gram_input,sbf_out_soft)


### Model Training
We now build a custom training loop to implement our batch sampling method. Because we are not using `fit`, we also implement early stopping by hand: from epoch 300 onwards we track the best validation loss, and stop once it has failed to improve for more than `patience` epochs.

In [28]:
@tf.function
def train_step(model,inputs,targets):
    """tf.function for applying gradient updates to the model.
        
        Args:
            model: Keras model to update
            inputs: Model inputs used to calculate losses for gradient descent
            targets: Integer class targets for the sparse categorical
                crossentropy loss
                
        Returns:
            Scalar loss value for the batch"""
    with tf.GradientTape() as tape:
        pred = model(inputs, training=True)
        loss_value = tf.keras.losses.SparseCategoricalCrossentropy()(
            targets,pred)
        
    grads = tape.gradient(loss_value, model.trainable_variables)
    adam_opt.apply_gradients(zip(grads, model.trainable_variables))
    
    acc_metric.update_state(targets,pred)
    
    return loss_value

In [29]:
sbf_final_model = load_model('outro_start_bar_finder_v3')

In [30]:
#The `lr` argument is deprecated in TF2 in favour of `learning_rate`
adam_opt = tf.keras.optimizers.Adam(learning_rate = 5e-4)
sbf_final_model.compile(optimizer = adam_opt, loss = None)

In [31]:
def run_sbf_epoch(batch_inds_lst,batch_size):
    """Runs a single training epoch for the Outro Start Bar Finder model,
        keeping track of training loss and accuracy.
    
        Args:
            batch_inds_lst: Shuffled list of training set indices to 
                be split into batches.
            batch_size: Batch size to use for training.
        
        Returns:
            Two (loss, accuracy) tuples: the first for the training
                data, the second for the fixed validation set.
    """
    batches = [batch_inds_lst[i*batch_size:(i+1)*batch_size] for i in range(num_batches)]
    
    epoch_losses = []
    for batch in batches:
        
        batch_input,batch_target = generate_sbf_training_data(batch)
        
        loss_value = train_step(sbf_final_model,batch_input,batch_target)
        epoch_losses.append(loss_value.numpy())
    train_acc = acc_metric.result().numpy()
    acc_metric.reset_states()
    
    training_info.append((np.mean(epoch_losses),train_acc))
    
    vad_pred = sbf_final_model(sbf_vad_input)
    vad_loss = tf.keras.losses.SparseCategoricalCrossentropy()(
            sbf_vad_targets,vad_pred)
    acc_metric.update_state(sbf_vad_targets,vad_pred)
    
    vad_acc = acc_metric.result().numpy()
    acc_metric.reset_states()
    
    vad_info.append((vad_loss.numpy(),vad_acc))
    return training_info[-1],vad_info[-1]

In [32]:
num_epochs = 600
patience = 75
patience_counter = 0
batch_size = 32
num_training_examples = len(train_set)
num_batches = num_training_examples//batch_size

acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()

training_info = []
vad_info = []
best_vad_acc = 0
best_vad_loss = 10000

In [33]:
for epoch in range(num_epochs):
    start = time.time()
    batch_inds_lst = list(range(num_training_examples))
    batch_inds_lst = random.sample(batch_inds_lst,num_training_examples)
    
    train_tup,vad_tup = run_sbf_epoch(batch_inds_lst,batch_size)
    if epoch == 300:
        weights_300 = sbf_final_model.get_weights()
    if epoch>=300:
        vad_loss = vad_tup[0]
        if vad_loss <= best_vad_loss:
            best_weights = sbf_final_model.get_weights()
            best_epoch = epoch+1
            best_train_info = training_info[-1]
            patience_counter = 0
            best_vad_loss = vad_loss
        else:
            patience_counter += 1
    if patience_counter > patience:
        break

    print(epoch+1,np.round(time.time()-start,3),training_info[-1],vad_info[-1])
sbf_final_model.set_weights(best_weights)
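
The stopping rule in the loop above can be isolated from the training code. The following sketch applies the same logic to a plain list of validation losses: nothing is tracked during the warm-up epochs, the best loss is recorded afterwards, and training stops once the counter exceeds the patience. The function name and the toy loss values are assumptions for illustration.

```python
def best_epoch_with_warmup(val_losses, warmup, patience):
    """Return (best_epoch, last_epoch_run), both 1-indexed.

    Losses seen before `warmup` epochs have elapsed are ignored, even
    if they are lower than anything seen later.
    """
    best_loss = float("inf")
    best_epoch = None
    counter = 0
    for epoch, loss in enumerate(val_losses):
        if epoch >= warmup:
            if loss <= best_loss:
                best_loss, best_epoch, counter = loss, epoch + 1, 0
            else:
                counter += 1
        if counter > patience:
            return best_epoch, epoch + 1
    return best_epoch, len(val_losses)

losses = [0.9, 0.5, 0.8, 0.7, 0.6, 0.65, 0.66, 0.67, 0.68]
result = best_epoch_with_warmup(losses, warmup=2, patience=2)
```

In this toy run the loss of 0.5 in the second epoch is ignored because it falls inside the warm-up period, so the best epoch is the fifth, and the run stops after the eighth.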

We can examine the performance on our fixed validation and test set data:

In [34]:
sbf_final_model.compile(optimizer = adam_opt, 
                        loss = 'sparse_categorical_crossentropy',
                        metrics = ['sparse_categorical_accuracy'])

In [35]:
sbf_final_model.evaluate(sbf_vad_input,np.stack(sbf_vad_targets))



[1.194238305091858, 0.47999998927116394]

In [36]:
sbf_final_model.evaluate(sbf_test_input,np.stack(sbf_test_targets))



[1.1707124710083008, 0.5099999904632568]

## Full Outro Model
We can now combine the two models trained above into an end-to-end process which, given beat and downbeat timestamps alongside chromagram and spectrogram data, can label transition points in the outro of a song. We will construct this process and evaluate overall performance on the test set. First, the transition timing model is used to find the approximate starting point of the transition. Then, if the song was predicted by the Introduction models to have its first downbeat on the first beat of the song, we will calculate the BPM and use it to construct a phrase grid which can be used to identify exact phrase points in the outro. For the remaining songs, the starting point is narrowed down to a single bar by the Start Bar Finder. The madmom downbeat prediction is then used to determine the exact starting beat, and subsequent transition points are identified by 32 beat jumps along with the output of the timing model.

In [37]:
tm_test_pred = tm_final_model.predict(tm_test_input)

We first identify the point in the transition timing model output where the start prediction goes above a certain threshold value. This will indicate approximately where the transition should begin, for use with the Start Bar Finder. We also identify the distance from this point to the correct starting point of the transition.

In [38]:
first_pred_inds = {}
first_pred_diffs = []
for j,song in enumerate(test_set):
    tm_start_pred = tm_test_pred[j][:,0]
    tm_labels = timing_model_labels[song]
    
    #Since some songs might have low probability predictions, we fall
    #back to progressively lower thresholds when finding the start point
    first_pred = None
    for threshold in [0.65,0.55,0.45,0.3,0.15]:
        above = [i for i in range(65) if tm_start_pred[i] >= threshold]
        if above:
            first_pred = above[0]
            break
    if first_pred is None:
        raise ValueError('No start prediction above 0.15 for '+song)
    
    first_pred_inds[song] = first_pred
    pos_labels = [i for i in range(len(tm_labels)) if np.sum(tm_labels[i])>0]
    next_label_diff = [i-first_pred for i in pos_labels if i>=first_pred][0]
    first_pred_diffs.append(next_label_diff)
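
The threshold fallback can be tried out on toy probability sequences. The sketch below returns the first index at or above the highest threshold that any prediction reaches, together with the threshold used; the function name is hypothetical, and unlike the cell above it returns `(None, None)` rather than raising when no threshold is met.

```python
def first_index_above(probabilities, thresholds=(0.65, 0.55, 0.45, 0.3, 0.15)):
    """Try each threshold in turn and return (index, threshold) for the
    first probability at or above it, or (None, None) if none qualifies."""
    for threshold in thresholds:
        for index, p in enumerate(probabilities):
            if p >= threshold:
                return index, threshold
    return None, None

confident = first_index_above([0.1, 0.2, 0.7, 0.9])
hesitant = first_index_above([0.1, 0.5, 0.4])
silent = first_index_above([0.05, 0.1])
```

Note that a lower threshold can select an earlier index than a higher one would, so the fallback affects where the Start Bar Finder window is placed as well as whether a start point is found at all.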

We then build the input arrays for the Start Bar Finder using these points, and calculate the output.

In [39]:
gram_slices_sbf_pred = {}
for song in test_set:    
    slice_labels = timing_model_labels[song]
    
    first_pos_label = first_pred_inds[song]
    
    slice_tuples = list(zip(*gram_slices_tp[song])) 
    slices_concat = [np.concatenate(
        tup,axis=-1).astype('float32') for tup in slice_tuples][-65:]
    
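    #Pad at the front so that indices stay aligned with timing_model_labels,
    #which is also padded at the front for short songs
    while len(slices_concat) < 65:
        slices_concat.insert(0,np.zeros((slice_length,140)))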
    if first_pos_label >= len(slice_labels) - 20:
        slices = slices_concat[first_pos_label-20:]
        slices += [np.zeros((slice_length,140)) for _ in range(
            first_pos_label-(len(slice_labels)-20))]
        
        gram_slices_sbf_pred[song] = slices[8:]
    elif first_pos_label < 20:
        slices = slices_concat[:first_pos_label+20]
        slices = [np.zeros((slice_length,140)) for _ in range(
            20 - first_pos_label)] + slices
        gram_slices_sbf_pred[song] = slices[8:]
    else:
        gram_slices_sbf_pred[song] = slices_concat[first_pos_label-20:first_pos_label+20][8:]

In [40]:
sbf_test_pred_input = np.stack([np.stack(
    gram_slices_sbf_pred[song]) for song in test_set]).astype('float32')

In [41]:
sbf_test_pred = sbf_final_model.predict(sbf_test_pred_input)
sbf_test_pred_ind = np.argmax(sbf_test_pred,axis=-1)

Next we begin the process of generating labels for each song. We first import the timestamp of the first downbeat prediction by the Introduction models for each song in our test set.

In [42]:
with open('first_downbeat_predictions.pkl','rb') as f:
    first_downbeat_pred = pickle.load(f)

In [43]:
def get_first_label_timestamp(song):
    """Utility function which calculates the timestamp of the
    first label in the outro."""
    
    ohs = ohs_dict[song].iloc[:,[0,-2,-1]]
    return ohs[(ohs['Outgoing Start']+ohs['Outgoing End'])>=1].values[0,0]


Now we need to calculate the BPM of each song to construct our phrase grids. We do this by calculating the BPM across the song in 32 beat slices and taking the mode, in order to account for small irregularities in beat structure in the madmom beat prediction.

In [44]:
def get_bpm(song):
    """Calculates the BPM of a song by taking the mode across the
        32 beat slices"""
    
    def get_bpm_slce(song,slce):
        """Calculates the BPM of a single 32 beat slice of a song using
        linear regression.
        
        Args:
            song: song name, used as the key into labels_dict
            slce: index of the 32 beat slice within the song
        Returns:
            BPM rounded to one decimal place
        """
        when_beats = labels_dict[song].iloc[:,0].apply(float).values[32*slce:32*(slce+1)]
        m_res = stats.linregress(np.arange(len(when_beats)),when_beats)
        beat_step = m_res.slope
        
        return np.round(60/beat_step,decimals = 1)
    
    num_slce = labels_dict[song].shape[0]//32
    slce_bpms = [get_bpm_slce(song,i) for i in range(num_slce)]
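    # np.atleast_1d copes with both SciPy return shapes: older versions
    # return the mode as a length-1 array, newer ones as a scalar
    mode = np.atleast_1d(stats.mode(slce_bpms)[0])[0]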
    if slce_bpms.count(mode) == slce_bpms.count(np.round(mode)):
        return np.round(mode)
    else:
        return mode
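
The regression-and-mode idea can be checked on synthetic data. The sketch below is illustrative only: the beat track is made up, `np.polyfit` stands in for `stats.linregress`, and `collections.Counter` stands in for `stats.mode`. The names `slice_bpm` and `beat_times` are hypothetical and are not part of the notebook.

```python
from collections import Counter

import numpy as np

def slice_bpm(beat_times):
    """Fit beat time against beat index; the slope is seconds per beat."""
    slope, _intercept = np.polyfit(np.arange(len(beat_times)), beat_times, 1)
    return round(float(60.0 / slope), 1)

# Hypothetical beat track: 128 beats at 125 BPM (0.48 s apart), with each
# beat alternately 2 ms late or early to mimic small tracking irregularities.
beat_index = np.arange(128)
beat_times = 1.0 + 0.48 * beat_index + 0.002 * (-1.0) ** beat_index

# One BPM estimate per 32-beat slice, then the most common value.
slice_bpms = [slice_bpm(beat_times[32 * k:32 * (k + 1)])
              for k in range(len(beat_times) // 32)]
most_common_bpm, n_slices = Counter(slice_bpms).most_common(1)[0]
print(slice_bpms, most_common_bpm)
```

Because the regression averages over all 32 beats in a slice, the 2 ms jitter shifts each estimate by far less than the 0.1 BPM rounding step, so every slice reports 125.0.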

We can then build our phrase grids by stepping in 32-beat jumps from the predicted first downbeat, multiplying by the beat length in seconds to obtain phrase timestamps. The Introduction models predicted that the first downbeat falls on the first beat for 81 of our 100 test songs. For 54 of these 81 songs, the first label in the outro lies on a 32-beat phrase boundary of the calculated grid, to within a third of a beat.

In [45]:
downbeat_grids = {}
phrase_grids = {}
bpms = {}
good_beat = []
first_beat_songs = []
for i,song in enumerate(list(test_set)):
    tags = labels_dict[song]
    bpm = get_bpm(song)
    beat_len = 60/bpm
    bpms[song] = bpm
    first_beat = tags.values[0,0]
    last_beat = tags.values[-1,0]

    starting_time = first_downbeat_pred[song]
    
    phrase_grid = [starting_time + 32*i*beat_len for i in range(-500,500)]
    phrase_grid = [x for x in phrase_grid if x>=first_beat and x<=last_beat]

    downbeat_grid = [starting_time + 4*i*beat_len for i in range(-500,500)]
    downbeat_grid = [x for x in downbeat_grid if x>=first_beat and x<=last_beat]
    if starting_time == first_beat:
        first_beat_songs.append(song)
        first_outro_label_timestamp = get_first_label_timestamp(song)
        
        if min([abs(x-first_outro_label_timestamp) for x in phrase_grid])<beat_len/3:
            good_beat.append(song)
    downbeat_grids[song] = downbeat_grid
    phrase_grids[song] = phrase_grid
    
print('Number of songs with first downbeat on first beat in Intro:',len(first_beat_songs))
print('Number of first beat songs with exact phrase match on outro label:', len(good_beat))

Number of songs with first downbeat on first beat in Intro: 81
Number of first beat songs with exact phrase match on outro label: 54
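
As a rough illustration of the grid construction, the sketch below builds a phrase grid and a downbeat grid for a made-up song and applies the same third-of-a-beat tolerance. It computes the valid range of steps directly rather than filtering a fixed range of 1000 candidates. The helper `build_grid` and all of the numbers are assumptions chosen for the example.

```python
import math

def build_grid(anchor_time, beat_len, beats_per_step, first_beat, last_beat):
    """Timestamps anchor_time + k * step that fall between first_beat and last_beat."""
    step = beats_per_step * beat_len
    k_min = math.ceil((first_beat - anchor_time) / step)
    k_max = math.floor((last_beat - anchor_time) / step)
    return [anchor_time + k * step for k in range(k_min, k_max + 1)]

# Hypothetical song: 125 BPM, first downbeat at 0.5 s, beats tracked up to 240.0 s.
beat_len = 60 / 125
phrase_grid = build_grid(0.5, beat_len, 32, first_beat=0.5, last_beat=240.0)
downbeat_grid = build_grid(0.5, beat_len, 4, first_beat=0.5, last_beat=240.0)

# A label 50 ms after the 14th phrase boundary still counts as on-grid,
# because 0.05 s is less than a third of a beat (0.16 s at 125 BPM).
label_time = 0.5 + 13 * 32 * beat_len + 0.05
on_grid = min(abs(t - label_time) for t in phrase_grid) < beat_len / 3
print(len(phrase_grid), len(downbeat_grid), on_grid)
```

The grid is only as good as its two inputs: an error in the predicted first downbeat shifts every phrase time by the same amount, while an error in the BPM grows with distance from the anchor, which matters most in the outro.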


We are now ready to extract the final labels and evaluate overall performance. 

In [46]:
def get_labels(prob_pair,threshold = 0.4):
    """Generates labels based on the transition timing model
        output at a single timestep.
        
        Args:
            prob_pair: tuple containing the (start prob,end prob)
                predicted by the model
            threshold: Threshold probability at which a label 
                will be generated
        Returns:
            Label of either Start, End, Start/End, or nan
    """
    start_prob,end_prob = prob_pair
    if start_prob > threshold:
        if end_prob > threshold:
            return 'Start/End'
        else:
            return 'Start'
    elif end_prob > threshold:
        return 'End'
    else:
        return np.nan

In [47]:
def get_nearest_slice_time_inds(downbeats,slice_times):
    """Determines the index of the nearest slice to a list of 
        downbeat/phrase times.
    
        Args:
            downbeats: List of downbeat/phrase timestamps.
            slice_times: List of model input slice timestamps
        
        Returns:
            List of indices of the same length as downbeats, containing
                the index of the nearest slice to each downbeat
    """
    nearest_slice_time_inds = []
    for downbeat in downbeats:
        nearest_slice_time_ind = np.argmin([abs(downbeat-x) for x in slice_times])
        nearest_slice_time_inds.append(nearest_slice_time_ind)
    return nearest_slice_time_inds
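
The two helpers above can be combined on toy data to see how phrase timestamps turn into labels. This is a sketch under stated assumptions: `nearest_indices` is a hypothetical NumPy-broadcasting variant of `get_nearest_slice_time_inds`, `label_from_probs` mirrors `get_labels` but returns `None` instead of `nan`, and the slice times and probabilities are invented.

```python
import numpy as np

def nearest_indices(query_times, slice_times):
    """Index of the slice closest in time to each query timestamp."""
    query = np.asarray(query_times, dtype=float)[:, None]   # shape (n_query, 1)
    slices = np.asarray(slice_times, dtype=float)[None, :]  # shape (1, n_slices)
    return np.abs(query - slices).argmin(axis=1)

def label_from_probs(start_prob, end_prob, threshold=0.4):
    """Same thresholding rule as get_labels, with None standing in for nan."""
    if start_prob > threshold and end_prob > threshold:
        return 'Start/End'
    if start_prob > threshold:
        return 'Start'
    if end_prob > threshold:
        return 'End'
    return None

# Hypothetical inputs: one slice per bar (1.92 s apart) and made-up probabilities.
slice_times = 10.0 + 1.92 * np.arange(64)
probs = np.zeros((64, 2))
probs[48] = [0.9, 0.1]   # model is confident a transition starts here
probs[56] = [0.2, 0.7]   # model is confident a transition ends here

phrase_times = [71.5, 86.8, 102.2, 117.5]
inds = nearest_indices(phrase_times, slice_times)
labels = [label_from_probs(*probs[i]) for i in inds]
print(inds.tolist(), labels)
```

Only the probabilities at phrase positions are read, so a confident prediction that falls between two phrase boundaries produces no label at all.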

We extract the timing of transition points and generate the relevant labels for each song in the test set. As mentioned above, for songs where the Introduction models predict that the first downbeat is on the first beat, we use the phrase grid to determine phrase timings in the outro; otherwise, we use the Start Bar Finder.

In [48]:
test_song_pred_info = {}
for i,song in enumerate(test_set):
    
    slice_times = gram_slice_times[song]
    timing_model_probs = tm_test_pred[i]
    
    first_pred_ind = first_pred_inds[song]
    tags = labels_dict[song]
    if len(slice_times)>=64:
        outro_slice_times = slice_times[-64:]
    else:
        outro_slice_times = slice_times
        
    first_outro_slice_time = outro_slice_times[0]
    
    if song in first_beat_songs:
        phrases = [x for x in phrase_grids[song] if x-first_outro_slice_time > -1]
        first_nearest_slice_time_ind = get_nearest_slice_time_inds(phrases,outro_slice_times)[0]
        first_nearest_slice_time = outro_slice_times[first_nearest_slice_time_ind]
        
        try:
            nearest_phrase_time = [x for x in phrases if x>outro_slice_times[first_pred_ind-1]][0]
        except IndexError:
            nearest_phrase_time = [x for x in phrases if x>outro_slice_times[first_pred_ind]][0]

        phrase_bar_inds = [first_nearest_slice_time_ind+8*j for j in range(-8,8)]
        phrase_bar_inds = [x for x in phrase_bar_inds if x < 64 and x>=0]
        
        nearest_beat_ind = np.argmin([abs(nearest_phrase_time-x) for x in tags.values[:,0]])
        phrase_beat_inds = [nearest_beat_ind+j*32 for j in range(-8,8)]
        phrase_beat_inds = [x for x in phrase_beat_inds if x<tags.shape[0]]
        phrase_beat_inds = phrase_beat_inds[-8:]
        
        phrase_times = tags.values[phrase_beat_inds,0].tolist()
    else:
        start_bar_index = first_pred_ind + sbf_test_pred_ind[i]
        
        start_bar_time = outro_slice_times[start_bar_index-1]
        start_bar_beat_index = tags[tags['Beat Timestamp']==start_bar_time].index[0]
        
        start_bar_tags = tags.iloc[start_bar_beat_index:start_bar_beat_index+4]
        start_bar_downbeat_index = start_bar_tags[start_bar_tags['Downbeat']==1].index[0]
    
                
        phrase_bar_inds = [start_bar_index+8*j for j in range(-8,8)]
        phrase_bar_inds = [x for x in phrase_bar_inds if x< 64 and x>=0]
        
        phrase_beat_inds = [start_bar_downbeat_index + j*32 for j in range(-8,8)]
        phrase_beat_inds = [x for x in phrase_beat_inds if  x<tags.shape[0]]
        phrase_beat_inds = phrase_beat_inds[-8:]
        phrase_times = tags.values[phrase_beat_inds,0].tolist()
        

        
    phrase_ind_probs = tm_test_pred[i][phrase_bar_inds,:]
    phrase_labels = [get_labels(pair) for pair in phrase_ind_probs]
        
        
    info = list(zip(*(phrase_times,phrase_labels)))
    info_df = pd.DataFrame(info,columns = ['Beat Timestamp','Predicted Outro Label'])
    test_song_pred_info[song] = info_df
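
Both branches of the loop step through the outro window in jumps of 8 slices, which corresponds to one 32-beat phrase if each slice covers one 4-beat bar. A minimal sketch of that index arithmetic, with the hypothetical helper `phrase_slice_indices` and a 64-slice window assumed:

```python
def phrase_slice_indices(anchor_ind, n_slices=64, bars_per_phrase=8):
    """Slice indices a whole number of phrases away from anchor_ind, kept
    only if they fall inside the n_slices-long outro window."""
    candidates = [anchor_ind + bars_per_phrase * j for j in range(-8, 8)]
    return [x for x in candidates if 0 <= x < n_slices]

# An anchor near the start and one near the end of the window.
early = phrase_slice_indices(5)
late = phrase_slice_indices(60)
print(early)
print(late)
```

Wherever the anchor sits, a 64-slice window holds exactly eight phrase positions, which matches the eight rows in each song's prediction table.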

We can take a look at the resulting labels for an example song, and join on the original labels to compare.

In [49]:
song = "Redlight - Sports Mode"
test_song_pred_info[song]

   Beat Timestamp Predicted Outro Label
0          162.27                   NaN
1          177.03                   NaN
2          192.03                   NaN
3          206.80                   NaN
4          221.57                   NaN
5          236.34                   NaN
6          251.11                 Start
7          265.88                   End


In [50]:
tags_label = labels_dict[song].loc[:,['Beat Timestamp','Outro Label']]
tags_label = tags_label[tags_label['Outro Label'].apply(lambda x: x in ['Start'])]
                                   
test_song_pred_info[song].merge(labels_dict[song].loc[:,['Beat Timestamp','Outro Label']].dropna(),
                              on = 'Beat Timestamp',how='outer').sort_values('Beat Timestamp')

   Beat Timestamp Predicted Outro Label Outro Label
0          162.27                   NaN         NaN
1          177.03                   NaN         NaN
2          192.03                   NaN         NaN
3          206.80                   NaN         NaN
4          221.57                   NaN         NaN
5          236.34                   NaN         NaN
6          251.11                 Start       Start
7          265.88                   End         End


As we did in the [Introduction Models notebook](2.%20Introduction%20Transition%20Models.ipynb), we can now evaluate performance across the full test set. We count the songs where the labelling is exactly correct, the songs where the downbeat is correct, and the songs where the number of labels is correct. Since the starting point in this case isn't fixed, we evaluate the downbeat by checking whether any of the predicted timestamps coincides with a manually labelled one.

In [51]:
downbeat_right = []
downbeat_diffs = []
length_right = []
length_diffs = []
exact = []
first_downbeats = {}
for song in test_set:
    tags_label = labels_dict[song].loc[:,['Beat Timestamp','Outro Label']]
    tags_label = tags_label.dropna().reset_index(drop = True)
    pred_df = test_song_pred_info[song].dropna().reset_index(drop = True)
    pred_df.columns = ['Beat Timestamp','Outro Label']
    if pred_df.shape[0]>0:
        # Only read the first prediction once we know there is at least one
        first_downbeats[song] = pred_df.values[0,0]
        if len(set(tags_label.values[:,0]).intersection(set(pred_df.values[:,0])))>0:
            downbeat_right.append(song)
        else:
            downbeat_diffs.append(tags_label.values[0,0] - pred_df.values[0,0])
    if tags_label.shape[0] == pred_df.shape[0]:
        length_right.append(song)
        if tags_label.equals(pred_df):
            exact.append(song)
        #We will relax the exact check slightly by allowing 'Start/End' to be 
        #equal to 'Start' or 'End'
        elif tags_label.replace('Start/End','Start').equals(pred_df.replace('Start/End','Start')):
            exact.append(song)
        elif tags_label.replace('Start/End','End').equals(pred_df.replace('Start/End','End')):
            exact.append(song)
    else:
        length_diffs.append((song,tags_label.shape[0] - pred_df.shape[0]))

            
print('Number of songs with downbeat prediction correct:',len(downbeat_right))
print('Number of songs with same number of transition points:',len(length_right))
print('Number of songs which are exactly correct:', len(exact))

Number of songs with downbeat prediction correct: 63
Number of songs with same number of transition points: 66
Number of songs which are exactly correct: 42
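
The two checks used above can be written out on plain lists to make the criteria explicit. This is a sketch only: it uses `(timestamp, label)` tuples instead of DataFrames, the helper names `downbeat_overlap` and `relaxed_exact` are hypothetical, and the second pair of label lists is invented. The first pair is taken from the 'Motez - Roll Out (Benson Remix)' output in this notebook.

```python
def downbeat_overlap(manual, predicted):
    """True if any predicted timestamp coincides with a manually labelled one."""
    return len({t for t, _ in manual} & {t for t, _ in predicted}) > 0

def relaxed_exact(manual, predicted):
    """True if the rows match exactly, or match once every 'Start/End'
    is read as 'Start', or once every 'Start/End' is read as 'End'."""
    for stand_in in ('Start/End', 'Start', 'End'):
        def swap(rows):
            return [(t, stand_in if lab == 'Start/End' else lab) for t, lab in rows]
        if swap(manual) == swap(predicted):
            return True
    return False

manual_1 = [(274.32, 'Start'), (289.55, 'End'), (304.79, 'Start/End'),
            (320.03, 'End'), (335.26, 'End')]
pred_1 = [(304.79, 'Start'), (320.03, 'End'), (335.26, 'End')]

# Invented pair where the only difference is 'Start/End' versus 'Start'.
manual_2 = [(100.0, 'Start/End'), (115.0, 'End')]
pred_2 = [(100.0, 'Start'), (115.0, 'End')]

print(downbeat_overlap(manual_1, pred_1), relaxed_exact(manual_1, pred_1))
print(downbeat_overlap(manual_2, pred_2), relaxed_exact(manual_2, pred_2))
```

Comparing timestamps with exact equality is only safe here because the predicted and manual timestamps are read from the same beat table; timestamps from two different sources would need a tolerance instead.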


As with the Introduction models, the downbeat is the key part of the prediction: if it is incorrect, the transitions will be timed either off-downbeat or off-phrase. If it is correct, however, there is some subjectivity to the timing of the transition. A manual review of the 21 songs which had the correct downbeat but were not an exact match found that 15 of them had appropriate transitions, with the inference of some missing labels based on simple rules (as explained in the [Introduction Models notebook](2.%20Introduction%20Transition%20Models.ipynb)). Two examples of these are below.

In [52]:
ex_song_1 = 'Motez - Roll Out (Benson Remix)'
print('Predicted Label:')
print(test_song_pred_info[ex_song_1].dropna().reset_index(drop=True))
print('\nManual Label:')
print(labels_dict[ex_song_1].loc[:,['Beat Timestamp','Outro Label']]\
      .dropna().reset_index(drop=True))

Predicted Label:
   Beat Timestamp Predicted Outro Label
0          304.79                 Start
1          320.03                   End
2          335.26                   End

Manual Label:
   Beat Timestamp Outro Label
0          274.32       Start
1          289.55         End
2          304.79   Start/End
3          320.03         End
4          335.26         End


In [53]:
ex_song_2 = 'Green Velvet & Mauro Venti - Share Now'
print('Predicted Label:')
print(test_song_pred_info[ex_song_2].dropna().reset_index(drop=True))
print('\nManual Label:')
print(labels_dict[ex_song_2].loc[:,['Beat Timestamp','Outro Label']]\
      .dropna().reset_index(drop=True))

Predicted Label:
   Beat Timestamp Predicted Outro Label
0          377.96                   End
1          393.08                   End
2          408.20                   End
3          423.34                   End

Manual Label:
   Beat Timestamp Outro Label
0          362.85       Start
1          377.96         End
2          393.08         End
3          408.20         End
4          423.34         End


This leaves us with 57 songs (the 42 exact matches plus the 15 accepted on manual review) which have been labelled with high-quality outro transition timings; 43 of these songs were also labelled correctly by the Introduction models, meaning we have 43 songs which are labelled completely for use in transitions. As discussed above, labelling the outro is more difficult than labelling the introduction, since the start of the song is not as useful a reference point. Future work should focus on identifying the downbeat correctly, and perhaps on finding a more effective way to combine the phrase grid structure with the output of the Start Bar Finder.