In [None]:
import glob
import gzip
import json
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Global matplotlib style shared by every figure in this notebook.
plt.rcParams.update({
    # Text: small serif font, no external LaTeX rendering.
    "text.usetex": False,
    "font.family": "serif",
    #"font.serif": ["Palatino"],
    'font.size': 8,

    # Legends are drawn without a frame.
    "legend.frameon": False,
    "legend.fancybox": False,

    # Thin axes spines, slightly thicker data lines.
    'axes.linewidth': 0.6,
    "lines.linewidth": 0.9,

    # Tick marks use the same width as the axes spines.
    'xtick.major.width': 0.6,
    'xtick.minor.width': 0.6,
    'ytick.major.width': 0.6,
    'ytick.minor.width': 0.6,

    # Very light grid behind every plot.
    "axes.grid": True,
    "grid.color": "#EEE",
})

# Preamble for LaTeX text rendering (text.usetex is off above).
plt.rc("text.latex", preamble=r"\usepackage{amsmath}")

In [None]:
def _veritas_column_prefix(record, filename):
    """Return the ``veritasNN`` column prefix for one Veritas result record.

    The time budget is taken from ``record["max_time"]`` when it is present and
    formats as an integer; otherwise it is parsed from the first ``time<digits>``
    occurrence in ``filename``.

    Raises:
        ValueError: if neither the record nor the filename provides a budget.
    """
    try:
        return f"veritas{record['max_time']:02d}"
    except (KeyError, TypeError, ValueError):
        # Missing, None or non-integer budget: fall back to the filename.
        pass
    match = re.search(r"time(\d+)", filename)
    if match is None:
        raise ValueError(
            f"record has no usable 'max_time' and no 'time<N>' found in filename {filename!r}")
    return f"veritas{int(match.group(1)):02d}"


def load_files(pattern): # data map (example_i[, target_label]) => data
    """Read gzipped JSON-lines result files and merge them per example.

    Args:
        pattern: glob pattern selecting the ``.gz`` files to read.

    Returns:
        dict mapping ``example_i`` (or ``(example_i, target_label)`` when the
        record has a ``"target_label"``) to a dict of ``<method>_time`` /
        ``<method>_delta`` values. Methods are ``veritasNN``, ``mext_T<c>_L<l>``
        and ``kan``, depending on which sections the records contain.

    Side effects:
        Prints every file name read and the largest final ``"memory"`` value
        seen in any ``veritas_log`` entry, divided by 1024*1024.
    """
    data = {}
    max_memory = 0.0
    # sorted(): glob order is arbitrary; make the read order reproducible.
    for filename in sorted(glob.glob(pattern)):
        print("Reading file:", filename)
        with gzip.open(filename, "r") as f:
            for line in f:  # stream instead of loading the whole file
                if not line.strip():
                    continue  # tolerate blank / trailing empty lines
                j = json.loads(line)
                example_i = j["example_i"]
                if "target_label" in j:
                    key = (example_i, j["target_label"])
                else:
                    key = example_i
                d = data.setdefault(key, {})

                if "veritas_deltas" in j:
                    # Last memory reading of each log entry; the log may be empty.
                    peaks = [x["memory"][-1] for x in j["veritas_log"]]
                    max_memory = max([max_memory, *peaks])
                    column_prefix = _veritas_column_prefix(j, filename)
                    d[f"{column_prefix}_time"] = j["veritas_time"]
                    # First component of the last entry in the delta log.
                    d[f"{column_prefix}_delta"] = j["veritas_deltas"][-1][0]

                if "merge_ext" in j and "max_clique" in j["merge_ext"]:
                    merge = j["merge_ext"]
                    column_prefix = f"mext_T{merge['max_clique']}_L{merge['max_level']}"
                    d[f"{column_prefix}_time"] = merge["times"][-1]
                    d[f"{column_prefix}_delta"] = merge["deltas"][-1]

                if "kantchelian" in j:
                    column_prefix = "kan"
                    d[f"{column_prefix}_time"] = j["kantchelian"]["time_p"]
                    d[f"{column_prefix}_delta"] = j["kantchelian_delta"]
    print(f"max_memory for {pattern} is: {max_memory/(1024*1024)}")
    return data

def get_column_names(data):
    """Return the sorted union of all column names used by the per-example dicts."""
    return sorted(set().union(*(value.keys() for value in data.values())))


