In [None]:
import json
import os
import pickle
from pathlib import Path
from pprint import pprint

import dictdiffer
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm

# Use seaborn's "paper" plotting context (smaller fonts/lines suited to figures in a document).
sns.set_theme(context="paper")

In [None]:
# Input locations. The defaults are the local paths this notebook was last run
# with; set the RESULTS_DIR / UNMIXDB_PATH environment variables to point
# elsewhere instead of editing the notebook. Previously used alternatives:
#   RESULTS_DIR  = /data5/anasynth_nonbp/andre/reverse-dj-mix/results
#   UNMIXDB_PATH = /data2/anasynth_nonbp/schwarz/abc-dj/data/unmixdb-zenodo
RESULTS_DIR = Path(os.environ.get("RESULTS_DIR", "/home/etiandre/stage/results"))
UNMIXDB_PATH = Path(
    os.environ.get("UNMIXDB_PATH", "/home/etiandre/stage/datasets/unmixdb")
)
# Name of the run directory (a timestamp) that is relabelled "baseline" when
# results and metadata are loaded.
BASELINE = "2024-08-14T15:20:57.073921"

In [None]:
def _run_label(run_dir_name):
    """Return the display label for a run directory: "baseline" for BASELINE, else the name unchanged."""
    return "baseline" if run_dir_name == BASELINE else run_dir_name


# Per-mix results are read from RESULTS_DIR/**/results.pickle; the run is
# identified by the grandparent directory of each file. Files are sorted so the
# row order of `data` is reproducible across runs of the notebook.
# NOTE: pickle.load can execute arbitrary code — only point RESULTS_DIR at
# output you produced yourself.
results = []
for pickle_path in tqdm(sorted(RESULTS_DIR.glob("**/results.pickle"))):
    with open(pickle_path, "rb") as f:
        result = pickle.load(f)
    result["run"] = _run_label(pickle_path.parent.parent.name)
    results.append(result)
assert results, f"no results.pickle files found under {RESULTS_DIR}"

# Flatten the result dicts into one row per mix, then split the mix name into
# its components: dataset id, time-stretch method, effect, and submix index.
data = pd.json_normalize(results)
data[["id", "stretch", "fx", "submix"]] = data["name"].str.extract(
    r"(set\d+mix\d+)-(\w+)-(\w+)-(\d+)"
)
# Shorten labels so the two-line x tick labels ("stretch\nfx") stay compact.
data["stretch"] = data["stretch"].replace("resample", "resamp")
data["fx"] = data["fx"].replace({"distortion": "dist", "compressor": "comp"})
data["transformation"] = data["stretch"] + "\n" + data["fx"]

# Run configuration, one meta.json per run directory, keyed by run label.
metas = {}
for meta_path in sorted(RESULTS_DIR.glob("*/meta.json")):
    with open(meta_path, encoding="utf-8") as f:
        metas[_run_label(meta_path.parent.name)] = json.load(f)

In [None]:
# Keep only runs with a complete set of per-mix results; runs with fewer rows
# are treated as incomplete and dropped.
# NOTE(review): 1931 is the author's expected row count for a full run —
# update it if the evaluation set changes.
MIN_RESULTS_PER_RUN = 1931
data = data.groupby("run").filter(
    lambda run_rows: len(run_rows) >= MIN_RESULTS_PER_RUN
)

In [None]:
# Non-null value count of every column, per run.
data.groupby("run").agg("count")

In [None]:
# For each run (the baseline itself included), print how its meta.json
# configuration differs from the baseline's.
for run_name in sorted(metas):
    print("baseline", run_name)
    changes = dictdiffer.diff(metas["baseline"], metas[run_name])
    pprint(list(changes))

In [None]:
# Replace the timestamp names of the NOISE_DIM runs with readable labels
# (these labels are what analysis_boxplots selects via `runs`).
run_labels = {
    "2024-08-15T09:54:37.938964": "NOISE_DIM=25",
    "2024-08-15T09:52:24.231529": "NOISE_DIM=15",
    "2024-08-14T16:03:54.912066": "NOISE_DIM=5",
}
data["run"] = data["run"].replace(run_labels)

In [None]:
def _mean_speed_ratio(tracks):
    """Mean of ``speed_est / speed_real`` over a mix's tracks.

    ``tracks`` is expected to be a list of dicts with ``speed_est`` and
    ``speed_real`` keys. Any non-list value (e.g. a NaN placeholder for a mix
    without track data) or an empty list yields NaN.
    """
    if not isinstance(tracks, list) or not tracks:
        # Return NaN directly rather than via np.mean([]), which also warns.
        return np.nan
    return np.mean([track["speed_est"] / track["speed_real"] for track in tracks])


