In [10]:
from IPython.display import display, Markdown, HTML
import plotly.express as px
import itertools

from azureml.core import Run, Model
from azureml.core import Datastore, Experiment, ScriptRunConfig, Workspace, Run

from model_drift import settings
from model_drift.helpers import column_xs, correlate_performance, mutual_info_performance, w_avg
import pandas as pd
import os
import datetime



In [11]:
import pandas as pd
import six


# Workspace handle built from the config file referenced by settings.AZUREML_CONFIG.
ws = Workspace.from_config(settings.AZUREML_CONFIG)
# Name of the experiment whose runs are tabulated below.
experiment_name = 'generate-drift-csv-label-mod-dbg'
# Module-level experiment handle; later cells read `ws` and `exp` as globals.
exp = Experiment(workspace=ws, name=experiment_name)


def run_to_dict(run):
    """Flatten a run into a plain dict: its tags plus a few identifying fields.

    The tags are copied first; ``id``, ``display_name``, ``url`` and ``run``
    are then written on top, so a tag using one of those names is overwritten.
    The ``run`` entry holds the run object itself.
    """
    record = dict(**run.tags)
    record.update({
        'id': run.id,
        'display_name': run.display_name,
        'url': run.get_portal_url(),
        'run': run,
    })
    # Start/end timestamps (run.get_details()["startTimeUtc"/"endTimeUtc"]) were
    # collected here at some point and are currently left out.
    return record


def experiment_to_dataframe(experiment, workspace=None):
    """Build a table with one row per *completed* run of ``experiment``.

    Args:
        experiment: an experiment object exposing ``get_runs()``, or the
            experiment's name as a string.
        workspace: workspace used to look the experiment up; required only
            when ``experiment`` is given as a string.

    Returns:
        pandas.DataFrame indexed by ``display_name`` whose columns are the keys
        produced by ``run_to_dict``. Empty (with the same index name) when the
        experiment has no completed runs.

    Raises:
        ValueError: ``experiment`` is a string and ``workspace`` is None.
    """
    if isinstance(experiment, str):
        if workspace is None:
            raise ValueError("if experiment is string, must provide workspace")
        experiment = Experiment(workspace=workspace, name=experiment)

    # Iterate the experiment that was passed in, not a module-level global.
    rows = [run_to_dict(run) for run in experiment.get_runs() if run.status == "Completed"]
    if not rows:
        # set_index would raise KeyError on a frame that has no columns at all.
        return pd.DataFrame(columns=['display_name']).set_index(['display_name'])
    return pd.DataFrame(rows).set_index(['display_name'])#.sort_values("endTimeUtc", ascending=False)

# One row per completed run of the experiment configured above.
df = experiment_to_dataframe(exp)

# Reference run from a different experiment, added below as the "Baseline" row.
run = Run(experiment=Experiment(workspace=ws, name='generate-drift-csv-3'), 
          run_id="generate-drift-csv-3_1639942528_d9f3ee90") # "tender_pear_lfbd6
# Keep only runs that carry a `mod_end_date` tag.
df = df[~df['mod_end_date'].isnull()]

# Baseline
d = run_to_dict(run)
d["label_modifiers"] = "Baseline"
baseline_name = d.pop("display_name")
# DataFrame.append was removed in pandas 2.0. Build the one-row frame explicitly
# and concatenate; rename_axis keeps the index name ("display_name") on the result.
baseline_row = (
    pd.Series(d, name=baseline_name)
    .to_frame()
    .T
    .infer_objects()
    .rename_axis(df.index.names)
)
df = pd.concat([df, baseline_row])

df.head()

Unnamed: 0_level_0,_aml_system_ComputeTargetStatus,mlflow.source.type,mlflow.source.name,run_azure,input_dir,output_dir,dataset,vae_dataset,classifier_dataset,vae_filter,...,num_workers,dbg,id,url,run,frontal_remove_date,nonfrontal_add_date,peds_end_date,peds_start_date,peds_weight
display_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
sad_caravan_nnryywcw,"{""AllocationState"":""steady"",""PreparingNodeCoun...",JOB,scripts/drift/generate-drift-csv-label.py,1,/mnt/batch/tasks/shared/LS_root/jobs/mlops_sha...,./outputs/,padchest,padchest-trained,padchest-finetuned-chx-frontalonly,all-data,...,10,1,generate-drift-csv-label-mod-dbg_1642612521_b3...,https://ml.azure.com/runs/generate-drift-csv-l...,Run(Experiment: generate-drift-csv-label-mod-d...,,,,,
serene_head_v04mgvx4,"{""AllocationState"":""steady"",""PreparingNodeCoun...",JOB,scripts/drift/generate-drift-csv-label.py,1,/mnt/batch/tasks/shared/LS_root/jobs/mlops_sha...,./outputs/,padchest,padchest-trained,padchest-finetuned-chx-frontalonly,all-data,...,10,1,generate-drift-csv-label-mod-dbg_1642612530_ae...,https://ml.azure.com/runs/generate-drift-csv-l...,Run(Experiment: generate-drift-csv-label-mod-d...,,,,,
salmon_cloud_wg3s2fnq,"{""AllocationState"":""steady"",""PreparingNodeCoun...",JOB,scripts/drift/generate-drift-csv-label.py,1,/mnt/batch/tasks/shared/LS_root/jobs/mlops_sha...,./outputs/,padchest,padchest-trained,padchest-finetuned-chx-frontalonly,all-data,...,10,1,generate-drift-csv-label-mod-dbg_1642612524_fb...,https://ml.azure.com/runs/generate-drift-csv-l...,Run(Experiment: generate-drift-csv-label-mod-d...,,,,,
affable_plum_cm6y51d3,"{""AllocationState"":""steady"",""PreparingNodeCoun...",JOB,scripts/drift/generate-drift-csv-label.py,1,/mnt/batch/tasks/shared/LS_root/jobs/mlops_sha...,./outputs/,padchest,padchest-trained,padchest-finetuned-chx-frontalonly,all-data,...,10,1,generate-drift-csv-label-mod-dbg_1642612534_c7...,https://ml.azure.com/runs/generate-drift-csv-l...,Run(Experiment: generate-drift-csv-label-mod-d...,,,,,
musing_lock_mgsp9xps,"{""AllocationState"":""steady"",""PreparingNodeCoun...",JOB,scripts/drift/generate-drift-csv-label.py,1,/mnt/batch/tasks/shared/LS_root/jobs/mlops_sha...,./outputs/,padchest,padchest-trained,padchest-finetuned-chx-frontalonly,all-data,...,10,1,generate-drift-csv-label-mod-dbg_1642461065_3b...,https://ml.azure.com/runs/generate-drift-csv-l...,Run(Experiment: generate-drift-csv-label-mod-d...,,,,,


In [12]:
# Folder that holds one sub-folder of pre-rendered graphs per run.
html_top_dir = settings.TOP_DIR.joinpath("html", "graphs_paper")
# parents=True so a fresh checkout without an "html" folder does not fail here.
html_top_dir.mkdir(exist_ok=True, parents=True)

In [13]:
def is_arg_col(c):
    """Return True when column name ``c`` is a run argument rather than bookkeeping.

    A name is rejected when it contains any of the substrings "mlflow", "_aml",
    "run_" or "url", or when it is exactly one of the ignored names below.
    """
    bookkeeping_markers = ("mlflow", "_aml", 'run_', 'url')
    if any(marker in c for marker in bookkeeping_markers):
        return False
    ignored_names = ("output_dir", "input_dir", 'run', 'display_name', 'id')
    return c not in ignored_names

def experiment_df_for_display(df, remove_duplicates=True, remove_const_columns=True):
    """Reduce a run table to the argument columns worth showing.

    Args:
        df: run table; columns are filtered with ``is_arg_col``.
        remove_duplicates: drop rows whose remaining argument values repeat,
            keeping the last occurrence.
        remove_const_columns: drop columns that hold a single distinct value
            (missing values count as one value, "NA").

    Returns:
        A new DataFrame (a copy, so callers may add columns to it freely).
    """
    arg_cols = [c for c in df.columns if is_arg_col(c)]
    arg_df = df[arg_cols]

    if remove_const_columns:
        varying_cols = [c for c in arg_df if arg_df[c].fillna('NA').nunique() > 1]
        arg_df = arg_df[varying_cols]

    # The filtered frame has to be assigned back; without that the flag does nothing.
    # Skipped when no columns remain, where the duplicate mask cannot be aligned to the rows.
    if remove_duplicates and len(arg_df.columns):
        arg_df = arg_df[~arg_df.duplicated(keep='last')]

    return arg_df.copy()