def to_df(data):
    """Turn the ``{example key: {column: value}}`` map into a DataFrame.

    Rows are the example keys and columns are the sorted column names; cells an
    example has no value for are NaN. Examples whose dict is empty do not
    produce a row.

    The previous version also called ``df.set_index(index)`` and threw the
    result away. That call had no effect on the returned frame, but raised a
    length-mismatch ValueError whenever some example had no columns at all, so
    it is removed; the index already comes from the dict keys.
    """
    colnames = get_column_names(data)
    columns = {
        c: {key: value[c] for key, value in data.items() if c in value}
        for c in colnames
    }
    return pd.DataFrame(columns)

def load_to_df(pattern, dropna=True):
    """Load every result file matching ``pattern`` into an index-sorted DataFrame.

    Args:
        pattern: glob pattern passed on to ``load_files``.
        dropna: when true (the default), rows containing any NaN are dropped.

    Returns:
        The DataFrame built by ``to_df``, sorted by its index.
    """
    frame = to_df(load_files(pattern)).sort_index(axis=0)
    return frame.dropna() if dropna else frame

In [None]:
dfs={}

In [None]:
dfs["covtype"] = load_to_df("/home/laurens/repos/veritas/tests/experiments/results/r1-covtype-*")

In [None]:
dfs["f-mnist"] = load_to_df("/home/laurens/repos/veritas/tests/experiments/results/r1-f-mnist-*")

In [None]:
dfs["higgs"] = load_to_df("/home/laurens/repos/veritas/tests/experiments/results/r1-higgs-*")

In [None]:
dfs["ijcnn1"] = load_to_df("/home/laurens/repos/veritas/tests/experiments/results/r1-ijcnn1-*")

In [None]:
dfs["mnist"] = load_to_df("/home/laurens/repos/veritas/tests/experiments/results/r1-mnist-*")

In [None]:
dfs["webspam"] = load_to_df("/home/laurens/repos/veritas/tests/experiments/results/r1-webspam-*")

In [None]:
dfs["mnist2v6"] = load_to_df("/home/laurens/repos/veritas/tests/experiments/results/r1-mnist2v6-*")

In [None]:
gridcolor="#EEEEEE"

# Bound difference

In [None]:
# Datasets to plot, one panel each. The saved file name uses the first entry.
#datasets = ["covtype", "f-mnist", "higgs", "ijcnn1", "mnist", "webspam", "mnist2v6"]
datasets = ["webspam"]

# squeeze=False makes plt.subplots always return a 2-D array of Axes, so the
# loop below works for a single dataset as well as for several.
fig, axs = plt.subplots(1, len(datasets), figsize=(len(datasets)*4.0, 1.8), squeeze=False)
fig.subplots_adjust(left=0.15, right=0.9, top=0.85, bottom=0.25)
axs = axs.flatten()

for d, ax in zip(datasets, axs):
    df = dfs[d]
    time_columns = [c for c in df.columns if c.endswith("time")]
    delta_columns = [c for c in df.columns if c.endswith("delta")]
    time_mean = df[time_columns].mean()
    time_std = df[time_columns].std()
    # Despite its name this is the plain mean delta per method, not a
    # deviation from the MILP value.
    div_from_opt = df[delta_columns].mean()

    # Diagnostic only: power-of-ten factor derived from the largest mean delta.
    scale_ = np.log10(div_from_opt.max()).round()
    scale = 10**-scale_ * 10
    print("scale", d, scale)

    veritas_time_columns = [c for c in time_columns if c.startswith("veritas")]
    veritas_delta_columns = [c for c in delta_columns if c.startswith("veritas")]
    mer_time_columns = [c for c in time_columns if c.startswith("mext")]
    mer_delta_columns = [c for c in delta_columns if c.startswith("mext")]

    ax.set_title(f"{d}")
    ax.set_xlabel("Time")
    # The y-axis label is placed as text above the top-left corner of the axes.
    ax.text(-0.1, 1.1, '$\\delta$', horizontalalignment='left', verticalalignment='center', transform=ax.transAxes)

    # Veritas: one point per time budget; horizontal bars show the std of the run time.
    ax.errorbar(time_mean[veritas_time_columns], div_from_opt[veritas_delta_columns], xerr=time_std[veritas_time_columns],
                capthick=1.0, elinewidth=None, capsize=2.0, marker=".", linestyle=":", markersize=4, errorevery=4, label="Veritas")
    # Merge: unconnected markers.
    ax.errorbar(time_mean[mer_time_columns], div_from_opt[mer_delta_columns], xerr=time_std[mer_time_columns],
                capthick=1.0, elinewidth=None, capsize=2.0, marker="8", markersize=5, linestyle="", label="Merge")
    # MILP (Kantchelian): single marker plus a horizontal reference line at its mean delta.
    ax.errorbar(time_mean[["kan_time"]], div_from_opt[["kan_delta"]], marker="*", linestyle="", markersize=4,
                xerr=time_std[["kan_time"]],
                capthick=1.0, elinewidth=0.0, capsize=2.0, barsabove=True, label="MILP")
    ax.axhline(y=div_from_opt["kan_delta"], c="gray", ls=":", label="Exact")

    ax.set_xscale("log")
    ax.legend(fontsize="large", bbox_to_anchor=(1.0, 0.8))

