In [None]:
import os
import pickle
import random

import h5py
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import mirdata
import pandas as pd
import tensorflow as tf

import steme.audio as audio
import steme.dataset as dataset
import steme.loader as loader
import steme.metrics as metrics
import steme.paths as paths
import steme.utils as utils

In [None]:
def plot_calibration(predictions, bins, model_name):
    """Draw a box plot of ``predictions`` on a new 20x5 figure.

    Args:
        predictions: data handed unchanged to ``Axes.boxplot``.
        bins: only ``len(bins)`` is used, to place major x ticks at
            positions ``1 .. len(bins)``.
        model_name: text shown as the figure's suptitle.

    Returns:
        None. The figure is created through ``plt.subplots`` and not returned.
    """
    figure, axis = plt.subplots(1, 1, figsize=(20, 5))

    axis.boxplot(predictions, vert=1)

    tick_positions = np.arange(1, len(bins) + 1)
    axis.xaxis.set_major_locator(ticker.FixedLocator(tick_positions))

    # Argument-less plot call kept from the original implementation.
    axis.plot()
    axis.set_xlabel("BPM")
    axis.set_ylabel("Model output")
    axis.grid(True)

    figure.suptitle(model_name, fontsize=16)
    
    
def _calibrate(bpm_dict, model, kmin, kmax, n_predictions=100, fixed=False):
    """Run ``model`` on evenly spaced tempogram slices for every entry of ``bpm_dict``.

    For each key of ``bpm_dict`` the tempogram stored under ``"T"`` is sliced
    ``n_predictions`` times (unshifted) and the first scalar of the model's
    third output is collected. The entry is then extended in place with the
    keys ``"slice"``, ``"shift"``, ``"estimation"`` (all taken from the last
    slice) and ``"predictions"``.

    Args:
        bpm_dict: mapping whose keys are used as the regression targets and
            whose values are dicts holding a 2-D array under ``"T"``.
        model: object with a Keras-style ``predict`` returning four outputs.
        kmin, kmax: forwarded to ``dataset.get_tempogram_slices``.
        n_predictions: number of slices predicted per entry.
        fixed: accepted but not used by this function.

    Returns:
        tuple: ``(bpm_dict, a, b, quad)`` where ``a, b`` come from
        ``utils.get_slope`` and ``quad`` is a degree-2 ``np.poly1d`` fitted
        from the per-entry median model output to the keys of ``bpm_dict``.
    """
    print("Calibrating model")
    model_output = np.zeros(len(bpm_dict.keys()))

    for position, bpm in enumerate(bpm_dict.keys()):
        entry = bpm_dict[bpm]
        T = entry["T"]

        preds = np.zeros(n_predictions)
        # Distance, in tempogram frames, between consecutive slice starts.
        step = T.shape[1] // n_predictions

        for i in range(n_predictions):
            s1, sh1, _, _, _ = dataset.get_tempogram_slices(
                T=T, kmin=kmin, kmax=kmax, shift_1=0, shift_2=0, slice_idx=i * step
            )
            s1 = s1[np.newaxis, :]

            # The same slice/shift feeds both model inputs.
            xhat1, _, y1, _ = model.predict([s1, s1, sh1, sh1], verbose=0)
            preds[i] = y1[0][0]

        # Only the last slice of the loop above is kept for inspection.
        entry["slice"] = s1[0, :, 0]
        entry["shift"] = sh1
        entry["estimation"] = xhat1[0, :, 0]
        entry["predictions"] = np.array(preds)

        model_output[position] = np.median(preds)

    targets = list(bpm_dict.keys())
    quad = np.poly1d(np.polyfit(model_output, targets, 2))
    a, b = utils.get_slope(model_output, targets)

    return bpm_dict, a, b, quad
    
