In [637]:
import json
import pandas as pd
import warnings
import seaborn as sns
from pyprojroot import here
import plotly.express as px
import numpy as np
import math

methods = ["JustCopy", "TimeGAN", "Time-Transformer", "TransFusion", "TTS-GAN", "TimeVQVAE"]
datasets = ["D2", "D3", "D4", "D5", "D6", "D7"]

rows = []

for method in methods:
    for dataset in datasets:
        file_candidates = list(here('result').glob(f'numeric_{method}_{dataset}_*.json'))
        if len(file_candidates) != 1:
            warnings.warn(f"Ignoring {method} {dataset}: Expected one result file for {method} {dataset}, instead matched {file_candidates}.")
            continue
        file_path = file_candidates[0]
        
        with open(file_path) as f:
            data = json.load(f)
        
        rows.append(((method, dataset), data))

df_all = pd.DataFrame.from_dict(dict(rows), orient="index")
df_all.index = pd.MultiIndex.from_tuples(df_all.index, names=["Method", "Dataset"])

# load timings
def timings_path(method):
    paths = list(here("models").glob(f"*{method}*/timings.csv"))
    assert len(paths) == 1
    return paths[0]

timings = pd.DataFrame()
for method in methods:
    df = pd.read_csv(timings_path(method))
    df["Method"] = method
    timings = pd.concat([timings, df], ignore_index=True)

_timings = timings.melt(id_vars="Method", var_name="Dataset", value_name="Time").set_index(["Method", "Dataset"]).sort_index(level="Method")

measures_order = ["PS", "DS", "C-FID","MDD", "ACD",	"SD", "KD", "ED", "DTW", "Time"]
ranking_order = ["JustCopy", "TransFusion", "TimeVQVAE", "Time-Transformer", "TimeGAN", "TTS-GAN"]
df_all = pd.concat([df_all, _timings], axis = 1).loc[ranking_order][measures_order]

In [638]:
cm = sns.diverging_palette(h_neg=130, h_pos=0, as_cmap=True)


def style_results_table(df):
    sci_columns = {"ED", "DTW", "C-FID"}
    decimal_columns = set(df.columns) - sci_columns

    # Create Styler on the renamed DataFrame
    styled = df.style

    # Formatting
    fmt = {}

    # Decimal formatting
    def decimal_fmt(x, sig=3):
        import math
        if pd.isna(x):
            return ""  # or "NaN"
        if x == 0:
            return "0"
        # Determine number of decimal places needed for 'sig' significant digits
        digits = sig - int(math.floor(math.log10(abs(x)))) - 1
        # Round the number
        rounded = round(x, max(digits, 0))
        # Convert to string and strip unnecessary zeros
        s = f"{rounded:.{max(digits,0)}f}".rstrip("0").rstrip(".")
        return s

    if decimal_columns:
        for col in decimal_columns:
            if col in df.columns:
                fmt[col] = decimal_fmt  # simple decimal formatting

    def sci_fmt_unicode(x):
        if x == 0:
            return "0"
        exp = math.floor(math.log10(abs(x)))
        coeff = x / 10**exp
        if exp == 0 or exp == 1 or exp == -1:
            return f"{coeff:.3g}"
        # Otherwise, use Unicode superscript
        superscripts = str.maketrans("0123456789-", "⁰¹²³⁴⁵⁶⁷⁸⁹⁻")
        return f"{coeff:.3g}×10{str(exp).translate(superscripts)}"

    if sci_columns:
        for col in sci_columns:
            if col in df.columns:
                fmt[col] = sci_fmt_unicode  

    styled = styled.format(fmt)

    # Background gradients

    ranges = {
        "DS": (0.012, 0.476),
        "PS": (0.113, 0.279),
        "C-FID": (0, 1),
        "MDD": (0.0231, 1.05),
        "ACD": (0, 1),
        "SD": (0, 1),
        "KD": (0, 2.16),
        "ED": (0, 5),
        "DTW": (0, 16),
        "Time": (0, 961),
    }
    for col, (vmin, vmax) in ranges.items():
        if col in df.columns:  # note: ranges refer to original df
            styled = styled.background_gradient(
                cmap=cm, vmin=vmin, vmax=vmax, subset=col
            )

    return styled


