In [1]:
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
from hyperimpute.utils.serialization import (load_model_from_file,
                                             save_model_to_file)
from sklearn.preprocessing import MinMaxScaler

workspace = Path("workspace")
results_dir = Path("results")
data_dir = Path("data")

workspace.mkdir(parents=True, exist_ok=True)

warnings.filterwarnings("ignore")

cat_limit = 10
n_seeds = 5

version = "take7"
changelog = f"hyperlatent_age_tweaks"

In [2]:
def dataframe_hash(df: pd.DataFrame) -> str:
    cols = sorted(list(df.columns))
    return str(abs(pd.util.hash_pandas_object(df[cols].fillna(0)).sum()))


def augment_base_dataset(df):
    df = df.sort_values(["RID_HASH", "VISCODE"])

    return df

In [3]:
dev_set = pd.read_csv(data_dir / "dev_set.csv")
dev_set = dev_set.sort_values(["RID_HASH", "VISCODE"])
dev_set = augment_base_dataset(dev_set)

scaled_cols = [
    "MMSE",
    "ADAS13",
    "Ventricles",
    "Hippocampus",
    "WholeBrain",
    "Entorhinal",
    "Fusiform",
    "MidTemp",
]

scaler = MinMaxScaler().fit(dev_set[scaled_cols])
dev_set[scaled_cols] = scaler.transform(dev_set[scaled_cols])

dev_set

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
2163,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0,79.1,0,20,1.0,1.0,0.5,0.923077,0.164384,0.071871,0.548646,0.376516,0.464021,0.194906,0.400709
154,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,6,79.6,0,20,1.0,1.0,1.5,0.923077,0.237397,0.071956,0.548307,0.366398,0.403880,0.193367,0.397291
1385,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0,72.9,1,12,1.0,1.0,1.0,1.000000,0.123288,0.142655,0.525169,0.235599,0.513404,0.356253,0.294774
2698,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,6,73.4,1,12,1.0,1.0,1.0,1.000000,0.164384,0.144729,0.549210,0.230361,0.435097,0.322395,0.294175
2291,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,12,73.9,1,12,1.0,1.0,1.0,0.961538,0.109589,0.155550,0.527878,0.215944,0.487831,0.342600,0.277552
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2895,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,60,79.8,1,19,1.0,0.0,3.0,0.923077,0.223699,0.170895,0.357020,0.321346,0.310935,0.399047,0.461476
2646,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,102,83.3,1,19,1.0,0.0,3.0,0.846154,0.168904,0.178231,0.352043,0.309095,0.256790,0.372685,0.416478
1962,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0,72.1,0,12,1.0,0.0,0.5,0.884615,0.150685,0.416382,0.602438,0.636654,0.610229,0.743037,0.624631
122,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,12,73.1,0,12,1.0,0.0,1.0,0.961538,0.155205,0.398451,0.608521,0.634650,0.617108,0.729087,0.638477


In [4]:
static_features = ["RID_HASH", "AGE", "PTGENDER_num", "PTEDUCAT", "APOE4"]  # first age
temporal_features = [
    "RID_HASH",
    "VISCODE",
    "DX_num",
    "CDRSB",
    "MMSE",
    "ADAS13",
    "Ventricles",
    "Hippocampus",
    "WholeBrain",
    "Entorhinal",
    "Fusiform",
    "MidTemp",
]  #

dev_set_static = dev_set.sort_values(["RID_HASH", "VISCODE"]).drop_duplicates(
    "RID_HASH"
)[static_features]
dev_set_temporal = dev_set.sort_values(["RID_HASH", "VISCODE"])[temporal_features]

dev_set_static

Unnamed: 0,RID_HASH,AGE,PTGENDER_num,PTEDUCAT,APOE4
2163,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,79.1,0,20,1.0
1385,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,72.9,1,12,1.0
298,0131f7f44ff183309c590b9ff440806b20f639c90c124d...,73.9,0,12,0.0
1762,01513c9ff1e8fcc22cbfc9093845a37ee69307e3493daf...,73.4,1,12,0.0
406,01705aaf2c869203d7a8374472f5907f53f3b15f7b4faa...,70.4,0,16,0.0
...,...,...,...,...,...
2205,ff1d8cc22fb5bf2bd80e31d6d3a6cf1709562bb7e9a22f...,71.5,1,16,1.0
1593,ff21c0f13c9535e8339ce653a268b26df8e4172212ac05...,75.9,1,18,0.0
3458,ff48382bcf5922a2db52db36c791b02910015feee82505...,70.3,0,12,1.0
1438,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,74.8,1,19,0.0


In [5]:
dev_set_temporal

Unnamed: 0,RID_HASH,VISCODE,DX_num,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
2163,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0,1.0,0.5,0.923077,0.164384,0.071871,0.548646,0.376516,0.464021,0.194906,0.400709
154,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,6,1.0,1.5,0.923077,0.237397,0.071956,0.548307,0.366398,0.403880,0.193367,0.397291
1385,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0,1.0,1.0,1.000000,0.123288,0.142655,0.525169,0.235599,0.513404,0.356253,0.294774
2698,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,6,1.0,1.0,1.000000,0.164384,0.144729,0.549210,0.230361,0.435097,0.322395,0.294175
2291,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,12,1.0,1.0,0.961538,0.109589,0.155550,0.527878,0.215944,0.487831,0.342600,0.277552
...,...,...,...,...,...,...,...,...,...,...,...,...
2895,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,60,1.0,3.0,0.923077,0.223699,0.170895,0.357020,0.321346,0.310935,0.399047,0.461476
2646,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,102,1.0,3.0,0.846154,0.168904,0.178231,0.352043,0.309095,0.256790,0.372685,0.416478
1962,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0,1.0,0.5,0.884615,0.150685,0.416382,0.602438,0.636654,0.610229,0.743037,0.624631
122,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,12,1.0,1.0,0.961538,0.155205,0.398451,0.608521,0.634650,0.617108,0.729087,0.638477


In [6]:
dev_1 = pd.read_csv(data_dir / "dev_1.csv")
dev_1 = augment_base_dataset(dev_1)
dev_1[scaled_cols] = scaler.transform(dev_1[scaled_cols])

dev_1

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
2163,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0,,0.0,20.0,1.0,1.0,0.5,0.923077,0.164384,,,0.376516,,,
154,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,6,79.6,0.0,20.0,1.0,1.0,1.5,0.923077,0.237397,0.071956,0.548307,0.366398,0.403880,0.193367,0.397291
1385,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0,,1.0,12.0,,1.0,,,,,0.525169,0.235599,0.513404,0.356253,0.294774
2698,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,6,,1.0,12.0,,1.0,,,,,0.549210,0.230361,0.435097,0.322395,0.294175
2291,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,12,,1.0,12.0,,1.0,,,,,0.527878,0.215944,0.487831,0.342600,0.277552
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2895,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,60,79.8,1.0,19.0,,0.0,,,,0.170895,,0.321346,,,
2646,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,102,83.3,1.0,19.0,,0.0,,,,0.178231,,0.309095,,,
1962,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0,72.1,,12.0,1.0,0.0,0.5,0.884615,0.150685,0.416382,0.602438,,0.610229,0.743037,0.624631
122,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,12,73.1,,12.0,1.0,0.0,1.0,0.961538,0.155205,0.398451,0.608521,,0.617108,0.729087,0.638477


In [7]:
dev_2 = pd.read_csv(data_dir / "dev_2.csv")
dev_2 = augment_base_dataset(dev_2)
dev_2[scaled_cols] = scaler.transform(dev_2[scaled_cols])

dev_2

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
2163,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0,79.1,0.0,20.0,1.0,1.0,0.5,0.923077,0.164384,0.071871,0.548646,0.376516,0.464021,0.194906,0.400709
154,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,6,79.6,,,,1.0,,,,0.071956,0.548307,,0.403880,0.193367,0.397291
1385,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0,72.9,,12.0,1.0,1.0,1.0,1.000000,0.123288,0.142655,0.525169,,0.513404,0.356253,0.294774
2698,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,6,,,12.0,1.0,1.0,1.0,1.000000,0.164384,,0.549210,,0.435097,0.322395,0.294175
2291,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,12,,,12.0,1.0,1.0,1.0,0.961538,0.109589,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2895,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,60,,,19.0,1.0,0.0,3.0,0.923077,0.223699,,0.357020,,0.310935,0.399047,0.461476
2646,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,102,,,19.0,1.0,0.0,3.0,0.846154,0.168904,,0.352043,,0.256790,0.372685,0.416478
1962,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0,72.1,,12.0,,0.0,,,,0.416382,0.602438,,0.610229,0.743037,0.624631
122,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,12,,,12.0,,0.0,,,,,,,,,


In [8]:
submission = pd.read_csv(data_dir / "sample_submission.csv")

submission.values[1]

array(['6b6a7136f42a8dbd469a201b88e2abb54a93667822761357db2f6d620da6af8a_0_Ventricles_test_A',
       40613.0818580834], dtype=object)

In [9]:
test_A = pd.read_csv(data_dir / "test_A.csv")
test_A = augment_base_dataset(test_A)
test_A[scaled_cols] = scaler.transform(test_A[scaled_cols])

test_A_gt = pd.read_csv(data_dir / "test_A_gt.csv")
test_A_gt = augment_base_dataset(test_A_gt)
test_A_gt[scaled_cols] = scaler.transform(test_A_gt[scaled_cols])

assert (test_A["VISCODE"] == test_A_gt["VISCODE"]).all()

test_A

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
247,00d5e0050fbd3b6b610f6673347232eb0862df77b5b7a8...,0,,,16.0,1.0,0.0,0.5,0.961538,0.219178,,,,,,
819,013c6f92763546c7ad9c0831f023886c15f05e7332aa0c...,0,72.5,1.0,12.0,,1.0,,,,0.057498,0.612302,0.423268,0.291182,0.433004,0.329131
276,013c6f92763546c7ad9c0831f023886c15f05e7332aa0c...,6,73.0,1.0,12.0,,1.0,,,,0.067972,,0.399942,,,
350,013c6f92763546c7ad9c0831f023886c15f05e7332aa0c...,12,73.5,1.0,12.0,1.0,1.0,2.0,0.769231,0.365342,0.077516,,0.415324,,,
1268,024efbff9265302acd00190e57ee08ba1fe1b90f561f79...,0,,0.0,14.0,1.0,1.0,2.0,1.000000,0.164384,,,0.515223,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
841,ff2966461950ba81280a0189ed2d504a8bd503d9f6b078...,0,,,18.0,1.0,1.0,1.5,0.807692,0.150685,,,,,,
330,ff2966461950ba81280a0189ed2d504a8bd503d9f6b078...,6,,,18.0,1.0,1.0,1.5,0.769231,0.095890,,,,,,
939,ff2966461950ba81280a0189ed2d504a8bd503d9f6b078...,24,,,18.0,1.0,1.0,1.5,0.769231,0.150685,,,,,,
119,ff2966461950ba81280a0189ed2d504a8bd503d9f6b078...,48,70.9,,18.0,1.0,1.0,2.5,0.807692,0.246575,0.307697,0.420993,,0.392416,0.577719,0.403872


In [10]:
test_B = pd.read_csv(data_dir / "test_B.csv")
test_B = augment_base_dataset(test_B)
test_B[scaled_cols] = scaler.transform(test_B[scaled_cols])

test_B_gt = pd.read_csv(data_dir / "test_B_gt.csv")
test_B_gt = augment_base_dataset(test_B_gt)
test_B_gt[scaled_cols] = scaler.transform(test_B_gt[scaled_cols])

assert (test_B["VISCODE"] == test_B_gt["VISCODE"]).all()

test_B

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
1181,001854e92967164311f3acd5a58be9790f28ab3968bbbc...,0,71.4,,15.0,0.0,2.0,0.0,0.961538,0.077671,0.085164,0.638939,,0.608113,0.424862,0.523781
1426,001854e92967164311f3acd5a58be9790f28ab3968bbbc...,36,74.4,,15.0,0.0,2.0,0.0,1.000000,0.027397,0.089750,,,,,
1201,0059bc7849aea9522b408fa0ddc60276a36cae00206b87...,0,,0.0,,1.0,0.0,0.5,0.846154,0.196301,,0.345711,0.286043,0.312698,0.276821,0.248579
757,0059bc7849aea9522b408fa0ddc60276a36cae00206b87...,6,,0.0,,1.0,0.0,1.0,1.000000,0.283151,,0.345147,0.278219,0.378307,0.289480,0.253793
763,0059bc7849aea9522b408fa0ddc60276a36cae00206b87...,12,,0.0,,1.0,0.0,2.5,0.807692,0.168904,,0.329233,0.253372,0.352028,0.259842,0.222042
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1423,ff4eb5a64e2b89861d5dea81190669893070b227f3a335...,0,,,18.0,1.0,1.0,1.5,0.884615,0.114110,,0.502370,,0.394356,0.397160,0.531003
558,ff4eb5a64e2b89861d5dea81190669893070b227f3a335...,12,,,18.0,1.0,1.0,1.5,0.923077,0.242055,,0.519639,,0.294356,0.416522,0.545575
70,ff4eb5a64e2b89861d5dea81190669893070b227f3a335...,84,,0.0,18.0,1.0,1.0,1.5,1.000000,0.178082,,0.432054,0.483387,0.363316,0.468451,0.508440
480,ffa86109ba8684f31325842d0ff26568e105f0f63b366a...,0,66.3,,13.0,0.0,0.0,0.0,0.923077,0.118767,0.177669,,,,,