def read_dataset_info(main_file, data_dir="/home/gigibs/Documents/steme/data"):
    """Read the ``<main_file>_metadata.h5`` file of a dataset into a dict.

    Args:
        main_file: dataset name; ``_metadata.h5`` is appended to build the
            file name.
        data_dir: directory that contains the metadata file. Defaults to the
            location that was previously hard-coded, so existing callers are
            unaffected; pass another directory to run on a different machine.

    Returns:
        dict with the keys ``main_file``, ``validation_file``, ``train_file``,
        ``main_filepath``, ``validation_filepath``, ``train_filepath``
        (decoded as UTF-8 strings), ``distribution`` (full array read) and
        ``validation_setsize``, ``train_setsize``, ``tmin``, ``tmax``
        (scalar reads).
    """
    dataset_metadata = os.path.join(data_dir, f"{main_file}_metadata.h5")
    print(f"Reading metadata file {dataset_metadata}")

    # Datasets stored as byte strings that must be decoded.
    text_fields = (
        "main_file",
        "validation_file",
        "train_file",
        "main_filepath",
        "validation_filepath",
        "train_filepath",
    )
    # Datasets read as-is with a scalar ``[()]`` access.
    scalar_fields = ("validation_setsize", "train_setsize", "tmin", "tmax")

    response = {}
    with h5py.File(dataset_metadata, "r") as hf:
        for field in text_fields:
            response[field] = hf.get(field)[()].decode("UTF-8")
        response["distribution"] = hf.get("distribution")[:]
        for field in scalar_fields:
            response[field] = hf.get(field)[()]

    return response

def default_variables():
    """Return the default parameter set (``tmin`` 25, 40 bins per octave) as a new dict."""
    return dict(
        tmin=25,
        n_bins=190,
        bins_per_octave=40,
        kmin=11,
        kmax=19,
    )

def wider_tempi_variables():
    """Return the alternative parameter set (``tmin`` 20, 30 bins per octave) as a new dict."""
    return dict(
        tmin=20,
        n_bins=190,
        bins_per_octave=30,
        kmin=0,
        kmax=8,
    )

def calibration_results(dists, t_types):
    """Collect model predictions on synthetic click tracks for every model.

    One model is loaded per ``(dist_name, t_type)`` pair from
    ``../models/{dist_name}_{t_type}_15_default``. For every center tempo in
    the notebook-level ``center`` array, a click track is synthesised for each
    tempo in ``center_dict[center_tempo]`` and the model is evaluated on
    ``n_predictions`` evenly spaced tempogram slices of it.

    Relies on these notebook-level names being defined before the call:
    ``center``, ``center_dict``, ``n_predictions``, ``theta``, ``kmin``,
    ``kmax``.

    Args:
        dists: iterable of distribution/dataset name prefixes.
        t_types: iterable of tempogram type names.

    Returns:
        dict: ``results[dist_name][t_type][center_tempo]["predictions"]`` is a
        1-D array with ``n_predictions`` values per click track.
    """
    sr = 22050
    results_dict = {}

    for dist_name in dists:
        results_dict[dist_name] = {}
        for t_type in t_types:
            print(dist_name, t_type)
            dataset_name = f"{dist_name}_{t_type}"

            # Kept from the original: reads (and logs) the dataset metadata
            # file; the returned values are not used below.
            read_dataset_info(dataset_name)

            results_dict[dist_name][t_type] = {}

            model_name = f"{dataset_name}_15_default"
            model_path = f"../models/{model_name}"

            model = tf.keras.models.load_model(model_path)

            for val in center:
                bin_bpms = center_dict[val]
                # Size the buffer from the tempi actually present: center_dict
                # may have been loaded from a cache built with a different
                # number of tracks per bin.
                preds = np.zeros(n_predictions * len(bin_bpms))

                j = 0
                for bpm in bin_bpms:
                    x = audio.click_track(bpm=bpm, sr=sr)
                    T, t, bpms = audio.tempogram(
                        x, sr, window_size_seconds=10, t_type=t_type, theta=theta
                    )

                    step = T.shape[1] // n_predictions

                    for i in range(n_predictions):
                        s1, sh1, s2, sh2, _ = dataset.get_tempogram_slices(
                            T=T, kmin=kmin, kmax=kmax, shift_1=0, shift_2=0,
                            slice_idx=i * step
                        )

                        s1 = s1[np.newaxis, :]

                        xhat1, xhat2, y1, y2 = model.predict(
                            [s1, s1, sh1, sh1], verbose=0
                        )
                        preds[j] = y1[0][0]
                        j += 1

                results_dict[dist_name][t_type][val] = {"predictions": preds}

            del model

    # Return only after every distribution was processed. The original
    # returned from inside the outer loop, i.e. after the first entry of
    # ``dists``.
    return results_dict