# Argument table for the runs collected above.
arg_df = experiment_df_for_display(df)
# "Link" is an HTML anchor to <run name>/index.html when a folder with the run's
# name exists under html_top_dir, otherwise the text "N/A".
# NOTE(review): the anchor carries a valueless `disabled=` attribute - confirm it is intended.
arg_df['Link'] = [f"""<a href="{name}/index.html" disabled=>Graphs</a>""" if html_top_dir.joinpath(name).exists() else "N/A" for name in arg_df.index]

arg_df

Unnamed: 0_level_0,label_modifiers,mod_end_date,randomize_start_date,dbg,peds_weight,Link
display_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
sad_caravan_nnryywcw,"{""Opacity"": [0.75, ""2014-04-01"", ""2014-08-15""]...",2014-12-31,,1,,
serene_head_v04mgvx4,"{""Cardiomegaly"": [0.75, ""2014-04-01"", ""2014-08...",2014-12-31,,1,,
salmon_cloud_wg3s2fnq,"{""Pleural Abnormalities"": [0.75, ""2014-04-01"",...",2014-12-31,,1,,
affable_plum_cm6y51d3,"{""Atelectasis"": [0.75, ""2014-04-01"", ""2014-08-...",2014-12-31,,1,,
musing_lock_mgsp9xps,"{""Opacity"": [0.75, ""2014-06-01"", ""2014-09-15""]...",2014-12-31,,1,,
mighty_beard_vvtgq0q9,"{""Pleural Abnormalities"": [0.75, ""2014-06-01"",...",2014-12-31,,1,,
jolly_tangelo_ch7q56v0,"{""Atelectasis"": [0.75, ""2014-06-01"", ""2014-09-...",2014-12-31,,1,,
keen_circle_9wktjvvd,"{""Atelectasis"": [0.75, ""2014-06-01"", ""2014-09-...",2014-12-31,,1,,
purple_toe_dnntc2j8,,2014-12-31,2014-06-01,1,,
goofy_dress_ztmzmh27,,2014-12-31,2014-04-01,1,,


In [14]:
# Inline <script> that appends the current page's query string
# (window.location.search) to the href of every <a> element on the page.
# NOTE(review): the braces are doubled although this is a plain string, so they are
# emitted literally as written - presumably the text is meant to go through
# str.format before being embedded; confirm at the point of use.
fix_links_script = """
              <script>
                var x = document.getElementsByTagName('a');
                var i;
                for (i = 0; i < x.length; i++) {{
                    let url = x[i].getAttribute("href");
                    x[i].href = url + window.location.search;
                }}
                </script>
              """

In [15]:
def get_run(display_name, experiment):
    """Return the first run of ``experiment`` whose display name equals ``display_name``.

    Runs are scanned in the order ``experiment.get_runs()`` yields them and the
    scan stops at the first match.

    Raises:
        KeyError: no run carries that display name.
    """
    found = next(
        (candidate for candidate in experiment.get_runs() if candidate.display_name == display_name),
        None,
    )
    if found is None:
        raise KeyError(f"'{display_name}' not found in experiment!")
    return found

In [16]:
# # run_row = df.loc[[not html_top_dir.joinpath(name).exists() for name in df.index]].iloc[0]
# # run_name, verbose_name = "tender_pear_lfbd6wwg", "baseline"
# # run_name, verbose_name = "tender_pear_lfbd6wwg", "baseline_2row"
# # run_name, verbose_name = "orange_scooter_dry7q6y9", "inject-laterals"


# runs = {"baseline": ("tender_pear_lfbd6wwg", "generate-drift-csv-3")}
# verbose_name = "baseline"
# run_name, experiment = runs[verbose_name]
# experiment = Experiment(ws, experiment)
# run = get_run(run_name, experiment)

# # Display settings
# span = 7
# which = 'mean'
# clip = 10
# performance_col = ("performance", "macro avg", "auroc")
# congruency_measure_col = ('in_distro', 'stats', 'mean')
# add_error_bars = True

# standardize_dates = (settings.PADCHEST_SPLIT_DATES[0], settings.PADCHEST_SPLIT_DATES[1])
# standardize_ix = pd.date_range(*standardize_dates)
# stat = []
# # stat.append('pval')
# stat.append('distance')

# verbose_name = verbose_name + '_' + '+'.join(stat) 

# write_html = True
# graph_start = "2014-01-01"
# graph_end = "2014-12-31"


# font=dict(size=12)

# print(run_name, verbose_name)


In [17]:
import itertools
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.express as px
from matplotlib import colors as mpl_colors
from collections import defaultdict

def to_rgba(rgb, alpha=None):
    """Convert any matplotlib colour spec into a CSS ``rgb()``/``rgba()`` string.

    Args:
        rgb: anything ``matplotlib.colors.to_rgb`` accepts (name, hex string, tuple).
        alpha: optional opacity; when given the result is an ``rgba(...)`` string.

    Returns:
        "rgb(R, G, B)" or "rgba(R, G, B, alpha)" with integer channels in 0-255.
    """
    # matplotlib returns channels as floats in [0, 1], whereas CSS-style rgb()
    # strings are read on a 0-255 scale; without rescaling every colour comes
    # out almost black.
    red, green, blue = (int(round(255 * channel)) for channel in mpl_colors.to_rgb(rgb))
    if alpha is None:
        return "rgb(%s, %s, %s)" % (red, green, blue)
    return "rgba(%s, %s, %s, %s)" % (red, green, blue, alpha)

def line_maker(color, **l):
    """Return a line-style dict: ``color`` plus any extra keyword settings."""
    return {"color": color, **l}

def marker_maker(color, **l):
    """Return a marker-style dict holding only ``color``; extra keywords are accepted and ignored."""
    return {"color": color}

def smooth(y: pd.DataFrame, span=7):
    """Exponentially-weighted moving average of ``y`` that keeps the original gaps.

    With ``span <= 0`` the input object is returned untouched. Otherwise the
    EWM mean is computed and every position that was missing in ``y`` is set
    back to missing, so smoothing never fills holes.
    """
    if not span > 0:
        return y
    smoothed = y.ewm(span=span, ignore_na=False).mean()
    smoothed[y.isna()] = None
    return smoothed

def add_date_line(fig, date, name, y=1.08):
    """Draw a dotted black vertical line at ``date`` and label it with ``name``.

    The line spans the full paper height (y from 0 to 1); the label is placed
    at paper-relative height ``y`` on the same x position.
    """
    marker_line = dict(
        type='line',
        x0=date,
        y0=0,
        x1=date,
        y1=1,
        line=dict(color='black', dash='dot'),
        xref='x',
        yref='paper',
    )
    label = dict(
        textangle=0,
        xref="x",
        yref="paper",
        x=date,
        y=y,
        text=name,
        showarrow=False,
        font=dict(size=18),
    )
    fig.add_shape(**marker_line)
    fig.add_annotation(**label)

def add_dates(fig, dates, line_y=1.05, include_date=True):
    """Add one labelled vertical line per entry of ``dates`` (a name -> date mapping).

    Entries whose date is missing (``pd.isna``) are skipped. With
    ``include_date`` the label gets the date appended on a second line.
    """
    for label, when in dates.items():
        if pd.isna(when):
            continue
        if include_date:
            text = f"{label}<br />({when})"
        else:
            text = label
        add_date_line(fig, when, text, y=line_y)
            
def collect_corr(y, yp, name, when, weights_name, start_date=None, end_date=None):
    """Correlate ``yp`` with ``y`` inside an optional [start_date, end_date] window.

    Both series are sliced with ``.loc[start_date: end_date]``; the correlation
    is computed once on the raw values and once after ``smooth`` is applied to
    each series.

    Returns:
        dict with keys "name", "weights_name", "corr (raw)", "corr (smoothed)"
        and "when".
    """
    yp_window = yp.loc[start_date: end_date]
    y_window = y.loc[start_date: end_date]
    raw_corr = yp_window.corr(y_window)
    smoothed_corr = smooth(yp_window).corr(smooth(y_window))
    return {
        "name": name,
        "weights_name": weights_name,
        "corr (raw)": raw_corr,
        "corr (smoothed)": smoothed_corr,
        "when": when,
    }