def group_measures(df):
    category_map = {
    "Utility": ["PS"],
    "Fidelity": ["DS", "C-FID", "MDD", "ACD", "SD", "KD", "ED", "DTW"],
    "Training Efficiency": ["Time"]
    }

    # Create a list of tuples (top-level, column name)
    new_cols = []
    for cat, cols in category_map.items():
        for col in cols:
            new_cols.append((cat, col))

    # Reorder df columns to match the new structure
    df = df[[col for _, col in new_cols]]

    # Apply MultiIndex to columns
    df.columns = pd.MultiIndex.from_tuples(new_cols)
    return df

df_no_copy = df_all.drop(index="JustCopy", level="Method")
style_results_table(df_all)

Unnamed: 0_level_0,Unnamed: 1_level_0,PS,DS,C-FID,MDD,ACD,SD,KD,ED,DTW,Time
Method,Dataset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
JustCopy,D2,0.0377,0.0101,-1.67×10⁻¹²,0.000317,0.0,0.0,0.0,0,0,0
JustCopy,D3,0.0367,0.0101,-2.12×10⁻¹²,0.000282,0.0,0.0,0.0,0,0,0
JustCopy,D4,0.0571,0.0138,-2.43×10⁻¹³,0.000264,0.0,0.0,0.0,0,0,0
JustCopy,D5,0.25,0.0167,-2.42×10⁻¹⁵,0.0607,0.0,0.0,0.0,0,0,0
JustCopy,D6,0.248,0.00992,-3.29×10⁻¹⁵,0.0711,0.0,0.0,0.0,0,0,0
JustCopy,D7,0.0493,0.0116,4.34×10⁻¹⁵,0.00579,0.0,0.0,0.0,0,0,0
TransFusion,D2,0.0398,0.147,1.03,0.794,0.0336,0.533,2.63,1.07,2.77,43
TransFusion,D3,0.0378,0.119,5.48×10⁻³,0.405,0.113,0.166,0.811,2.68,6.83,222
TransFusion,D4,0.0549,0.0314,1.04×10⁻²,0.292,0.0402,0.118,0.243,2.81,8.95,454
TransFusion,D5,0.251,0.11,1.03×10⁻²,0.224,0.0633,0.106,0.421,1.01,6.28,201


In [639]:
print(style_results_table(df_all).to_latex())

\begin{tabular}{llrrrrrrrrrr}
{} & {} & {PS} & {DS} & {C-FID} & {MDD} & {ACD} & {SD} & {KD} & {ED} & {DTW} & {Time} \\
{Method} & {Dataset} & {} & {} & {} & {} & {} & {} & {} & {} & {} & {} \\
\multirow[c]{6}{*}{JustCopy} & D2 & \background-color#398641 \color#f1f1f1 0.0377 & \background-color#398641 \color#f1f1f1 0.0101 & \background-color#398641 \color#f1f1f1 -1.67×10⁻¹² & \background-color#398641 \color#f1f1f1 0.000317 & \background-color#398641 \color#f1f1f1 0 & \background-color#398641 \color#f1f1f1 0 & \background-color#398641 \color#f1f1f1 0 & \background-color#398641 \color#f1f1f1 0 & \background-color#398641 \color#f1f1f1 0 & \background-color#398641 \color#f1f1f1 0 \\
 & D3 & \background-color#398641 \color#f1f1f1 0.0367 & \background-color#398641 \color#f1f1f1 0.0101 & \background-color#398641 \color#f1f1f1 -2.12×10⁻¹² & \background-color#398641 \color#f1f1f1 0.000282 & \background-color#398641 \color#f1f1f1 0 & \background-color#398641 \color#f1f1f1 0 & \background-color#39

In [640]:
style_results_table(
    df_all
        .drop(index="TTS-GAN", level="Method")
        .drop(index="JustCopy", level="Method")
        .groupby(level="Dataset").mean()
    )

Unnamed: 0_level_0,PS,DS,C-FID,MDD,ACD,SD,KD,ED,DTW,Time
Dataset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
D2,0.0514,0.209,1.19,0.617,0.152,0.472,2.4,1.18,3.05,97
D3,0.0493,0.227,7.41×10⁻²,0.653,0.256,0.291,1.9,2.55,6.49,426
D4,0.0749,0.203,5.24×10⁻²,0.511,0.146,0.162,0.317,2.77,8.8,443
D5,0.307,0.398,1.99,0.376,0.389,0.237,0.833,1.01,6.26,161
D6,0.278,0.432,4.19,0.507,1.05,0.35,0.872,2.48,1.44,701
D7,0.0596,0.295,2.08,0.293,0.84,0.31,0.751,2.4,9.01,596