## Calibrate model with synthetic data


In [None]:
# Unpack the default parameter set into notebook-level names used by later cells.
variables = default_variables()
tmin, n_bins, bins_per_octave, kmin, kmax = (
    variables[key]
    for key in ("tmin", "n_bins", "bins_per_octave", "kmin", "kmax")
)
# Array later filtered by BPM range and passed to audio.tempogram as `theta`.
theta = dataset.variables_non_linear(
    tmin,
    n_bins=n_bins,
    bins_per_octave=bins_per_octave,
)

In [None]:
# Pick every `step`-th value of theta inside (30, 350) as bin edges/centers.
step = 8
offset = 5
theta_in_range = theta[(theta > 30) & (theta < 350)]
left = theta_in_range[::step]
center = theta_in_range[offset::step]
# NOTE(review): `right` is the very same slice as `center`, so the upper
# boundary sqrt(center * right) printed and used below equals the center
# itself. This looks like a copy/paste slip, but the intended slice is not
# evident from this notebook; confirm before changing, since cached results
# depend on it.
right = theta_in_range[offset::step]

bins_tmp = []
for i, j, k in zip(left, center, right):
    # Boundaries are geometric means of neighbouring values; both are rounded
    # to two decimals.
    print(f"boundaries for {np.round(j,2)}: [{np.round(np.sqrt(i*j),2)}, {np.round(np.sqrt(j*k),2)}]")
    bins_tmp.append(i)
    bins_tmp.append(j)

# len(bins_tmp)

In [None]:
# Dataset/model name prefixes to evaluate. Every entry, including the active
# one, ends with a comma so that (un)commenting a line can never merge two
# adjacent string literals into one.
dists = [
    "gtzan_augmented_log_25_190_40",
#     "log_uniform_25_190_40",
#     "synthetic_lognorm_0.7_30_50_1000_25_190_40", 
#     "synthetic_lognorm_0.7_70_50_1000_25_190_40",
#     "synthetic_lognorm_0.7_120_50_1000_25_190_40", 
#     "gtzan_25_190_40",
]

# Tempogram types; one model per (dist, t_type) pair is loaded.
t_types = ["fourier", "autocorrelation", "hybrid"]

In [None]:
# Number of tempogram slices evaluated per click track.
n_predictions = 2
# Number of click-track tempi sampled for each center bin.
tracks_per_bin = 50

# Seeded generator so the sampled tempi are the same after a kernel restart.
rng = np.random.default_rng(42)

# Maps each center tempo to `tracks_per_bin` tempi drawn uniformly between the
# geometric means of (left, center) and (center, right).
center_dict = {}
for idx, val in enumerate(center):
    left_boundary = np.sqrt(left[idx]*center[idx])
    right_boundary = np.sqrt(center[idx]*right[idx])

    center_dict[val] = rng.uniform(left_boundary, right_boundary, size=tracks_per_bin)

In [None]:
calculate_results = False
calculate_center_bins = False

try:
    with open("results_dict_aug.pkl", "rb") as pickle_file:
        results_dict = pickle.load(pickle_file)
except:
    calculate_results = True 
    
try:
    with open('center_dict_aug.pkl', 'rb') as f:
        center_dict = pickle.load(f)
except:
    calculate_center_bins = True

In [None]:
# Display the two cache flags; True means the corresponding pickle could not be loaded.
calculate_results, calculate_center_bins

In [None]:
if calculate_results:
    results_dict = calibration_results(dists, t_types)

    # Write the results first, then the sampled tempi they were computed from.
    for cache_name, payload in (
        ('results_dict_aug.pkl', results_dict),
        ('center_dict_aug.pkl', center_dict),
    ):
        with open(cache_name, 'wb') as f:
            pickle.dump(payload, f)

In [None]:
# for dist_name in dists:
#     for t_type in t_types:
#         model_name = f"{dist_name}_{t_type}"
#         res = results_dict[dist_name][t_type]
#         predictions = [v["predictions"] for k, v in res.items()]
#         plot_calibration(predictions, np.round(center, 2), model_name)