In [11]:
test_A.isna().sum()

RID_HASH          0
VISCODE           0
AGE             612
PTGENDER_num    626
PTEDUCAT         65
DX_num          428
APOE4            49
CDRSB           428
MMSE            428
ADAS13          428
Ventricles      612
Hippocampus     668
WholeBrain      626
Entorhinal      668
Fusiform        668
MidTemp         668
dtype: int64

In [12]:
test_A.columns

Index(['RID_HASH', 'VISCODE', 'AGE', 'PTGENDER_num', 'PTEDUCAT', 'DX_num',
       'APOE4', 'CDRSB', 'MMSE', 'ADAS13', 'Ventricles', 'Hippocampus',
       'WholeBrain', 'Entorhinal', 'Fusiform', 'MidTemp'],
      dtype='object')

## Emulate missingness

In [13]:
def copy_missingness(ref_data):
    ref_data_ids = ref_data["RID_HASH"].unique()

    len_to_miss = {}
    for rid in ref_data_ids:
        local_A = ref_data[ref_data["RID_HASH"] == rid]
        # print(len(local_A), local_A.isna().sum().sum())

        local_len = len(local_A)
        if local_len not in len_to_miss:
            len_to_miss[local_len] = []
        for reps in range(5):
            len_to_miss[local_len].append(local_A.notna().reset_index(drop=True))

    out_data = pd.DataFrame([], columns=dev_set.columns)
    out_data_ids = dev_set["RID_HASH"].unique()
    for rid in out_data_ids:
        local_A = dev_set[dev_set["RID_HASH"] == rid].copy().reset_index(drop=True)
        local_len = len(local_A)

        if local_len in len_to_miss and len(len_to_miss[local_len]) > 0:
            target_mask = len_to_miss[local_len].pop(0)
            out_data = pd.concat([out_data, local_A[target_mask]], ignore_index=True)
        else:
            out_data = pd.concat([out_data, local_A], ignore_index=True)

    return out_data

In [14]:
dev_sim_A = copy_missingness(test_A)

dev_sim_A

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
0,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0,79.1,0,20,,1.0,,,,0.071871,0.548646,0.376516,0.464021,0.194906,0.400709
1,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,6,,0,20,1.0,1.0,1.5,0.923077,0.237397,,0.548307,0.366398,0.40388,0.193367,0.397291
2,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0,72.9,1,12,,1.0,,,,0.142655,,0.235599,,,
3,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,6,,1,12,,1.0,,,,,,0.230361,,,
4,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,12,,1,12,,1.0,,,,,,0.215944,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4096,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,60,79.8,1,19.0,1.0,0.0,3.0,0.923077,0.223699,0.170895,0.35702,0.321346,0.310935,0.399047,0.461476
4097,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,102,83.3,1,19.0,1.0,0.0,3.0,0.846154,0.168904,0.178231,0.352043,0.309095,0.25679,0.372685,0.416478
4098,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0,72.1,0.0,12,1.0,0.0,0.5,0.884615,0.150685,0.416382,,0.636654,,,
4099,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,12,,,12,1.0,0.0,1.0,0.961538,0.155205,,,,,,


In [15]:
dev_sim_B = copy_missingness(test_B)

dev_sim_B

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
0,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0,79.1,,20,1.0,1.0,0.5,0.923077,0.164384,0.071871,0.548646,,0.464021,0.194906,0.400709
1,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,6,79.6,,20,1.0,1.0,1.5,0.923077,0.237397,0.071956,,,,,
2,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0,,,12,1.0,1.0,1.0,1.0,0.123288,,0.525169,,0.513404,0.356253,0.294774
3,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,6,,,12,1.0,1.0,1.0,1.0,0.164384,,,,,,
4,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,12,,,12,1.0,1.0,1.0,0.961538,0.109589,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4096,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,60,,1,,1.0,0.0,3.0,0.923077,0.223699,,,0.321346,,,
4097,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,102,,1,,1.0,0.0,3.0,0.846154,0.168904,,,0.309095,,,
4098,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0,,,12,1.0,0.0,0.5,0.884615,0.150685,,,,,,
4099,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,12,,,12,1.0,0.0,1.0,0.961538,0.155205,,,,,,


In [16]:
dataframe_hash(dev_sim_B)

'1718898274424252771'

## Baseline imputation

In [17]:
from hyperimpute.plugins.imputers import Imputers

# VISCODE 6 * x -> AGE 0.5 * x

const_by_patient = ["PTGENDER_num", "PTEDUCAT", "APOE4"]


def prepare_consts(train_data, test_data):
    test_data = test_data.copy()
    train_data = train_data.copy()

    train_data = train_data.sort_values(["RID_HASH", "VISCODE"])
    test_data = test_data.sort_values(["RID_HASH", "VISCODE"])

    for item in test_data.groupby("RID_HASH"):
        local = item[1]

        # fill consts
        for col in const_by_patient:
            if len(local[col].unique()) == 1:
                continue
            rid = local["RID_HASH"].unique()[0]

            val = local[col][~local[col].isna()].unique()[0]
            local[col] = local[col].fillna(val)
            test_data.loc[test_data["RID_HASH"] == rid, col] = test_data[
                test_data["RID_HASH"] == rid
            ][col].fillna(val)
            assert len(local[col].unique()) == 1, col

    return test_data


def prepare_age(train_data, test_data):
    test_data = test_data.copy()
    train_data = train_data.copy()

    train_data = train_data.sort_values(["RID_HASH", "VISCODE"])
    test_data = test_data.sort_values(["RID_HASH", "VISCODE"])

    col = "AGE"

    for rid in test_data["RID_HASH"].unique():
        local = test_data[test_data["RID_HASH"] == rid]

        # fill age
        ages = local["AGE"]
        if ages.isna().sum() == 0:
            continue

        if ages.isna().sum() == len(ages):
            continue

        # forward impute age
        prev_viscode = 0
        prev_age = 0
        for idx, row in local.iterrows():
            current_viscode = row["VISCODE"]
            local_idx = (test_data["VISCODE"] == current_viscode) & (
                test_data["RID_HASH"] == rid
            )
            if prev_age > 0 and prev_age == prev_age:
                pred_age = (current_viscode - prev_viscode) / 6 * 0.5 + prev_age
            else:
                pred_age = row[col]

            if pred_age == pred_age:
                # print("forward imputed", pred_age, current_viscode)
                test_data.loc[local_idx, col] = test_data.loc[local_idx][col].fillna(
                    pred_age
                )

            prev_viscode = row["VISCODE"]
            prev_age = pred_age

        # reverse impute age
        prev_viscode = 0
        prev_age = 0
        for idx, row in local.iloc[::-1].iterrows():
            current_viscode = row["VISCODE"]
            local_idx = (test_data["VISCODE"] == current_viscode) & (
                test_data["RID_HASH"] == rid
            )

            if prev_age > 0 and prev_age == prev_age:
                pred_age = prev_age - (prev_viscode - current_viscode) / 6 * 0.5
            else:
                pred_age = row[col]

            if pred_age == pred_age:
                # print("reversed imputed", pred_age, current_viscode)
                test_data.loc[local_idx, col] = test_data.loc[local_idx][col].fillna(
                    pred_age
                )

            prev_viscode = row["VISCODE"]
            prev_age = pred_age

        # print(test_data[(test_data["RID_HASH"] == rid)][["VISCODE", "AGE"]])
    return test_data


def interm_imputation(test_data, forward_first = True):
    test_data = test_data.copy()

    for rid in test_data["RID_HASH"].unique():
        local = test_data[test_data["RID_HASH"] == rid]

        if forward_first:
            local = local.ffill().bfill()
        else:
            local = local.bfill().ffill()

        test_data.loc[test_data["RID_HASH"] == rid] = local

    return test_data


def full_imputation(train_data, test_data, random_state: int = 0):
    imputed_test_data = test_data.copy()

    imputer_kwargs = {
        "optimizer": "simple",
        "classifier_seed": ["xgboost"],
        "regression_seed": ["xgboost_regressor"],
        "class_threshold": cat_limit,
        "random_state": random_state,
    }

    imputer = Imputers().get(
        "hyperimpute",
        **imputer_kwargs,
    )
    imputation_input = pd.concat([train_data, test_data], ignore_index=True)
    imputed_test_data = imputer.fit_transform(imputation_input)
    imputed_test_data = imputed_test_data.tail(len(test_data))

    return imputed_test_data


def evaluate_static_imputation(train_data, test_data, static_imputation):
    test_data = test_data.copy()
    train_data = train_data.copy()

    train_data = train_data.sort_values(["RID_HASH", "VISCODE"])
    test_data = test_data.sort_values(["RID_HASH", "VISCODE"])

    for rid in test_data["RID_HASH"].unique():
        patient = test_data[test_data["RID_HASH"] == rid]
        misses = []
        viscodes = []
        for idx, row in patient.iterrows():
            misses.append(row.isna().sum())
            viscodes.append(row["VISCODE"])
        cidx = np.argmin(misses)

        current_viscode = viscodes[cidx]
        local_idx = (test_data["VISCODE"] == current_viscode) & (
            test_data["RID_HASH"] == rid
        )
        imputed_idx = (static_imputation["VISCODE"] == current_viscode) & (
            static_imputation["RID_HASH"] == rid
        )

        if len(test_data[local_idx]) == 0:
            continue

        for col in test_data.columns:
            val = test_data.loc[local_idx][col].values[0]
            if val == val:
                continue
            imputed_val = static_imputation.loc[imputed_idx][col].values[0]
            test_data.loc[local_idx, col] = imputed_val

            # print("imputed", test_data.loc[local_idx, col])

    return test_data


def impute_data_step(
    train_data,
    test_data,
    use_longitudinal=True,
    static_strategy="missmin",
    random_state: int = 0,
):
    test_id = dataframe_hash(test_data)
    train_id = dataframe_hash(train_data)

    print(" >>> Evaluate constants", test_id, test_data.isna().sum().sum())
    test_data = prepare_consts(train_data, test_data)
    test_data = prepare_age(train_data, test_data)
    test_data = interm_imputation(test_data)

    print(
        " >>> Evaluate static imputation",
        test_id,
        test_data.isna().sum().sum(),
        static_strategy,
    )
    static_imputation = full_imputation(
        train_data, test_data, random_state=random_state
    )
    test_data = evaluate_static_imputation(train_data, test_data, static_imputation)

    print(" >>> Evaluate constants take 2", test_id, test_data.isna().sum().sum())
    test_data = prepare_consts(train_data, test_data)
    test_data = prepare_age(train_data, test_data)

    test_data = interm_imputation(test_data)

    assert test_data.isna().sum().sum() == 0, test_data

    return test_data


def impute_baseline_data(train_data, test_data, random_state: int = 0):
    test_id = dataframe_hash(test_data)
    train_id = dataframe_hash(train_data)

    bkp_file = (
        workspace
        / f"seed_imputation_{version}_{train_id}_{test_id}_{random_state}_catlimit{cat_limit}.csv"
    )

    if bkp_file.exists():
        print("Using cached ", bkp_file)

        return pd.read_csv(bkp_file)

    print("Evaluate ", bkp_file)
    imputed_test_data = impute_data_step(
        train_data,
        test_data,
        random_state=random_state,
    )
    imputed_test_data.to_csv(bkp_file, index=None)

    return imputed_test_data

In [18]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

nn_scaled_cols = ["AGE", "CDRSB", "PTEDUCAT"]
nn_scaler = MinMaxScaler().fit(dev_set[nn_scaled_cols])
# tabular_encoder_static = TabularEncoder(categorical_limit = cat_limit).fit(dev_set_static.reset_index(drop = True))
# tabular_encoder_temporal = TabularEncoder(categorical_limit = cat_limit).fit(dev_set_temporal.reset_index(drop = True))


def mask_columns_map(s: str):
    return f"masked_{s}"