In [641]:
style_results_table(
    df_all
        .groupby(level="Method").mean()
        .loc[ranking_order]
    )

Unnamed: 0_level_0,PS,DS,C-FID,MDD,ACD,SD,KD,ED,DTW,Time
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
JustCopy,0.113,0.012,-6.73×10⁻¹³,0.0231,0.0,0.0,0.0,0,0,0.0
TransFusion,0.114,0.115,2.41×10⁻²,0.344,0.122,0.176,0.718,2.1,8.13,527.0
TimeVQVAE,0.146,0.379,2.3,0.509,0.551,0.315,1.27,1.99,7.73,84.8
Time-Transformer,0.143,0.396,2.11,0.534,0.495,0.378,1.41,2.03,7.91,42.5
TimeGAN,0.143,0.286,2.49,0.583,0.717,0.346,1.31,2.14,8.22,961.0
TTS-GAN,0.279,0.476,5.25×10¹⁶,1.05,6.71,1.0,2.16,4.68×10⁷,1.52×10⁸,652.0


In [642]:
style_results_table(df_all.groupby("Method").mean()["PS"].to_frame().sort_values("PS"))

Unnamed: 0_level_0,PS
Method,Unnamed: 1_level_1
JustCopy,0.113
TransFusion,0.114
Time-Transformer,0.143
TimeGAN,0.143
TimeVQVAE,0.146
TTS-GAN,0.279


In [643]:
style_results_table(df_all.groupby("Method").mean()["Time"].to_frame().sort_values("Time").sort_values("Time"))

Unnamed: 0_level_0,Time
Method,Unnamed: 1_level_1
JustCopy,0.0
Time-Transformer,42.5
TimeVQVAE,84.8
TransFusion,527.0
TTS-GAN,652.0
TimeGAN,961.0


In [644]:
style_results_table(
    df_all
        .groupby("Method").mean()
        .drop(columns=["Time", "PS"])
        .loc[ranking_order]
        )

Unnamed: 0_level_0,DS,C-FID,MDD,ACD,SD,KD,ED,DTW
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
JustCopy,0.012,-6.73×10⁻¹³,0.0231,0.0,0.0,0.0,0,0
TransFusion,0.115,2.41×10⁻²,0.344,0.122,0.176,0.718,2.1,8.13
TimeVQVAE,0.379,2.3,0.509,0.551,0.315,1.27,1.99,7.73
Time-Transformer,0.396,2.11,0.534,0.495,0.378,1.41,2.03,7.91
TimeGAN,0.286,2.49,0.583,0.717,0.346,1.31,2.14,8.22
TTS-GAN,0.476,5.25×10¹⁶,1.05,6.71,1.0,2.16,4.68×10⁷,1.52×10⁸


## Spider Plots

In [645]:
rankings = df_no_copy.unstack(level=1).rank()

ranking_by_metric = rankings.stack().reset_index().groupby("Method").mean()
spider_by_metric = ranking_by_metric.reset_index().melt(id_vars="Method")
px.line_polar(spider_by_metric, r="value", theta="variable", color="Method", line_close=True) \
    .update_layout(polar={"radialaxis": {"range": [5.9, 1], "dtick": 1}})

In [646]:
ranking_by_dataset = rankings.stack(level=0).reset_index().groupby("Method").mean()
spider_by_dataset = ranking_by_dataset.reset_index().melt(id_vars="Method")
px.line_polar(spider_by_dataset, r="value", theta="Dataset", color="Method", line_close=True) \
    .update_layout(polar={"radialaxis": {"range": [5.9,1], "dtick": 1}})

In [647]:
import plotly.graph_objects as go
import scikit_posthocs as sp

def conover_test(df):
    df = df.unstack(level=1).rank().T.reset_index().melt(id_vars=['level_0', 'Dataset'], var_name='Method', value_name='Rank')
    posthoc = sp.posthoc_conover(df, group_col='Method', val_col='Rank', p_adjust='bonferroni')

    return posthoc

def pretty_conover(df):
    posthoc = conover_test(df)

    p_threshold = 0.05 / len(posthoc.columns)
    print(p_threshold)

    def highlight_below_threshold(val):
        color = 'background-color: turquoise' if val < p_threshold else ''
        return color

    return posthoc.style.applymap(highlight_below_threshold).format("{:.2e}")