In [None]:
def plot_calibration_ax(predictions, bins, model_name, ax):
    """Draw a box plot of ``predictions`` on an existing axes.

    Every third box gets a major x tick, labelled with the matching entry of
    ``bins``.

    Args:
        predictions: data handed unchanged to ``Axes.boxplot``.
        bins: sliceable sequence of tick labels, one per box.
        model_name: accepted but not used by this function.
        ax: matplotlib axes to draw on.
    """
    label_stride = 3

    ax.boxplot(predictions, vert=1)

    x_axis = ax.xaxis
    # Boxes sit at positions 1..len(bins); label positions 1, 4, 7, ...
    labelled_positions = np.arange(1, len(bins) + 1, label_stride)
    x_axis.set_major_locator(ticker.FixedLocator(labelled_positions))
    x_axis.set_major_formatter(ticker.FixedFormatter(bins[::label_stride]))

    # Argument-less plot call kept from the original implementation.
    ax.plot()
    ax.grid(True, alpha=0.8)

In [None]:
# Maps each dataset/model name prefix to the display name used in figure titles.
model_dict = {
#     "synthetic_lognorm_0.7_30_50_1000_25_190_40": "lognorm @ 70", 
#     "synthetic_lognorm_0.7_70_50_1000_25_190_40": "lognorm @ 120",
#     "synthetic_lognorm_0.7_120_50_1000_25_190_40": "lognorm @ 170",
#     "gtzan_25_190_40": "GTZAN",
#     "log_uniform_25_190_40": "log uniform",
    "gtzan_augmented_log_25_190_40": "gtzan_augmented"
}

In [None]:
# Larger axis-label and tick-label fonts for every figure created below.
plt.rcParams.update({
    "axes.labelsize": 18,
    "xtick.labelsize": 15,
    "ytick.labelsize": 15,
})

In [None]:
import matplotlib
from scipy.stats import lognorm, uniform

def gtzan_data():
    """Initialise the GTZAN genre dataset through mirdata and collect track tempi.

    The track ``"reggae.00086"`` is left out. The dataset's own ``track_ids``
    list is not modified (the original removed the id from it in place, which
    also raised ``ValueError`` when the id was absent).

    Returns:
        tuple: ``(gtzan, tracks, tempi)`` — the mirdata dataset object, the
        list of kept track ids, and ``track.tempo`` for each kept id, in the
        same order.
    """
    gtzan = mirdata.initialize("gtzan_genre",
            data_home="../../datasets/gtzan_genre",
            version="default")

    excluded_ids = {"reggae.00086"}
    tracks = [track_id for track_id in gtzan.track_ids if track_id not in excluded_ids]
    tempi = [gtzan.track(track_id).tempo for track_id in tracks]

    return gtzan, tracks, tempi

# NOTE(review): this overwrites the notebook-level `theta` and passes
# (25, 40, 190) positionally, whereas the earlier call uses
# n_bins=190, bins_per_octave=40 by keyword. Confirm the positional order of
# dataset.variables_non_linear matches.
theta = dataset.variables_non_linear(25, 40, 190)
bins = theta[(theta > 30) & (theta < 370)][::2]

# plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in
# Matplotlib 3.9.
cmap = plt.get_cmap('tab10')

# Reference tempo distributions drawn for the histogram row; all use seed 42.
dist_low = lognorm.rvs(0.25, loc=30, scale=50, size=1000, random_state=42)
dist_medium = lognorm.rvs(0.25, loc=70, scale=50, size=1000, random_state=42)
dist_high = lognorm.rvs(0.25, loc=120, scale=50, size=1000, random_state=42)
dist_uniform = uniform.rvs(30, scale=210,size=1000, random_state=42)
# Log-uniform between 30 and 240; seeded like the distributions above.
dist_log_uniform = 30*np.e**(np.random.RandomState(42).rand(1000)*np.log(240/30))
_, _, dist_gtzan = gtzan_data()
dist_gtzan = np.array(dist_gtzan)

In [None]:
# add_subplot example
# https://towardsdatascience.com/customizing-multiple-subplots-in-matplotlib-a3e1c2e099bc
# https://python-course.eu/numerical-programming/creating-subplots-in-matplotlib.php