class FigureHelper(object):
    """Queue plotly traces (and optional error bands) and lay them out on a subplot grid.

    Traces are collected with ``add_trace`` / ``add_bar`` and only turned into
    a figure by ``make_fig``. The first time a colour key is seen it is given
    the next (dash, colour) pair from a repeating cycle over
    ``dashes x color_list``; traces that share a key share that style.
    """

    def __init__(self, x=None, color_list=px.colors.qualitative.Plotly, dashes=('solid',), smooth_func=smooth, merge_hover=True):
        """
        Args:
            x: default x values, used by ``add_trace`` when no ``x`` is passed.
            color_list: colours to cycle through.
            dashes: dash styles to cycle through (outer loop of the cycle).
            smooth_func: stored as ``self.smooth``; not called by this class.
            merge_hover: when True, every trace's hover text lists the values
                of all traces (see ``make_fig``).
        """
        self.traces = []
        self.error_traces = []
        self.color_list = color_list
        self.line_picker = itertools.cycle(itertools.product(dashes, self.color_list))
        # Missing keys are filled lazily with the next style from the cycle.
        self.lines = defaultdict(lambda: dict(zip(['dash', 'color'], next(self.line_picker))))
        self.names = set()
        self.smooth = smooth_func
        self.x = x
        self.merge_hover = merge_hover

    def set_line(self, key, line=None):
        """Return the stored style dict for ``key``, updated with ``line`` when given.

        Looking the key up creates its default style on first use. The returned
        dict is the stored one, not a copy.
        """
        style = self.lines[key]
        if line:
            style.update(line)
        return style

    def make_error_traces(self, x, yu, yl, name, color, alpha):
        """Build the two zero-width traces that shade the band between ``yl`` and ``yu``.

        Positions where either bound is missing are dropped first. The second
        trace fills up to the first (``fill='tonexty'``) with ``color`` at
        opacity ``alpha``; both are hidden from the legend and from hover.
        """
        # need to remove nans from error traces
        k = ~(yu.isnull()|yl.isnull())
        xe = x[k]
        yl = yl[k]
        yu = yu[k]

        return [go.Scatter(x=xe,
                            y=yu,
                            hoverinfo="skip",
                            showlegend=False,
                            legendgroup=name,
                            name=name,
                            connectgaps=False,
                            line=dict(width=0),
                ),
                go.Scatter(x=xe,
                            y=yl,
                            fillcolor=to_rgba(color, alpha),
                            fill='tonexty',
                            hoverinfo="skip",
                            showlegend=False,
                            legendgroup=name,
                            name=name,
                            connectgaps=False,
                            line=dict(width=0),
                )]

    def add_trace(self, y, name, x=None, kind=go.Scatter, color_key=None, row=1, col=1, line=None,
                  std=None, yu=None, yl=None, **trace_kwargs):
        """Queue one trace for subplot (``row``, ``col``).

        Args:
            y: values; reindexed onto ``x`` with ``y.reindex(x)``.
            name: trace name, also the default legend group and colour key.
            x: x values; defaults to the helper's ``x``.
            kind: trace class to instantiate (``go.Scatter`` by default).
            color_key: key under which the line style is stored and shared.
            line: style overrides merged into the stored style for the key.
            std: when given, the error band is ``y - std`` .. ``y + std``.
            yu, yl: explicit upper/lower band; used when both are given.
            **trace_kwargs: passed to the trace constructor. ``showlegend``
                defaults to True only the first time ``name`` is used.
        """
        color_key = color_key or name
        trace_kwargs.setdefault('showlegend', name not in self.names)
        self.names.add(name)
        trace_kwargs.setdefault('legendgroup', name)

        line = self.set_line(color_key, line)
        # `x or self.x` would evaluate the truth value of x, which raises for
        # array-like x (e.g. a DatetimeIndex); test for None explicitly.
        if x is None:
            x = self.x
        y = y.reindex(x)
        t = kind(x=x, y=y, name=name, **trace_kwargs)
        if not isinstance(t, go.Bar):
            t.line = line_maker(**line)
        else:
            t.marker = marker_maker(**line)

        self.traces.append((row, col, t))

        if std is not None:
            yu = y+std
            yl = y-std

        if yu is not None and yl is not None:
            for t_ in self.make_error_traces(x, yu, yl, name=name, color=line["color"], alpha=0.2):
                self.error_traces.append((row, col, t_))

    def add_bar(self, y, name, x=None, color_key=None, row=1, col=1, line=None, include_line=True,
                **trace_kwargs):
        """Queue ``y`` as a bar trace, preceded by a line trace when ``include_line`` is True.

        ``x`` is forwarded to ``add_trace`` for both traces.
        """
        if include_line:
            self.add_trace(y=y, name=name, x=x, color_key=color_key, line=line, row=row, col=col, **trace_kwargs)
        self.add_trace(y=y, name=name, x=x, color_key=color_key, kind=go.Bar, line=line, row=row, col=col, **trace_kwargs)

    def make_fig(self, **fig_kwargs):
        """Create the subplot figure and add every queued trace to it.

        The grid size is the largest row/column any trace was queued for;
        ``fig_kwargs`` go to ``make_subplots``. With ``merge_hover`` each trace
        carries a table of all traces' values (one column per trace name,
        sorted) as ``customdata`` and a hover template listing them. Error-band
        traces are added after the main traces.
        """
        data = {}
        max_row = 1
        max_col = 1
        for r, c, t in self.traces:
            max_row = max(r, max_row)
            max_col = max(c, max_col)
            # Traces sharing a name overwrite each other here (last one wins).
            data[t.name] = pd.Series(t.y, index=t.x)

        customdata = pd.DataFrame(data)
        fig = make_subplots(rows=max_row, cols=max_col, **fig_kwargs)

        if self.merge_hover:
            # Identical for every trace, so build it once.
            cus_cols = sorted(customdata)
            ho = "<br />".join(["{name}=%{{customdata[{i}]:.3f}}".format(i=i, name=name) for i,name in enumerate(cus_cols)])

        for r, c, t in self.traces:
            if self.merge_hover:
                hovertemplate = "%{x}<br>" + f"{t.name}: " +"%{y}<br><br>"+f"{ho}<extra></extra>"
                t.customdata = customdata[cus_cols]
                t.hovertemplate = hovertemplate
            # t.hoverlabel = {'bgcolor': 'white'}
            fig.add_trace(t, row=r, col=c)

        for r, c, t in self.error_traces:
            fig.add_trace(t, row=r, col=c)
        return fig
        
        
        
        

In [23]:
def download_and_prepare_run_data(run, which="mean"):
    """Download a run's ``outputs/output.csv`` and split it into error and value tables.

    The file is stored under ``<TOP_DIR>/results/drift/<display_name>.csv`` and
    read with a four-row column header and a datetime index. Columns selected
    by ``column_xs(..., include=["pval"])`` are replaced by ``1 - value``.

    Args:
        run: run object providing ``display_name`` and ``download_file``.
        which: label on the last column level to keep as the value table.

    Returns:
        ``(error_df, combined_df)``: the "std" slice and the ``which`` slice of
        the last column level, each with that level dropped.
    """
    def _slice_last_level(frame, label):
        # Bring the last level to the front, select on it, restore the order, drop it.
        picked = frame.swaplevel(0, -1, axis=1)[[label]]
        return picked.swaplevel(0, -1, axis=1).droplevel(-1, axis=1).copy()

    output_file_path = settings.TOP_DIR.joinpath('results', 'drift', run.display_name+".csv")
    run.download_file("outputs/output.csv", output_file_path=output_file_path)

    raw_df = pd.read_csv(str(output_file_path), index_col=0, header=[0, 1, 2, 3])
    raw_df.index = pd.to_datetime(raw_df.index)
    pval_cols = column_xs(raw_df, include=["pval"])
    raw_df[pval_cols] = 1-raw_df[pval_cols]

    error_df = _slice_last_level(raw_df, "std")
    combined_df = _slice_last_level(raw_df, which)
    return error_df, combined_df


def calc_stats(other_df, standardize_dates):
    """Per-column mean and std of ``other_df`` over a reference date window.

    Columns containing any missing value (anywhere in ``other_df``) are dropped
    first; the rest is reindexed onto ``pd.date_range(*standardize_dates)``.

    Returns:
        DataFrame with index ["mean", "std"] and one column per kept column.
    """
    window = pd.date_range(*standardize_dates)
    complete_cols = other_df.dropna(axis=1)
    return complete_cols.reindex(window).agg(["mean", "std"])

def standardize(other_df, standardize_dates=None, stats=None, clip=None) -> pd.DataFrame:
    """Z-score every column of ``other_df`` against reference statistics.

    Args:
        other_df: values to standardize.
        standardize_dates: arguments for ``pd.date_range`` describing the
            reference window; used only when ``stats`` is not supplied.
        stats: precomputed frame with rows "mean" and "std". It is not
            modified by this function.
        clip: when given, results are limited to ``[-clip, clip]``.

    Returns:
        A new DataFrame of ``(other_df - mean) / std``. Columns absent from the
        statistics come out as missing.
    """
    if stats is None:
        stats = calc_stats(other_df, standardize_dates)
    else:
        # Work on a copy so the zero-std patch below never leaks into the caller's frame.
        stats = stats.copy()

    # cannot divide by zero: a constant column keeps its centred value (0) instead
    stats.loc["std", stats.loc['std'] == 0] = 1

    otherstd = (other_df-stats.loc['mean'])/(stats.loc["std"])

    if clip is not None:
        otherstd = otherstd.clip(-1*clip, clip)

    return otherstd