plt.savefig(f"/tmp/bound_err_{datasets[0]}.pdf")
plt.show()

# Counting stats

In [None]:

# squeeze=False + flatten: `axs` is always a 1-D array of Axes, even when
# `dfs` holds a single dataset (plt.subplots would otherwise return a bare Axes).
fig, axs = plt.subplots(1, len(dfs), figsize=(len(dfs)*5, 4), squeeze=False)
axs = axs.flatten()

# Per-dataset fraction of examples (0..1) where Veritas is better / worse /
# the same as Merge, indexed by Veritas delta column.
better_stats = {}
worse_stats = {}
same_stats = {}

for (d, df), ax in zip(dfs.items(), axs):
    time_columns = [c for c in df.columns if c.endswith("time")]
    delta_columns = [c for c in df.columns if c.endswith("delta")]
    time_mean = df[time_columns].mean()
    veritas_time_columns = [c for c in time_columns if c.startswith("veritas")]
    veritas_delta_columns = [c for c in delta_columns if c.startswith("veritas")]
    mer_delta_columns = [c for c in delta_columns if c.startswith("mext")]
    mer_time_columns = [c for c in time_columns if c.startswith("mext")]

    # Absolute error of every method w.r.t. the MILP (kan) delta.
    div_from_opt = df[delta_columns].subtract(df["kan_delta"], axis=0).abs()

    # Errors closer than 1% of the 20%-80% quantile range of the exact deltas count as "same".
    same_threshold = (df["kan_delta"].quantile(0.8) - df["kan_delta"].quantile(0.2)) / 100
    print(f"same threshold {d}: {same_threshold}")

    # Error of each method minus the error of the first Merge column (negative = more accurate).
    err_vs_merge = div_from_opt.subtract(div_from_opt[mer_delta_columns[0]], axis=0)
    ver_better = err_vs_merge < -same_threshold
    ver_worse = err_vs_merge > same_threshold
    ver_same = ~ver_better & ~ver_worse

    n = len(df)
    frac_better = ver_better.sum()[veritas_delta_columns]/n
    frac_worse = ver_worse.sum()[veritas_delta_columns]/n
    frac_same = ver_same.sum()[veritas_delta_columns]/n

    ax.set_title(f"{d} (n={n})")
    ax.set_xlabel("Time [s]")
    ax.text(-0.1, 1.04, '%', horizontalalignment='right', verticalalignment='center', transform=ax.transAxes)

    # The axis is labelled "%", so plot percentages rather than raw fractions.
    ax.plot(time_mean[veritas_time_columns], frac_better*100, marker="^", linestyle=":", label="Better")
    ax.plot(time_mean[veritas_time_columns], frac_worse*100, marker="v", linestyle=":", label="Worse")
    ax.plot(time_mean[veritas_time_columns], frac_same*100, marker=".", linestyle=":", label="Same")
    ax.axvline(x=time_mean[mer_time_columns[0]], ls="--", color="gray", label="Merge time")
    ax.legend()

    # Stored as fractions; the table cell that follows multiplies by 100.
    better_stats[d] = frac_better
    worse_stats[d] = frac_worse
    same_stats[d] = frac_same

In [None]:
# One row per dataset, one column per Veritas delta column; values are
# percentages rounded to one decimal.
better_df, worse_df, same_df = (
    pd.DataFrame(stats).T.mul(100).round(1)
    for stats in (better_stats, worse_stats, same_stats)
)

In [None]:
display(better_df, worse_df, same_df)