def plot_on_number_line(s, title, axis_label, connections=[], textpositions=[], range_max=None, range_min=None):
    colors = [color for _, color in zip(methods, px.colors.qualitative.Plotly)]

    fig = go.Figure(go.Scatter(
        x=s.values,
        y=[0]*len(s),
        mode="markers+text",
        text=s.index,
        textfont=dict(color=colors),
        textposition=textpositions,
        marker=dict(size=12, color=colors)
    ))

    if not range_max:
        range_max = s.max()
    if not range_min:
        range_min = s.min()
    h_margin = 0.1*(range_max-range_min)

    for i, (m1, m2) in enumerate(connections):
        y_offset = (i+1)*0.07
        gray_shade = 50 + 200 * (i/len(connections))
        line_color = f"rgb({gray_shade},{gray_shade},{gray_shade})"
        fig.add_shape(
            type="line",
            x0=s[m1], x1=s[m2],
            y0=y_offset, y1=y_offset,
            line=dict(color=line_color, width=2),
            layer="below"
        )

    # Thin horizontal number line
    fig.add_shape(type="line",
                x0=range_min, x1=range_max,
                y0=0, y1=0,
                line=dict(color="black", width=1),
                layer="below")
    
    # Start line
    fig.add_shape(type="line",
                x0=range_min, x1=range_min,
                y0=-1, y1=1,
                line=dict(color="white", width=1),
                layer="below")

    # Layout tweaks for minimal look
    fig.update_yaxes(visible=False)
    fig.update_xaxes(range=[range_min-h_margin, range_max+h_margin], showgrid=True, zeroline=False)
    fig.update_layout(
        height=200,
        xaxis_title=axis_label,
        yaxis_title="",
        showlegend=False,
        margin=dict(t=60, b=60, l=150, r=150),
        title=title,
    )

    fig.update_layout(width=1000, height=210)
    fig.show()
    fig.update_layout(
        title=None,
        margin=dict(t=60, b=60, l=50, r=50)
    )
    fig.write_image(f"{title}.pdf")
    return fig


def plot_average_rankings(df, title, connections=[], textpositions=["top center","bottom center","top center","top center","top center"]):
    s = df.unstack(level=1).rank().mean(axis=1)

    range_max = len(df.index.get_level_values(level="Method").unique())

    return plot_on_number_line(s, title=title, axis_label="Average Rankings", connections=connections, textpositions=textpositions, range_min=1, range_max=range_max)

connections = [("TimeGAN", "Time-Transformer"), ("Time-Transformer", "TimeVQVAE")]
plot_average_rankings(df_no_copy, connections=connections, title="All Measures")


In [648]:
pretty_conover(df_no_copy)

0.01


Unnamed: 0,TTS-GAN,Time-Transformer,TimeGAN,TimeVQVAE,TransFusion
TTS-GAN,1.0,2.8799999999999995e-23,1.27e-15,1.09e-27,3.01e-41
Time-Transformer,2.8799999999999995e-23,1.0,0.197,1.0,1.22e-05
TimeGAN,1.27e-15,0.197,1.0,0.00363,2.69e-11
TimeVQVAE,1.09e-27,1.0,0.00363,1.0,0.00265
TransFusion,3.01e-41,1.22e-05,2.69e-11,0.00265,1.0


In [649]:
connections=[("TimeVQVAE", "TimeGAN")]
plot_average_rankings(df_all.drop(index="JustCopy", level="Method").drop(columns=["PS", "Time"]), title="Fidelity", connections=connections)

In [650]:
pretty_conover(df_no_copy.drop(columns=["PS", "Time"]))

0.01


Unnamed: 0,TTS-GAN,Time-Transformer,TimeGAN,TimeVQVAE,TransFusion
TTS-GAN,1.0,1.33e-22,3.73e-20,3.0699999999999998e-30,3.29e-43
Time-Transformer,1.33e-22,1.0,1.0,0.203,2.17e-08
TimeGAN,3.73e-20,1.0,1.0,0.021,2.64e-10
TimeVQVAE,3.0699999999999998e-30,0.203,0.021,1.0,0.0013
TransFusion,3.29e-43,2.17e-08,2.64e-10,0.0013,1.0


In [651]:
connections=[("TTS-GAN", "TimeGAN"), ("TimeVQVAE", "TransFusion")]
plot_average_rankings(df_no_copy["PS"].to_frame(), connections=connections, title="Utility", textpositions=["top center","top center","top left","bottom center","top center"])

In [652]:
pretty_conover(df_no_copy["PS"].to_frame())

0.01