def calculate_weights(yp, otherstd) -> pd.DataFrame:
    """Build the table of candidate per-metric weights.

    ``yp`` (renamed "auroc") is compared against the columns of ``otherstd``
    with ``correlate_performance`` and ``mutual_info_performance`` (25 bins).

    Returns:
        DataFrame with, in order: the mutual-information column, "corr"
        (correlation with negatives replaced by 0), "abs(corr)",
        "mean[abs(corr),info_gain]" and "no_weights" (constant 1). Missing
        values are replaced by 0.
    """
    all_corr_df = correlate_performance(yp.rename('auroc'), otherstd)
    all_ig_df = mutual_info_performance(yp.rename('auroc'), otherstd, bins=25)

    info_gain = all_ig_df.to_frame()
    positive_corr = all_corr_df.apply(lambda x: max(0, x)).rename('corr')
    abs_corr = all_corr_df.abs().rename('abs(corr)')

    # Average exactly the two quantities the column is named after; the clipped
    # 'corr' column must not take part in this mean.
    mean_weight = info_gain.join(abs_corr).mean(axis=1).rename('mean[abs(corr),info_gain]')

    m_ = info_gain.join(positive_corr).join(abs_corr).join(mean_weight)
    m_ = m_.assign(no_weights=1)
    m_ = m_.fillna(0)

    return m_

class DriftFromRun(object):
    """Drift metrics of one run, standardized and paired with its performance series.

    Construction downloads the run's output table, standardizes the selected
    drift statistics against the reference window and computes the weight
    table used by ``unify_metrics``. The class attributes below act as settings.
    """

    # Label of the last column level to use as the value table.
    which = "mean"
    # Reference window (start, end) for standardization and for the weights.
    standardize_dates = (settings.PADCHEST_SPLIT_DATES[0], settings.PADCHEST_SPLIT_DATES[1])
    # Statistic names used to select the drift columns.
    stat = ['distance']
    # Not read inside this class.
    span = 7
    # Standardized values are limited to [-clip, clip].
    clip = 10
    # Column key of the performance series.
    performance_col = ("performance", "micro avg", "auroc")
    # performance_col = ("performance", "Pneumonia", "auroc")
    # performance_col = ("performance", "Pleural Abnormalities", "auroc")
    # Not read inside this class.
    graph_start = "2014-01-01"
    graph_end = "2014-12-31"

    def __init__(self, run):
        """Download and prepare everything for ``run`` (performs the download)."""
        error_df, combined_df = download_and_prepare_run_data(run, which=self.which)
        self.combined_df = combined_df
        self.perf_df = combined_df[self.performance_col]

        # Drift columns: everything except the performance and count groups,
        # narrowed to the configured statistics and standardized.
        self.other_cols = column_xs(combined_df, exclude=['performance', 'count'])
        drift_df = combined_df[self.other_cols]
        stat_cols = column_xs(drift_df, include=self.stat)
        self.otherstd = standardize(drift_df[stat_cols], self.standardize_dates, clip=self.clip)

        self.standardize_ix = pd.date_range(*self.standardize_dates)

        # Per-label rates: 'support' columns (those whose second level contains
        # "avg" are excluded) divided by the first 'count' column.
        support_cols = [c for c in column_xs(combined_df, ['support']) if 'avg' not in c[1]]
        label_counts = combined_df[support_cols]
        label_counts.columns = [c[1] for c in label_counts.columns]
        num_samples = combined_df['count'].iloc[:, 0]
        self.label_rates = label_counts.div(num_samples, axis=0)

        # Column groups, told apart by a substring of the first column level.
        self.vae_cols = [c for c in list(self.otherstd) if "mu." in c[0]]
        self.score_cols = [c for c in list(self.otherstd) if "activation." in c[0]]
        self.metadata_cols = sorted(
            set(self.otherstd).difference(self.vae_cols).difference(self.score_cols)
        )

        self.perf_error_df = error_df[self.performance_col]
        # Weights are computed on the reference window only.
        self.m_ = calculate_weights(
            self.perf_df.reindex(self.standardize_ix),
            self.otherstd.reindex(self.standardize_ix),
        )

    def unify_metrics(self, cols=None, weights="abs(corr)", start=None, end=None, m=None):
        """Collapse the standardized columns into one series: the negated weighted average.

        Args:
            cols: columns to combine; all standardized columns when None.
            weights: name of the column of the weight table to use.
            start, end: accepted but not used.
            m: weight table; the run's own table when None.
        """
        selected = self.otherstd.columns.tolist() if cols is None else cols
        weight_table = self.m_ if m is None else m
        return -w_avg(self.otherstd[selected], weights=weight_table[weights].to_dict())
# Weight table loaded from a pickle in the current working directory; passed as `m=`
# to DriftFromRun.unify_metrics below. How this file is produced is not shown here.
# NOTE: unpickling executes code stored in the file - only load a trusted "mega_m.pkl".
mega_m = pd.read_pickle("mega_m.pkl")

In [44]:
import json

# Output file stem: "label", the '+'-joined statistics and the performance group
# (spaces replaced by '-', lower-cased), joined with '_'.
figure_name = "_".join(["label", "+".join(DriftFromRun.stat), DriftFromRun.performance_col[1].replace(' ', '-').lower()])
# Legend label -> (run display name, experiment name).
runs = {
    "Baseline": ("tender_pear_lfbd6wwg", "generate-drift-csv-3"),
    "Trial 1": ("musing_lock_mgsp9xps", "generate-drift-csv-label-mod-dbg"),
    "Trial 2": ("mighty_beard_vvtgq0q9", "generate-drift-csv-label-mod-dbg"),
    "Trial 3": ("keen_circle_9wktjvvd", "generate-drift-csv-label-mod-dbg"),
    }
# Not read in this cell.
verbose_name = "baseline"
# Visible x-range of the figure built in the next cell.
graph_start = "2014-03-01"
graph_end = "2015-01-01"


# figure_name = "_".join(["label-t2", "+".join(DriftFromRun.stat), DriftFromRun.performance_col[1].replace(' ', '-').lower()])
# runs = {
#     "Baseline": ("tender_pear_lfbd6wwg", "generate-drift-csv-3"),
#     "Trial 1": ("patient_hominy_8v6p78yh", "generate-drift-csv-label-mod-dbg"),
#     "Trial 2": ("ivory_branch_d350x4yq", "generate-drift-csv-label-mod-dbg"),
#     "Trial 3": ("red_answer_rdp62h0t", "generate-drift-csv-label-mod-dbg"),
#     }
# verbose_name = "baseline"
# graph_start = "2014-03-01"
# graph_end = "2015-01-01"

# Daily index covering the graph window padded by pd.DateOffset(n=30) on each side.
x = pd.date_range(pd.to_datetime(graph_start)-pd.DateOffset(n=30), pd.to_datetime(graph_end)+pd.DateOffset(n=30))

# Per-label collections filled by the loop below.
runs_ = {}
builds = {}
score_graphs = {}
perf_graphs = {}
score_graphs_mega = {}

# Marker letter -> date; shared by all runs, so a later run overwrites the
# letters set by an earlier one.
dates = {}
# One caption line per run.
S = []
for name, (run_name, experiment) in runs.items():
    print(name)
    experiment = Experiment(ws, experiment)
    run = get_run(run_name, experiment)
    # Downloads the run's output and prepares its drift metrics.
    build = DriftFromRun(run)
    runs_[name] = run
    builds[name] = build
    # Unified score over the score columns, using the run's own weights.
    score_graphs[name] = build.unify_metrics(build.score_cols).reindex(x)
    # Performance ("yp") and its error ("ye") side by side.
    perf_graphs[name] = pd.concat([build.perf_df.reindex(x).rename("yp"), 
                                   build.perf_error_df.reindex(x).rename("ye")], axis=1)
    
    # Same unified score, weighted with the shared `mega_m` table instead.
    score_graphs_mega[name] = build.unify_metrics(build.score_cols, m=mega_m).reindex(x) 
    
    k = name
    if "label_modifiers" in run.tags:
        # The tag holds JSON: label -> list whose item 0 is formatted as a
        # percentage and whose item 1 is stored as the marker date.
        lm = json.loads(run.tags["label_modifiers"])
        keys = list(lm.keys())
        s = []
        # Letters are paired with keys before the "No Finding" check, so a
        # skipped key still uses up its letter.
        for kk, p in zip(keys, "ABCDE"):
            if kk == "No Finding": continue
            # '%' is backslash-escaped - presumably for LaTeX captions; confirm.
            s.append(f"{p}={kk} to {lm[kk][0]:.0%}".replace("%", r"\%"))
            dates[p] = lm[kk][1]
        s = ", ".join(s)
        S.append(f"({k}) {s}")
        
        # print(f"({k}) A={keys[0]} to {lm[keys[0]][0]:.0%}, B={keys[1]} to {lm[keys[1]][0]:.0%}. ".replace("%", r"\%"))
    else:
        S.append(f"({k}) No modification at A or B. ")