def generate_testcase(ref_df):
    baseline_imputation = impute_baseline_data(dev_set, ref_df).reset_index(drop=True)

    baseline_imputation_nn = baseline_imputation.copy()
    baseline_imputation_nn[nn_scaled_cols] = nn_scaler.transform(
        baseline_imputation_nn[nn_scaled_cols]
    )

    baseline_imputation_static = baseline_imputation_nn.sort_values(
        ["RID_HASH", "VISCODE"]
    )[static_features]
    baseline_imputation_temporal = baseline_imputation_nn.sort_values(
        ["RID_HASH", "VISCODE"]
    )[temporal_features]

    mask = (
        ref_df.isna()
        .astype(int)
        .drop(columns=["RID_HASH", "VISCODE"])
        .rename(mask_columns_map, axis="columns")
    ).reset_index(drop=True)

    full_input = pd.concat(
        [
            baseline_imputation_static.drop(columns=["AGE"]),
            baseline_imputation_temporal.drop(columns=["RID_HASH"]),
            mask,
        ],
        axis=1,
    )

    return baseline_imputation, full_input

In [19]:
testcases = []
for src in [
    dev_1,
    dev_2,
    dev_sim_A,
    dev_sim_B,
]:
    _, src_input = generate_testcase(src)
    testcases.append(src_input)

full_output = dev_set.copy()
full_output[nn_scaled_cols] = nn_scaler.transform(full_output[nn_scaled_cols])

full_output_static = full_output.sort_values(["RID_HASH", "VISCODE"])[static_features]
# full_output_static =  tabular_encoder_static.transform(full_output_static.reset_index(drop = True))

full_output_temporal = full_output.sort_values(["RID_HASH", "VISCODE"])[
    temporal_features
]
# full_output_temporal =  tabular_encoder_temporal.transform(full_output_temporal.reset_index(drop = True))

full_output_static

Using cached  workspace/seed_imputation_take7_4097467927144633164_8477102391824886331_0_catlimit10.csv
Using cached  workspace/seed_imputation_take7_4097467927144633164_6199915737732549321_0_catlimit10.csv
Using cached  workspace/seed_imputation_take7_4097467927144633164_1578441905907828792_0_catlimit10.csv
Using cached  workspace/seed_imputation_take7_4097467927144633164_1718898274424252771_0_catlimit10.csv


Unnamed: 0,RID_HASH,AGE,PTGENDER_num,PTEDUCAT,APOE4
2163,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0.574419,0,1.0000,1.0
154,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0.586047,0,1.0000,1.0
1385,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0.430233,1,0.5000,1.0
2698,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0.441860,1,0.5000,1.0
2291,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0.453488,1,0.5000,1.0
...,...,...,...,...,...
2895,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,0.590698,1,0.9375,0.0
2646,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,0.672093,1,0.9375,0.0
1962,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0.411628,0,0.5000,0.0
122,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0.434884,0,0.5000,0.0


In [20]:
full_output_temporal

Unnamed: 0,RID_HASH,VISCODE,DX_num,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
2163,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0,1.0,0.03125,0.923077,0.164384,0.071871,0.548646,0.376516,0.464021,0.194906,0.400709
154,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,6,1.0,0.09375,0.923077,0.237397,0.071956,0.548307,0.366398,0.403880,0.193367,0.397291
1385,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0,1.0,0.06250,1.000000,0.123288,0.142655,0.525169,0.235599,0.513404,0.356253,0.294774
2698,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,6,1.0,0.06250,1.000000,0.164384,0.144729,0.549210,0.230361,0.435097,0.322395,0.294175
2291,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,12,1.0,0.06250,0.961538,0.109589,0.155550,0.527878,0.215944,0.487831,0.342600,0.277552
...,...,...,...,...,...,...,...,...,...,...,...,...
2895,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,60,1.0,0.18750,0.923077,0.223699,0.170895,0.357020,0.321346,0.310935,0.399047,0.461476
2646,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,102,1.0,0.18750,0.846154,0.168904,0.178231,0.352043,0.309095,0.256790,0.372685,0.416478
1962,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0,1.0,0.03125,0.884615,0.150685,0.416382,0.602438,0.636654,0.610229,0.743037,0.624631
122,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,12,1.0,0.06250,0.961538,0.155205,0.398451,0.608521,0.634650,0.617108,0.729087,0.638477


In [21]:
testcases[0]

Unnamed: 0,RID_HASH,PTGENDER_num,PTEDUCAT,APOE4,VISCODE,DX_num,CDRSB,MMSE,ADAS13,Ventricles,...,masked_APOE4,masked_CDRSB,masked_MMSE,masked_ADAS13,masked_Ventricles,masked_Hippocampus,masked_WholeBrain,masked_Entorhinal,masked_Fusiform,masked_MidTemp
0,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0.0,1.0000,1.0,0,1.0,0.031250,0.923077,0.164384,0.071956,...,0,0,0,0,1,1,0,1,1,1
1,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0.0,1.0000,1.0,6,1.0,0.093750,0.923077,0.237397,0.071956,...,0,0,0,0,0,0,0,0,0,0
2,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,1.0,0.5000,1.0,0,1.0,0.073349,0.984970,0.133982,0.187193,...,0,1,1,1,1,0,0,0,0,0
3,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,1.0,0.5000,1.0,6,1.0,0.073349,0.984970,0.133982,0.187193,...,0,1,1,1,1,0,0,0,0,0
4,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,1.0,0.5000,1.0,12,1.0,0.073349,0.984970,0.133982,0.187193,...,0,1,1,1,1,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4096,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,1.0,0.9375,0.0,60,1.0,0.169053,0.922065,0.234043,0.170895,...,0,1,1,1,0,1,0,1,1,1
4097,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,1.0,0.9375,0.0,102,1.0,0.169053,0.922065,0.234043,0.178231,...,0,1,1,1,0,1,0,1,1,1
4098,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0.0,0.5000,0.0,0,1.0,0.031250,0.884615,0.150685,0.416382,...,0,0,0,0,0,0,1,0,0,0
4099,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0.0,0.5000,0.0,12,1.0,0.062500,0.961538,0.155205,0.398451,...,0,0,0,0,0,0,1,0,0,0


In [22]:
# activation_layout_static = tabular_encoder_static.activation_layout(discrete_activation = "softmax")[1 : ]
# activation_layout_temporal = tabular_encoder_temporal.activation_layout(discrete_activation = "softmax")[2 : ]

# activation_layout_static

In [23]:
# activation_layout_temporal

In [24]:
from hyperimpute.utils.serialization import (load_model_from_file,
                                             save_model_to_file)
from sklearn.model_selection import train_test_split

from ts_imputer import TimeSeriesImputer, modes

n_hidden_units = 150
n_hidden_layers = 2

for mode in modes:
    print("Training", mode)
    bkp_file = workspace / f"nn_imputer_mode_{mode}_{n_hidden_layers}_{n_hidden_units}.bkp"
    #bkp_file = workspace / f"nn_imputer_mode_{mode}.bkp"
    

    if bkp_file.exists():
        continue

    imputer = TimeSeriesImputer(
        n_units_in=testcases[0].shape[-1] - 1,  # DROP RID_HASH
        n_units_out_static=full_output_static.shape[-1] - 1,  # DROP RID_HASH
        n_units_out_temporal=full_output_temporal.shape[-1]- 2,  # DROP RID_HASH and VISCODE
        nonlin="leaky_relu",
        dropout=0.05,
        # nonlin_out_static = activation_layout_static,
        # nonlin_out_temporal = activation_layout_temporal,
        n_layers_hidden=2,
        n_units_hidden=150,
        n_iter=10000,
        mode=mode,
        residual=False,
    )

    for outer_iter in range(5):
        for idx, full_input in enumerate(testcases):
            (
                train_input,
                test_input,
                train_output_static,
                test_output_static,
                train_output_temporal,
                test_output_temporal,
            ) = train_test_split(
                full_input, full_output_static, full_output_temporal, random_state=0
            )
            imputer.fit(
                train_input,
                train_output_static,
                train_output_temporal,
                test_input,
                test_output_static,
                test_output_temporal,
            )
    save_model_to_file(bkp_file, imputer)

Training LSTM
Training GRU
Training RNN
Training Transformer
Training XceptionTime
Training ResCNN


In [25]:
from hyperimpute.utils.serialization import (load_model_from_file,
                                             save_model_to_file)


def get_latent_imputer(mode):
    #bkp_file = workspace / f"nn_imputer_mode_{mode}.bkp"
    bkp_file = workspace / f"nn_imputer_mode_{mode}_{n_hidden_layers}_{n_hidden_units}.bkp"

    return load_model_from_file(bkp_file)

def generate_latent_repr(ref_df, mode):
    imputer = get_latent_imputer(mode=mode)

    _, test_input = generate_testcase(ref_df)
    test_id = dataframe_hash(test_input)
    bkp_file = workspace / f"latent_repr_testcase_{test_id}_{mode}_{n_hidden_layers}_{n_hidden_units}.bkp"
    if bkp_file.exists():
        latent = load_model_from_file(bkp_file)
    else:
        latent = imputer.predict_latent(test_input)
        save_model_to_file(bkp_file, latent)
    return latent

def generate_training_latent_repr(mode):
    imputer = get_latent_imputer(mode=mode)

    output = []
    for idx, full_input in enumerate(testcases):
        test_id = dataframe_hash(full_input)
        bkp_file = workspace / f"latent_repr_testcase_{test_id}_{mode}_{n_hidden_layers}_{n_hidden_units}.bkp"
        if bkp_file.exists():
            latent = load_model_from_file(bkp_file)
        else:
            latent = imputer.predict_latent(full_input)
            save_model_to_file(bkp_file, latent)
        output.append(latent)
    return output


train_latents = {}
for mode in modes:
    train_latents[mode] = generate_training_latent_repr(mode=mode)

train_latents["LSTM"][0]

Unnamed: 0,RID_HASH,0,1,2,3,4,5,6,7,8,...,140,141,142,143,144,145,146,147,148,149
0,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bdc713271adea9eaa158,-0.00055,-0.004908,0.000454,0.00693,0.001649,-0.003626,0.000997,-0.000456,-0.000113,...,-0.000112,-0.000045,-0.005524,-0.000863,-0.000003,-0.000524,0.00789,0.002598,-0.001932,-0.000063
1,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8ea199212588d2d672c,-0.000603,-0.004097,0.000817,0.007188,0.003833,-0.001723,0.001478,-0.000897,-0.000102,...,-0.000004,-0.000035,-0.005342,-0.001108,-0.000002,-0.000749,0.006699,0.004134,-0.006042,0.00002
2,0131f7f44ff183309c590b9ff440806b20f639c90c124da03f0c76b377cd6e2b,-0.000488,-0.005174,0.00055,0.010375,0.002223,-0.005376,0.000496,0.000017,-0.000122,...,-0.000103,-0.000041,-0.005251,-0.000597,-0.000004,-0.000286,0.011368,0.002295,-0.000574,-0.000043
3,01513c9ff1e8fcc22cbfc9093845a37ee69307e3493daf0697429bd4d177d5e6,-0.000591,-0.003357,0.00083,0.007135,0.002256,-0.001296,0.001315,-0.000735,-0.000112,...,0.000091,-0.000043,-0.004797,-0.001007,-0.000002,-0.000684,0.005734,0.00377,-0.007016,0.000042
4,01705aaf2c869203d7a8374472f5907f53f3b15f7b4faa4af169b8843859c4cb,-0.000506,-0.00419,0.000504,0.007887,0.00111,-0.003703,0.000709,-0.000193,-0.000119,...,-0.000046,-0.000052,-0.00481,-0.000696,-0.000003,-0.000392,0.008026,0.002345,-0.002405,-0.000045
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1221,ff1d8cc22fb5bf2bd80e31d6d3a6cf1709562bb7e9a22f405074a77fb34ac067,-0.000523,-0.005198,0.000586,0.008612,0.003155,-0.004147,0.000922,-0.000411,-0.000109,...,-0.0001,-0.000048,-0.005529,-0.00081,-0.000003,-0.00048,0.009766,0.002979,-0.001673,-0.000045
1222,ff21c0f13c9535e8339ce653a268b26df8e4172212ac0588b1e6b69cd257dfd8,-0.000669,-0.002873,0.00091,0.005679,0.001979,0.000205,0.001735,-0.001075,-0.000111,...,0.000129,-0.000028,-0.004969,-0.001248,-0.000001,-0.000891,0.003468,0.004301,-0.009551,0.000075
1223,ff48382bcf5922a2db52db36c791b02910015feee82505f411dd74b35cb0f4ce,-0.000503,-0.005577,0.000553,0.009302,0.003248,-0.004886,0.00078,-0.000287,-0.000111,...,-0.000127,-0.00005,-0.005631,-0.000734,-0.000004,-0.00041,0.01094,0.002758,-0.000522,-0.000057
1224,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803bcd30ea99c58dcf91d7,-0.000711,-0.002992,0.001008,0.006133,0.002391,0.000379,0.001893,-0.001184,-0.000113,...,0.000134,-0.000012,-0.005255,-0.001348,-0.0,-0.000969,0.003674,0.004663,-0.010495,0.000104