# How many problems are solved in 1s, 2s, ...

In [None]:
datasets = ["covtype", "f-mnist", "higgs", "ijcnn1", "mnist", "mnist2v6"]
fig, axs = plt.subplots(1, len(datasets), figsize=(len(datasets)*1.4, 1.8), sharey=True, sharex=True)
fig.subplots_adjust(left=0.04, bottom=0.22, right=0.99, top=0.7, wspace=0.1, hspace=0.4)
axs = axs.flatten()

# NOTE: this cell used to reset better_stats / worse_stats / same_stats to {}
# without using them, which wiped the results of the counting-stats cell.

# Time budgets on the x axis; identical for every panel.
times = np.linspace(0, 12, 50)

for d, ax in zip(datasets, axs):
    df = dfs[d]
    veritas_time_columns = [c for c in df.columns if c.endswith("time") and c.startswith("veritas")]
    veritas_delta_columns = [c for c in df.columns if c.endswith("delta") and c.startswith("veritas")]
    mer_delta_columns = [c for c in df.columns if c.endswith("delta") and c.startswith("mext")]
    mer_time_columns = [c for c in df.columns if c.endswith("time") and c.startswith("mext")]

    # Accuracy threshold: mean absolute error of the Merge deltas w.r.t. the
    # MILP delta. (Historical name: it is a mean, not a median.)
    q50 = pd.Series(df[mer_delta_columns].subtract(df["kan_delta"], axis=0).abs().to_numpy().flatten()).mean()

    # MILP: an example counts as solved at budget t once its run time is <= t.
    in_time_kan = pd.concat([df["kan_time"]]*len(times), axis=1).le(times, axis=1)
    in_time_kan.columns = [f"in_time{t:.2f}" for t in times]

    # Veritas: solved at budget t if ANY time-budget run finished within t
    # with an error of at most q50.
    in_time_ver50 = None
    for tcol, dcol in zip(veritas_time_columns, veritas_delta_columns):
        accurate = (df[dcol]-df["kan_delta"]).abs()<=q50
        x50 = pd.concat([(df[tcol]<=t) & accurate for t in times], axis=1)
        in_time_ver50 = (in_time_ver50 | x50) if in_time_ver50 is not None else x50

    # Merge: same criterion, first Merge configuration only.
    mer_accurate = (df[mer_delta_columns[0]]-df["kan_delta"]).abs()<=q50
    in_time_mer50 = pd.concat([(df[mer_time_columns[0]]<=t) & mer_accurate for t in times], axis=1)

    ax.set_title(f"{d}")
    ax.set_xlabel("Time")
    # The y axis is shared, so the unit label is only drawn on the covtype panel.
    if d=="covtype":
        ax.text(-0.1, 1.09, '%', horizontalalignment='right', verticalalignment='center', transform=ax.transAxes)
    lv, = ax.plot(times, in_time_ver50.mean()*100, ls="-", label="Veritas")
    lm, = ax.plot(times, in_time_mer50.mean()*100, ls="--", label="Merge")
    ax.plot(times, in_time_kan.mean()*100, ls="-.", label="MILP")

# One shared legend, anchored relative to the second panel.
axs[1].legend(ncol=3, bbox_to_anchor=(3.4, 1.6), fontsize="large")
plt.savefig("/tmp/solved_per_time.pdf")

# Tables

In [None]:
rows = {}

def map_name(n):
    """Map a column name to the label used in the LaTeX means table.

    The first matching substring wins, checked in the order "kan", "veritas",
    "mext". Returns None when none of them occurs in ``n``.
    """
    labels = (
        ("kan", "MIPS"),
        ("veritas", "\\ouralg{}"),
        ("mext", "\\merge{}"),
    )
    return next((label for needle, label in labels if needle in n), None)

def which_column(d):
    """Return the Veritas column prefix reported in the table for dataset ``d``.

    "f-mnist" uses the "veritas06" columns; every other dataset uses "veritas02".
    """
    return "veritas06" if d == "f-mnist" else "veritas02"