print('\n'.join(S))
dates

Baseline



merging between different levels can give an unintended result (3 levels on the left,1 on the right)



Trial 1



merging between different levels can give an unintended result (3 levels on the left,1 on the right)



Trial 2



merging between different levels can give an unintended result (3 levels on the left,1 on the right)



Trial 3



merging between different levels can give an unintended result (3 levels on the left,1 on the right)



(Baseline) No modification at A or B. 
(Trial 1) A=Opacity to 75\%, B=Pleural Effusion to 75\%
(Trial 2) A=Pleural Abnormalities to 75\%, B=Pneumonia to 75\%
(Trial 3) A=Atelectasis to 75\%, B=Lesion to 75\%


{'A': '2014-06-01', 'B': '2014-09-16'}

In [46]:

# Two-row figure: performance with error band on top, unified score below.
fh = FigureHelper(x)


# Row 1: smoothed performance per run; band = smoothed (yp + ye) and (yp - ye).
row=1
for name, score_unify in perf_graphs.items():
    fh.add_trace(y=smooth(score_unify["yp"], span=7), 
                 yu=smooth(score_unify["yp"]+score_unify["ye"], span=7),
                 yl=smooth(score_unify["yp"]-score_unify["ye"], span=7), name=name, connectgaps=False, row=row, col=1)

# Row 2: smoothed unified score computed with the `mega_m` weights.
row=2
for name, score_unify in score_graphs_mega.items():
    fh.add_trace(y=smooth(score_unify, span=7), name=name, connectgaps=False, row=row, col=1)

# `row` now holds the number of rows; it sizes row_heights and the figure height.
fig = fh.make_fig(shared_xaxes=True, vertical_spacing=0.025, row_heights=[.2]*row)

# Vertical markers at the dates collected in the previous cell (labels only).
add_dates(fig, dates, line_y=1.045, include_date=False)

fig.update_xaxes(showspikes=True, spikecolor="black", spikesnap="cursor", spikemode="across", spikethickness=1)
fig.update_layout(spikedistance=1000)
fig.update_layout(height=300*row)
fig.update_layout(legend={"x": .5, "orientation":"h", "borderwidth": .5, "xanchor": "center",})
fig.update_xaxes(range=[graph_start, graph_end])
fig.update_layout(yaxis1=dict(
    # range=[0.8, 1], 
    title=f"AUROC ({DriftFromRun.performance_col[1].title()})"),
    yaxis2=dict(range=[-10, 1], title=r"$MMC_w(\text{Score})$"))
fig.update_layout(barmode='overlay')
fig.update_layout(font=dict(size=13))
fig.update_layout(plot_bgcolor="#E8E8EA")

# fig.update_layout(
#     yaxis2 = dict(
#         tickmode = 'linear',
#         tick0 = 0,
#         dtick = 25,
#     )
# )

# Monthly ticks formatted as YYYY-MM-DD.
xaxis = dict(
        tickformat = '%Y-%m-%d',
        tickmode = 'linear',
        dtick = "M1"
    )
# NOTE(review): xaxis3 is configured although only two rows are built in this cell.
fig.update_layout(
    xaxis1=xaxis,
    xaxis2=xaxis,
    xaxis3 = xaxis,
)

fig.show()

In [47]:
import plotly.io as pio

# Rebinds html_top_dir to html/paper/graphs (an earlier cell pointed it at html/graphs_paper).
html_top_dir = settings.TOP_DIR.joinpath("html", "paper", 'graphs')
html_top_dir.mkdir(exist_ok=True, parents=True)

# Write the figure as <figure_name>.svg, 1500 x 600 at scale 1.
fname = html_top_dir.joinpath(figure_name+".svg")
pio.write_image(fig, fname, scale=1, width=5*300, height=2*300)
print(fname)

# Show the written file through a path relative to the current working directory.
relfname = os.path.relpath(str(fname), os.getcwd())
display(HTML(f"""<img src="{relfname}" />"""))

D:\Code\MLOpsDay2\MedImaging-ModelDriftMonitoring\html\paper\graphs\label_distance_micro-avg.svg


In [43]:
# Bare `raise` with no active exception: fails with RuntimeError (see the recorded
# output), which stops a top-to-bottom run before the cells below.
raise

RuntimeError: No active exception to reraise

In [None]:

# NOTE(review): from here on the cells read names that no visible, active cell
# assigns at module level (`r`, `span`, `which`, `clip`, `stat`, `performance_col`,
# `write_html`; some appear only in commented-out code), and `name` is whatever an
# earlier loop left behind.
output_file_path = settings.TOP_DIR.joinpath('results', 'drift', name+".csv")
fname = str(output_file_path)
r.download_file("outputs/output.csv", output_file_path=output_file_path)
# # Settings to file CSV file

# display_args = ["span", "which", "clip", "standardize_perf", "shift_drift_to_perf", "performance_col", "this_center", "this_range", "standardize_dates", "stat", "add_error_bars"]
# `display_args` is iterated below, but its only visible definition is the
# commented-out list above.
d = locals()
display_args = {k: d[k] for k in display_args if k in d}

print(display_args)


if not os.path.exists(fname):
    raise ValueError("no fn")

# Untouched copy of the file (apart from the pval flip) kept as combined_df_o.
combined_df_o = pd.read_csv(str(fname), index_col=0, header=[0, 1, 2, 3])
combined_df_o.index = pd.to_datetime(combined_df_o.index)

# Columns selected with include=["pval"] are replaced by 1 - value.
flip = column_xs(combined_df_o, include=["pval"])
combined_df_o[flip] = 1-combined_df_o[flip]
combined_df = combined_df_o.copy()


smooth_name = f"ewm{span}"

# Split on the last column level: "std" -> error_df, `which` -> combined_df.
error_df = combined_df.swaplevel(0, -1, axis=1)[["std"]].swaplevel(0, -1, axis=1).droplevel(-1, axis=1).copy()
combined_df = combined_df.swaplevel(0, -1, axis=1)[[which]].swaplevel(0, -1, axis=1).droplevel(-1, axis=1).copy()

html_dir = html_top_dir.joinpath(name)
perf_col_name = '-'.join(performance_col)

if not os.path.exists(html_dir):
    os.makedirs(html_dir)

# Report file name encodes the display settings.
stat_str = '+'.join(stat)
fn = f"{html_dir}/{which}_{stat_str}_stdclip{clip}_smooth-{smooth_name}_{perf_col_name}.html"

print("output:", fn)
def is_arg_col(c):
    """Return True when column name ``c`` is a run argument rather than bookkeeping.

    A name is rejected when it contains any of the substrings "mlflow", "_aml",
    "run_" or "url", or when it is exactly one of the ignored names below.
    Re-defines the function of the same name from an earlier cell with the
    same logic.
    """
    bookkeeping_markers = ("mlflow", "_aml", 'run_', 'url')
    if any(marker in c for marker in bookkeeping_markers):
        return False

    ignored_names = ("output_dir", "input_dir", 'run', 'display_name', 'id')
    return c not in ignored_names

# NOTE(review): `run_row` appears only in commented-out code and `arg_cols` only as
# a local of experiment_df_for_display; neither is assigned at module level in the
# visible notebook.
arg_row = run_row[arg_cols].copy()
display_row = pd.Series(display_args)
# Two-level table: run arguments under "Drift", display settings under "Display".
params = pd.concat({'Drift': arg_row, "Display": display_row}, axis=0).rename("Value").to_frame()

if write_html:
    # Mode 'w' starts the report file afresh: style sheet, then title and arguments table.
    with open(fn, 'w') as f:
        print("""
            <style>
            table {
            font-family: arial, sans-serif;
            border-collapse: collapse;
            width: 80%;
            }

            td, th {
            border: 1px solid #dddddd;
            text-align: left;
            padding: 8px;
            }

            tr:nth-child(even) {
            background-color: #dddddd;
            }
            </style>
        """, file=f)
        print(f"""
            <h1>Drift report</h1> created: {datetime.datetime.now()}
            <br /><br />
            <h2>Arguments </h2>
            {params.to_html()}
            """, file=f)

def shift_to_other(this, other, this_range=None, this_center=None):
    """Rescale ``this`` so that it takes on the mean and std of ``other``.

    ``this`` is centred on ``this_center`` (its own mean by default), divided
    by ``this_range`` (its own std by default), then multiplied by the std of
    ``other`` and shifted by the mean of ``other``.
    """
    target_center = other.mean()
    target_scale = other.std()  # other.max()-other.min()

    source_scale = this.std() if this_range is None else this_range  # this.max()-this.min()
    source_center = this.mean() if this_center is None else this_center

    return (this-source_center)/(source_scale)*target_scale+target_center

# Script form of the standardization that calc_stats/standardize implement as functions.
perf_col = performance_col
perf_df = combined_df[perf_col]