In [26]:
def map_latent_columns(s: str):
    return f"latent_{s}"

def map_prefix(prefix: str):
    def map_cbk(s: str):
        return f"{prefix}_{s}"

    return map_cbk

def expand_input_df(ref_df, mode):
    ref_id = dataframe_hash(ref_df)
    bkp_file = workspace / f"latent_ext_version_{ref_id}_{mode}_{n_hidden_layers}_{n_hidden_units}_v3.csv"
    
    #if bkp_file.exists():
    #    return pd.read_csv(bkp_file)
    
    latent = generate_latent_repr(ref_df, mode = mode)
    latent = latent.rename(map_latent_columns, axis = "columns")
    
    output = ref_df.copy()
    output[latent.columns] = 0
    
    output["total_visits"] = 0
    output["last_visit"] = 0
    
    for rid in ref_df["RID_HASH"].unique():
        local_latent = latent[latent["latent_RID_HASH"] == rid].values
        assert len(local_latent) == 1
        assert len(local_latent[0]) == len(latent.columns)
        
        output.loc[output["RID_HASH"] == rid, latent.columns] = local_latent

        visits = len(ref_df[ref_df["RID_HASH"] == rid])
        last_visit = ref_df[ref_df["RID_HASH"] == rid]["VISCODE"].max()

        output.loc[output["RID_HASH"] == rid, "total_visits"] = visits
        output.loc[output["RID_HASH"] == rid, "last_visit"] = last_visit
        
    output = output.drop(columns = ["latent_RID_HASH"])
    output = output.reset_index(drop = True)
    
#     fb_imputation = interm_imputation(ref_df, forward_first = True)[temporal_features]
#     fb_imputation = fb_imputation.rename(map_prefix("fb"), axis = "columns")
#     assert (fb_imputation["fb_RID_HASH"].values == ref_df["RID_HASH"].values).all()
#     fb_imputation = fb_imputation.drop(columns = ["fb_RID_HASH", "fb_VISCODE"])
#     fb_imputation = fb_imputation.fillna(fb_imputation.mean())
#     output[fb_imputation.columns] = fb_imputation.values

#     bf_imputation = interm_imputation(ref_df, forward_first = False)
#     bf_imputation = bf_imputation.rename(map_prefix("bf"), axis = "columns")
#     assert (bf_imputation["bf_RID_HASH"].values == ref_df["RID_HASH"].values).all()
#     bf_imputation = bf_imputation.drop(columns = ["bf_RID_HASH", "bf_VISCODE"])
#     bf_imputation = bf_imputation.fillna(bf_imputation.mean())
#     output[bf_imputation.columns] = bf_imputation.values
    
    output.to_csv(bkp_file, index = None)
    return output

def prepare_imputation_input(ref_df, mode):
    ext_ref_df = expand_input_df(ref_df, mode = mode)

    ext_ref_df = prepare_consts(dev_set, ext_ref_df)
    ext_ref_df = prepare_age(dev_set, ext_ref_df)
    
#     for rid in ext_ref_df["RID_HASH"].unique():
#         patient = ext_ref_df[ext_ref_df["RID_HASH"] == rid]
        
#         if patient["AGE"].isna().sum() == 0:
#             continue
        
#         if len(patient) > 1:
#             continue
            
#         first_viscode = patient["VISCODE"].min()
#         patient_index = (ext_ref_df["RID_HASH"] == rid) & (ext_ref_df["VISCODE"] == first_viscode)
        
#         ext_ref_df.loc[patient_index, "AGE"] = dev_set["AGE"].mean().round(1)

#     ext_ref_df = prepare_age(dev_set, ext_ref_df)

    return ext_ref_df
    
eval_mode = "Transformer"
dev_1_ext = prepare_imputation_input(dev_1, mode = eval_mode)
dev_2_ext = prepare_imputation_input(dev_2, mode = eval_mode)
#dev_sim_A_ext = prepare_imputation_input(dev_sim_A, mode = eval_mode)
#dev_sim_B_ext = prepare_imputation_input(dev_sim_B, mode = eval_mode)
test_A_ext = prepare_imputation_input(test_A, mode = eval_mode)
test_B_ext = prepare_imputation_input(test_B, mode = eval_mode)
dev_set_ext = prepare_imputation_input(dev_set, mode = eval_mode)

dev_1_ext

Using cached  workspace/seed_imputation_take7_4097467927144633164_8477102391824886331_0_catlimit10.csv
Using cached  workspace/seed_imputation_take7_4097467927144633164_6199915737732549321_0_catlimit10.csv
Using cached  workspace/seed_imputation_take7_4097467927144633164_8499560854800192759_0_catlimit10.csv
Using cached  workspace/seed_imputation_take7_4097467927144633164_7802728617463065696_0_catlimit10.csv
Using cached  workspace/seed_imputation_take7_4097467927144633164_4097467927144633164_0_catlimit10.csv


Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,...,latent_142,latent_143,latent_144,latent_145,latent_146,latent_147,latent_148,latent_149,total_visits,last_visit
0,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bdc713271adea9eaa158,0,79.1,0.0,20.0,1.0,1.0,0.5,0.923077,0.164384,...,0.031356,0.029956,-0.026714,-0.000240,-0.045432,0.006648,0.035536,0.002062,2,6
1,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bdc713271adea9eaa158,6,79.6,0.0,20.0,1.0,1.0,1.5,0.923077,0.237397,...,0.031356,0.029956,-0.026714,-0.000240,-0.045432,0.006648,0.035536,0.002062,2,6
2,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8ea199212588d2d672c,0,72.9,1.0,12.0,,1.0,,,,...,0.015364,0.006526,-0.011569,0.016994,0.017259,0.014446,0.002768,-0.010133,6,60
3,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8ea199212588d2d672c,6,73.4,1.0,12.0,,1.0,,,,...,0.015364,0.006526,-0.011569,0.016994,0.017259,0.014446,0.002768,-0.010133,6,60
4,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8ea199212588d2d672c,12,73.9,1.0,12.0,,1.0,,,,...,0.015364,0.006526,-0.011569,0.016994,0.017259,0.014446,0.002768,-0.010133,6,60
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4096,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803bcd30ea99c58dcf91d7,60,79.8,1.0,19.0,,0.0,,,,...,0.019282,0.017428,-0.016017,0.001812,-0.025430,0.010571,0.038981,0.003129,7,102
4097,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803bcd30ea99c58dcf91d7,102,83.3,1.0,19.0,,0.0,,,,...,0.019282,0.017428,-0.016017,0.001812,-0.025430,0.010571,0.038981,0.003129,7,102
4098,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c81b37c40ecf646cb0c6,0,72.1,0.0,12.0,1.0,0.0,0.5,0.884615,0.150685,...,0.017718,0.004001,-0.014053,0.012950,0.014980,0.022750,0.009605,-0.008867,3,24
4099,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c81b37c40ecf646cb0c6,12,73.1,0.0,12.0,1.0,0.0,1.0,0.961538,0.155205,...,0.017718,0.004001,-0.014053,0.012950,0.014980,0.022750,0.009605,-0.008867,3,24


In [27]:
from hyperimpute.plugins.imputers import Imputers

def full_imputation(ref_df, imputer = "hyperimpute"):
    imputed_test_data = ref_df.copy()
    if imputer == "hyperimpute":
        imputer_kwargs = {
            "optimizer": "simple",
            "classifier_seed": ["catboost"],
            "regression_seed": ["catboost_regressor", "xgboost_regressor"],
            "class_threshold": cat_limit,
        }
        imputer = Imputers().get(
            "hyperimpute",
            **imputer_kwargs,
        )
    else:
        imputer = Imputers().get(
            imputer
        )
    imputed_test_data = imputer.fit_transform(imputed_test_data)

    return imputed_test_data


all_dev_sets = pd.concat([
    dev_set_ext,
    dev_1_ext, dev_2_ext,
    #dev_sim_A_ext, dev_sim_B_ext,
    test_A_ext, test_B_ext,
], ignore_index = True)

all_dev_rids = all_dev_sets["RID_HASH"]
all_dev_sets = all_dev_sets.drop(columns = ["RID_HASH"])
all_dev_sets

imputer = "hyperimpute"
all_dev_sets_hash = dataframe_hash(all_dev_sets)
interm_bkp_file = workspace / f"interm_imputation_{all_dev_sets_hash}_{eval_mode}_{n_hidden_layers}_{n_hidden_units}_{imputer}.csv"


if interm_bkp_file.exists():
    all_dev_sets_imputed = pd.read_csv(interm_bkp_file)
else:
    all_dev_sets_imputed = full_imputation(all_dev_sets, imputer)
    all_dev_sets_imputed.to_csv(interm_bkp_file, index = None)
    

cols = list(all_dev_sets.columns)
all_dev_sets_imputed.columns = cols

all_dev_sets_imputed["RID_HASH"] = all_dev_rids.values

all_dev_sets_imputed

Unnamed: 0,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,...,latent_143,latent_144,latent_145,latent_146,latent_147,latent_148,latent_149,total_visits,last_visit,RID_HASH
0,0.0,79.100000,0.0,20.0,1.0,1.0,0.5,0.923077,0.164384,0.071871,...,0.019390,-0.022054,0.002527,-0.038565,0.016597,0.031256,0.001197,2.0,6.0,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bdc713271adea9eaa158
1,6.0,79.600000,0.0,20.0,1.0,1.0,1.5,0.923077,0.237397,0.071956,...,0.019390,-0.022054,0.002527,-0.038565,0.016597,0.031256,0.001197,2.0,6.0,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bdc713271adea9eaa158
2,0.0,72.900000,1.0,12.0,1.0,1.0,1.0,1.000000,0.123288,0.142655,...,0.009699,-0.010656,0.018631,0.009689,0.007170,0.001198,-0.011246,6.0,60.0,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8ea199212588d2d672c
3,6.0,73.400000,1.0,12.0,1.0,1.0,1.0,1.000000,0.164384,0.144729,...,0.009699,-0.010656,0.018631,0.009689,0.007170,0.001198,-0.011246,6.0,60.0,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8ea199212588d2d672c
4,12.0,73.900000,1.0,12.0,1.0,1.0,1.0,0.961538,0.109589,0.155550,...,0.009699,-0.010656,0.018631,0.009689,0.007170,0.001198,-0.011246,6.0,60.0,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8ea199212588d2d672c
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
15096,0.0,76.792175,0.0,18.0,1.0,1.0,1.5,0.884615,0.114110,0.257823,...,0.026051,-0.022910,0.000029,-0.029419,0.002335,0.023245,0.000420,3.0,84.0,ff4eb5a64e2b89861d5dea81190669893070b227f3a33577708f31b910ee5d96
15097,12.0,77.375237,0.0,18.0,1.0,1.0,1.5,0.923077,0.242055,0.330060,...,0.026051,-0.022910,0.000029,-0.029419,0.002335,0.023245,0.000420,3.0,84.0,ff4eb5a64e2b89861d5dea81190669893070b227f3a33577708f31b910ee5d96
15098,84.0,83.421364,0.0,18.0,1.0,1.0,1.5,1.000000,0.178082,0.483653,...,0.026051,-0.022910,0.000029,-0.029419,0.002335,0.023245,0.000420,3.0,84.0,ff4eb5a64e2b89861d5dea81190669893070b227f3a33577708f31b910ee5d96
15099,0.0,66.300000,1.0,13.0,0.0,0.0,0.0,0.923077,0.118767,0.177669,...,-0.000899,-0.000812,0.009761,0.035019,0.000149,-0.011346,-0.007122,2.0,24.0,ffa86109ba8684f31325842d0ff26568e105f0f63b366acd4c77c0d2ece69a2f


In [28]:
dev_set_first_visit = dev_set[dev_set["VISCODE"] == 0]
dev_1_first_visit = dev_1[dev_1["VISCODE"] == 0]
dev_2_first_visit = dev_2[dev_2["VISCODE"] == 0]

test_A_first_visit = test_A[test_A["VISCODE"] == 0]
test_B_first_visit = test_B[test_B["VISCODE"] == 0]

first_visits = pd.concat([
    dev_set_first_visit,
    dev_1_first_visit,
    dev_2_first_visit,
    test_A_first_visit,
    test_B_first_visit,
], ignore_index = True)

first_visits_imputed = full_imputation(first_visits)

