In [107]:
import json
import pandas as pd
import warnings
import seaborn as sns
from pyprojroot import here
import plotly.express as px

methods = ["JustCopy", "TimeGAN", "Time-Transformer", "TransFusion", "TTS-GAN", "TimeVQVAE"]
datasets = ["D2", "D3", "D4", "D5", "D6", "D7"]

rows = []

for method in methods:
    for dataset in datasets:
        file_candidates = list(here('result').glob(f'numeric_{method}_{dataset}_*.json'))
        if len(file_candidates) != 1:
            warnings.warn(f"Ignoring {method} {dataset}: Expected one result file for {method} {dataset}, instead matched {file_candidates}.")
            continue
        file_path = file_candidates[0]
        
        with open(file_path) as f:
            data = json.load(f)
        
        rows.append(((method, dataset), data))

df_all = pd.DataFrame.from_dict(dict(rows), orient="index")
df_all.index = pd.MultiIndex.from_tuples(df_all.index, names=["Method", "Dataset"])

# load timings
def timings_path(method):
    paths = list(here("models").glob(f"*{method}*/timings.csv"))
    assert len(paths) == 1
    return paths[0]

timings = pd.DataFrame()
for method in methods:
    df = pd.read_csv(timings_path(method))
    df["Method"] = method
    timings = pd.concat([timings, df], ignore_index=True)

_timings = timings.melt(id_vars="Method", var_name="Dataset", value_name="Time").set_index(["Method", "Dataset"]).sort_index(level="Method")

df_all = pd.concat([df_all, _timings], axis = 1)

In [108]:
cm = sns.diverging_palette(h_neg=130, h_pos=0, as_cmap=True)

def style_results_table(df):
    return df.style\
    .background_gradient(cmap=cm, vmin=0, vmax=0.5, subset=["DS", "PS"])\
    .background_gradient(cmap=cm, vmin=0, vmax=1, subset=["C-FID"])\
    .background_gradient(cmap=cm, vmin=0, vmax=2, subset=["MDD"])\
    .background_gradient(cmap=cm, vmin=0, vmax=1, subset=["ACD"])\
    .background_gradient(cmap=cm, vmin=0, vmax=1, subset=["SD"])\
    .background_gradient(cmap=cm, vmin=0, vmax=3, subset=["KD"])\
    .background_gradient(cmap=cm, vmin=0, vmax=3, subset=["ED"])\
    .background_gradient(cmap=cm, vmin=0, vmax=15, subset=["DTW"])\
    .background_gradient(cmap=cm, vmin=0, vmax=1000, subset=["Time"])\
    .format({"DTW": "{:.2e}", "C-FID": "{:.2e}", "ED": "{:.2e}"})


def group_measures(df):
    category_map = {
    "Utility": ["PS"],
    "Fidelity": ["DS", "C-FID", "MDD", "ACD", "SD", "KD", "ED", "DTW"],
    "Training Efficiency": ["Time"]
    }

    # Create a list of tuples (top-level, column name)
    new_cols = []
    for cat, cols in category_map.items():
        for col in cols:
            new_cols.append((cat, col))

    # Reorder df columns to match the new structure
    df = df[[col for _, col in new_cols]]

    # Apply MultiIndex to columns
    df.columns = pd.MultiIndex.from_tuples(new_cols)
    return df

df_no_copy = df_all.drop(index="JustCopy", level="Method")
style_results_table(df_all)

Unnamed: 0_level_0,Unnamed: 1_level_0,DS,PS,C-FID,MDD,ACD,SD,KD,ED,DTW,Time
Method,Dataset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
JustCopy,D2,0.010101,0.03772,-1.67e-12,0.000317,0.0,0.0,0.0,0.0,0.0,0
JustCopy,D3,0.01014,0.036687,-2.12e-12,0.000282,0.0,0.0,0.0,0.0,0.0,0
JustCopy,D4,0.013775,0.05706,-2.43e-13,0.000264,0.0,0.0,0.0,0.0,0.0,0
JustCopy,D5,0.016671,0.250396,-2.42e-15,0.060657,0.0,0.0,0.0,0.0,0.0,0
JustCopy,D6,0.009915,0.247743,-3.29e-15,0.071081,0.0,0.0,0.0,0.0,0.0,0
JustCopy,D7,0.011593,0.049266,4.34e-15,0.005795,0.0,0.0,0.0,0.0,0.0,0
TTS-GAN,D2,0.353788,0.543119,3.14e+17,1.003093,2.90814,1.346879,4.078087,206000000.0,636000000.0,646
TTS-GAN,D3,0.5,0.070101,6.62,1.249867,7.3996,1.411439,2.743262,7.81,19.6,660
TTS-GAN,D4,0.499752,0.287067,758000000000000.0,1.012819,9.31871,0.815356,0.909358,43400000.0,132000000.0,649
TTS-GAN,D5,0.5,0.298715,4800000000.0,1.010239,5.06993,0.982785,1.486653,25800.0,153000.0,649