# Drift columns: everything outside the 'performance' and 'count' groups.
other_cols = column_xs(combined_df, exclude=['performance', 'count'])
other_df = combined_df[other_cols]


cxs = column_xs(other_df, include=stat)


# NOTE(review): `extra_valids` and `standardize_ix` are not assigned at module level
# in any visible, active cell.
stats = pd.concat([other_df[cxs].dropna(axis=1), extra_valids[cxs].dropna(axis=1)], axis=0).sort_index()
stats = stats.loc[standardize_ix]
stats = stats.agg(["mean", "std"])


# Bare expression in the middle of a cell: evaluated but not displayed.
stats.T

otherstd = other_df[cxs].copy()

# cannot divide by zero
std0 = stats.loc['std'] == 0
stats.loc["std", stats.loc['std'] == 0] = 1
otherstd = (otherstd-stats.loc['mean'])/(stats.loc["std"])
# The error table is centred and scaled with the same statistics as the values.
errorstd = (error_df[cxs]-stats.loc['mean'])/(stats.loc["std"])
# Columns that contain at least one missing value after standardization.
bad_cols = otherstd.columns[otherstd.isnull().max(axis=0)].tolist()

print(bad_cols)

# Column groups, told apart by a substring of the first column level.
vae_cols = [c for c in list(otherstd) if "mu." in c[0]]
score_cols = [c for c in list(otherstd) if "activation." in c[0]]
metadata_cols = sorted(set(otherstd).difference(vae_cols).difference(score_cols))

if clip is not None:
  otherstd = otherstd.clip(-1*clip, clip)

In [None]:
# Daily index spanning the whole table.
x = pd.date_range(combined_df.index.min(), combined_df.index.max())


yp = combined_df[perf_col].reindex(x)
perf_error_df = error_df[perf_col].reindex(x)

# Weight table; here the mean column averages only the mutual-information column and abs(corr).
all_corr_df = correlate_performance(yp.rename('auroc'), otherstd)
all_ig_df = mutual_info_performance(yp.rename('auroc'), otherstd, bins=25)
m_ = all_ig_df.to_frame().join(all_corr_df.abs().rename('abs(corr)'))
m_ = m_.join(m_.mean(axis=1).rename('mean[abs(corr),info_gain]'))
m_ = m_.assign(no_weights=1)
m_ = m_.fillna(0)

# 'obs' column of the 'count' group, taken from the table before the `which` split.
true_counts = combined_df_o['count'].droplevel([0, 1], axis=1)['obs']
count_df = combined_df['count'].reindex(x)


# Marker label -> date; add_dates skips entries whose date is missing.
# NOTE(review): `run_row` is not assigned in any visible, active cell.
dates = {"Laterals Injected": run_row['nonfrontal_add_date'],
         "Frontals Removed": run_row['frontal_remove_date'],
         "Peds Added": run_row['peds_start_date'],
         "Peds Stop": run_row['peds_end_date'],
         "Val Start": settings.PADCHEST_SPLIT_DATES[0],
        #  "Test Start": settings.PADCHEST_SPLIT_DATES[1],
         }

# Keep only days strictly before graph_end.
x = x[x<graph_end]

In [None]:
# Weight-table column -> legend label (LaTeX) of the curve drawn with those weights.
weight_names = {"no_weights": r"$MMC$", "abs(corr)": r"$MMC_w$",}

In [None]:
# NOTE(review): `m`, `counts`, `counts2` and `ystd` are assigned in this cell but not read in it.
m = m_.copy()

yp = yp.reindex(x)
otherstd = otherstd.reindex(x)
counts = combined_df['count'].iloc[:, 0].reindex(x)
counts2 = true_counts.reindex(x)

#collect_corr(y, yp, name, when, weights_name, start_date=None, end_date=None)
# Row 1: smoothed performance with its error band.
row=1
fh = FigureHelper(x)
fh.add_trace(y=smooth(yp), name="AUROC", connectgaps=False, line={"color": "blue"}, 
             yu=smooth(yp+perf_error_df.reindex(x)), 
             yl=smooth(yp-perf_error_df.reindex(x)),
             row=row, col=1)
# row += 1
# fh.add_trace(y=smooth(combined_df[congruency_measure_col]*100), row=row, name='% In-distr.', line={"color": "purple"})
row += 1
# Align the error table's columns with the standardized values.
errorstd = errorstd[otherstd.columns]

# Row 2: one curve per weighting scheme - the negated weighted average of the
# standardized drift columns.
for i, (name, vname) in enumerate(weight_names.items()):
    weights = m_[name].sort_values(ascending=False)
    # weights = weights.iloc[:5]
    y = -w_avg(otherstd.reindex(x), weights=weights.to_dict())
    ystd = -w_avg(errorstd.reindex(x), weights=weights.to_dict())
    fh.add_trace(y=smooth(y),
                        # customdata=smooth(yo),
                        showlegend=True, legendgroup=vname,
                        name=vname,  
                        connectgaps=False, row=row, col=1)

# display(HTML("""<script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
#     <script type="text/x-mathjax-config">MathJax.Hub.Config({"tex2jax": {"inlineMath": [["$", "$"], ["\\(", "\\)"]], "processEscapes": true, "ignoreClass": "document", "processClass": "math|output_area"}})</script>"""))

fig = fh.make_fig(shared_xaxes=True, vertical_spacing=0.025, row_heights=[.2]*row)
add_dates(fig, dates, line_y=1.045, include_date=False)
fig.update_xaxes(showspikes=True, spikecolor="black", spikesnap="cursor", spikemode="across", spikethickness=1)
fig.update_layout(spikedistance=1000)
fig.update_layout(height=200*row)
fig.update_layout(legend={"x": .35, "orientation":"h", "borderwidth": .5})
fig.update_xaxes(range=[graph_start, graph_end])
fig.update_layout(barmode='overlay')
fig.update_layout(font=dict(size=13))
fig.update_layout(plot_bgcolor="#E8E8EA")

# Linear ticks every 25 units on the second y axis.
fig.update_layout(
    yaxis2 = dict(
        tickmode = 'linear',
        tick0 = 0,
        dtick = 25,
    )
)

# Monthly ticks formatted as YYYY-MM-DD.
xaxis = dict(
        tickformat = '%Y-%m-%d',
        tickmode = 'linear',
        dtick = "M1"
    )
# NOTE(review): xaxis3 is configured although only two rows are built in this cell.
fig.update_layout(
    xaxis1=xaxis,
    xaxis2=xaxis,
    xaxis3 = xaxis,
)

fig.show()


In [None]:
import plotly.io as pio

# Export the current figure as <verbose_name>.svg (1500 x 600, scale 1) and
# leave the written path as the cell's output.
fname = html_top_dir.joinpath(verbose_name+'.svg')
pio.write_image(fig, fname, scale=1, width=5*300, height=2*300)
fname

In [None]:
from pathlib import Path
# Show the exported file inline through a path relative to the current working directory.
relfname = os.path.relpath(str(fname), os.getcwd())
display(HTML(f"""<img src="{relfname}" />"""))

In [None]:
!pip install -U kaleido

In [None]:
# Bare `raise` with no active exception: raises RuntimeError, so a top-to-bottom
# run stops here and the cells below are not executed.
raise

In [None]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.express as px

# One subplot row per metric family; each entry holds the columns of
# `combined_df` selected for that metric.
auroc_cols = column_xs(combined_df, ['auroc'])
precision_cols = column_xs(combined_df, ['precision'])
recall_cols = column_xs(combined_df, ['recall'])
f1_cols = column_xs(combined_df, ['f1-score'])
support_cols = column_xs(combined_df, ['support'])

cols_ = [auroc_cols, recall_cols, precision_cols, f1_cols, support_cols]
fig = make_subplots(rows=len(cols_), cols=1, shared_xaxes=True, vertical_spacing=0.01)

# Sort on (is-average, name): names containing 'avg' still come last, and the
# name tie-breaker keeps the colour assignment identical from run to run
# (iteration order of a set of strings is not stable across kernels).
names = sorted({c[1] for c in itertools.chain(*cols_)}, key=lambda name: ('avg' in name, name))
colors = px.colors.qualitative.Plotly
# Plotly's named dash styles; 'dotted' is not one of them, 'dot' is.
dashes = ['solid', 'dash', 'dot']

# Give every name its own (dash, colour) pair. The style sequence is cycled so
# that a name never ends up without an entry when there are more names than
# distinct styles.
lines = {
    name: {'color': color, 'dash': dash}
    for name, (dash, color) in zip(names, itertools.cycle(itertools.product(dashes, colors)))
}