first_visits_imputed

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
0,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bdc713271adea9eaa158,0,79.100000,0.0,20.000000,1.0,1.0,0.500000,0.923077,0.164384,0.071871,0.548646,0.376516,0.464021,0.194906,0.400709
1,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8ea199212588d2d672c,0,72.900000,1.0,12.000000,1.0,1.0,1.000000,1.000000,0.123288,0.142655,0.525169,0.235599,0.513404,0.356253,0.294774
2,0131f7f44ff183309c590b9ff440806b20f639c90c124da03f0c76b377cd6e2b,0,73.900000,0.0,12.000000,1.0,0.0,0.500000,0.884615,0.305890,0.289870,0.393341,0.506837,0.416049,0.552102,0.519123
3,01513c9ff1e8fcc22cbfc9093845a37ee69307e3493daf0697429bd4d177d5e6,0,73.400000,1.0,12.000000,1.0,0.0,1.000000,0.923077,0.073014,0.303962,0.479910,0.362519,0.600882,0.421089,0.448699
4,01705aaf2c869203d7a8374472f5907f53f3b15f7b4faa4af169b8843859c4cb,0,70.400000,0.0,16.000000,2.0,0.0,5.000000,0.730769,0.296849,0.342653,0.346072,0.544908,0.417813,0.524649,0.455579
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3917,fe82ebf6fe6c75cd24f06574aba0dca66b7437ff3e6c4026cdb3878b1971e0cf,0,81.900000,1.0,16.000000,1.0,1.0,2.561409,0.840629,0.358562,0.589613,0.356546,0.333930,0.159259,0.152708,0.298876
3918,fec7a837563be0867cf942005a22eef6b45d646e56fa81c03087a3f50803a88a,0,81.900000,0.0,16.672003,2.0,1.0,1.000000,0.769231,0.251096,0.277787,0.504393,0.500762,0.513678,0.540986,0.458125
3919,ff09fabfcd92e4039749dd4c1b0af3f1438f9a2cb85401ec5f8c19f6ff09677d,0,73.881142,1.0,13.624475,1.0,0.0,2.000000,0.884615,0.397260,0.211863,0.402800,0.490408,0.360247,0.428215,0.575683
3920,ff4eb5a64e2b89861d5dea81190669893070b227f3a33577708f31b910ee5d96,0,59.003201,1.0,18.000000,1.0,1.0,1.500000,0.884615,0.114110,0.054334,0.502370,0.364933,0.394356,0.397160,0.531003


In [29]:
all_dev_sets_imputed[all_dev_sets_imputed["VISCODE"] == 0]

Unnamed: 0,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,...,latent_143,latent_144,latent_145,latent_146,latent_147,latent_148,latent_149,total_visits,last_visit,RID_HASH
0,0.0,79.100000,0.0,20.000000,1.0,1.0,0.500000,0.923077,0.164384,0.071871,...,0.019390,-0.022054,0.002527,-0.038565,0.016597,0.031256,0.001197,2.0,6.0,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bdc713271adea9eaa158
2,0.0,72.900000,1.0,12.000000,1.0,1.0,1.000000,1.000000,0.123288,0.142655,...,0.009699,-0.010656,0.018631,0.009689,0.007170,0.001198,-0.011246,6.0,60.0,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8ea199212588d2d672c
8,0.0,73.900000,0.0,12.000000,1.0,0.0,0.500000,0.884615,0.305890,0.289870,...,0.001109,-0.005188,0.017499,0.018854,0.006221,-0.010471,-0.013342,8.0,120.0,0131f7f44ff183309c590b9ff440806b20f639c90c124da03f0c76b377cd6e2b
16,0.0,73.400000,1.0,12.000000,1.0,0.0,1.000000,0.923077,0.073014,0.303962,...,0.005184,-0.006846,0.016963,0.017519,0.006961,0.002465,-0.010239,10.0,120.0,01513c9ff1e8fcc22cbfc9093845a37ee69307e3493daf0697429bd4d177d5e6
26,0.0,70.400000,0.0,16.000000,2.0,0.0,5.000000,0.730769,0.296849,0.342653,...,0.019974,-0.019008,0.007587,-0.035680,0.010265,0.031221,-0.003105,2.0,12.0,01705aaf2c869203d7a8374472f5907f53f3b15f7b4faa4af169b8843859c4cb
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
15080,0.0,81.900000,0.0,16.000000,2.0,1.0,5.819269,0.606758,0.513758,0.589613,...,0.023188,-0.016623,0.008465,-0.050981,-0.000699,0.024377,-0.004088,2.0,6.0,fe82ebf6fe6c75cd24f06574aba0dca66b7437ff3e6c4026cdb3878b1971e0cf
15082,0.0,81.900000,0.0,14.686128,2.0,1.0,1.000000,0.769231,0.251096,0.277787,...,0.026414,-0.023963,0.008275,-0.034905,0.006983,0.027018,-0.005210,3.0,24.0,fec7a837563be0867cf942005a22eef6b45d646e56fa81c03087a3f50803a88a
15085,0.0,76.423119,1.0,12.000000,1.0,0.0,2.000000,0.884615,0.397260,0.360508,...,0.018344,-0.011875,0.013646,-0.009061,-0.004112,0.014994,-0.008767,11.0,144.0,ff09fabfcd92e4039749dd4c1b0af3f1438f9a2cb85401ec5f8c19f6ff09677d
15096,0.0,76.792175,0.0,18.000000,1.0,1.0,1.500000,0.884615,0.114110,0.257823,...,0.026051,-0.022910,0.000029,-0.029419,0.002335,0.023245,0.000420,3.0,84.0,ff4eb5a64e2b89861d5dea81190669893070b227f3a33577708f31b910ee5d96


In [30]:
# def fill_first_visit(test_data, imputed_df):
#     test_data = test_data.copy()
#     test_data = test_data.sort_values(["RID_HASH", "VISCODE"])

#     for rid in test_data["RID_HASH"].unique():
#         local_idx = (test_data["VISCODE"] == 0) & (
#             test_data["RID_HASH"] == rid
#         )
#         imputed_idx = (imputed_df["VISCODE"] == 0) & (
#             imputed_df["RID_HASH"] == rid
#         )

#         if len(test_data[local_idx]) == 0:
#             continue

#         for col in test_data.columns:
#             val = test_data.loc[local_idx][col].values[0]
#             if val == val:
#                 continue
#             imputed_val = imputed_df.loc[imputed_idx][col].values[0]
#             test_data.loc[local_idx, col] = imputed_val

#     return test_data

def fill_only_one_line(test_data, imputed_df):
    test_data = test_data.copy()
    test_data = test_data.sort_values(["RID_HASH", "VISCODE"])

    for rid in test_data["RID_HASH"].unique():
        patient = test_data[test_data["RID_HASH"] == rid]
        misses = []
        viscodes = []
        for idx, row in patient.iterrows():
            misses.append(row.isna().sum())
            viscodes.append(row["VISCODE"])
        cidx = np.argmin(misses)

        current_viscode = viscodes[cidx]
        local_idx = (test_data["VISCODE"] == current_viscode) & (
            test_data["RID_HASH"] == rid
        )
        imputed_idx = (imputed_df["VISCODE"] == current_viscode) & (
            imputed_df["RID_HASH"] == rid
        )

        if len(test_data[local_idx]) == 0:
            continue

        for col in test_data.columns:
            val = test_data.loc[local_idx][col].values[0]
            if val == val:
                continue
            imputed_val = imputed_df.loc[imputed_idx][col].values[0]
            test_data.loc[local_idx, col] = imputed_val

    return test_data

#dev_1_ext = fill_first_visit(dev_1_ext,first_visits_imputed)
dev_1_ext = fill_only_one_line(dev_1_ext,all_dev_sets_imputed)

#dev_2_ext = fill_first_visit(dev_2_ext,first_visits_imputed)
dev_2_ext = fill_only_one_line(dev_2_ext,all_dev_sets_imputed)

# dev_sim_A_ext = fill_first_visit(dev_sim_A_ext,all_dev_sets_imputed)
# dev_sim_A_ext = fill_only_one_line(dev_sim_A_ext,all_dev_sets_imputed)

# dev_sim_B_ext = fill_first_visit(dev_sim_B_ext,all_dev_sets_imputed)
# dev_sim_B_ext = fill_only_one_line(dev_sim_B_ext,all_dev_sets_imputed)

#test_A_ext = fill_first_visit(test_A_ext,first_visits_imputed)
test_A_ext = fill_only_one_line(test_A_ext,all_dev_sets_imputed)

#test_B_ext = fill_first_visit(test_B_ext,first_visits_imputed)
test_B_ext = fill_only_one_line(test_B_ext,all_dev_sets_imputed)

In [31]:
def review_constants(ref_df):
    ref_df = prepare_consts(dev_set, ref_df)
    ref_df = prepare_age(dev_set, ref_df)
    
    return ref_df

dev_1_ext = review_constants(dev_1_ext)
dev_2_ext = review_constants(dev_2_ext)
# dev_sim_A_ext = review_constants(dev_sim_A_ext)
# dev_sim_B_ext = review_constants(dev_sim_B_ext)
test_A_ext = review_constants(test_A_ext)
test_B_ext = review_constants(test_B_ext)

test_B_ext.isna().sum().sum()

5126

In [32]:
output_all_dev_sets = pd.concat([
    dev_set_ext,
    dev_1_ext, dev_2_ext,
    #dev_sim_A_ext, dev_sim_B_ext,
    test_A_ext, test_B_ext,
], ignore_index = True)

output_all_dev_rids = output_all_dev_sets["RID_HASH"]
output_all_dev_sets = output_all_dev_sets.drop(columns = ["RID_HASH"])

output_all_dev_sets_hash = dataframe_hash(output_all_dev_sets)
interm_bkp_file = workspace / f"interm_imputation_{output_all_dev_sets_hash}_{eval_mode}_{n_hidden_layers}_{n_hidden_units}_{imputer}.csv"

if interm_bkp_file.exists():
    output_all_dev_sets_imputed = pd.read_csv(interm_bkp_file)
else:
    output_all_dev_sets_imputed = full_imputation(all_dev_sets, imputer = imputer)
    output_all_dev_sets_imputed.to_csv(interm_bkp_file, index = None)
    
output_cols = list(all_dev_sets.columns)
output_all_dev_sets_imputed.columns = output_cols
output_all_dev_sets_imputed["RID_HASH"] = output_all_dev_rids.values

output_all_dev_sets_imputed

Unnamed: 0,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,...,latent_143,latent_144,latent_145,latent_146,latent_147,latent_148,latent_149,total_visits,last_visit,RID_HASH
0,0.0,79.100000,0.0,20.0,1.0,1.0,0.5,0.923077,0.164384,0.071871,...,0.019390,-0.022054,0.002527,-0.038565,0.016597,0.031256,0.001197,2.0,6.0,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bdc713271adea9eaa158
1,6.0,79.600000,0.0,20.0,1.0,1.0,1.5,0.923077,0.237397,0.071956,...,0.019390,-0.022054,0.002527,-0.038565,0.016597,0.031256,0.001197,2.0,6.0,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bdc713271adea9eaa158
2,0.0,72.900000,1.0,12.0,1.0,1.0,1.0,1.000000,0.123288,0.142655,...,0.009699,-0.010656,0.018631,0.009689,0.007170,0.001198,-0.011246,6.0,60.0,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8ea199212588d2d672c
3,6.0,73.400000,1.0,12.0,1.0,1.0,1.0,1.000000,0.164384,0.144729,...,0.009699,-0.010656,0.018631,0.009689,0.007170,0.001198,-0.011246,6.0,60.0,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8ea199212588d2d672c
4,12.0,73.900000,1.0,12.0,1.0,1.0,1.0,0.961538,0.109589,0.155550,...,0.009699,-0.010656,0.018631,0.009689,0.007170,0.001198,-0.011246,6.0,60.0,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8ea199212588d2d672c
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
15096,0.0,76.792175,0.0,18.0,1.0,1.0,1.5,0.884615,0.114110,0.257823,...,0.026051,-0.022910,0.000029,-0.029419,0.002335,0.023245,0.000420,3.0,84.0,ff4eb5a64e2b89861d5dea81190669893070b227f3a33577708f31b910ee5d96
15097,12.0,77.375237,0.0,18.0,1.0,1.0,1.5,0.923077,0.242055,0.330060,...,0.026051,-0.022910,0.000029,-0.029419,0.002335,0.023245,0.000420,3.0,84.0,ff4eb5a64e2b89861d5dea81190669893070b227f3a33577708f31b910ee5d96
15098,84.0,83.421364,0.0,18.0,1.0,1.0,1.5,1.000000,0.178082,0.483653,...,0.026051,-0.022910,0.000029,-0.029419,0.002335,0.023245,0.000420,3.0,84.0,ff4eb5a64e2b89861d5dea81190669893070b227f3a33577708f31b910ee5d96
15099,0.0,66.300000,1.0,13.0,0.0,0.0,0.0,0.923077,0.118767,0.177669,...,-0.000899,-0.000812,0.009761,0.035019,0.000149,-0.011346,-0.007122,2.0,24.0,ffa86109ba8684f31325842d0ff26568e105f0f63b366acd4c77c0d2ece69a2f


