In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets

In [None]:
dataset_dir = 'dataset/e-gmd-v1.0.0'
slim_metadata_df = pd.read_csv('dataset/e-gmd-v1.0.0-slim.csv') # See `create_slim_metadata.py` for details.
note_occurrences_slim_df = pd.read_csv('dataset/note_occurrences_slim.csv') # See `create_label_mappings.py` for details.
labels_df = pd.read_csv('dataset/label_mapping.csv') # See `create_label_mappings.py` for details.
chopped_df = pd.read_csv('dataset/chopped.csv') # See `chop_dataset.py` for details.

# Fail fast if a generated CSV lacks a column that later cells read.
assert {'audio_filename', 'kit_name'} <= set(slim_metadata_df.columns)
assert {'name', 'occurrences'} <= set(note_occurrences_slim_df.columns)
assert {'id', 'name'} <= set(labels_df.columns)
assert {'file_path', 'begin_frame', 'num_frames', 'label', 'slim_id'} <= set(chopped_df.columns)
# Later cells look a label's name up with `labels_df.iloc[label]`, which is only
# correct when row i of the mapping has id i.
assert labels_df['id'].tolist() == list(range(len(labels_df))), 'label_mapping.csv ids must be 0..N-1 in row order'

# Note occurrences in slim dataset

In [None]:
# First few rows of the per-note occurrence counts.
note_occurrences_slim_df.head(n=5)

In [None]:
# Bar chart of how often each drum note occurs in the slim dataset.
fig, ax = plt.subplots()
ax.bar(note_occurrences_slim_df['name'], note_occurrences_slim_df['occurrences'])
ax.set_title('Note Occurrences')
ax.set_ylabel('Occurrences')
ax.tick_params(axis='x', labelrotation=90)  # vertical note names, as before
plt.show()

# Label mappings

The top-5 most frequently occurring drum instrument types are used for training.

The label mappings contain a row for each training drum instrument, with the following columns:
- `id`: Used for one-hot encoding during training. Corresponds to the instrument's occurrence frequency rank in the slim dataset, with the smallest value corresponding to the most common.
- `note`: The MIDI note of the drum instrument.
- `name`: The human-readable name of the drum instrument.

In [None]:
# Training label mapping: one row per instrument (id, note, name).
labels_df

In [None]:
def get_name(label):
    """Return the human-readable instrument name for a training label id.

    Looks the label up by row position in `labels_df`, so this assumes row i
    of the label mapping has id i.
    """
    return labels_df.iloc[label]['name']

# Chopped dataset

The "chopped" dataset is the final, processed dataset used for training.

It consists of a row per "drum hit", which is composed of one or more simultaneously sounding drum instruments, and it has the following columns:
- `file_path`: The path to the audio file in the E-GMD dataset.
- `begin_frame`: The frame (sample index) of the beginning of the hit.
- `num_frames`: The length, in frames, of the hit.
- `label`: A drum instrument label, corresponding to the `id` column in the `dataset/label_mapping.csv` file generated by the `create_label_mappings.py` script.
- `slim_id`: The session ID (index in `e-gmd-v1.0.0-slim.csv`) in which this hit was found, for access to any other metadata.

In [None]:
# First few hits of the chopped dataset.
chopped_df.head(n=5)

In [None]:
# Hits per training label, with the index relabelled by instrument name.
label_counts = chopped_df['label'].value_counts().rename(index=get_name)
label_axes = label_counts.plot.bar()
label_axes.set_title('Label occurrences in "chopped" dataset')
label_axes.set_xlabel('')
label_axes.set_ylabel('Occurrences')
_ = plt.xticks(rotation='vertical')

In [None]:
from IPython.display import Audio
from scipy.io import wavfile
import os

def get_audio(file_path, base_dir=None):
    """Load a WAV file and scale its samples to float32 in roughly [-1, 1).

    Args:
        file_path: Path of the audio file, relative to `base_dir`.
        base_dir: Directory `file_path` is relative to; defaults to the
            notebook's `dataset_dir`.

    Returns:
        `(sample_rate, samples)` with `samples` as a float32 array, or
        `(None, None)` if the file does not exist.
    """
    if base_dir is None:
        base_dir = dataset_dir
    audio_file_path = f'{base_dir}/{file_path}'
    if not os.path.exists(audio_file_path):
        return None, None
    sample_rate, raw = wavfile.read(audio_file_path)
    if raw.dtype == np.uint8:
        # 8-bit PCM is unsigned and centred on 128.
        samples = (raw.astype(np.float32) - 128) / 128
    elif np.issubdtype(raw.dtype, np.integer):
        # Signed PCM: divide by 2**(bits - 1), i.e. 2**15 for int16.
        samples = raw.astype(np.float32) / (np.iinfo(raw.dtype).max + 1)
    else:
        # Already floating point; just normalise the dtype.
        samples = raw.astype(np.float32)
    return sample_rate, samples