visited = set()
for r, cols__ in enumerate(cols_, 1):
    last_col = None
    for c in cols__:
        label = c[1]
        ypp = combined_df[c].reindex(x)
        fig.add_trace(
            go.Scatter(
                x=x,
                y=smooth(ypp),
                # Only the first trace of each legend group gets a legend entry.
                showlegend=label not in visited,
                legendgroup=label,
                name=label,
                hovertemplate="%{y: .5f}",
                connectgaps=False,
                line=lines[label],
            ),
            row=r,
            col=1,
        )
        visited.add(label)
        last_col = c
    # Title the row after the last level of its own columns; skip rows with no
    # columns instead of reusing a column left over from a previous row.
    if last_col is not None:
        fig.update_yaxes(title_text=last_col[-1], row=r, col=1)
add_dates(fig, dates, 1.025)

fig.update_layout(
    title="Performance",
    hovermode="x unified",
    height=300*len(cols_),
    font=font,
)
fig.update_xaxes(range=[graph_start, graph_end])

fig.show()

# Append this section to the per-run HTML report.
if write_html:
    fig_html = fig.to_html()
    with open(fn, 'a') as f:
        print("<h2>Performance</h2>", file=f)
        print(fig_html, file=f)

In [None]:
# Row 1: AUROC with its error band; row 2: the data-congruency measure;
# row 3: one weighted-average drift curve per feature group.
fh = FigureHelper(x)
fh.add_trace(
    y=smooth(yp),
    name="AUROC",
    connectgaps=False,
    line={"color": "blue"},
    yu=smooth(yp+perf_error_df),
    yl=smooth(yp-perf_error_df),
    row=1,
    col=1,
)
fh.add_trace(y=smooth(combined_df[congruency_measure_col]), row=2, name='Data Congruency (True)')
corrs = []

# The metadata feature groups are only included when peds_weight is zero.
if not float(run_row['peds_weight']):
    xcols = zip(
        ["metadata", "vae", "score", "vae+score", "metadate+vae+score"],
        [metadata_cols, vae_cols, score_cols, vae_cols+score_cols, vae_cols+score_cols+metadata_cols],
    )
else:
    xcols = zip(
        ["vae", "score", "vae+score"],
        [vae_cols, score_cols, vae_cols+score_cols],
    )

# Date windows over which each curve is correlated: (label, collect_corr kwargs).
_corr_windows = [
    ("Everything", {}),
    ("Validation", {"start_date": settings.PADCHEST_SPLIT_DATES[0], "end_date": settings.PADCHEST_SPLIT_DATES[1]}),
    ("Test", {"start_date": settings.PADCHEST_SPLIT_DATES[1]}),
    ("First Year of Test", {"start_date": settings.PADCHEST_SPLIT_DATES[1], "end_date": "2014-12-31"}),
]

# NOTE(review): `y` is not assigned in this cell; it is whatever an earlier
# cell left in the kernel - confirm it is the intended reference series.
for row, (name_, cols) in enumerate(xcols, 1):
    for i, name in enumerate(["abs(corr)"]):
        otherstd_ = otherstd[cols]
        weights = m[name].sort_values(ascending=False)
        yo = -w_avg(otherstd_.loc[x], weights=weights.to_dict())

        fh.add_trace(y=smooth(yo), name=name_, line={"width": 1},  connectgaps=False, row=3, col=1)

        for _when, _span in _corr_windows:
            corrs.append(collect_corr(y, yo, name_, _when, name, **_span))


fig = fh.make_fig(shared_xaxes=True, vertical_spacing=0.01)
add_dates(fig, dates, 1.08)

fig.update_xaxes(
    showspikes=True,
    spikecolor="black",
    spikesnap="cursor",
    spikemode="across",
    spikethickness=1,
    range=[graph_start, graph_end],
)
fig.update_layout(
    title=f"Level 1 Metrics",
    spikedistance=1000,
    height=600,
    barmode='overlay',
)
corr_df = pd.DataFrame(corrs).sort_values('when')
display(corr_df)
fig.show()

fig_html = fig.to_html()

# Append the correlation tables (one per window) and the figure to the report.
if write_html:
    with open(fn, 'a') as f:
        print(f"<h2>Level 1 Unified</h2>", file=f)
        for w, grp in corr_df.groupby('when'):
            print(f"<strong>{w}</strong>{grp.to_html()}", file=f)
        print(fig_html, file=f)


In [None]:
# Fresh figure helper: row 1 is AUROC with its error band, row 2 the
# data-congruency measure.
fh = FigureHelper(
    x,
    dashes=['solid', 'dot', 'dash', 'longdash', 'dashdot', 'longdashdot'],
)
fh.add_trace(
    y=smooth(yp),
    name="AUROC",
    connectgaps=False,
    line={"color": "blue"},
    yu=smooth(yp+perf_error_df),
    yl=smooth(yp-perf_error_df),
    row=1,
    col=1,
)
fh.add_trace(y=smooth(combined_df[congruency_measure_col]), row=2, name='Data Congruency (True)')
# Correlation records collected by the plotting loop below.
corrs = []

def partition(l, n):
    """Split ``l`` into consecutive slices of at most ``n`` items each.

    The last slice holds the remainder and may be shorter than ``n``.
    """
    chunks = []
    for start in range(0, len(l), n):
        chunks.append(l[start:start + n])
    return chunks

# Categorical metadata drift: keep the 'chi2' columns and plot them in groups
# of 12, one group per subplot row starting at row 3.
cols = metadata_cols
otherstd_ = otherstd[cols]
cols_chi2 = column_xs(otherstd_, include='chi2')
otherstd_ = otherstd[cols_chi2]
cols_ = partition(cols_chi2, 12)
print(len(cols_))

# Date windows over which each curve is correlated: (label, collect_corr kwargs).
_corr_windows = [
    ("Everything", {}),
    ("Validation", {"start_date": settings.PADCHEST_SPLIT_DATES[0], "end_date": settings.PADCHEST_SPLIT_DATES[1]}),
    ("Test", {"start_date": settings.PADCHEST_SPLIT_DATES[1]}),
    ("First Year of Test", {"start_date": settings.PADCHEST_SPLIT_DATES[1], "end_date": "2014-12-31"}),
]

for row, cols in enumerate(cols_, 3):
    for c in cols:
        yo = -otherstd[c]
        _label = str(c)
        fh.add_trace(y=smooth(yo), name=_label, connectgaps=False, row=row, col=1)

        for _when, _span in _corr_windows:
            corrs.append(collect_corr(y, yo, _label, _when, 'N/A', **_span))

# The figure and report section are only produced when peds_weight is zero.
if not float(run_row['peds_weight']):
    fig = fh.make_fig(shared_xaxes=True, vertical_spacing=0.01)
    add_dates(fig, dates, 1.05)

    fig.update_xaxes(
        showspikes=True,
        spikecolor="black",
        spikesnap="cursor",
        spikemode="across",
        spikethickness=1,
        range=[graph_start, graph_end],
    )
    fig.update_layout(
        title=f"Metadata Categorical",
        spikedistance=1000,
        height=200*(len(cols_)+2),  # two fixed rows plus one per column group
        font=font,
        barmode='overlay',
    )

    corr_df = pd.DataFrame(corrs)
    display(corr_df)

    fig.show()

    fig_html = fig.to_html()
    if write_html:
        with open(fn, 'a') as f:
            print(f"<h2>Metadata Categorical</h2>", file=f)
            for w, grp in corr_df.groupby('when'):
                print(f"<strong>{w}</strong>{grp.to_html()}", file=f)
            print(fig_html, file=f)

In [None]:
# Fresh figure helper: row 1 is AUROC with its error band, row 2 the
# data-congruency measure.
fh = FigureHelper(
    x,
    dashes=['solid', 'dot', 'dash', 'longdash', 'dashdot', 'longdashdot'],
)
fh.add_trace(
    y=smooth(yp),
    name="AUROC",
    connectgaps=False,
    line={"color": "blue"},
    yu=smooth(yp+perf_error_df),
    yl=smooth(yp-perf_error_df),
    row=1,
    col=1,
)
fh.add_trace(y=smooth(combined_df[congruency_measure_col]), row=2, name='Data Congruency (True)')
# Correlation records collected by the plotting loop below.
corrs = []

def partition(l, n):
    """Split ``l`` into consecutive slices of at most ``n`` items each.

    The last slice holds the remainder and may be shorter than ``n``.
    """
    chunks = []
    for start in range(0, len(l), n):
        chunks.append(l[start:start + n])
    return chunks

# Real-valued metadata drift: keep the 'ks' columns and plot them in groups of
# 14, one group per subplot row starting at row 3.
# (`cols_chi2` keeps its name from the categorical cell but holds 'ks' columns here.)
cols = metadata_cols
otherstd_ = otherstd[cols]
cols_chi2 = column_xs(otherstd_, include='ks')
otherstd_ = otherstd[cols_chi2]
cols_ = partition(cols_chi2, 14)
print(len(cols_chi2))