# Build the means table: per dataset one row each for the mean delta, mean
# time, speed-up over the first column and time standard deviation.
# Columns come out of to_df sorted by name, so the "kan_*" column is first and
# serves as the reference for the relative values.
for d, df in dfs.items():
    veritas_prefix = which_column(d)
    time_columns = [c for c in df.columns if c.endswith("time") and (not c.startswith("veritas") or c.startswith(veritas_prefix))]
    delta_columns = [c for c in df.columns if c.endswith("delta") and (not c.startswith("veritas") or c.startswith(veritas_prefix))]
    mean_delta = df[delta_columns].mean()
    time_mean = df[time_columns].mean()

    # r1: reference mean delta as a number, the others as a percentage of it.
    # astype(object) first: formatted strings are written into the Series.
    ref_delta, other_deltas = mean_delta.index[0], mean_delta.index[1:]
    r1 = mean_delta.astype(object)
    r1[other_deltas] = [f"\\SI{{{x:.3g}}}{{\\percent}}" for x in mean_delta[other_deltas] / mean_delta[ref_delta] * 100.0]
    r1.index = [map_name(n) for n in r1.index]

    # r2 / r3: mean and standard deviation of the run time.
    r2 = time_mean.copy()
    r2.index = [map_name(n) for n in r2.index]
    r3 = df[time_columns].std()
    r3.index = [map_name(n) for n in r3.index]

    # r4: speed-up factor of each method relative to the reference time.
    # "\\times" must be escaped: a bare "\t" in the f-string is a TAB character.
    ref_time, other_times = time_mean.index[0], time_mean.index[1:]
    r4 = time_mean.astype(object)
    r4[ref_time] = ""
    r4[other_times] = [f"\\SI{{{x:.0f}}}{{\\times}}" for x in time_mean[ref_time] / time_mean[other_times]]
    r4.index = [map_name(n) for n in r4.index]

    rows[(d, "$\\delta$")] = r1
    rows[(d, "$t$")] = r2
    rows[(d, "$\\times$")] = r4
    rows[(d, "$\\sigma_t$")] = r3

means_df = pd.DataFrame(rows)
means_df = means_df.transpose()
means_df

In [None]:
print(means_df.to_latex(escape=False))

In [None]:
# Counts table
rows = {}

def map_name(n):
    """Map a Veritas column name to its "Budget N" label for the counts table.

    N is the integer parsed from characters 7-8 of ``n``. Returns None for
    names that do not contain "veritas".
    """
    if "veritas" not in n:
        return None
    budget = int(n[7:9])
    return f"Budget {budget}"

# Per dataset and Veritas budget: percentage of examples where Veritas is
# better / worse / the same as Merge, faster than Merge, or both slower and worse.
for d, df in dfs.items():
    time_columns = [c for c in df.columns if c.endswith("time") and c.startswith("ver")]
    delta_columns = [c for c in df.columns if c.endswith("delta") and c.startswith("ver")]
    mer_time_column = [c for c in df.columns if c.endswith("time") and c.startswith("mext")][0]
    mer_delta_column = [c for c in df.columns if c.endswith("delta") and c.startswith("mext")][0]

    kan_delta = df["kan_delta"]
    # Error gaps smaller than 0.1% of the 40%-60% quantile range count as "same".
    same_threshold = (kan_delta.quantile(0.6) - kan_delta.quantile(0.4)) / 1000

    # Absolute errors w.r.t. the MILP delta. The Veritas error is made
    # absolute as well (the signed difference was used before), so both sides
    # of the comparison are errors, matching the counting-stats cell.
    ver_abs_diff = df[delta_columns].subtract(kan_delta, axis=0).abs()
    mer_abs_diff = (df[mer_delta_column] - kan_delta).abs()
    # Negative gap: Veritas is closer to the exact value than Merge.
    err_gap = ver_abs_diff.subtract(mer_abs_diff, axis=0)
    err_gap.columns = [map_name(n) for n in err_gap.columns]

    better = err_gap.le(-same_threshold)
    worse = err_gap.ge(same_threshold)
    same = err_gap.abs().lt(same_threshold)

    faster = df[time_columns].lt(df[mer_time_column], axis=0)
    slower = df[time_columns].gt(df[mer_time_column], axis=0)
    # Rename to the shared "Budget N" labels so time and delta frames align.
    faster.columns = [map_name(n) for n in faster.columns]
    slower.columns = [map_name(n) for n in slower.columns]

    rows[(d, "better")] = better.mean() * 100.0
    rows[(d, "worse")] = worse.mean() * 100.0
    rows[(d, "same")] = same.mean() * 100.0
    rows[(d, "faster")] = faster.mean() * 100.0
    rows[(d, "slower and worse")] = (slower & worse).mean() * 100.0