# One speed ratio per mix; 1.0 means the estimated speed matches the real one.
data["speed_ratio"] = data["tracks"].map(_mean_speed_ratio)

In [None]:
def analysis_boxplots(runs, name=None, figsize=(7, 4), df=None):
    """Draw one box plot per evaluation metric, grouped by mix transformation.

    For each metric, boxes are laid out along the x axis by transformation
    (time-stretch method x effect, the ``transformation`` column) and coloured
    by run.

    Parameters
    ----------
    runs : list of str
        Values of the ``run`` column to plot, in legend order. With a single
        run the legend is removed.
    name : str, optional
        If given, every figure is also saved as
        ``results-plots/{name}_{metric}.svg``; the directory is created when
        missing.
    figsize : tuple of float
        Matplotlib figure size in inches.
    df : pandas.DataFrame, optional
        Results table to plot. Defaults to the notebook-level ``data`` frame.
    """
    if df is None:
        df = data

    # x-axis order: stretch method outer, effect inner (matches "stretch\nfx" labels).
    order = [
        f"{stretch}\n{fx}"
        for stretch in ("none", "resamp", "stretch")
        for fx in ("none", "bass", "comp", "dist")
    ]

    # (column, y-axis label, log scale?, y limits) — one entry per figure.
    # (None, None) limits leave matplotlib's automatic range unchanged.
    metrics = [
        ("gain_err", "Gain error", True, (None, None)),
        ("warp_err", "Warp error [s]", False, (None, None)),
        ("track_start_err", "Cue point error [s]", False, (0, 35)),
        ("speed_ratio", "Speed ratio", False, (0.5, 1.5)),
        ("fadein_start_err", "Fade-in start error [s]", False, (None, None)),
        ("fadein_stop_err", "Fade-in end error [s]", False, (None, None)),
        ("fadeout_start_err", "Fade-out start error [s]", False, (None, None)),
        ("fadeout_stop_err", "Fade-out end error [s]", False, (None, None)),
    ]

    out_dir = Path("results-plots")
    if name is not None:
        # savefig does not create missing directories.
        out_dir.mkdir(parents=True, exist_ok=True)

    for column, label, log_scale, ylim in metrics:
        fig, ax = plt.subplots(figsize=figsize)
        sns.boxplot(
            df,
            x="transformation",
            y=column,
            hue="run",
            log_scale=log_scale,
            order=order,
            hue_order=runs,
            fliersize=1,
            ax=ax,
        )
        ax.set_ylabel(label)
        ax.set_ylim(*ylim)
        legend = ax.get_legend()
        # A single-run legend is redundant; guard against seaborn not drawing one.
        if len(runs) == 1 and legend is not None:
            legend.remove()
        fig.tight_layout()

        if name is not None:
            fig.savefig(out_dir / f"{name}_{column}.svg")
        plt.show()

In [None]:
# Box plots for the selected configuration alone.
analysis_boxplots(runs=["NOISE_DIM=15"], name="best")

In [None]:
# Baseline vs. NOISE_DIM=15, side by side in every box plot.
compared_runs = ["baseline", "NOISE_DIM=15"]
analysis_boxplots(runs=compared_runs, name="noise")

In [None]:
# Rank every evaluated mix by gain error (ascending, missing errors last) and
# take the row at position 4 of that ranking as the showcase example.
# NOTE(review): position 4 is the 5th-lowest gain error, not the single best —
# confirm this choice is intended.
sorted_data = data.sort_values(
    by=["gain_err"], ascending=True, na_position="last"
).reset_index()
best_gain = sorted_data.loc[4]

print(best_gain["run"], best_gain["name"])
print(best_gain["gain_err"])
print(best_gain["warp_err"])
import plot

# Estimated (*_est) vs. reference (*_real) curves for that mix, one SVG each.
for plot_curve, estimate, reference, out_file in (
    (plot.plot_gain, best_gain["gain_est"], best_gain["gain_real"], "best_gain.svg"),
    (plot.plot_warp, best_gain["warp_est"], best_gain["warp_real"], "best_warp.svg"),
):
    plt.figure(figsize=(6, 4))
    plot_curve(best_gain["tau"], estimate, reference)
    plt.savefig(out_file)