In [33]:
# from hyperimpute.plugins.imputers import Imputers
# from hyperimpute.utils.benchmarks import benchmark_model
# from sklearn.preprocessing import LabelEncoder

# gt = pd.concat([dev_set, dev_set], ignore_index=True)
# gt = gt.sort_values(["RID_HASH", "VISCODE"]).reset_index(drop=True)

# ordered_cols = list(gt.columns)

# gt_mask = pd.concat([dev_1, dev_2], ignore_index=True)
# gt_mask = gt_mask.sort_values(["RID_HASH", "VISCODE"]).reset_index(drop=True)
# gt_mask = gt_mask[ordered_cols].isna().astype(int)


# predictions = output_all_dev_sets_imputed[dev_set.columns].tail(len(output_all_dev_sets_imputed) - len(dev_set)).head(2 * len(dev_set))
# predictions = predictions.sort_values(["RID_HASH", "VISCODE"]).reset_index(drop=True)
# predictions = predictions[ordered_cols]

# le = LabelEncoder().fit(gt["RID_HASH"])
# gt["RID_HASH"] = le.transform(gt["RID_HASH"])
# predictions["RID_HASH"] = le.transform(predictions["RID_HASH"])

# plugin = Imputers().get(
#     "hyperimpute",
#     optimizer="simple",
#     classifier_seed=["catboost"],
#     regression_seed=["xgboost_regressor", "catboost_regressor"],
#     class_threshold=cat_limit,
# )

# benchmark_model("nn", plugin, gt, predictions, gt_mask)

In [34]:
# from hyperimpute.plugins.imputers import Imputers
# from hyperimpute.utils.benchmarks import benchmark_model
# from sklearn.preprocessing import LabelEncoder

# test_A_baseline_imputation = impute_baseline_data(dev_set, test_A).reset_index(drop=True)
# test_B_baseline_imputation = impute_baseline_data(dev_set, test_B).reset_index(drop=True)

# gt = pd.concat([test_A_gt, test_B_gt], ignore_index=True)
# gt = gt.sort_values(["RID_HASH", "VISCODE"]).reset_index(drop=True)

# ordered_cols = list(gt.columns)

# gt_mask = pd.concat([test_A, test_B], ignore_index=True)
# gt_mask = gt_mask.sort_values(["RID_HASH", "VISCODE"]).reset_index(drop=True)
# gt_mask = gt_mask[ordered_cols].isna().astype(int)

# predictions = pd.concat([test_A_baseline_imputation, test_B_baseline_imputation], ignore_index=True)
# predictions = predictions.sort_values(["RID_HASH", "VISCODE"]).reset_index(drop=True)


# le = LabelEncoder().fit(gt["RID_HASH"])
# gt["RID_HASH"] = le.transform(gt["RID_HASH"])
# predictions["RID_HASH"] = le.transform(predictions["RID_HASH"])

# plugin = Imputers().get(
#     "hyperimpute",
#     optimizer="simple",
#     classifier_seed=["catboost"],
#     regression_seed=["xgboost_regressor", "catboost_regressor"],
#     class_threshold=cat_limit,
# )

# benchmark_model("nn", plugin, gt, predictions, gt_mask)



In [35]:
# full
from hyperimpute.plugins.imputers import Imputers
from hyperimpute.utils.benchmarks import benchmark_model
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import MinMaxScaler

gt = pd.concat([test_A_gt, test_B_gt], ignore_index=True)
gt = gt.sort_values(["RID_HASH", "VISCODE"]).reset_index(drop=True)
extra_scaler = MinMaxScaler().fit(gt[["AGE", "PTEDUCAT"]])
gt.loc[:, ["AGE", "PTEDUCAT"]] = extra_scaler.transform(gt[["AGE", "PTEDUCAT"]])

gt  = gt.reset_index(drop = True)
ordered_cols = list(gt.columns)

gt_mask = pd.concat([test_A, test_B], ignore_index=True)
gt_mask = gt_mask.sort_values(["RID_HASH", "VISCODE"]).reset_index(drop=True)
gt_mask = gt_mask[ordered_cols].isna().astype(int)


predictions = output_all_dev_sets_imputed[dev_set.columns].tail(len(test_A) + len(test_B))
predictions = predictions.sort_values(["RID_HASH", "VISCODE"]).reset_index(drop=True)
predictions = predictions[ordered_cols]
predictions.loc[:, ["AGE", "PTEDUCAT"]] = extra_scaler.transform(predictions[["AGE", "PTEDUCAT"]])

le = LabelEncoder().fit(gt["RID_HASH"])
gt["RID_HASH"] = le.transform(gt["RID_HASH"])
predictions["RID_HASH"] = le.transform(predictions["RID_HASH"])

plugin = Imputers().get(
    "hyperimpute",
    optimizer="simple",
    classifier_seed=["catboost"],
    regression_seed=["xgboost_regressor", "catboost_regressor"],
    class_threshold=cat_limit,
)

benchmark_model("nn", plugin, gt, predictions, gt_mask)

(0.425304110915405, 0.2519891938036695)

In [36]:
# only baseline
from hyperimpute.plugins.imputers import Imputers
from hyperimpute.utils.benchmarks import benchmark_model
from sklearn.preprocessing import LabelEncoder

gt = pd.concat([test_A_gt, test_B_gt], ignore_index=True)
gt = gt.sort_values(["RID_HASH", "VISCODE"]).reset_index(drop=True)
gt  = gt[gt["VISCODE"] == 0].reset_index(drop = True)
extra_scaler = MinMaxScaler().fit(gt[["AGE", "PTEDUCAT"]])
gt.loc[:, ["AGE", "PTEDUCAT"]] = extra_scaler.transform(gt[["AGE", "PTEDUCAT"]])

ordered_cols = list(gt.columns)

gt_mask = pd.concat([test_A, test_B], ignore_index=True)
gt_mask = gt_mask.sort_values(["RID_HASH", "VISCODE"]).reset_index(drop=True)
gt_mask  = gt_mask[gt_mask["VISCODE"] == 0].reset_index(drop = True)

gt_mask = gt_mask[ordered_cols].isna().astype(int)


predictions = output_all_dev_sets_imputed[dev_set.columns].tail(len(test_A) + len(test_B))
predictions = predictions.sort_values(["RID_HASH", "VISCODE"]).reset_index(drop=True)
predictions = predictions[ordered_cols]
predictions  = predictions[predictions["VISCODE"] == 0].reset_index(drop = True)
predictions.loc[:, ["AGE", "PTEDUCAT"]] = extra_scaler.transform(predictions[["AGE", "PTEDUCAT"]])

le = LabelEncoder().fit(gt["RID_HASH"])
gt["RID_HASH"] = le.transform(gt["RID_HASH"])
predictions["RID_HASH"] = le.transform(predictions["RID_HASH"])

plugin = Imputers().get(
    "hyperimpute",
    optimizer="simple",
    classifier_seed=["catboost"],
    regression_seed=["xgboost_regressor", "catboost_regressor"],
    class_threshold=cat_limit,
)

benchmark_model("nn", plugin, gt, predictions, gt_mask)

(0.42331869329319277, 0.24586861237925278)

In [37]:
#by patient
from hyperimpute.utils.benchmarks import RMSE
from sklearn.preprocessing import MinMaxScaler

gt = pd.concat([test_A_gt, test_B_gt], ignore_index=True)
gt = gt.sort_values(["RID_HASH", "VISCODE"]).reset_index(drop=True)
gt  = gt.reset_index(drop = True)

extra_scaler = MinMaxScaler().fit(gt[["AGE", "PTEDUCAT"]])
test_data = pd.concat([test_A, test_B], ignore_index=True)
test_data = test_data.sort_values(["RID_HASH", "VISCODE"]).reset_index(drop=True)
test_data  = test_data.reset_index(drop = True)


ordered_cols = list(gt.columns)

predictions = output_all_dev_sets_imputed[dev_set.columns].tail(len(test_A) + len(test_B))
predictions = predictions.sort_values(["RID_HASH", "VISCODE"]).reset_index(drop=True)
predictions = predictions[ordered_cols]

rids = list(predictions["RID_HASH"].unique())
patient_errors = []
for rid in rids:
    predicted_patient = predictions[predictions["RID_HASH"] == rid].copy()
    predicted_patient.loc[:, ["AGE", "PTEDUCAT"]] = extra_scaler.transform(predicted_patient[["AGE", "PTEDUCAT"]])
    gt_patient = gt[gt["RID_HASH"] == rid].copy()
    gt_patient.loc[:, ["AGE", "PTEDUCAT"]] = extra_scaler.transform(gt_patient[["AGE", "PTEDUCAT"]])
    
    patient_mask = test_data[test_data["RID_HASH"] == rid]
    assert len(predicted_patient) == len(gt_patient)
    assert (gt_patient["VISCODE"].values == predicted_patient["VISCODE"].values).all()
    
    patient_err = RMSE(
        predicted_patient.drop(columns=["RID_HASH"]).values,
        gt_patient.drop(columns=["RID_HASH"]).values,
        patient_mask.drop(columns=["RID_HASH"]).values,
    )
    patient_errors.append(patient_err)
    
np.argmax(patient_errors)

484

In [38]:
err_rank = np.argsort(patient_errors)
err_id = err_rank[-2]

In [39]:
worst_rid = rids[err_id]

predictions[predictions["RID_HASH"] == worst_rid]

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
828,4bd1aa17b0c3faf8c327c10f38b5ae7a4e9e2aa2c2c77c92626a0a4d7b65267c,0.0,71.51725,0.0,18.0,1.0,2.0,1.111976,0.968519,0.148704,0.157724,0.583219,0.553319,0.546004,0.552429,0.579001


In [40]:
gt[gt["RID_HASH"] == worst_rid]

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
828,4bd1aa17b0c3faf8c327c10f38b5ae7a4e9e2aa2c2c77c92626a0a4d7b65267c,0.0,83.2,0.0,18.0,2.0,2.0,7.0,0.730769,0.383562,0.415863,0.335102,0.553319,0.208466,0.469195,0.49421


In [41]:
test_data[test_data["RID_HASH"] == worst_rid]

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
828,4bd1aa17b0c3faf8c327c10f38b5ae7a4e9e2aa2c2c77c92626a0a4d7b65267c,0,,0.0,18.0,,2.0,,,,,,0.553319,,,


In [42]:
single_visits = pd.DataFrame([], columns = test_data.columns)

for rid in test_data["RID_HASH"].unique():
    test_patient = gt[gt["RID_HASH"] == rid]
    
    if len(test_patient) > 1:
        continue
        
    single_visits = pd.concat([single_visits, test_patient], ignore_index = True)
    
single_visits

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
0,00d5e0050fbd3b6b610f6673347232eb0862df77b5b7a8f667526b0e4520129b,0.0,81.3,1.0,16.0,1.0,0.0,0.5,0.961538,0.219178,0.274673,0.397517,0.272565,0.405996,0.345331,0.50579
1,03970bdbd31927905f2c098ccd3f150b31e4f9600b2640af59d7286cbccb2d4b,18.0,66.5,1.0,18.0,0.0,0.0,0.0,1.0,0.0,0.141111,0.628442,0.623964,0.584127,0.516854,0.533439
2,04226ed8396b78fdf3995df094821d7a47b19bff154005cfadc69c6909a9b777,0.0,75.420151,1.0,18.0,1.0,0.0,1.73887,0.960689,0.182116,0.235311,0.477359,0.311863,0.457496,0.347168,0.47045
3,0629011c0446a19b9d67018b9ecfd46c2bd51bde6156b6324dbe2b5d1db72105,12.0,82.9,1.0,18.0,1.0,0.0,1.0,0.961538,0.164384,0.222455,0.324266,0.250778,0.159965,0.134985,0.252553
4,06407d9ec85d62cd38189108ddffec23822f421b3db357329effdd717a74b837,0.0,83.3,0.0,20.0,1.0,0.0,2.0,0.961538,0.351644,0.18368,0.478705,0.447394,0.429316,0.435656,0.457018
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
177,fbfdb767734a44e62b52e76d6311f7d0fb0f4eb957a08eb424915336dcdce1d9,12.0,82.5,0.0,18.0,1.0,0.0,3.5,0.846154,0.415479,0.516466,0.203725,0.474423,0.147266,0.333565,0.220973
178,fc4178306591e7d4c2212dd69fbae58562590b0843c553c0299eaf2ae2f6e7d7,0.0,70.6,0.0,18.0,1.0,0.0,1.21826,0.953822,0.16017,0.295162,0.619226,0.564826,0.562248,0.55994,0.585126
179,fcced825186806c1933bb385511159f55b5c6b5049bfe9f309f7983369923156,12.0,71.5,0.0,18.0,1.0,0.0,1.0,0.923077,0.191781,0.191544,0.523815,0.577101,0.563139,0.603386,0.57942
180,fd1c6b33c75383b037d555012ee62dc049b582153d6ec80b4d613e6581cdb5ba,0.0,74.767059,0.0,20.0,1.0,0.0,1.582055,0.948814,0.181941,0.321542,0.548691,0.476798,0.347972,0.503252,0.478313


