In [None]:
import os
import pickle
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm

# Global seaborn styling: use the "notebook" plotting context for every figure below.
sns.set_theme(context="notebook")

In [None]:
# Data locations. Set the RESULTS_DIR / UNMIXDB_PATH environment variables to
# point elsewhere instead of editing this cell. Alternative locations used
# previously:
#   RESULTS_DIR=/data5/anasynth_nonbp/andre/reverse-dj-mix/results
#   UNMIXDB_PATH=/data2/anasynth_nonbp/schwarz/abc-dj/data/unmixdb-zenodo
RESULTS_DIR = Path(os.environ.get("RESULTS_DIR", "/home/etiandre/stage/results"))
UNMIXDB_PATH = Path(os.environ.get("UNMIXDB_PATH", "/home/etiandre/stage/datasets/unmixdb"))

In [None]:
# Load every RESULTS_DIR/<run>/<subdir>/results.pickle into a list of dicts,
# tagging each with the name of its <run> directory so runs can be compared.
# The paths are sorted so that row order -- and therefore the integer index of
# the DataFrame built from `results` -- is reproducible across machines/re-runs
# (Path.glob order is filesystem-dependent).
# NOTE: pickle.load can execute arbitrary code; only load result files you trust.
results = []
for result_path in tqdm(sorted(RESULTS_DIR.glob("*/*/results.pickle")), desc="loading results"):
    with open(result_path, "rb") as f:
        record = pickle.load(f)
    record["run"] = result_path.parent.parent.name
    results.append(record)

In [None]:
# Flatten the list of result dicts into a DataFrame, one row per dict.
data = pd.json_normalize(results)

# Mix names look like "<setNmixM>-<stretch>-<fx>-<submix>"; split them into columns.
data[["id", "stretch", "fx", "submix"]] = data["name"].str.extract(
    r"(set\d+mix\d+)-(\w+)-(\w+)-(\d+)"
)

# Abbreviate the labels shown on the plot axes.
data["stretch"] = data["stretch"].replace({"resample": "resamp"})
data["fx"] = data["fx"].replace({"distortion": "dist", "compressor": "comp"})

# Two-line category label "<stretch>\n<fx>", used as the x axis of the box plots.
data["transformation"] = data["stretch"] + "\n" + data["fx"]

In [None]:
# Non-null value counts per column for each run: a quick check of how many
# results each run contributed.
data.groupby(by="run").count()

In [None]:
# x-axis category order for the box plots: every (stretch, fx) combination,
# grouped by stretch. Labels are "<stretch>\n<fx>", the same format as the
# `transformation` column.
order = [
    f"{stretch}\n{fx}"
    for stretch in ("none", "resamp", "stretch")
    for fx in ("none", "bass", "comp", "dist")
]
# Gain error per transformation, one box per run, on a logarithmic y axis.
fig, ax = plt.subplots(figsize=(8, 6))
sns.boxplot(data, x="transformation", y="gain_err", hue="run", log_scale=True, order=order, ax=ax)
ax.set(
    ylabel="mean linear estimated gain error",
    title="Gain estimation error by transformation",
)
plt.show()

In [None]:
# Warp error (seconds) per transformation, one box per run, on a linear y axis.
# Same figure size as the gain-error plot so the two are directly comparable.
fig, ax = plt.subplots(figsize=(8, 6))
sns.boxplot(data, x="transformation", y="warp_err", hue="run", log_scale=False, order=order, ax=ax)
ax.tick_params(axis="x", labelrotation=90)
ax.set(
    ylabel="mean estimated warp error (s)",
    title="Warp estimation error by transformation",
)
plt.show()

In [None]:
# High-level mix parameters to plot, as (result column, y-axis label) pairs.
# They are unzipped into two parallel lists so each column stays next to its label.
HIGHPARAMS, HIGHPARAMS_names = map(
    list,
    zip(
        *[
            ("track_start_err", "track start mean error (s)"),
            ("speed_err", "speed mean error"),
            ("fadein_start_err", "fade-in start mean error (s)"),
            ("fadein_stop_err", "fade-in end mean error (s)"),
            ("fadeout_start_err", "fade-out start mean error (s)"),
            ("fadeout_stop_err", "fade-out end mean error (s)"),
        ]
    ),
)

# One box plot per high-level parameter error, split by transformation and run.
for param, label in zip(HIGHPARAMS, HIGHPARAMS_names):
    fig, ax = plt.subplots(figsize=(8, 6))
    # "speed_err" is drawn on a linear y axis; every other parameter uses a log
    # y axis. Selecting by column name rather than list position keeps this
    # correct if HIGHPARAMS is reordered.
    sns.boxplot(
        data,
        x="transformation",
        y=param,
        hue="run",
        log_scale=param != "speed_err",
        order=order,
        ax=ax,
    )
    ax.tick_params(axis="x", labelrotation=90)
    ax.set_ylabel(label)
    if param == "track_start_err":
        # This axis is logarithmic, so a lower limit of 0 is invalid
        # (matplotlib warns and ignores it); only cap the top of the range.
        ax.set_ylim(top=1e3)
    plt.show()
    plt.close(fig)

In [None]:
# Results with no stretch and no effect, largest warp error first.
data[data["stretch"].eq("none") & data["fx"].eq("none")].sort_values("warp_err", ascending=False)

In [None]:
# NOTE(review): 240 is an index label into `data`, whose row order follows the
# order in which result files were loaded. It only refers to the same mix while
# that order is stable. Consider selecting the row by its "name" instead.
data.loc[240, :]