In [109]:
print(style_results_table(df_all).to_latex())

\begin{tabular}{llrrrrrrrrrr}
{} & {} & {DS} & {PS} & {C-FID} & {MDD} & {ACD} & {SD} & {KD} & {ED} & {DTW} & {Time} \\
{Method} & {Dataset} & {} & {} & {} & {} & {} & {} & {} & {} & {} & {} \\
\multirow[c]{6}{*}{JustCopy} & D2 & \background-color#408b48 \color#f1f1f1 0.010101 & \background-color#54975a \color#f1f1f1 0.037720 & \background-color#398641 \color#f1f1f1 -1.67e-12 & \background-color#398641 \color#f1f1f1 0.000317 & \background-color#398641 \color#f1f1f1 0.000000 & \background-color#398641 \color#f1f1f1 0.000000 & \background-color#398641 \color#f1f1f1 0.000000 & \background-color#398641 \color#f1f1f1 0.00e+00 & \background-color#398641 \color#f1f1f1 0.00e+00 & \background-color#398641 \color#f1f1f1 0 \\
 & D3 & \background-color#408b48 \color#f1f1f1 0.010140 & \background-color#529659 \color#f1f1f1 0.036687 & \background-color#398641 \color#f1f1f1 -2.12e-12 & \background-color#398641 \color#f1f1f1 0.000282 & \background-color#398641 \color#f1f1f1 0.000000 & \background-color

In [110]:
style_results_table(df_all.swaplevel("Method", "Dataset")\
    .sort_index(level=["Dataset", "Method"]))

Unnamed: 0_level_0,Unnamed: 1_level_0,DS,PS,C-FID,MDD,ACD,SD,KD,ED,DTW,Time
Dataset,Method,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
D2,JustCopy,0.010101,0.03772,-1.67e-12,0.000317,0.0,0.0,0.0,0.0,0.0,0
D2,TTS-GAN,0.353788,0.543119,3.14e+17,1.003093,2.90814,1.346879,4.078087,206000000.0,636000000.0,646
D2,Time-Transformer,0.324495,0.085745,0.266,0.629389,0.24659,0.582419,2.577668,1.31,3.38,13
D2,TimeGAN,0.093434,0.037907,0.0201,0.443126,0.174772,0.472857,2.370855,1.18,3.05,259
D2,TimeVQVAE,0.271212,0.042083,0.0861,0.600689,0.151148,0.300332,2.007588,1.16,2.99,73
D2,TransFusion,0.14697,0.039753,0.103,0.793641,0.033649,0.533438,2.629961,1.07,2.77,43
D3,JustCopy,0.01014,0.036687,-2.12e-12,0.000282,0.0,0.0,0.0,0.0,0.0,0
D3,TTS-GAN,0.5,0.070101,6.62,1.249867,7.3996,1.411439,2.743262,7.81,19.6,660
D3,Time-Transformer,0.297712,0.071347,0.185,0.747137,0.489822,0.532613,2.88285,2.48,6.3,16
D3,TimeGAN,0.224909,0.039703,0.0542,0.864601,0.144977,0.272758,2.354675,2.64,6.67,1373


## Spider Plots

In [111]:
rankings = df_no_copy.unstack(level=1).rank()

ranking_by_metric = rankings.stack().reset_index().groupby("Method").mean()
spider_by_metric = ranking_by_metric.reset_index().melt(id_vars="Method")
px.line_polar(spider_by_metric, r="value", theta="variable", color="Method", line_close=True) \
    .update_layout(polar={"radialaxis": {"range": [5.9, 1], "dtick": 1}})