In [43]:
dev_set["AGE"].astype(float).describe()

count    4101.000000
mean       74.699829
std         7.154518
min        54.400000
25%        70.200000
50%        75.000000
75%        79.700000
max        97.400000
Name: AGE, dtype: float64

In [44]:
#after baseline (1.2506788549446284, 0.7308089802338017)
#only baseline (1.6911729720046835, 0.7748322807857492)

#drop extensions dev_A_sim 
#first visit (1.6682034054762505, 0.8184243427517103)
#full (1.3553018330092792, 0.6488727381867302)

# added all directions
# full (1.7102879134817492, 1.8964830911951538)


In [45]:
# Test evaluation
# Baseline (1.4042880265987805, 0.731769237280525)
# LSTM 1 100 (1.3640127691743418, 1.0356662381066515)
# Transformer 1 100 (1.3549483385727494, 0.8050221892104822)
# xceptionTime 1 100(1.4155041581436953, 0.8720785341842062)

# xceptionTime 2 150 (1.3620218597921618, 0.9483736516506034)
# Transformer (1.4409427762164733, 0.6911534459721259)
# Transformer + data aug (1.36128888764713, 0.6224523920143119)



# Train evaluation
# Transformer 1 100 (1.2383225684818857, 0.8607633819994501)


In [46]:
# Benchmark static predictors
# from pydantic import validate_arguments
# import numpy as np
# from hyperimpute.plugins.prediction import Classifiers, Regression
# from hyperimpute.utils.tester import evaluate_estimator, evaluate_regression
# from sklearn.preprocessing import LabelEncoder

# static_features_imputation = ["AGE", "PTGENDER_num", "PTEDUCAT", "APOE4"]
# temporal_features_imputation = [
#     "AGE",
#     "DX_num",
#     "CDRSB",
#     "MMSE",
#     "ADAS13",
#     "Ventricles",
#     "Hippocampus",
#     "WholeBrain",
#     "Entorhinal",
#     "Fusiform",
#     "MidTemp",
# ]

# @validate_arguments(config=dict(arbitrary_types_allowed=True))
# def prepare_static_feature_covariates(
#     latent_space: list
# ):
#     covariates = []
#     ids = []
#     for latent in latent_space:
#         covariates.append(latent.drop(columns=["RID_HASH"]))
#         ids.append(latent["RID_HASH"])
#     return pd.concat(ids, ignore_index=True), pd.concat(covariates, ignore_index=True).astype(float)

# @validate_arguments(config=dict(arbitrary_types_allowed=True))
# def prepare_static_feature_data(
#     dataset: pd.DataFrame, latents: list, feature: str
# ):
#     dbg_Y = (
#         dataset.drop_duplicates("RID_HASH")[feature]
#         .reset_index(drop=True)
#         .values.astype(float)
#     )

#     labels = pd.Series(dbg_Y.tolist() * len(latents))

#     _, covariates = prepare_static_feature_covariates(latents)
#     return covariates, labels

# @validate_arguments(config=dict(arbitrary_types_allowed=True))
# def prepare_temporal_feature_covariates(
#     dataset: pd.DataFrame, latents: list
# ):
#     working_latents =  [[] for i in range(len(latents))]

#     rids = []
#     viscodes = []
#     for rid in dataset["RID_HASH"].unique():
#         patient = dataset[dataset["RID_HASH"] == rid]

#         patient_viscode = patient["VISCODE"]

#         for idx, latent in enumerate(latents):
#             patient_latent = latent[latent["RID_HASH"] == rid]
#             patient_latent_data = patient_latent.loc[
#                 patient_latent.index.repeat(len(patient_viscode))
#             ].reset_index(drop=True)
#             patient_latent_data["VISCODE"] = patient_viscode.values
#             patient_latent_data = patient_latent_data.drop(columns=["RID_HASH"])
            
#             rids.append(patient["RID_HASH"])
#             viscodes.append(patient["VISCODE"])
#             working_latents[idx].append(patient_latent_data)

#     full_latents = []
#     for idx, latent in enumerate(working_latents):
#         full_latents.append(pd.concat(latent, ignore_index=True))

#     covariates = pd.concat(full_latents, ignore_index=True).astype(float)
#     rids = pd.concat(rids, ignore_index=True)
#     viscodes = pd.concat(viscodes, ignore_index=True)
    
#     return rids, viscodes, covariates

# @validate_arguments(config=dict(arbitrary_types_allowed=True))
# def prepare_temporal_feature_data(
#     dataset: pd.DataFrame, latents: list, feature: str
# ):
#     working_target = [[] for i in range(len(latents))]
#     working_latents =  [[] for i in range(len(latents))]

#     for rid in dataset["RID_HASH"].unique():
#         patient = dataset[dataset["RID_HASH"] == rid]

#         patient_target = patient[feature]
#         patient_viscode = patient["VISCODE"]

#         for idx, latent in enumerate(latents):
#             patient_latent = latent[latent["RID_HASH"] == rid]
#             patient_latent_data = patient_latent.loc[
#                 patient_latent.index.repeat(len(patient_viscode))
#             ].reset_index(drop=True)
#             patient_latent_data["VISCODE"] = patient_viscode.values
#             patient_latent_data = patient_latent_data.drop(columns=["RID_HASH"])

#             working_latents[idx].append(patient_latent_data)
#             working_target[idx].append(patient_target)

#     full_latents = []
#     full_targets = []
#     for idx, latent in enumerate(working_latents):
#         full_latents.append(pd.concat(latent, ignore_index=True))
#         full_targets.append(pd.concat(working_target[idx], ignore_index=True))

#     covariates = pd.concat(full_latents, ignore_index=True).astype(float)
#     labels = pd.concat(full_targets, ignore_index=True)

#     return covariates, labels

# def benchmark_static_feature(feature, base_model="xgboost"):
#     print("Benchmarking static", feature)
#     for mode in modes:
#         covariates, labels = prepare_static_feature_data(
#             dev_set, train_latents[mode], feature
#         )

#         if len(np.unique(labels)) < cat_limit:
#             encoded_labels = LabelEncoder().fit_transform(labels)

#             eval_model = Classifiers().get(base_model)
#             score = evaluate_estimator(
#                 eval_model, covariates, pd.Series(encoded_labels)
#             )["str"]
#         else:
#             eval_model = Regression().get(f"{base_model}_regressor")
#             score = evaluate_regression(eval_model, covariates.values, labels.values)[
#                 "str"
#             ]

#         print(" >>> ", mode, score)


# def benchmark_temporal_feature(feature):
#     print("Benchmarking temporal ", feature)
#     for mode in modes:
#         covariates, labels = prepare_temporal_feature_data(
#             dev_set, train_latents[mode], feature
#         )

#         if len(np.unique(labels)) < cat_limit:
#             encoded_labels = LabelEncoder().fit_transform(labels)

#             eval_model = Classifiers().get("xgboost")
#             score = evaluate_estimator(
#                 eval_model, covariates, pd.Series(encoded_labels)
#             )["str"]
#         else:
#             eval_model = Regression().get("xgboost_regressor")
#             score = evaluate_regression(eval_model, covariates.values, labels.values)[
#                 "str"
#             ]

#         print(" >>> ", mode, score)

# # static
# benchmark_static_feature("AGE")
# benchmark_static_feature("PTGENDER_num")
# benchmark_static_feature("PTEDUCAT")
# benchmark_static_feature("APOE4")

# # temporal
# benchmark_temporal_feature("AGE")
# benchmark_temporal_feature("DX_num")
# benchmark_temporal_feature("CDRSB")
# benchmark_temporal_feature("MMSE")
# benchmark_temporal_feature("ADAS13")
# benchmark_temporal_feature("Ventricles")
# benchmark_temporal_feature("Hippocampus")
# benchmark_temporal_feature("WholeBrain")
# benchmark_temporal_feature("Entorhinal")
# benchmark_temporal_feature("Fusiform")
# benchmark_temporal_feature("MidTemp")