def get_clip(row):
    """Return `(sample_rate, samples)` for one hit of the chopped dataset.

    `row` must provide `file_path`, `begin_frame` and `num_frames` (e.g. a row
    of `chopped_df`). Returns `(None, None)` when `get_audio` reports the
    track's audio file as missing, so callers can test `sample_rate is None`.
    """
    sample_rate, track_data = get_audio(row.file_path)
    if track_data is None:
        # Propagate the "missing file" signal instead of failing on None[...].
        return None, None
    start = row.begin_frame
    return sample_rate, track_data[start:start + row.num_frames]

def preview_clip_row(row):
    """Plot the waveform of one chopped-dataset hit and return a playable widget.

    The legend shows the instrument name and the kit name of the session the
    hit came from. Returns an `IPython.display.Audio` for the clip, or `None`
    if `get_clip` reports a missing file (`sample_rate is None`).
    """
    sample_rate, clip = get_clip(row)
    if sample_rate is None:
        return None
    num_samples = clip.shape[0]
    length = num_samples / sample_rate
    # Sample i is heard at i / sample_rate seconds; linspace(0, length, n) would
    # stretch the axis by one sample period.
    time = np.arange(num_samples) / sample_rate
    name = get_name(row.label)
    kit_name = slim_metadata_df.iloc[row.slim_id].kit_name

    plt.plot(time, clip, label=f'{name} ({kit_name})')

    plt.legend()
    plt.xlabel('Time (s)')
    plt.ylabel('Amplitude')
    plt.xlim([0, length])
    plt.show()

    return Audio(clip, rate=sample_rate)

def preview_track(audio_filename):
    """Plot the waveform of a whole audio track and return a playable widget.

    `audio_filename` is resolved by `get_audio` (relative to `dataset_dir`).
    Returns an `IPython.display.Audio`, or `None` if the file does not exist.
    """
    sample_rate, track = get_audio(audio_filename)
    if sample_rate is None:
        return None
    num_samples = track.shape[0]
    length = num_samples / sample_rate
    time = np.arange(num_samples) / sample_rate  # sample i is heard at i / sample_rate seconds

    plt.plot(time, track, label=f'{audio_filename}')

    plt.legend()
    plt.xlabel('Time (s)')
    plt.ylabel('Amplitude')
    plt.xlim([0, length])
    plt.show()

    return Audio(track, rate=sample_rate)

def num_rows(label):
    """Return how many hits in `chopped_df` carry the given label."""
    return int((chopped_df.label == label).sum())


def get_row(label, index):
    """Return the `index`-th (positional) hit in `chopped_df` with the given label."""
    return chopped_df[chopped_df.label == label].iloc[index]


def preview_clip(label, index):
    """Plot and return an audio preview of the `index`-th hit with the given label."""
    return preview_clip_row(get_row(label, index))

In [None]:
# The first hit with label 2 (see `labels_df` for its instrument name).
get_row(2, 0)

In [None]:
def update_index_range(*args):
    """Keep the index slider within the hits available for the selected label."""
    index_slider.max = num_rows(label_slider.value) - 1

def display_preview_clip(label, index):
    """Render the waveform and audio player for one hit into `clip_preview_output`."""
    with clip_preview_output:
        clip_preview_output.clear_output()
        display(preview_clip(label, index))

# continuous_update=False: each preview reads a whole WAV file from disk, so only
# refresh when a slider is released rather than on every intermediate value.
label_slider = widgets.IntSlider(min=0, max=len(labels_df) - 1, step=1, value=0, description='Label', continuous_update=False)
label_slider.observe(update_index_range, 'value')
index_slider = widgets.IntSlider(min=0, max=num_rows(0) - 1, step=1, value=0, description='Index', continuous_update=False)
# Cell-specific name: a later cell also creates an Output widget, and the callback
# above looks this global up at call time, so a shared name would let that cell
# redirect these previews.
clip_preview_output = widgets.Output()

# `interactive` wires the sliders to `display_preview_clip`; the layout is built manually below.
widgets.interactive(display_preview_clip, label=label_slider, index=index_slider)
display(widgets.HBox([label_slider, index_slider]), clip_preview_output)