In [112]:
ranking_by_dataset = rankings.stack(level=0).reset_index().groupby("Method").mean()
spider_by_dataset = ranking_by_dataset.reset_index().melt(id_vars="Method")
px.line_polar(spider_by_dataset, r="value", theta="Dataset", color="Method", line_close=True) \
    .update_layout(polar={"radialaxis": {"range": [5.9,1], "dtick": 1}})

In [113]:
import plotly.graph_objects as go
import scikit_posthocs as sp
import numpy as np

def conover_test(df):
    df = df.unstack(level=1).rank().T.reset_index().melt(id_vars=['level_0', 'Dataset'], var_name='Method', value_name='Rank')
    posthoc = sp.posthoc_conover(df, group_col='Method', val_col='Rank', p_adjust='bonferroni')

    return posthoc

def pretty_conover(df):
    posthoc = conover_test(df)

    p_threshold = 0.05 / len(posthoc.columns)
    print(p_threshold)

    def highlight_below_threshold(val):
        color = 'background-color: turquoise' if val < p_threshold else ''
        return color

    return posthoc.style.applymap(highlight_below_threshold).format("{:.2e}")

def plot_on_number_line(s, title, axis_label, connections=[], textpositions=[], range_max=None, range_min=None):
    colors = [color for _, color in zip(methods, px.colors.qualitative.Plotly)]

    fig = go.Figure(go.Scatter(
        x=s.values,
        y=[0]*len(s),
        mode="markers+text",
        text=s.index,
        textfont=dict(color=colors),
        textposition=textpositions,
        marker=dict(size=12, color=colors)
    ))

    if not range_max:
        range_max = s.max()
    if not range_min:
        range_min = s.min()
    h_margin = 0.1*(range_max-range_min)

    for i, (m1, m2) in enumerate(connections):
        y_offset = (i+1)*0.07
        gray_shade = 50 + 200 * (i/len(connections))
        line_color = f"rgb({gray_shade},{gray_shade},{gray_shade})"
        fig.add_shape(
            type="line",
            x0=s[m1], x1=s[m2],
            y0=y_offset, y1=y_offset,
            line=dict(color=line_color, width=2),
            layer="below"
        )

    # Thin horizontal number line
    fig.add_shape(type="line",
                x0=range_min, x1=range_max,
                y0=0, y1=0,
                line=dict(color="black", width=1),
                layer="below")
    
    # Start line
    fig.add_shape(type="line",
                x0=range_min, x1=range_min,
                y0=-1, y1=1,
                line=dict(color="black", width=1),
                layer="below")

    # Layout tweaks for minimal look
    fig.update_yaxes(visible=False)
    fig.update_xaxes(range=[range_min-h_margin, range_max+h_margin], showgrid=True, zeroline=False)
    fig.update_layout(
        height=200,
        xaxis_title=axis_label,
        yaxis_title="",
        showlegend=False,
        margin=dict(t=60, b=60, l=150, r=150),
        title=title,
    )

    fig.show()

def plot_average_rankings(df, title, connections=[], textpositions=["top center","bottom center","top center","top center","top center"]):
    s = df.unstack(level=1).rank().mean(axis=1)

    range_max = len(df.index.get_level_values(level="Method").unique())

    plot_on_number_line(s, title=title, axis_label="Average Rankings", connections=connections, textpositions=textpositions, range_min=1, range_max=range_max)

connections = [("TimeGAN", "Time-Transformer"), ("Time-Transformer", "TimeVQVAE")]
plot_average_rankings(df_no_copy, connections=connections, title="All Measures")


In [114]:
pretty_conover(df_no_copy)

0.01


Unnamed: 0,TTS-GAN,Time-Transformer,TimeGAN,TimeVQVAE,TransFusion
TTS-GAN,1.0,2.8799999999999995e-23,1.27e-15,1.09e-27,3.01e-41
Time-Transformer,2.8799999999999995e-23,1.0,0.197,1.0,1.22e-05
TimeGAN,1.27e-15,0.197,1.0,0.00363,2.69e-11
TimeVQVAE,1.09e-27,1.0,0.00363,1.0,0.00265
TransFusion,3.01e-41,1.22e-05,2.69e-11,0.00265,1.0


In [115]:
connections=[("TimeVQVAE", "TimeGAN")]
plot_average_rankings(df_all.drop(index="JustCopy", level="Method").drop(columns=["PS", "Time"]), title="Fidelity", connections=connections)

In [116]:
pretty_conover(df_no_copy.drop(columns=["PS", "Time"]))

0.01