counts_df = pd.DataFrame(rows)
counts_df = counts_df.transpose()
counts_df.round(1)

In [None]:
def formatter(x):
    # Render one percentage cell as a siunitx \SI{..}{\percent} entry.
    return f"\\SI{{{x:.1f}}}{{\\percent}}"

# to_latex needs exactly one formatter per column; the columns of counts_df are
# the budgets, so derive the count from the frame instead of hard-coding 5.
print(counts_df.to_latex(escape=False, formatters=[formatter] * len(counts_df.columns)))

In [None]:
# NOTE: `df` is not defined in this cell. It is whichever frame an earlier
# cell last bound to the global `df` (a loop variable, or the
# `df = dfs[...]` cell at the bottom of the notebook).
time_columns = [c for c in df.columns if c.endswith("time")]
delta_columns = [c for c in df.columns if c.endswith("delta")]

# Per-method subsets, selected by column-name prefix.
veritas_time_columns = [c for c in time_columns if c.startswith("veritas")]
veritas_delta_columns = [c for c in delta_columns if c.startswith("veritas")]
mer_time_columns = [c for c in time_columns if c.startswith("mext")]
mer_delta_columns = [c for c in delta_columns if c.startswith("mext")]

In [None]:
delta_mean = df[delta_columns].mean()
time_mean = df[time_columns].mean()

In [None]:
df[delta_columns]

In [None]:
plt.errorbar(time_mean, delta_mean, marker="x", ls="", xerr=df[time_columns].std())#, yerr=df[delta_columns].std())

In [None]:
time_mean[veritas_time_columns]

In [None]:
div_from_opt = df[delta_columns].subtract(df["kan_delta"], axis=0).abs().mean()

In [None]:
# Mean absolute error per method against mean run time, with horizontal
# reference lines at the first Merge value and at the MILP value.
plt.title("Mean absolute difference of delta value")
plt.plot(time_mean[veritas_time_columns], div_from_opt[veritas_delta_columns], marker=".", linestyle=":", label="Veritas")
l, = plt.plot(time_mean[mer_time_columns], div_from_opt[mer_delta_columns], marker="o", linestyle=":", label="Merge")
# .iloc[0]: explicit positional access; plain [0] on a label-indexed Series is deprecated.
plt.axhline(y=div_from_opt[mer_delta_columns].iloc[0], c=l.get_color(), ls=l.get_linestyle())
l, = plt.plot(time_mean[["kan_time"]], div_from_opt[["kan_delta"]], marker="x", linestyle=":", label="MILP")
plt.axhline(y=div_from_opt["kan_delta"], c=l.get_color(), ls=l.get_linestyle())
plt.legend()

In [None]:
# Mean delta per method against mean run time, with horizontal reference
# lines at the first Merge value and at the MILP value.
plt.title("Mean delta value")
plt.plot(time_mean[veritas_time_columns], delta_mean[veritas_delta_columns], marker=".", linestyle=":", label="Veritas")
l, = plt.plot(time_mean[mer_time_columns], delta_mean[mer_delta_columns], marker="o", linestyle=":", label="Merge")
# .iloc[0]: explicit positional access; plain [0] on a label-indexed Series is deprecated.
plt.axhline(y=delta_mean[mer_delta_columns].iloc[0], c=l.get_color(), ls=l.get_linestyle())
l, = plt.plot(time_mean[["kan_time"]], delta_mean[["kan_delta"]], marker="x", linestyle=":", label="MILP")
plt.axhline(y=delta_mean["kan_delta"], c=l.get_color(), ls=l.get_linestyle())
plt.legend()

In [None]:
time_mean

In [None]:
df[delta_columns].subtract(df["kan_delta"], axis=0).describe()

In [None]:
df[df["kan_delta"]>20]

In [None]:
# Re-load f-mnist from only the files matching "r1-f-mnist-time2*". This
# replaces the dfs["f-mnist"] entry loaded near the top of the notebook.
# NOTE(review): cells above this one that already ran used the earlier entry.
dfs["f-mnist"] = load_to_df("/home/laurens/repos/veritas/tests/experiments/results/r1-f-mnist-time2*")

In [None]:

# Select the frame the ad-hoc inspection cells read through the global `df`.
# NOTE(review): this sits below those cells, so on a top-to-bottom run they
# see the `df` left over from an earlier loop instead — consider moving it up.
df = dfs["mnist"]