In [47]:
# Benchmarking static AGE
#  >>>  LSTM {'rmse': '33.7224 +/- 1.2081', 'wnd': '1.7105 +/- 0.2028', 'r2': '0.3432 +/- 0.0393'}
#  >>>  GRU {'rmse': '40.3391 +/- 1.6157', 'wnd': '2.1703 +/- 0.2149', 'r2': '0.2151 +/- 0.0282'}
#  >>>  RNN {'rmse': '41.1159 +/- 1.8664', 'wnd': '2.1867 +/- 0.1716', 'r2': '0.2001 +/- 0.0315'}
#  >>>  Transformer {'rmse': '33.5526 +/- 0.3647', 'wnd': '1.6906 +/- 0.119', 'r2': '0.3464 +/- 0.0349'}
#  >>>  XceptionTime {'rmse': '35.2435 +/- 0.2436', 'wnd': '1.807 +/- 0.1296', 'r2': '0.3136 +/- 0.0331'}
#  >>>  ResCNN {'rmse': '36.9825 +/- 0.3269', 'wnd': '1.7995 +/- 0.2043', 'r2': '0.2799 +/- 0.0305'}
# Benchmarking static PTGENDER_num
#  >>>  LSTM {'aucroc': '0.9775 +/- 0.0039'}
#  >>>  GRU {'aucroc': '0.9766 +/- 0.0027'}
#  >>>  RNN {'aucroc': '0.9775 +/- 0.002'}
#  >>>  Transformer {'aucroc': '0.9775 +/- 0.0033'}
#  >>>  XceptionTime {'aucroc': '0.9776 +/- 0.0025'}
#  >>>  ResCNN {'aucroc': '0.9803 +/- 0.0037'}
# Benchmarking static PTEDUCAT
#  >>>  LSTM {'rmse': '0.8336 +/- 0.098', 'wnd': '0.3238 +/- 0.0099', 'r2': '0.8915 +/- 0.0059'}
#  >>>  GRU {'rmse': '1.0568 +/- 0.1082', 'wnd': '0.3458 +/- 0.0174', 'r2': '0.8623 +/- 0.0055'}
#  >>>  RNN {'rmse': '1.3529 +/- 0.122', 'wnd': '0.3798 +/- 0.0182', 'r2': '0.8234 +/- 0.009'}
#  >>>  Transformer {'rmse': '0.6925 +/- 0.071', 'wnd': '0.3234 +/- 0.0042', 'r2': '0.9097 +/- 0.0054'}
#  >>>  XceptionTime {'rmse': '0.6739 +/- 0.0253', 'wnd': '0.3191 +/- 0.0068', 'r2': '0.9119 +/- 0.0029'}
#  >>>  ResCNN {'rmse': '1.5594 +/- 0.1129', 'wnd': '0.432 +/- 0.0105', 'r2': '0.7965 +/- 0.0015'}
# Benchmarking static APOE4
#  >>>  LSTM {'aucroc': '0.9981 +/- 0.001'}
#  >>>  GRU {'aucroc': '0.9981 +/- 0.0011'}
#  >>>  RNN {'aucroc': '0.998 +/- 0.0009'}
#  >>>  Transformer {'aucroc': '0.9979 +/- 0.0011'}
#  >>>  XceptionTime {'aucroc': '0.9987 +/- 0.0011'}
#  >>>  ResCNN {'aucroc': '0.9982 +/- 0.0008'}
# Benchmarking temporal  AGE
#  >>>  LSTM {'rmse': '9.5289 +/- 0.5976', 'wnd': '1.0305 +/- 0.064', 'r2': '0.8138 +/- 0.01'}
#  >>>  GRU {'rmse': '11.5789 +/- 0.2229', 'wnd': '1.3678 +/- 0.0432', 'r2': '0.7737 +/- 0.0026'}
#  >>>  RNN {'rmse': '12.3926 +/- 0.4477', 'wnd': '1.3992 +/- 0.0059', 'r2': '0.7577 +/- 0.0114'}
#  >>>  Transformer {'rmse': '9.3344 +/- 0.4739', 'wnd': '0.9091 +/- 0.0398', 'r2': '0.8175 +/- 0.0098'}
#  >>>  XceptionTime {'rmse': '10.2583 +/- 0.2251', 'wnd': '0.9717 +/- 0.0564', 'r2': '0.7995 +/- 0.002'}
#  >>>  ResCNN {'rmse': '9.2169 +/- 0.3061', 'wnd': '0.9536 +/- 0.0477', 'r2': '0.8199 +/- 0.0048'}
# Benchmarking temporal  DX_num
#  >>>  LSTM {'aucroc': '0.9883 +/- 0.0007'}
#  >>>  GRU {'aucroc': '0.9861 +/- 0.0006'}
#  >>>  RNN {'aucroc': '0.9864 +/- 0.0011'}
#  >>>  Transformer {'aucroc': '0.9899 +/- 0.0005'}
#  >>>  XceptionTime {'aucroc': '0.987 +/- 0.0007'}
#  >>>  ResCNN {'aucroc': '0.9849 +/- 0.0008'}
# Benchmarking temporal  CDRSB
#  >>>  LSTM {'rmse': '0.9127 +/- 0.0342', 'wnd': '0.1806 +/- 0.0092', 'r2': '0.8142 +/- 0.0091'}
#  >>>  GRU {'rmse': '0.9477 +/- 0.0403', 'wnd': '0.1821 +/- 0.0039', 'r2': '0.8071 +/- 0.0101'}
#  >>>  RNN {'rmse': '0.9338 +/- 0.0353', 'wnd': '0.1786 +/- 0.0053', 'r2': '0.8099 +/- 0.0094'}
#  >>>  Transformer {'rmse': '0.9776 +/- 0.0059', 'wnd': '0.1801 +/- 0.0113', 'r2': '0.8011 +/- 0.0032'}
#  >>>  XceptionTime {'rmse': '1.0576 +/- 0.023', 'wnd': '0.1958 +/- 0.0052', 'r2': '0.7848 +/- 0.0071'}
#  >>>  ResCNN {'rmse': '0.986 +/- 0.0388', 'wnd': '0.1999 +/- 0.0067', 'r2': '0.7993 +/- 0.0101'}
# Benchmarking temporal  MMSE
#  >>>  LSTM {'rmse': '0.0049 +/- 0.0002', 'wnd': '0.0161 +/- 0.0007', 'r2': '0.7109 +/- 0.0181'}
#  >>>  GRU {'rmse': '0.0051 +/- 0.0002', 'wnd': '0.0167 +/- 0.0007', 'r2': '0.6988 +/- 0.0172'}
#  >>>  RNN {'rmse': '0.0051 +/- 0.0002', 'wnd': '0.0172 +/- 0.0005', 'r2': '0.6976 +/- 0.0183'}
#  >>>  Transformer {'rmse': '0.0051 +/- 0.0002', 'wnd': '0.0162 +/- 0.0006', 'r2': '0.6991 +/- 0.0187'}
#  >>>  XceptionTime {'rmse': '0.0055 +/- 0.0003', 'wnd': '0.0169 +/- 0.0007', 'r2': '0.6785 +/- 0.0225'}
#  >>>  ResCNN {'rmse': '0.0051 +/- 0.0001', 'wnd': '0.0169 +/- 0.0012', 'r2': '0.6995 +/- 0.0141'}
# Benchmarking temporal  ADAS13
#  >>>  LSTM {'rmse': '0.0039 +/- 0.0001', 'wnd': '0.0113 +/- 0.0005', 'r2': '0.8165 +/- 0.0074'}
#  >>>  GRU {'rmse': '0.0042 +/- 0.0', 'wnd': '0.0122 +/- 0.0006', 'r2': '0.8025 +/- 0.0055'}
#  >>>  RNN {'rmse': '0.0042 +/- 0.0001', 'wnd': '0.0126 +/- 0.0008', 'r2': '0.8007 +/- 0.0071'}
#  >>>  Transformer {'rmse': '0.004 +/- 0.0', 'wnd': '0.0104 +/- 0.0008', 'r2': '0.8108 +/- 0.0038'}
#  >>>  XceptionTime {'rmse': '0.0044 +/- 0.0001', 'wnd': '0.0112 +/- 0.0007', 'r2': '0.7955 +/- 0.0085'}
#  >>>  ResCNN {'rmse': '0.0043 +/- 0.0001', 'wnd': '0.0121 +/- 0.0007', 'r2': '0.7973 +/- 0.0075'}
# Benchmarking temporal  Ventricles
#  >>>  LSTM {'rmse': '0.0037 +/- 0.0002', 'wnd': '0.0196 +/- 0.0007', 'r2': '0.8195 +/- 0.0121'}
#  >>>  GRU {'rmse': '0.0042 +/- 0.0003', 'wnd': '0.0235 +/- 0.0007', 'r2': '0.7968 +/- 0.0171'}
#  >>>  RNN {'rmse': '0.0043 +/- 0.0002', 'wnd': '0.0232 +/- 0.0008', 'r2': '0.7921 +/- 0.0125'}
#  >>>  Transformer {'rmse': '0.0021 +/- 0.0001', 'wnd': '0.0074 +/- 0.0005', 'r2': '0.8997 +/- 0.0053'}
#  >>>  XceptionTime {'rmse': '0.0021 +/- 0.0001', 'wnd': '0.0078 +/- 0.0005', 'r2': '0.8982 +/- 0.0081'}
#  >>>  ResCNN {'rmse': '0.0035 +/- 0.0002', 'wnd': '0.0183 +/- 0.0008', 'r2': '0.828 +/- 0.0075'}
# Benchmarking temporal  Hippocampus
#  >>>  LSTM {'rmse': '0.002 +/- 0.0002', 'wnd': '0.009 +/- 0.0011', 'r2': '0.8918 +/- 0.0082'}
#  >>>  GRU {'rmse': '0.0028 +/- 0.0001', 'wnd': '0.0157 +/- 0.0012', 'r2': '0.849 +/- 0.0084'}
#  >>>  RNN {'rmse': '0.0029 +/- 0.0001', 'wnd': '0.0157 +/- 0.0005', 'r2': '0.8449 +/- 0.0032'}
#  >>>  Transformer {'rmse': '0.0019 +/- 0.0001', 'wnd': '0.0075 +/- 0.0009', 'r2': '0.9003 +/- 0.0041'}
#  >>>  XceptionTime {'rmse': '0.0021 +/- 0.0', 'wnd': '0.0087 +/- 0.0008', 'r2': '0.8879 +/- 0.0006'}
#  >>>  ResCNN {'rmse': '0.002 +/- 0.0', 'wnd': '0.0093 +/- 0.0013', 'r2': '0.8912 +/- 0.0021'}
# Benchmarking temporal  WholeBrain
#  >>>  LSTM {'rmse': '0.0019 +/- 0.0001', 'wnd': '0.0096 +/- 0.0005', 'r2': '0.8922 +/- 0.0067'}
#  >>>  GRU {'rmse': '0.0029 +/- 0.0001', 'wnd': '0.0172 +/- 0.0002', 'r2': '0.8399 +/- 0.005'}
#  >>>  RNN {'rmse': '0.0033 +/- 0.0002', 'wnd': '0.0189 +/- 0.0001', 'r2': '0.8185 +/- 0.0135'}
#  >>>  Transformer {'rmse': '0.0016 +/- 0.0001', 'wnd': '0.0068 +/- 0.0007', 'r2': '0.9097 +/- 0.0059'}
#  >>>  XceptionTime {'rmse': '0.0018 +/- 0.0001', 'wnd': '0.0073 +/- 0.0005', 'r2': '0.9004 +/- 0.0068'}
#  >>>  ResCNN {'rmse': '0.0019 +/- 0.0001', 'wnd': '0.009 +/- 0.0008', 'r2': '0.8943 +/- 0.0062'}
# Benchmarking temporal  Entorhinal
#  >>>  LSTM {'rmse': '0.005 +/- 0.0001', 'wnd': '0.0184 +/- 0.0013', 'r2': '0.7782 +/- 0.0038'}
#  >>>  GRU {'rmse': '0.0057 +/- 0.0002', 'wnd': '0.0234 +/- 0.0018', 'r2': '0.7446 +/- 0.0074'}
#  >>>  RNN {'rmse': '0.0058 +/- 0.0002', 'wnd': '0.0232 +/- 0.0012', 'r2': '0.743 +/- 0.0115'}
#  >>>  Transformer {'rmse': '0.0052 +/- 0.0001', 'wnd': '0.0178 +/- 0.0013', 'r2': '0.7688 +/- 0.0027'}
#  >>>  XceptionTime {'rmse': '0.0055 +/- 0.0001', 'wnd': '0.0184 +/- 0.0008', 'r2': '0.754 +/- 0.0086'}
#  >>>  ResCNN {'rmse': '0.0051 +/- 0.0001', 'wnd': '0.0176 +/- 0.0007', 'r2': '0.7736 +/- 0.0016'}
# Benchmarking temporal  Fusiform
#  >>>  LSTM {'rmse': '0.0027 +/- 0.0', 'wnd': '0.0124 +/- 0.0008', 'r2': '0.8499 +/- 0.0024'}
#  >>>  GRU {'rmse': '0.0037 +/- 0.0001', 'wnd': '0.0195 +/- 0.0007', 'r2': '0.7959 +/- 0.0116'}
#  >>>  RNN {'rmse': '0.004 +/- 0.0002', 'wnd': '0.0207 +/- 0.0014', 'r2': '0.78 +/- 0.0133'}
#  >>>  Transformer {'rmse': '0.0027 +/- 0.0001', 'wnd': '0.0108 +/- 0.0006', 'r2': '0.8528 +/- 0.006'}
#  >>>  XceptionTime {'rmse': '0.003 +/- 0.0001', 'wnd': '0.0122 +/- 0.0007', 'r2': '0.835 +/- 0.0059'}
#  >>>  ResCNN {'rmse': '0.003 +/- 0.0001', 'wnd': '0.0127 +/- 0.001', 'r2': '0.8373 +/- 0.0076'}
# Benchmarking temporal  MidTemp
#  >>>  LSTM {'rmse': '0.0023 +/- 0.0', 'wnd': '0.011 +/- 0.0007', 'r2': '0.8684 +/- 0.0029'}
#  >>>  GRU {'rmse': '0.0031 +/- 0.0001', 'wnd': '0.0185 +/- 0.0003', 'r2': '0.8222 +/- 0.0053'}
#  >>>  RNN {'rmse': '0.0033 +/- 0.0001', 'wnd': '0.0188 +/- 0.0007', 'r2': '0.8101 +/- 0.0064'}
#  >>>  Transformer {'rmse': '0.0023 +/- 0.0001', 'wnd': '0.0102 +/- 0.001', 'r2': '0.8695 +/- 0.0065'}
#  >>>  XceptionTime {'rmse': '0.0024 +/- 0.0001', 'wnd': '0.0099 +/- 0.0003', 'r2': '0.8652 +/- 0.0088'}
#  >>>  ResCNN {'rmse': '0.0024 +/- 0.0001', 'wnd': '0.011 +/- 0.0003', 'r2': '0.8664 +/- 0.0082'}

In [48]:
raise

RuntimeError: No active exception to reraise

## Submission data

In [None]:
def normalize_output(test_data):
    test_data = test_data.copy()
    factor = test_data["CDRSB"] / 0.5
    factor[factor < 0] = 0
    factor = factor.fillna(-1)
    factor = factor.round(0).astype(int)
    factor = factor.replace(-1, np.nan)
    test_data["CDRSB"] = factor * 0.5

    test_data["ADAS13"] = ((test_data["ADAS13"] * 3).round(0) / 3).round(2)
    test_data["MMSE"] = test_data["MMSE"].round(0)

    return test_data


def dump_results(imputed_data: pd.DataFrame, fpath: str):
    results = []

    for name, data in [
        ("test_A", test_A.sort_index()),
        ("test_B", test_B.sort_index()),
    ]:
        for idx, row in data.iterrows():
            for col in row.index:
                local = row.T
                val = local[col]
                if val == val:
                    continue
                imputed_id = f"{local['RID_HASH']}_{local['VISCODE']}_{col}_{name}"
                imputed_val = imputed_data[
                    (imputed_data["RID_HASH"] == local["RID_HASH"])
                    & (imputed_data["VISCODE"] == local["VISCODE"])
                ][col].values[0]

                assert imputed_val == imputed_val
                assert imputed_val != ""

                results.append([imputed_id, imputed_val])

    output = pd.DataFrame(results, columns=submission.columns)
    output.to_csv(fpath, index=None)

    return output


def get_submission_data(inputed_data):
    inputed_data = inputed_data.copy()
    inputed_data[scaled_cols] = scaler.inverse_transform(inputed_data[scaled_cols])

    output_fpath = (
        results_dir
        / f"imputation_results_{version}_{changelog}_{eval_mode}_normalized.csv"
    )

    print("Prepare output", output_fpath)
    output_normalized = dump_results(normalize_output(inputed_data), output_fpath)

    return output_fpath, output_normalized


fpath, output = get_submission_data(output_all_dev_sets_imputed)

In [None]:
output