# Date windows over which each curve is correlated: (label, collect_corr kwargs).
_corr_windows = [
    ("Everything", {}),
    ("Validation", {"start_date": settings.PADCHEST_SPLIT_DATES[0], "end_date": settings.PADCHEST_SPLIT_DATES[1]}),
    ("Test", {"start_date": settings.PADCHEST_SPLIT_DATES[1]}),
    ("First Year of Test", {"start_date": settings.PADCHEST_SPLIT_DATES[1], "end_date": "2014-12-31"}),
]

for row, cols in enumerate(cols_, 3):
    for c in cols:
        yo = -otherstd[c]
        _label = str(c)
        fh.add_trace(y=smooth(yo), name=_label, connectgaps=False, row=row, col=1)

        for _when, _span in _corr_windows:
            corrs.append(collect_corr(y, yo, _label, _when, 'N/A', **_span))

# The figure and report section are only produced when peds_weight is zero.
if not float(run_row['peds_weight']):
    fig = fh.make_fig(shared_xaxes=True, vertical_spacing=0.01)
    add_dates(fig, dates, 1.08)

    fig.update_layout(title=f"Metadata Real Valued")
    fig.update_xaxes(showspikes=True, spikecolor="black", spikesnap="cursor", spikemode="across", spikethickness=1)
    fig.update_layout(spikedistance=1000)
    fig.update_layout(height=200*(len(cols_)+2))
    fig.update_xaxes(range=[graph_start, graph_end])
    fig.update_layout(font=font)
    fig.update_layout(barmode='overlay')

    # Build the table from THIS cell's correlations. Without this, the report
    # below would reuse whatever `corr_df` an earlier cell left in the kernel
    # and publish it under the "Metadata Real Valued" heading.
    corr_df = pd.DataFrame(corrs)
    display(corr_df)

    fig.show()

    fig_html = fig.to_html()
    if write_html:
        with open(fn, 'a') as f:
            print(f"<h2>Metadata Real Valued</h2>", file=f)
            for w, grp in corr_df.groupby('when'):
                print(f"<strong>{w}</strong>{grp.to_html()}", file=f)
            print(fig_html, file=f)

In [None]:
# Fresh figure helper: row 1 is AUROC with its error band, row 2 the
# data-congruency measure.
fh = FigureHelper(
    x,
    dashes=['solid', 'dot', 'dash', 'longdash', 'dashdot', 'longdashdot'],
)
fh.add_trace(
    y=smooth(yp),
    name="AUROC",
    connectgaps=False,
    line={"color": "blue"},
    yu=smooth(yp+perf_error_df),
    yl=smooth(yp-perf_error_df),
    row=1,
    col=1,
)
fh.add_trace(y=smooth(combined_df[congruency_measure_col]), row=2, name='Data Congruency (True)')
# Correlation records collected by the plotting loop below.
corrs = []

def partition(l, n):
    """Split ``l`` into consecutive slices of at most ``n`` items each.

    The last slice holds the remainder and may be shorter than ``n``.
    """
    chunks = []
    for start in range(0, len(l), n):
        chunks.append(l[start:start + n])
    return chunks


# Pick the 12 VAE 'distance' columns with the largest maximum from the second
# split date onwards. The level swap moves level 2 to the front so 'distance'
# can be selected by name, then restores the original level order.
cols = vae_cols
o = (
    other_df[cols]
    .loc[settings.PADCHEST_SPLIT_DATES[1]:]
    .swaplevel(0, 2, axis=1)[['distance']]
    .swaplevel(0, 2, axis=1)
)
colss = o.max(axis=0).sort_values(ascending=False).head(12).index.tolist()
cols_ = partition(colss, 12)

# Date windows over which each curve is correlated: (label, collect_corr kwargs).
_corr_windows = [
    ("Everything", {}),
    ("Validation", {"start_date": settings.PADCHEST_SPLIT_DATES[0], "end_date": settings.PADCHEST_SPLIT_DATES[1]}),
    ("Test", {"start_date": settings.PADCHEST_SPLIT_DATES[1]}),
    ("First Year of Test", {"start_date": settings.PADCHEST_SPLIT_DATES[1], "end_date": "2014-12-31"}),
]

for row, cols in enumerate(cols_, 3):
    for c in cols:
        yo = -otherstd[c]
        _label = str(c)
        fh.add_trace(y=smooth(yo), name=_label, connectgaps=False, row=row, col=1)

        for _when, _span in _corr_windows:
            corrs.append(collect_corr(y, yo, _label, _when, 'N/A', **_span))

fig = fh.make_fig(shared_xaxes=True, vertical_spacing=0.01)
add_dates(fig, dates, 1.08)

# Unlike the other figures, the x-range starts at the second split date.
fig.update_xaxes(
    showspikes=True,
    spikecolor="black",
    spikesnap="cursor",
    spikemode="across",
    spikethickness=1,
    range=[settings.PADCHEST_SPLIT_DATES[1], graph_end],
)
fig.update_layout(
    title=f"VAE Mu (top {len(colss)})",
    spikedistance=1000,
    height=200*(len(cols_)+2),  # two fixed rows plus one per column group
    barmode='overlay',
    font=font,
)
corr_df = pd.DataFrame(corrs).sort_values('when')
display(corr_df)
fig.show()


fig_html = fig.to_html()
# Append the correlation tables (one per window) and the figure to the report.
if write_html:
    with open(fn, 'a') as f:
        print(f"<h2>VAE Mu (top {len(colss)})</h2>", file=f)
        for w, grp in corr_df.groupby('when'):
            print(f"<strong>{w}</strong>{grp.to_html()}", file=f)
        print(fig_html, file=f)

In [None]:
def _graph_link(run_name):
    """Return the 'Graphs' anchor for a run that has an output folder, else 'N/A'."""
    if html_top_dir.joinpath(run_name).exists():
        return f"""<a href="{run_name}/index.html" disabled=>Graphs</a>"""
    return "N/A"


arg_df2['Link'] = [_graph_link(run_name) for run_name in arg_df2.index]

# Inline stylesheet for the summary table (bordered cells, striped rows).
index_css = """
            <style>
            table {
            font-family: arial, sans-serif;
            border-collapse: collapse;
            width: 80%;
            }

            td, th {
            border: 1px solid #dddddd;
            text-align: left;
            padding: 8px;
            }

            tr:nth-child(even) {
            background-color: #dddddd;
            }
            </style>
        """

# Top-level index: stylesheet, generation timestamp, the run table (rendered
# unescaped so the anchors stay clickable) and the link-fixing script.
with open(html_top_dir.joinpath("index.html"), 'w') as f:
    print(
        index_css,
        f"Generated: {datetime.datetime.now()}",
        arg_df2.to_html(escape=False),
        fix_links_script,
        sep="\n",
        file=f,
    )

In [None]:
def create_index_html(child):
    """Write an ``index.html`` into directory ``child`` that links its contents.

    Files directly inside ``child`` are listed under "files", except the
    ``index.html`` itself. Sub-folders are listed under "folders" only when
    they already contain their own ``index.html``; a ``..`` entry is put first
    when the parent directory has one. Entries are sorted by name so the
    generated page does not depend on directory iteration order. The
    module-level ``fix_links_script`` is appended after the listing.

    Args:
        child: ``pathlib.Path`` to process. A regular file is ignored.

    Returns:
        None.
    """
    if child.is_file():
        return

    html_files = []
    html_folders = []
    for entry in sorted(child.iterdir()):
        n = entry.relative_to(child)
        if entry.is_file():
            # Compare the exact file name: a suffix test on the full path would
            # also hide files such as "myindex.html".
            if entry.name != 'index.html':
                html_files.append(n)
        elif entry.joinpath('index.html').exists():
            html_folders.append(n)

    # Link back to the parent listing first, when there is one.
    if child.parent.joinpath('index.html').exists():
        html_folders.insert(0, "..")

    html = "folders: <ul>"
    for n in html_folders:
        html += f"""
            <li><a href="{n}/index.html">{n}</a></li>
            """
    html += "</ul>"
    html += "files:<ul>"
    for n in html_files:
        html += f"""
            <li><a href="{n}">{n}</a></li>
            """
    html += "</ul>"

    with open(child.joinpath('index.html'), 'w') as f:
        print(html, file=f)
        print(fix_links_script, file=f)
        
        


In [None]:
# Index the 'vae[all-data]' sibling folder first, then the folder that contains
# it, so the parent listing can pick up the sibling's freshly written index.html.
reports_root = html_top_dir.parent
create_index_html(reports_root / 'vae[all-data]')
create_index_html(reports_root)
    

In [None]:
# Write an index.html into every entry of the report folder
# (create_index_html ignores entries that are regular files).
for child in html_top_dir.iterdir():
    create_index_html(child)
    
            
    
    
            

In [None]:
# NOTE(review): in the visible code `html` is only assigned as a local inside
# create_index_html; unless it is also defined at top level elsewhere, this
# cell raises NameError on a fresh kernel - looks like leftover debugging, confirm or remove.
html