Unnamed: 0,TTS-GAN,Time-Transformer,TimeGAN,TimeVQVAE,TransFusion
TTS-GAN,1.0,1.33e-22,3.73e-20,3.0699999999999998e-30,3.29e-43
Time-Transformer,1.33e-22,1.0,1.0,0.203,2.17e-08
TimeGAN,3.73e-20,1.0,1.0,0.021,2.64e-10
TimeVQVAE,3.0699999999999998e-30,0.203,0.021,1.0,0.0013
TransFusion,3.29e-43,2.17e-08,2.64e-10,0.0013,1.0


In [117]:
connections=[("TTS-GAN", "TimeGAN"), ("TimeVQVAE", "TransFusion")]
plot_average_rankings(df_no_copy["PS"].to_frame(), connections=connections, title="Utility", textpositions=["top center","top center","top left","bottom center","top center"])

In [118]:
pretty_conover(df_no_copy["PS"].to_frame())

0.01


Unnamed: 0,TTS-GAN,Time-Transformer,TimeGAN,TimeVQVAE,TransFusion
TTS-GAN,1.0,1.0,0.785,1.0,0.000775
Time-Transformer,1.0,1.0,1.0,1.0,0.0223
TimeGAN,0.785,1.0,1.0,1.0,0.0799
TimeVQVAE,1.0,1.0,1.0,1.0,0.0223
TransFusion,0.000775,0.0223,0.0799,0.0223,1.0


In [119]:
from scipy.stats import friedmanchisquare

df = df_no_copy.unstack(level=1).T
ranks = df.rank(axis=0, method='average', ascending=False)

stat, p = friedmanchisquare(*[ranks.loc[m].values for m in ranks.index])
p

8.69339661590917e-22

In [120]:
import pandas as pd

def leaderboard(df):
    out = pd.DataFrame(index=df.index.levels[1], columns=df.columns)

    for lvl2 in df.index.levels[1]:
        subset = df.xs(lvl2, level=1)
        min_idx = subset.idxmin()
        out.loc[lvl2] = min_idx


    def color_cells(value):
        if value == "TTS-GAN":
            return 'background-color: yellow'
        if value == "TransFusion":
            return 'background-color: lightgreen'
        if value == "TimeGAN":
            return 'background-color: lightblue'
        if value == "Time-Transformer":
            return 'background-color: pink'
        if value == "TimeVQVAE":
            return 'background-color: turquoise'

    return out.style.applymap(color_cells)

leaderboard(df_no_copy)

Unnamed: 0_level_0,DS,PS,C-FID,MDD,ACD,SD,KD,ED,DTW,Time
Dataset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
D2,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TransFusion,TimeVQVAE,TimeVQVAE,TransFusion,TransFusion,Time-Transformer
D3,TransFusion,TransFusion,TransFusion,TransFusion,TransFusion,TransFusion,TransFusion,TimeVQVAE,TimeVQVAE,Time-Transformer
D4,TransFusion,TransFusion,TransFusion,TransFusion,TransFusion,TimeVQVAE,TimeVQVAE,Time-Transformer,Time-Transformer,Time-Transformer
D5,TransFusion,TransFusion,TransFusion,TransFusion,TransFusion,TransFusion,TransFusion,Time-Transformer,Time-Transformer,Time-Transformer
D6,TransFusion,TransFusion,TransFusion,TransFusion,TransFusion,TransFusion,TransFusion,TimeVQVAE,TimeVQVAE,Time-Transformer
D7,TransFusion,TransFusion,TransFusion,TransFusion,TransFusion,TransFusion,TransFusion,TimeVQVAE,TimeVQVAE,Time-Transformer


In [121]:
leaderboard(df_all[df_all.index.get_level_values(level=0).isin(["TTS-GAN", "TimeGAN"])])

Unnamed: 0_level_0,DS,PS,C-FID,MDD,ACD,SD,KD,ED,DTW,Time
Dataset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
D2,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN
D3,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TTS-GAN
D4,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TTS-GAN
D5,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN
D6,TTS-GAN,TTS-GAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TTS-GAN
D7,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TTS-GAN


In [122]:
leaderboard(df_all[df_all.index.get_level_values(level=0).isin(["Time-Transformer", "TimeGAN"])])