Unnamed: 0,TTS-GAN,Time-Transformer,TimeGAN,TimeVQVAE,TransFusion
TTS-GAN,1.0,1.0,0.785,1.0,0.000775
Time-Transformer,1.0,1.0,1.0,1.0,0.0223
TimeGAN,0.785,1.0,1.0,1.0,0.0799
TimeVQVAE,1.0,1.0,1.0,1.0,0.0223
TransFusion,0.000775,0.0223,0.0799,0.0223,1.0


In [653]:
from scipy.stats import friedmanchisquare

df = df_no_copy.unstack(level=1).T
ranks = df.rank(axis=0, method='average', ascending=False)

stat, p = friedmanchisquare(*[ranks.loc[m].values for m in ranks.index])
p

8.69339661590917e-22

In [654]:
import pandas as pd

def leaderboard(df):
    out = pd.DataFrame(index=df.index.levels[1], columns=df.columns)

    for lvl2 in df.index.levels[1]:
        subset = df.xs(lvl2, level=1)
        min_idx = subset.idxmin()
        out.loc[lvl2] = min_idx


    def color_cells(value):
        if value == "TTS-GAN":
            return 'background-color: yellow'
        if value == "TransFusion":
            return 'background-color: lightgreen'
        if value == "TimeGAN":
            return 'background-color: lightblue'
        if value == "Time-Transformer":
            return 'background-color: pink'
        if value == "TimeVQVAE":
            return 'background-color: turquoise'

    return out.style.applymap(color_cells)

leaderboard(df_no_copy)

Unnamed: 0_level_0,PS,DS,C-FID,MDD,ACD,SD,KD,ED,DTW,Time
Dataset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
D2,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TransFusion,TimeVQVAE,TimeVQVAE,TransFusion,TransFusion,Time-Transformer
D3,TransFusion,TransFusion,TransFusion,TransFusion,TransFusion,TransFusion,TransFusion,TimeVQVAE,TimeVQVAE,Time-Transformer
D4,TransFusion,TransFusion,TransFusion,TransFusion,TransFusion,TimeVQVAE,TimeVQVAE,Time-Transformer,Time-Transformer,Time-Transformer
D5,TransFusion,TransFusion,TransFusion,TransFusion,TransFusion,TransFusion,TransFusion,Time-Transformer,Time-Transformer,Time-Transformer
D6,TransFusion,TransFusion,TransFusion,TransFusion,TransFusion,TransFusion,TransFusion,TimeVQVAE,TimeVQVAE,Time-Transformer
D7,TransFusion,TransFusion,TransFusion,TransFusion,TransFusion,TransFusion,TransFusion,TimeVQVAE,TimeVQVAE,Time-Transformer


In [655]:
leaderboard(df_all[df_all.index.get_level_values(level=0).isin(["TTS-GAN", "TimeGAN"])])

Unnamed: 0_level_0,PS,DS,C-FID,MDD,ACD,SD,KD,ED,DTW,Time
Dataset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
D2,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN
D3,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TTS-GAN
D4,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TTS-GAN
D5,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN
D6,TTS-GAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TTS-GAN
D7,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TTS-GAN


In [656]:
leaderboard(df_all[df_all.index.get_level_values(level=0).isin(["Time-Transformer", "TimeGAN"])])

Unnamed: 0_level_0,PS,DS,C-FID,MDD,ACD,SD,KD,ED,DTW,Time
Dataset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
D2,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,TimeGAN,Time-Transformer
D3,TimeGAN,TimeGAN,TimeGAN,Time-Transformer,TimeGAN,TimeGAN,TimeGAN,Time-Transformer,Time-Transformer,Time-Transformer
D4,Time-Transformer,TimeGAN,Time-Transformer,Time-Transformer,TimeGAN,TimeGAN,Time-Transformer,Time-Transformer,Time-Transformer,Time-Transformer
D5,TimeGAN,TimeGAN,TimeGAN,TimeGAN,Time-Transformer,Time-Transformer,Time-Transformer,Time-Transformer,Time-Transformer,Time-Transformer
D6,Time-Transformer,Time-Transformer,Time-Transformer,Time-Transformer,Time-Transformer,Time-Transformer,TimeGAN,Time-Transformer,Time-Transformer,Time-Transformer
D7,Time-Transformer,TimeGAN,Time-Transformer,TimeGAN,Time-Transformer,Time-Transformer,Time-Transformer,Time-Transformer,Time-Transformer,Time-Transformer


In [657]:
rankings