# Figure layout: first row shows the five reference tempo distributions, then
# one row of calibration box plots per tempogram type (one column per entry of
# `dists`).
fig = plt.figure(figsize=(20, 10))
plt.subplots_adjust(wspace= 0.25, hspace= 0.25)

kwargs = {
    "alpha": 0.7,
    "histtype": "stepfilled"
}

dist_bins = np.round(bins)

n_histograms = 5
n_cols = max(n_histograms, len(dists))
n_rows = 1 + len(t_types)

p0 = fig.add_subplot(n_rows, n_cols, 1)
p0.hist(dist_log_uniform, bins=dist_bins, label="log uniform", edgecolor="black", color=cmap.colors[3],**kwargs)
p0.set_xscale("log")
p0.set_xticks([], [])
p0.set_xticks(center.astype(int)[::3])
p0.xaxis.set_major_formatter(ticker.ScalarFormatter())
p0.set_xlim(28, 360)
p0.set_ylim(0, 200)
p0.set_ylabel("Distribution")
p0.set_title("log uniform", fontsize=18)

# The remaining histograms share both axes with p0.
p2 = fig.add_subplot(n_rows, n_cols, 2, sharex=p0, sharey=p0)
p2.hist(dist_low, bins=dist_bins, label="lognorm @ 70", edgecolor="black", color=cmap.colors[0], **kwargs)
p2.set_title("lognorm @ 70", fontsize=18)

p3 = fig.add_subplot(n_rows, n_cols, 3, sharex=p0, sharey=p0)
p3.hist(dist_medium, bins=dist_bins, label="lognorm @ 120", edgecolor="black",color=cmap.colors[2], **kwargs)
p3.set_title("lognorm @ 120", fontsize=18)

p4 = fig.add_subplot(n_rows, n_cols, 4, sharex=p0, sharey=p0)
p4.hist(dist_high, bins=dist_bins, label="lognorm @ 170", edgecolor="black",color=cmap.colors[4], **kwargs)
p4.set_title("lognorm @ 170", fontsize=18)

p5 = fig.add_subplot(n_rows, n_cols, 5, sharex=p0, sharey=p0)
p5.hist(dist_gtzan, bins=dist_bins, label="GTZAN", edgecolor="black", color=cmap.colors[8], **kwargs)
p5.set_title("GTZAN", fontsize=18)

calibration_bins = center[::2].astype(int)
last_row = len(t_types) - 1
for row, t_type in enumerate(t_types):
    for col, dist_name in enumerate(dists):
        # Derive the grid cell from (row, col) so each tempogram type keeps
        # its own row no matter how many entries `dists` has. A running
        # counter only lines up when there are exactly five distributions.
        plot_index = n_cols * (row + 1) + col + 1

        model_name = model_dict[dist_name]
        res = results_dict[dist_name][t_type]
        predictions = [v["predictions"] for k, v in res.items()]

        p = fig.add_subplot(n_rows, n_cols, plot_index)
        plot_calibration_ax(predictions, center.astype(int), model_name, p)

        if row == last_row:
            p.set_xlabel("BPM")
        if col == 0:
            # "fourier" -> "Fourier", "autocorrelation" -> "Autocorrelation", ...
            p.set_ylabel(t_type.capitalize())

#plt.savefig("calibration_with_dists.png", format="png")

In [None]:
# One stacked panel per (tempogram type, distribution) pair. The number of
# rows is derived from the two lists instead of being fixed at 3, so adding
# entries to `dists` or `t_types` does not make add_subplot fail.
n_panels = len(t_types) * len(dists)

fig = plt.figure(figsize=(20, 10))
plot_index = 1
for t_type in t_types:
    for dist_name in dists:
        model_name = model_dict[dist_name]
        res = results_dict[dist_name][t_type]
        predictions = [v["predictions"] for k, v in res.items()]

        p = fig.add_subplot(n_panels, 1, plot_index)
        p.set_title(f"{model_name}_{t_type}", fontsize=18)

        plot_calibration_ax(predictions, center.astype(int), model_name, p)
        plot_index += 1

plt.tight_layout()