Unnamed: 0_level_0,DS,PS,C-FID,MDD,ACD,SD,KD,ED,DTW,Time
Dataset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
D2,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,Time-Transformer
D3,TimeGAN,TimeGAN,TimeGAN,Time-Transformer,TimeGAN,TimeGAN,TimeGAN,Time-Transformer,Time-Transformer,Time-Transformer
D4,TimeGAN,Time-Transformer,Time-Transformer,Time-Transformer,TimeGAN,TimeGAN,Time-Transformer,Time-Transformer,Time-Transformer,Time-Transformer
D5,TimeGAN,TimeGAN,TimeGAN,TimeGAN,Time-Transformer,Time-Transformer,Time-Transformer,Time-Transformer,Time-Transformer,Time-Transformer
D6,Time-Transformer,Time-Transformer,Time-Transformer,Time-Transformer,Time-Transformer,Time-Transformer,TimeGAN,Time-Transformer,Time-Transformer,Time-Transformer
D7,TimeGAN,Time-Transformer,Time-Transformer,TimeGAN,Time-Transformer,Time-Transformer,Time-Transformer,Time-Transformer,Time-Transformer,Time-Transformer


In [123]:
rankings

Unnamed: 0_level_0,DS,DS,DS,DS,DS,DS,PS,PS,PS,PS,...,DTW,DTW,DTW,DTW,Time,Time,Time,Time,Time,Time
Dataset,D2,D3,D4,D5,D6,D7,D2,D3,D4,D5,...,D4,D5,D6,D7,D2,D3,D4,D5,D6,D7
Method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
TTS-GAN,5.0,5.0,5.0,5.0,4.0,5.0,5.0,4.0,5.0,3.0,...,5.0,5.0,5.0,5.0,5.0,4.0,4.0,5.0,3.0,3.0
Time-Transformer,4.0,4.0,3.0,4.0,4.0,4.0,4.0,5.0,3.0,4.0,...,1.0,1.0,2.0,2.0,1.0,1.0,1.0,1.0,1.0,1.0
TimeGAN,1.0,2.0,2.0,2.0,4.0,2.0,1.0,2.0,4.0,2.0,...,4.0,4.0,4.0,4.0,4.0,5.0,5.0,4.0,4.0,5.0
TimeVQVAE,3.0,3.0,4.0,3.0,2.0,3.0,3.0,3.0,2.0,5.0,...,2.0,3.0,1.0,1.0,3.0,2.0,2.0,2.0,2.0,2.0
TransFusion,2.0,1.0,1.0,1.0,1.0,1.0,2.0,1.0,1.0,1.0,...,3.0,2.0,3.0,3.0,2.0,3.0,3.0,3.0,5.0,4.0


In [124]:
df = df_no_copy.unstack(level=1).rank()
utility_cols = [col for col in df.columns if col[0] in ["PS"]]

fidelity_cols = [col for col in df.columns if col[0] not in ["PS", "Time"]]

timing_cols = [col for col in df.columns if col[0] in ["Time"]]

avg_utility = df[utility_cols].mean(axis=1)
avg_fidelity = df[fidelity_cols].mean(axis=1)
avg_timing = df[timing_cols].mean(axis=1)

new_df = pd.DataFrame({
    "Utility": avg_utility,
    "Fidelity": avg_fidelity,
    "Time": avg_timing
})

new_df.style.background_gradient(cmap=cm)

Unnamed: 0_level_0,Utility,Fidelity,Time
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
TTS-GAN,4.166667,4.958333,4.0
Time-Transformer,3.333333,2.875,1.0
TimeGAN,3.0,3.020833,4.5
TimeVQVAE,3.333333,2.4375,2.166667
TransFusion,1.166667,1.708333,3.333333


In [125]:
timings.set_index("Method").drop(index="JustCopy")

Unnamed: 0_level_0,D2,D3,D4,D5,D6,D7
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
TimeGAN,259,1373,1195,304,1252,1385
Time-Transformer,13,16,28,66,77,55
TransFusion,43,222,454,201,1387,856
TTS-GAN,646,660,649,649,660,645
TimeVQVAE,73,94,94,74,88,86


In [126]:
average_timings = timings.set_index("Method").mean(axis=1).to_frame(name="Average Wall Clock Time").drop(index="JustCopy")


plot_on_number_line(\
    average_timings["Average Wall Clock Time"], \
    title = "Average Wall Clock Time", \
    axis_label = "minutes", \
    range_min=1, \
    range_max=1000, \
    textpositions=["top center","bottom center","top center","bottom center","top center"])