Unnamed: 0_level_0,PS,PS,PS,PS,PS,PS,DS,DS,DS,DS,...,DTW,DTW,DTW,DTW,Time,Time,Time,Time,Time,Time
Dataset,D2,D3,D4,D5,D6,D7,D2,D3,D4,D5,...,D4,D5,D6,D7,D2,D3,D4,D5,D6,D7
Method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
TransFusion,2.0,1.0,1.0,1.0,1.0,1.0,2.0,1.0,1.0,1.0,...,3.0,2.0,3.0,3.0,2.0,3.0,3.0,3.0,5.0,4.0
TimeVQVAE,3.0,3.0,2.0,5.0,4.0,3.0,3.0,3.0,4.0,3.0,...,2.0,3.0,1.0,1.0,3.0,2.0,2.0,2.0,2.0,2.0
Time-Transformer,4.0,5.0,3.0,4.0,2.0,2.0,4.0,4.0,3.0,4.0,...,1.0,1.0,2.0,2.0,1.0,1.0,1.0,1.0,1.0,1.0
TimeGAN,1.0,2.0,4.0,2.0,5.0,4.0,1.0,2.0,2.0,2.0,...,4.0,4.0,4.0,4.0,4.0,5.0,5.0,4.0,4.0,5.0
TTS-GAN,5.0,4.0,5.0,3.0,3.0,5.0,5.0,5.0,5.0,5.0,...,5.0,5.0,5.0,5.0,5.0,4.0,4.0,5.0,3.0,3.0


In [658]:
df = df_no_copy.unstack(level=1).rank()
utility_cols = [col for col in df.columns if col[0] in ["PS"]]

fidelity_cols = [col for col in df.columns if col[0] not in ["PS", "Time"]]

timing_cols = [col for col in df.columns if col[0] in ["Time"]]

avg_utility = df[utility_cols].mean(axis=1)
avg_fidelity = df[fidelity_cols].mean(axis=1)
avg_timing = df[timing_cols].mean(axis=1)

new_df = pd.DataFrame({
    "Utility": avg_utility,
    "Fidelity": avg_fidelity,
    "Time": avg_timing
})

new_df.style.background_gradient(cmap=cm)

Unnamed: 0_level_0,Utility,Fidelity,Time
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
TransFusion,1.166667,1.708333,3.333333
TimeVQVAE,3.333333,2.4375,2.166667
Time-Transformer,3.333333,2.875,1.0
TimeGAN,3.0,3.020833,4.5
TTS-GAN,4.166667,4.958333,4.0


In [659]:
timings.set_index("Method").drop(index="JustCopy")

Unnamed: 0_level_0,D2,D3,D4,D5,D6,D7
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
TimeGAN,259,1373,1195,304,1252,1385
Time-Transformer,13,16,28,66,77,55
TransFusion,43,222,454,201,1387,856
TTS-GAN,646,660,649,649,660,645
TimeVQVAE,73,94,94,74,88,86


In [660]:
average_timings = timings.set_index("Method").mean(axis=1).to_frame(name="Average Wall Clock Time").drop(index="JustCopy")


plot_on_number_line(\
    average_timings["Average Wall Clock Time"], \
    title = "Average Wall Clock Time", \
    axis_label = "minutes", \
    range_min=1, \
    range_max=1000, \
    textpositions=["top center","bottom center","top center","bottom center","top center"])

In [661]:
from plotly.subplots import make_subplots

# Create your separate figures (using your existing functions)
fig1 = plot_average_rankings(df_no_copy, title="Overall",  connections=[])
fig2 = plot_average_rankings(df_no_copy.drop(columns=["PS", "Time"]), title="Fidelity", connections=[])
fig3 = plot_average_rankings(df_no_copy["PS"].to_frame(), title="Utility", connections=[])

# MAKE SUBPLOTS (1 row, 3 columns for example)
combined = make_subplots(rows=3, cols=1, shared_yaxes=True)

# Helper to add a figure into a subplot cell:
def add_fig_to_subplot(fig, row, col):
    for trace in fig.data:
        combined.add_trace(trace, row=row, col=col)
    if "shapes" in fig.layout:
        shapes = list(fig.layout.shapes)
        for s in shapes:
            s.xref = f"x{col}"
            s.yref = f"y{row}"
            combined.add_shape(s)

# Add your figures to the grid:
add_fig_to_subplot(fig1, 1, 1)
add_fig_to_subplot(fig2, 2, 1)
add_fig_to_subplot(fig3, 3, 1)

combined.update_layout(height=600, width=1000, title_text="Combined Ranking Plots")
combined.show()