In [1]:
import re
import wandb
import pandas as pd
import numpy as np

pd.set_option('display.max_rows', 48)
pd.set_option('display.max_columns', 20)
pd.set_option('display.width', 1000)

api = wandb.Api()

runs = api.runs("rap1ide/slice_inflate")
recent_mmwhs_runs = list(filter(lambda r:
                          r.name.startswith("202402")
                          and 'meta_config_id' in r.config
                          and isinstance(r.config['meta_config_id'], int)
                          and r.config['dataset'][0] == 'mmwhs', runs))

mmwhs_runs = [dict(name=f"dummy{idx:02d}") for idx in range(0,48)]

for run_idx in range(len(mmwhs_runs)):
    meta_config_id = int(run_idx // 3)
    fold_idx = run_idx % 3
    if meta_config_id in [9,11,13,15]:
        wandb_runs = list(filter(lambda r: f"fold-{fold_idx}" in r.name
            and 'meta_config_id' in r.config
            and r.config['meta_config_id'] == meta_config_id
            and '-opt_second' in r.name,
            recent_mmwhs_runs))
    else:
        wandb_runs = list(filter(lambda r: f"fold-{fold_idx}" in r.name
            and 'meta_config_id' in r.config
            and r.config['meta_config_id'] == meta_config_id
            and '-ref' in r.name,
            recent_mmwhs_runs))

    if len(wandb_runs) > 1:
        raise ValueError(f"Expected exactly one run, got {len(wandb_runs)}")
    elif len(wandb_runs) == 0:
        continue

    run = wandb_runs[0]
    name = run.name
    path = run.path
    meta_config_id = run.config['meta_config_id']

    fold_idx = int(re.match(r'.*fold-(\d+).*', name).group(1))
    entry = dict(
        name=name,
        id=meta_config_id,
        fold=fold_idx,
        access_key='/'.join(path),
    )
    mmwhs_runs[run_idx] = entry


In [2]:
def get_agg_dict(filtered_frame):
    agg_dict = {
        n: 'first' if t != float else 'mean' for n,t in zip(filtered_frame.columns,filtered_frame.dtypes)
    }
    return agg_dict

# get n chunks of a list
def chunks(lst, n):
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

def join_cols(frame, first_col_key, second_col_key, join_str="+", drop_second=True):
    first_col = frame[first_col_key].astype(str)
    second_col = frame[second_col_key].astype(str)
    frame[first_col_key] = first_col + join_str + second_col
    if drop_second:
        frame.drop(second_col_key, axis=1, inplace=True)
    else:
        frame.drop(first_col_key, axis=1, inplace=True)
    return frame

In [3]:
settings = {
   "0:1": dict(description="Clinical standard",
      first_view="p2CH",
      second_view="p4CH",
      prescan_res ="$(1.5mm)^3$",
      prescan_type ="GT",
      slice_res ="$(1.5mm)^2$",
      slice_type="GT",
   ),
   "1:2": dict(description="Clinical standard",
      first_view="2CH",
      second_view="4CH",
      prescan_res ="$(1.5mm)^3$",
      prescan_type ="GT",
      slice_res ="$(1.5mm)^2$",
      slice_type="GT",
   ),
   "2:3": dict(description="Clinical standard",
      first_view="2CH",
      second_view="SA",
      prescan_res ="$(1.5mm)^3$",
      prescan_type ="GT",
      slice_res ="$(1.5mm)^2$",
      slice_type="GT",
   ),
   "3:9": dict(description="Mean out of 6 Random",
      first_view="RND",
      second_view="RND",
      prescan_res ="$(1.5mm)^3$",
      prescan_type ="GT",
      slice_res ="$(1.5mm)^2$",
      slice_type="GT",
   ),
   "9:10": dict(description="Optimized",
      first_view="OPT",
      second_view="OPT",
      prescan_res ="$(1.5mm)^3$",
      prescan_type ="GT",
      slice_res ="$(1.5mm)^2$",
      slice_type="GT",
   ),

   "10:11": dict(description="Clinical standard",
      first_view="2CH",
      second_view="4CH",
      prescan_res ="$(6mm)^3$",
      prescan_type ="GT",
      slice_res ="$(1.5mm)^2$",
      slice_type="GT",
   ),
   "11:12": dict(description="Optimized",
      first_view="OPT",
      second_view="OPT",
      prescan_res ="$(6mm)^3$",
      prescan_type ="GT",
      slice_res ="$(1.5mm)^2$",
      slice_type="GT",
   ),

   "12:13": dict(description="Clinical standard",
      first_view="2CH",
      second_view="4CH",
      prescan_res ="$(6mm)^3$",
      prescan_type ="SEG",
      slice_res ="$(1.5mm)^2$",
      slice_type="SEG",
   ),
   "13:14": dict(description="Optimized",
      first_view="OPT",
      second_view="OPT",
      prescan_res ="$(6mm)^3$",
      prescan_type ="SEG",
      slice_res ="$(1.5mm)^2$",
      slice_type="SEG",
   ),

   "14:15": dict(description="Clinical standard",
      first_view="2CH",
      second_view="4CH",
      prescan_res ="$(6mm)^3$",
      prescan_type ="SEG",
      slice_res ="$(6mm)^3$",
      slice_type="SEG",
   ),
   "15:16": dict(description="Optimized",
      first_view="OPT",
      second_view="OPT",
      prescan_res ="$(6mm)^3$",
      prescan_type ="SEG",
      slice_res ="$(6mm)^3$",
      slice_type="SEG",
   ),
}

# Build latex tables

In [4]:
def get_wanted_keys(phase):
    assert phase in ['val', 'test']

    wanted_keys = {
        '_id': '_id',
        'prescan_res': 'Precsan props.',
        # 'description': 'Description',
        'prescan_type': 'Prescan props.',
        'slice_res': 'Slice resolution',
        'slice_type': 'Slice props.',
        'first_view': 'Slice view(s)',
        'second_view': 'Second view',
        'fold': 'fold',
        f'scores/{phase}_mean_dice_MYO': 'MYO',
        f'scores/{phase}_mean_dice_LV': 'LV',
        f'scores/{phase}_mean_dice_RV': 'RV',
        f'scores/{phase}_mean_dice_LA': 'LA',
        f'scores/{phase}_mean_dice_RA': 'RA',

        f'scores/{phase}_mean_oa_exclude_bg_dice': '\multicolumn{1}{c}{$\mu\pm\sigma$ }',
        f'scores/{phase}_std_oa_exclude_bg_dice': 'N/A',

        f'scores/{phase}_mean_hd95_MYO': 'MYO',
        f'scores/{phase}_mean_hd95_LV': 'LV',
        f'scores/{phase}_mean_hd95_RV': 'RV',
        f'scores/{phase}_mean_hd95_LA': 'LA',
        f'scores/{phase}_mean_hd95_RA': 'RA',
        f'scores/{phase}_mean_oa_exclude_bg_hd95': '\multicolumn{1}{c}{$\mu\pm\sigma$  }',
        f'scores/{phase}_std_oa_exclude_bg_hd95': 'N/A',

        # 'scores/val_mean_oa_exclude_bg_iou': 'IOU',
        # 'scores/val_std_oa_exclude_bg_iou': '',

        f'scores/{phase}_mean_delta_vol_rel_LV': '\multicolumn{1}{c}{$\mu\pm\sigma$}   ',
        f'scores/{phase}_std_delta_vol_rel_LV': '',
    }
    return wanted_keys

## Build latex table for MMWHS

In [5]:
df = pd.DataFrame()

for run_idx, rr in enumerate(mmwhs_runs):
   wandb_run_name = rr['name']
   run_frame = pd.DataFrame()
   run_numeric_id = run_idx // 3
   try:
      run_key = rr['access_key']
      wrun = api.run(run_key)
      run_frame = pd.DataFrame(wrun.history())
      run_frame = run_frame.iloc[-1:]
      run_frame.index = [run_idx]

   except:
      print(f"Failed to fetch run {wandb_run_name}")
      run_frame = run_frame.map(lambda x: np.nan)

   if 3 <= run_numeric_id < 9:
      run_settings = settings['3:9']
   else:
      run_settings = settings[f"{run_numeric_id}:{run_numeric_id+1}"]

   run_frame.insert(0, '_id', [run_numeric_id])
   run_frame.insert(1, 'description', [run_settings['description']])
   run_frame.insert(2, 'first_view', run_settings['first_view'])
   run_frame.insert(3, 'second_view', run_settings['second_view'])
   run_frame.insert(4, 'prescan_res', run_settings['prescan_res'])
   run_frame.insert(5, 'prescan_type', run_settings['prescan_type'])
   run_frame.insert(6, 'slice_res', run_settings['slice_res'])
   run_frame.insert(7, 'slice_type', run_settings['slice_type'])
   fold_idx = run_idx % 3
   run_frame.insert(8, 'fold', fold_idx)

   df = pd.concat([df,run_frame])

df = df.fillna(0)

Failed to fetch run dummy09
Failed to fetch run dummy10
Failed to fetch run dummy11
Failed to fetch run dummy12
Failed to fetch run dummy13
Failed to fetch run dummy14
Failed to fetch run dummy15
Failed to fetch run dummy16
Failed to fetch run dummy17
Failed to fetch run dummy18
Failed to fetch run dummy19
Failed to fetch run dummy20
Failed to fetch run dummy21
Failed to fetch run dummy22
Failed to fetch run dummy23
Failed to fetch run dummy24
Failed to fetch run dummy25
Failed to fetch run dummy26
Failed to fetch run dummy29
Failed to fetch run dummy33
Failed to fetch run dummy34
Failed to fetch run dummy35
Failed to fetch run dummy37
Failed to fetch run dummy38
Failed to fetch run dummy39
Failed to fetch run dummy40
Failed to fetch run dummy41
Failed to fetch run dummy42
Failed to fetch run dummy43
Failed to fetch run dummy44
Failed to fetch run dummy45
Failed to fetch run dummy46
Failed to fetch run dummy47


In [6]:
df

Unnamed: 0,_id,description,first_view,second_view,prescan_res,prescan_type,slice_res,slice_type,fold,scores/val_std_delta_vol_rel_RA,...,orientations/test_hla_theta_ap0_mean,orientations/val_hla_theta_zp0_mean,orientations/val_sa_theta_t_offsets2_std,orientations/val_hla_theta_zp0_std,orientations/test_sa_theta_t_offsets1_std,orientations/val_sa_theta_zp0_mean,orientations/test_hla_theta_ap2_std,orientations/train_hla_theta_ap1_std,orientations/test_sa_theta_ap4_std,orientations/train_sa_theta_ap4_std
0,0,Clinical standard,p2CH,p4CH,$(1.5mm)^3$,GT,$(1.5mm)^2$,GT,0,0.082965,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0,Clinical standard,p2CH,p4CH,$(1.5mm)^3$,GT,$(1.5mm)^2$,GT,1,0.208663,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0,Clinical standard,p2CH,p4CH,$(1.5mm)^3$,GT,$(1.5mm)^2$,GT,2,0.34197,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,1,Clinical standard,2CH,4CH,$(1.5mm)^3$,GT,$(1.5mm)^2$,GT,0,0.244374,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,1,Clinical standard,2CH,4CH,$(1.5mm)^3$,GT,$(1.5mm)^2$,GT,1,0.11786,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5,1,Clinical standard,2CH,4CH,$(1.5mm)^3$,GT,$(1.5mm)^2$,GT,2,0.168998,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
6,2,Clinical standard,2CH,SA,$(1.5mm)^3$,GT,$(1.5mm)^2$,GT,0,0.213786,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
7,2,Clinical standard,2CH,SA,$(1.5mm)^3$,GT,$(1.5mm)^2$,GT,1,0.441933,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
8,2,Clinical standard,2CH,SA,$(1.5mm)^3$,GT,$(1.5mm)^2$,GT,2,0.239835,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0,3,Mean out of 6 Random,RND,RND,$(1.5mm)^3$,GT,$(1.5mm)^2$,GT,0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [7]:
filtered_frame = df.copy()
FIRST_FOLD_ONLY = True
if FIRST_FOLD_ONLY:
    filtered_frame = filtered_frame[filtered_frame['fold'] == 0]
else:
    filtered_frame = filtered_frame.groupby('_id').agg(get_agg_dict(filtered_frame)) # TODO readd

# Prepare values
wanted_keys = get_wanted_keys('val')
filtered_frame = filtered_frame[wanted_keys.keys()]
filtered_frame

Unnamed: 0,_id,prescan_res,prescan_type,slice_res,slice_type,first_view,second_view,fold,scores/val_mean_dice_MYO,scores/val_mean_dice_LV,...,scores/val_std_oa_exclude_bg_dice,scores/val_mean_hd95_MYO,scores/val_mean_hd95_LV,scores/val_mean_hd95_RV,scores/val_mean_hd95_LA,scores/val_mean_hd95_RA,scores/val_mean_oa_exclude_bg_hd95,scores/val_std_oa_exclude_bg_hd95,scores/val_mean_delta_vol_rel_LV,scores/val_std_delta_vol_rel_LV
0,0,$(1.5mm)^3$,GT,$(1.5mm)^2$,GT,p2CH,p4CH,0,0.787386,0.882927,...,0.16188,7.653829,8.184327,30.340399,27.579405,38.687687,22.48913,25.385578,0.165246,0.148523
3,1,$(1.5mm)^3$,GT,$(1.5mm)^2$,GT,2CH,4CH,0,0.818286,0.88684,...,0.094747,6.848581,8.203704,19.513845,8.874496,27.100212,14.108168,10.160501,0.170663,0.193362
6,2,$(1.5mm)^3$,GT,$(1.5mm)^2$,GT,2CH,SA,0,0.799069,0.876908,...,0.121211,7.778636,10.23329,16.517567,13.801551,31.586577,15.983524,10.033892,0.14361,0.107447
0,3,$(1.5mm)^3$,GT,$(1.5mm)^2$,GT,RND,RND,0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0,4,$(1.5mm)^3$,GT,$(1.5mm)^2$,GT,RND,RND,0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0,5,$(1.5mm)^3$,GT,$(1.5mm)^2$,GT,RND,RND,0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0,6,$(1.5mm)^3$,GT,$(1.5mm)^2$,GT,RND,RND,0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0,7,$(1.5mm)^3$,GT,$(1.5mm)^2$,GT,RND,RND,0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0,8,$(1.5mm)^3$,GT,$(1.5mm)^2$,GT,RND,RND,0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
27,9,$(1.5mm)^3$,GT,$(1.5mm)^2$,GT,OPT,OPT,0,0.790135,0.873602,...,0.063857,8.506029,9.639396,13.170564,12.028209,13.943667,11.457573,4.916939,0.137696,0.15863


In [8]:
percent_keys = [k for k in wanted_keys if "dice" in k or "vol_rel" in k]
mean_keys = [k for k in wanted_keys if "mean" in k]
min_metrics_keys = [k for k in wanted_keys if ("delta" in k or "hd" in k) and not 'std' in k]
max_metrics_keys = [k for k in wanted_keys if ("dice" in k or "iou" in k) and not 'std' in k]
filtered_frame.loc[:,percent_keys] *= 100.

# Aggregate random runs to build mean
random_runs = list(range(0,3)) + 6 * ['is_random'] + list(range(9,16))
filtered_frame.insert(1, 'is_random', random_runs)
filtered_frame = filtered_frame.groupby('is_random', as_index=False).agg(get_agg_dict(filtered_frame))

reindex_idx = list(range(0,3)) + [-1] + list(range(3,10))
filtered_frame = filtered_frame.iloc[reindex_idx,:]
filtered_frame.drop('is_random', axis=1, inplace=True)

filtered_frame

Unnamed: 0,_id,prescan_res,prescan_type,slice_res,slice_type,first_view,second_view,fold,scores/val_mean_dice_MYO,scores/val_mean_dice_LV,...,scores/val_std_oa_exclude_bg_dice,scores/val_mean_hd95_MYO,scores/val_mean_hd95_LV,scores/val_mean_hd95_RV,scores/val_mean_hd95_LA,scores/val_mean_hd95_RA,scores/val_mean_oa_exclude_bg_hd95,scores/val_std_oa_exclude_bg_hd95,scores/val_mean_delta_vol_rel_LV,scores/val_std_delta_vol_rel_LV
0,0,$(1.5mm)^3$,GT,$(1.5mm)^2$,GT,p2CH,p4CH,0,78.738603,88.292696,...,16.188044,7.653829,8.184327,30.340399,27.579405,38.687687,22.48913,25.385578,16.524597,14.852262
1,1,$(1.5mm)^3$,GT,$(1.5mm)^2$,GT,2CH,4CH,0,81.828588,88.684019,...,9.474719,6.848581,8.203704,19.513845,8.874496,27.100212,14.108168,10.160501,17.066317,19.336211
2,2,$(1.5mm)^3$,GT,$(1.5mm)^2$,GT,2CH,SA,0,79.906897,87.690754,...,12.121097,7.778636,10.23329,16.517567,13.801551,31.586577,15.983524,10.033892,14.360977,10.744652
10,3,$(1.5mm)^3$,GT,$(1.5mm)^2$,GT,RND,RND,0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,9,$(1.5mm)^3$,GT,$(1.5mm)^2$,GT,OPT,OPT,0,79.013486,87.360189,...,6.385709,8.506029,9.639396,13.170564,12.028209,13.943667,11.457573,4.916939,13.769631,15.862962
4,10,$(6mm)^3$,GT,$(1.5mm)^2$,GT,2CH,4CH,0,81.008878,89.431286,...,8.628087,7.543268,8.092771,18.864174,10.964181,22.749193,13.642718,9.240935,13.829095,11.850469
5,11,$(6mm)^3$,GT,$(1.5mm)^2$,GT,OPT,OPT,0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
6,12,$(6mm)^3$,SEG,$(1.5mm)^2$,SEG,2CH,4CH,0,29.225207,59.013867,...,21.455296,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
7,13,$(6mm)^3$,SEG,$(1.5mm)^2$,SEG,OPT,OPT,0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
8,14,$(6mm)^3$,SEG,$(6mm)^3$,SEG,2CH,4CH,0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [9]:
# Get bold values (best)
group_ranges = [
    [0,5],
    [5,7],
    [7,9],
    [9,11],
]

bold_idxs = []
for sub_range in group_ranges:
    sub_frame = filtered_frame.iloc[sub_range[0]:sub_range[1]]
    optimal_vals_maximize = sub_frame[max_metrics_keys].idxmax()
    optimal_vals_minimize = sub_frame[min_metrics_keys].idxmin()

    bold_idxs.append(optimal_vals_maximize)
    bold_idxs.append(optimal_vals_minimize)

In [10]:
# Round and convert
filtered_frame = filtered_frame.round(decimals=1)
string_frame = filtered_frame.copy().astype(str)

# Fuse mean and std columns
for m in mean_keys:
    mean_col = string_frame[m].astype(str)
    std_col_key = m.replace("mean","std")
    if std_col_key in string_frame.columns:
        std_col = string_frame[std_col_key].astype(str)
        string_frame[m] = mean_col + r" \pm " + std_col
        string_frame.drop(std_col_key, axis=1, inplace=True)

# Add bold to optimal values
for bold_group in bold_idxs:
    for col_name, idx in zip(bold_group.index, bold_group.values):
        row_idx = string_frame.index.get_loc(idx)
        col_idx = string_frame.columns.get_loc(col_name)
        string_frame.iloc[row_idx,col_idx] = r"\B " + string_frame.iloc[row_idx,col_idx]

# Join first and second view
string_frame = join_cols(string_frame, 'first_view', 'second_view', join_str=r"+")
string_frame = join_cols(string_frame, 'prescan_type', 'prescan_res', join_str=r" ", drop_second=True)
string_frame = join_cols(string_frame, 'slice_type', 'slice_res', join_str=r" ", drop_second=True)

# Replace column names
string_frame.columns = [wanted_keys[c] for c in string_frame.columns]

# Drop and insert spacer columns
string_frame.drop('_id', axis=1, inplace=True)
string_frame.drop('fold', axis=1, inplace=True)

string_frame.insert(3, ' ', len(string_frame)*["\hspace{1pt}"])
string_frame.insert(10, '  ', len(string_frame)*["\hspace{1pt}"])
string_frame.insert(17, '   ', len(string_frame)*["\hspace{1pt}"])

# Add offset to groups in latex
for group_idxs in group_ranges[:-1]:
    txt = string_frame.iloc[group_idxs[1],0]
    string_frame.iloc[group_idxs[1],0] = r"\rule{0pt}{4ex} "+txt

string_frame

Unnamed: 0,Prescan props.,Slice props.,Slice view(s),Unnamed: 4,MYO,LV,RV,LA,RA,\multicolumn{1}{c}{$\mu\pm\sigma$ },Unnamed: 11,MYO.1,LV.1,RV.1,LA.1,RA.1,\multicolumn{1}{c}{$\mu\pm\sigma$ }.1,Unnamed: 18,\multicolumn{1}{c}{$\mu\pm\sigma$}
0,GT $(1.5mm)^3$,GT $(1.5mm)^2$,p2CH+p4CH,\hspace{1pt},78.7,88.3,69.4,75.7,65.4,75.5 \pm 16.2,\hspace{1pt},7.7,8.2,30.3,27.6,38.7,22.5 \pm 25.4,\hspace{1pt},16.5 \pm 14.9
1,GT $(1.5mm)^3$,GT $(1.5mm)^2$,2CH+4CH,\hspace{1pt},\B 81.8,\B 88.7,77.2,\B 86.5,74.9,81.8 \pm 9.5,\hspace{1pt},6.8,8.2,19.5,8.9,27.1,14.1 \pm 10.2,\hspace{1pt},17.1 \pm 19.3
2,GT $(1.5mm)^3$,GT $(1.5mm)^2$,2CH+SA,\hspace{1pt},79.9,87.7,77.0,79.7,61.3,77.1 \pm 12.1,\hspace{1pt},7.8,10.2,16.5,13.8,31.6,16.0 \pm 10.0,\hspace{1pt},14.4 \pm 10.7
10,GT $(1.5mm)^3$,GT $(1.5mm)^2$,RND+RND,\hspace{1pt},0.0,0.0,0.0,0.0,0.0,0.0 \pm 0.0,\hspace{1pt},\B 0.0,\B 0.0,\B 0.0,\B 0.0,\B 0.0,\B 0.0 \pm 0.0,\hspace{1pt},\B 0.0 \pm 0.0
3,GT $(1.5mm)^3$,GT $(1.5mm)^2$,OPT+OPT,\hspace{1pt},79.0,87.4,\B 82.7,81.9,\B 84.0,\B 83.0 \pm 6.4,\hspace{1pt},8.5,9.6,13.2,12.0,13.9,11.5 \pm 4.9,\hspace{1pt},13.8 \pm 15.9
4,\rule{0pt}{4ex} GT $(6mm)^3$,GT $(1.5mm)^2$,2CH+4CH,\hspace{1pt},\B 81.0,\B 89.4,\B 78.9,\B 85.2,\B 76.4,\B 82.2 \pm 8.6,\hspace{1pt},7.5,8.1,18.9,11.0,22.7,13.6 \pm 9.2,\hspace{1pt},13.8 \pm 11.9
5,GT $(6mm)^3$,GT $(1.5mm)^2$,OPT+OPT,\hspace{1pt},0.0,0.0,0.0,0.0,0.0,0.0 \pm 0.0,\hspace{1pt},\B 0.0,\B 0.0,\B 0.0,\B 0.0,\B 0.0,\B 0.0 \pm 0.0,\hspace{1pt},\B 0.0 \pm 0.0
6,\rule{0pt}{4ex} SEG $(6mm)^3$,SEG $(1.5mm)^2$,2CH+4CH,\hspace{1pt},\B 29.2,\B 59.0,\B 21.4,\B 8.8,\B 0.0,\B 23.7 \pm 21.5,\hspace{1pt},\B 0.0,\B 0.0,\B 0.0,\B 0.0,\B 0.0,\B 0.0 \pm 0.0,\hspace{1pt},\B 0.0 \pm 0.0
7,SEG $(6mm)^3$,SEG $(1.5mm)^2$,OPT+OPT,\hspace{1pt},0.0,0.0,0.0,0.0,0.0,0.0 \pm 0.0,\hspace{1pt},0.0,0.0,0.0,0.0,0.0,0.0 \pm 0.0,\hspace{1pt},0.0 \pm 0.0
8,\rule{0pt}{4ex} SEG $(6mm)^3$,SEG $(6mm)^3$,2CH+4CH,\hspace{1pt},\B 0.0,\B 0.0,\B 0.0,\B 0.0,\B 0.0,\B 0.0 \pm 0.0,\hspace{1pt},\B 0.0,\B 0.0,\B 0.0,\B 0.0,\B 0.0,\B 0.0 \pm 0.0,\hspace{1pt},\B 0.0 \pm 0.0


In [20]:
# Save to latex

PM_COL_FORMAT = "S[table-figures-decimal=1,separate-uncertainty=true,table-format=3.1(3)]"
COL_FORMAT = \
    ("c" * 9) \
    + PM_COL_FORMAT \
    + ("c" * 6) \
    + PM_COL_FORMAT \
    + "c" \
    + PM_COL_FORMAT
# 19 cols

header = [r'\multicolumn{3}{c}{\textbf{Experiment I}}'] \
    + 2*[None] \
    + [r'\hspace{1pt}'] \
    + [r'\multicolumn{6}{c}{\textbf{Dice in \% $\uparrow$}}'] + 5*[None] \
    + [r'\hspace{1pt}'] \
    + [r'\multicolumn{6}{c}{\textbf{HD95 in mm $\downarrow$}}'] + 5*[None] \
    + [r'\hspace{1pt}'] \
    + [r'\textbf{$\Delta$vol LV in \% $\downarrow$}']

latex_frame = pd.concat([pd.DataFrame(header, index=string_frame.columns).T, string_frame])

latex_frame.to_latex(
    buf="mmwhs_results.txt",
    escape=False,
    column_format=COL_FORMAT,
    index=False,
)

# Load latex file and replace & NaN with &
with open("mmwhs_results.txt", "r") as f:
    lines = f.readlines()

    lines = [l.replace("& NaN", "") for l in lines]

with open("mmwhs_results.txt", "w") as f:
    f.writelines(lines)

# Insert 5th line before third line
with open("mmwhs_results.txt", "r") as f:
    lines = f.readlines()
    lines.insert(2, lines[4])
    del lines[5]
with open("mmwhs_results.txt", "w") as f:
    f.writelines(lines)

# Build latex table for MRXCAT

In [21]:
df = pd.DataFrame()

for run_idx, (wandb_run_name, rr) in enumerate(mrxcat_run_dict.items()):
   run_key = rr['access_key']
   run_numeric_id = rr['id']
   try:
      wrun = api.run(run_key)
   except:
      print(f"Failed to fetch run {run_key}")
      continue

   run_frame = pd.DataFrame(wrun.history())
   run_frame = run_frame.iloc[-1:]
   run_frame.index = [run_idx]

   if 'dummy' in wandb_run_name:
      run_frame = run_frame.map(lambda x: 0)

   if 3 <= run_numeric_id < 9:
      run_settings = settings['3:9']
   else:
      run_settings = settings[f"{run_numeric_id}:{run_numeric_id+1}"]

   run_frame.insert(0, '_id', [rr['id']])
   run_frame.insert(1, 'description', [run_settings['description']])
   run_frame.insert(2, 'first_view', run_settings['first_view'])
   run_frame.insert(3, 'second_view', run_settings['second_view'])
   run_frame.insert(4, 'prescan_res', run_settings['prescan_res'])
   run_frame.insert(5, 'prescan_type', run_settings['prescan_type'])
   run_frame.insert(6, 'slice_res', run_settings['slice_res'])
   run_frame.insert(7, 'slice_type', run_settings['slice_type'])
   run_frame.insert(8, 'fold', rr['fold'])

   df = pd.concat([df,run_frame])

df

NameError: name 'mrxcat_run_dict' is not defined

In [None]:
filtered_frame = df.copy()
filtered_frame = filtered_frame.groupby('_id').agg(get_agg_dict(filtered_frame))

wanted_keys = get_wanted_keys('test') # Use test here for MRXCAT

In [None]:
# Prepare values
filtered_frame = filtered_frame[wanted_keys.keys()]

percent_keys = [k for k in wanted_keys if "dice" in k or "vol_rel" in k]
mean_keys = [k for k in wanted_keys if "mean" in k]
min_metrics_keys = [k for k in wanted_keys if ("delta" in k or "hd" in k) and not 'std' in k]
max_metrics_keys = [k for k in wanted_keys if ("dice" in k or "iou" in k) and not 'std' in k]
filtered_frame.loc[:,percent_keys] *= 100.

# Aggregate random runs to build mean
random_runs = list(range(0,3)) + 6 * ['is_random'] + list(range(9,16))
filtered_frame.insert(1, 'is_random', random_runs)
filtered_frame = filtered_frame.groupby('is_random', as_index=False).agg(get_agg_dict(filtered_frame))

reindex_idx = list(range(0,3)) + [-1] + list(range(3,10))
filtered_frame = filtered_frame.iloc[reindex_idx,:]
filtered_frame.drop('is_random', axis=1, inplace=True)

filtered_frame

Unnamed: 0,_id,prescan_res,prescan_type,slice_res,slice_type,first_view,second_view,fold,scores/test_mean_dice_MYO,scores/test_mean_dice_LV,...,scores/test_std_oa_exclude_bg_dice,scores/test_mean_hd95_MYO,scores/test_mean_hd95_LV,scores/test_mean_hd95_RV,scores/test_mean_hd95_LA,scores/test_mean_hd95_RA,scores/test_mean_oa_exclude_bg_hd95,scores/test_std_oa_exclude_bg_hd95,scores/test_mean_delta_vol_rel_LV,scores/test_std_delta_vol_rel_LV
0,0,$(1.5mm)^3$,GT,$(1.5mm)^2$,GT,p2CH,p4CH,0,84.641796,89.905107,...,3.84151,5.006896,5.481754,6.514866,7.524068,7.246971,6.354911,1.65476,21.401555,13.816542
1,1,$(1.5mm)^3$,GT,$(1.5mm)^2$,GT,2CH,4CH,0,76.878982,85.586771,...,3.961466,6.094784,6.461951,7.552638,7.859861,9.500514,7.493949,1.785818,33.92009,13.669775
2,2,$(1.5mm)^3$,GT,$(1.5mm)^2$,GT,2CH,SA,0,82.850911,89.889262,...,6.484449,5.579479,5.645747,15.448795,11.5988,13.66721,10.388006,5.381621,21.31151,9.46876
10,3,$(1.5mm)^3$,GT,$(1.5mm)^2$,GT,RND,RND,0,80.936613,87.501125,...,5.28614,7.515442,7.79629,10.717961,10.048375,11.308519,9.477318,3.516369,23.495407,12.508354
3,9,$(1.5mm)^3$,GT,$(1.5mm)^2$,GT,OPT,OPT,0,74.85398,86.27445,...,5.826232,7.899631,7.0626,5.884158,12.272633,9.148672,8.453539,3.088667,29.449003,16.144714
4,10,$(6mm)^3$,GT,$(1.5mm)^2$,GT,2CH,4CH,0,85.769944,90.76504,...,4.281002,4.898186,14.383776,6.328859,9.484782,10.113405,9.041802,8.460434,17.118691,10.181439
5,11,$(6mm)^3$,GT,$(1.5mm)^2$,GT,OPT,OPT,0,75.110361,83.801749,...,6.160986,7.205254,7.734472,5.700679,8.15207,9.602473,7.67899,2.252497,39.924002,21.947657
6,12,$(6mm)^3$,SEG,$(1.5mm)^2$,SEG,2CH,4CH,0,70.274425,81.067562,...,6.262354,7.62783,8.427763,8.130813,8.532231,10.996769,8.743081,2.596051,48.438331,21.510787
7,13,$(6mm)^3$,SEG,$(1.5mm)^2$,SEG,OPT,OPT,0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
8,14,$(6mm)^3$,SEG,$(6mm)^3$,SEG,2CH,4CH,0,79.952453,89.927898,...,4.030221,6.017407,5.274241,6.658468,7.465805,8.546093,6.792403,1.569478,18.916123,6.733484


In [None]:
# Get bold values (best)
group_ranges = [
    [0,5],
    [5,7],
    [7,9],
    [9,11],
]

bold_idxs = []
for sub_range in group_ranges:
    sub_frame = filtered_frame.iloc[sub_range[0]:sub_range[1]]
    optimal_vals_maximize = sub_frame[max_metrics_keys].idxmax()
    optimal_vals_minimize = sub_frame[min_metrics_keys].idxmin()

    bold_idxs.append(optimal_vals_maximize)
    bold_idxs.append(optimal_vals_minimize)

In [None]:
# Round and convert
filtered_frame = filtered_frame.round(decimals=1)
string_frame = filtered_frame.copy().astype(str)

# Fuse mean and std columns
for m in mean_keys:
    mean_col = string_frame[m].astype(str)
    std_col_key = m.replace("mean","std")
    if std_col_key in string_frame.columns:
        std_col = string_frame[std_col_key].astype(str)
        string_frame[m] = mean_col + r" \pm " + std_col
        string_frame.drop(std_col_key, axis=1, inplace=True)

# Add bold to optimal values
for bold_group in bold_idxs:
    for col_name, idx in zip(bold_group.index, bold_group.values):
        row_idx = string_frame.index.get_loc(idx)
        col_idx = string_frame.columns.get_loc(col_name)
        string_frame.iloc[row_idx,col_idx] = r"\B " + string_frame.iloc[row_idx,col_idx]

# Join first and second view
string_frame = join_cols(string_frame, 'first_view', 'second_view', join_str=r"+")
string_frame = join_cols(string_frame, 'prescan_type', 'prescan_res', join_str=r" ", drop_second=True)
string_frame = join_cols(string_frame, 'slice_type', 'slice_res', join_str=r" ", drop_second=True)

# Replace column names
string_frame.columns = [wanted_keys[c] for c in string_frame.columns]

# Drop and insert spacer columns
string_frame.drop('_id', axis=1, inplace=True)
string_frame.drop('fold', axis=1, inplace=True)

string_frame.insert(3, ' ', len(string_frame)*["\hspace{1pt}"])
string_frame.insert(10, '  ', len(string_frame)*["\hspace{1pt}"])
string_frame.insert(17, '   ', len(string_frame)*["\hspace{1pt}"])

# Add offset to groups in latex
for group_idxs in group_ranges[:-1]:
    txt = string_frame.iloc[group_idxs[1],0]
    string_frame.iloc[group_idxs[1],0] = r"\rule{0pt}{4ex} "+txt

string_frame

Unnamed: 0,Prescan props.,Slice props.,Slice view(s),Unnamed: 4,MYO,LV,RV,LA,RA,\multicolumn{1}{c}{$\mu\pm\sigma$ },Unnamed: 11,MYO.1,LV.1,RV.1,LA.1,RA.1,\multicolumn{1}{c}{$\mu\pm\sigma$ }.1,Unnamed: 18,\multicolumn{1}{c}{$\mu\pm\sigma$}
0,GT $(1.5mm)^3$,GT $(1.5mm)^2$,p2CH+p4CH,\hspace{1pt},\B 84.6,\B 89.9,87.3,\B 83.9,\B 85.6,\B 86.3 \pm 3.8,\hspace{1pt},\B 5.0,\B 5.5,6.5,\B 7.5,\B 7.2,\B 6.4 \pm 1.7,\hspace{1pt},21.4 \pm 13.8
1,GT $(1.5mm)^3$,GT $(1.5mm)^2$,2CH+4CH,\hspace{1pt},76.9,85.6,84.2,81.6,82.9,82.2 \pm 4.0,\hspace{1pt},6.1,6.5,7.6,7.9,9.5,7.5 \pm 1.8,\hspace{1pt},33.9 \pm 13.7
2,GT $(1.5mm)^3$,GT $(1.5mm)^2$,2CH+SA,\hspace{1pt},82.9,89.9,78.6,79.4,73.2,80.8 \pm 6.5,\hspace{1pt},5.6,5.6,15.4,11.6,13.7,10.4 \pm 5.4,\hspace{1pt},\B 21.3 \pm 9.5
10,GT $(1.5mm)^3$,GT $(1.5mm)^2$,RND+RND,\hspace{1pt},80.9,87.5,84.5,80.3,79.3,82.5 \pm 5.3,\hspace{1pt},7.5,7.8,10.7,10.0,11.3,9.5 \pm 3.5,\hspace{1pt},23.5 \pm 12.5
3,GT $(1.5mm)^3$,GT $(1.5mm)^2$,OPT+OPT,\hspace{1pt},74.9,86.3,\B 91.0,82.3,83.9,83.7 \pm 5.8,\hspace{1pt},7.9,7.1,\B 5.9,12.3,9.1,8.5 \pm 3.1,\hspace{1pt},29.4 \pm 16.1
4,\rule{0pt}{4ex} GT $(6mm)^3$,GT $(1.5mm)^2$,2CH+4CH,\hspace{1pt},\B 85.8,\B 90.8,88.7,83.5,80.8,\B 85.9 \pm 4.3,\hspace{1pt},\B 4.9,14.4,6.3,9.5,10.1,9.0 \pm 8.5,\hspace{1pt},\B 17.1 \pm 10.2
5,GT $(6mm)^3$,GT $(1.5mm)^2$,OPT+OPT,\hspace{1pt},75.1,83.8,\B 89.4,\B 86.0,\B 82.9,83.4 \pm 6.2,\hspace{1pt},7.2,\B 7.7,\B 5.7,\B 8.2,\B 9.6,\B 7.7 \pm 2.3,\hspace{1pt},39.9 \pm 21.9
6,\rule{0pt}{4ex} SEG $(6mm)^3$,SEG $(1.5mm)^2$,2CH+4CH,\hspace{1pt},\B 70.3,\B 81.1,\B 84.7,\B 81.9,\B 81.9,\B 80.0 \pm 6.3,\hspace{1pt},7.6,8.4,8.1,8.5,11.0,8.7 \pm 2.6,\hspace{1pt},48.4 \pm 21.5
7,SEG $(6mm)^3$,SEG $(1.5mm)^2$,OPT+OPT,\hspace{1pt},0.0,0.0,0.0,0.0,0.0,0.0 \pm 0.0,\hspace{1pt},\B 0.0,\B 0.0,\B 0.0,\B 0.0,\B 0.0,\B 0.0 \pm 0.0,\hspace{1pt},\B 0.0 \pm 0.0
8,\rule{0pt}{4ex} SEG $(6mm)^3$,SEG $(6mm)^3$,2CH+4CH,\hspace{1pt},\B 80.0,\B 89.9,\B 86.8,\B 81.5,\B 84.0,\B 84.5 \pm 4.0,\hspace{1pt},6.0,5.3,6.7,7.5,8.5,6.8 \pm 1.6,\hspace{1pt},18.9 \pm 6.7


In [None]:
# Save to latex

PM_COL_FORMAT = "S[table-figures-decimal=1,separate-uncertainty=true,table-format=3.1(3)]"
COL_FORMAT = \
    ("c" * 9) \
    + PM_COL_FORMAT \
    + ("c" * 6) \
    + PM_COL_FORMAT \
    + "c" \
    + PM_COL_FORMAT
# 19 cols

header = [r'\multicolumn{3}{c}{\textbf{Experiment II}}'] \
    + 2*[None] \
    + [r'\hspace{1pt}'] \
    + [r'\multicolumn{6}{c}{\textbf{Dice in \% $\uparrow$}}'] + 5*[None] \
    + [r'\hspace{1pt}'] \
    + [r'\multicolumn{6}{c}{\textbf{HD95 in mm $\downarrow$}}'] + 5*[None] \
    + [r'\hspace{1pt}'] \
    + [r'\textbf{$\Delta$vol LV in \% $\downarrow$}']

latex_frame = pd.concat([pd.DataFrame(header, index=string_frame.columns).T, string_frame])

latex_frame.to_latex(
    buf="mrxcat_results.txt",
    escape=False,
    column_format=COL_FORMAT,
    index=False,
)

# Load latex file and replace & NaN with &
with open("mrxcat_results.txt", "r") as f:
    lines = f.readlines()

    lines = [l.replace("& NaN", "") for l in lines]

with open("mrxcat_results.txt", "w") as f:
    f.writelines(lines)

# Insert 5th line before third line
with open("mrxcat_results.txt", "r") as f:
    lines = f.readlines()
    lines.insert(2, lines[4])
    del lines[5]
with open("mrxcat_results.txt", "w") as f:
    f.writelines(lines)