In [None]:
def clip_with_end_padding(row, padding_ms=250):
    """Return the samples of one hit followed by `padding_ms` milliseconds of silence.

    Raises:
        FileNotFoundError: if `get_clip` reports the hit's audio file as missing.
    """
    sample_rate, clip = get_clip(row)
    if sample_rate is None:
        raise FileNotFoundError(f'Audio file not found: {row.file_path}')
    pad_frames = int(padding_ms * sample_rate / 1000)
    # Match the clip's trailing dimensions and dtype so multi-channel clips concatenate too.
    silence = np.zeros((pad_frames,) + clip.shape[1:], dtype=clip.dtype)
    return np.concatenate([clip, silence])

def create_supercut(label, num_records=100, random_state=None, sample_rate=44100):
    """Concatenate randomly chosen hits of one label into a single playable clip.

    Args:
        label: Training label id whose hits are sampled from `chopped_df`.
        num_records: Number of hits to include; capped at the number available.
        random_state: Passed to `DataFrame.sample`; set it for a reproducible supercut.
        sample_rate: Playback rate of the returned `Audio`. The default is the
            value previously hard-coded here; NOTE(review): confirm it matches
            the dataset's WAV files.

    Returns:
        An `IPython.display.Audio` of the sampled hits, each followed by the
        default end padding of `clip_with_end_padding`.
    """
    label_rows = chopped_df[chopped_df.label == label]
    # Without replacement, DataFrame.sample raises if asked for more rows than exist.
    num_records = min(num_records, len(label_rows))
    print(f'Creating a random supercut of {num_records} records for label {label} ({get_name(label)})...')
    records = label_rows.sample(num_records, random_state=random_state)
    supercut = np.concatenate([clip_with_end_padding(row) for _, row in records.iterrows()])
    return Audio(supercut, rate=sample_rate)

In [None]:
def display_supercut(label, num_records):
    """Render a supercut for `label` into `output_area` and show the label's name."""
    with output_area:
        output_area.clear_output()
        description_label.value = get_name(label)
        display(create_supercut(label, num_records))

# continuous_update=False: building a supercut reads one WAV file per clip, so only
# rebuild when a slider is released instead of on every intermediate value.
label_input = widgets.IntSlider(min=0, max=len(labels_df) - 1, step=1, value=0, description='Label', continuous_update=False)
num_records_input = widgets.IntSlider(min=1, max=100, step=1, value=50, description='Num clips', continuous_update=False)
description_label = widgets.Label(value=get_name(label_input.value))
slider_with_description = widgets.HBox([label_input, description_label, num_records_input])
output_area = widgets.Output()

widgets.interactive(display_supercut, label=label_input, num_records=num_records_input)
display(slider_with_description, output_area)

In [None]:
def find_closest_matches(query, max_results=10, metadata_df=None):
    """Return up to `max_results` unique audio filenames containing `query`, ignoring case.

    `query` is matched as a literal substring, not a regular expression, so
    user-typed characters such as '(' or '.' are safe. Rows with a missing
    filename never match. `metadata_df` defaults to `slim_metadata_df`.
    """
    if metadata_df is None:
        metadata_df = slim_metadata_df
    filenames = metadata_df.audio_filename
    mask = filenames.str.lower().str.contains(query.lower(), regex=False, na=False)
    return filenames[mask].unique()[:max_results]

def on_text_change(change):
    """Refresh the dropdown with filenames matching the text box's new value.

    Every notification other than a change of the `value` trait is ignored.
    """
    is_value_change = change['type'] == 'change' and change['name'] == 'value'
    if not is_value_change:
        return
    dropdown.options = find_closest_matches(change['new'])

text_input = widgets.Text(description='Audio file:')
dropdown = widgets.Dropdown(description='Matches:')
# Only `value` changes matter; `on_text_change` still filters defensively.
text_input.observe(on_text_change, names='value')

def display_preview_track(audio_filename):
    """Render the waveform and audio player for a whole track into `track_preview_output`.

    Only clears the output while the dropdown has no selection (`None`).
    """
    with track_preview_output:
        track_preview_output.clear_output()
        if audio_filename is None:
            return
        display(preview_track(audio_filename))

# Cell-specific name for the Output widget: widget callbacks look their output
# global up at call time, so rebinding a name used by an earlier cell would
# redirect that cell's previews into this one.
track_preview_output = widgets.Output()

widgets.interactive(display_preview_track, audio_filename=dropdown)
display(widgets.VBox([text_input, dropdown]), track_preview_output)