In [1]:
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
from hyperimpute.utils.serialization import (load_model_from_file,
                                             save_model_to_file)
from sklearn.preprocessing import MinMaxScaler

workspace = Path("workspace")
results_dir = Path("results")
data_dir = Path("data")

workspace.mkdir(parents=True, exist_ok=True)

warnings.filterwarnings("ignore")

cat_limit = 10
n_seeds = 5

version = "take7"
changelog = f"ffill_catlimit{cat_limit}"

In [2]:
def dataframe_hash(df: pd.DataFrame) -> str:
    cols = sorted(list(df.columns))
    return str(abs(pd.util.hash_pandas_object(df[cols].fillna(0)).sum()))


def augment_base_dataset(df):
    df = df.sort_values(["RID_HASH", "VISCODE"])

    return df

In [3]:
dev_set = pd.read_csv(data_dir / "dev_set.csv")
dev_set = dev_set.sort_values(["RID_HASH", "VISCODE"])
dev_set = augment_base_dataset(dev_set)

scaled_cols = [
    "MMSE",
    "ADAS13",
    "Ventricles",
    "Hippocampus",
    "WholeBrain",
    "Entorhinal",
    "Fusiform",
    "MidTemp",
]

scaler = MinMaxScaler().fit(dev_set[scaled_cols])
dev_set[scaled_cols] = scaler.transform(dev_set[scaled_cols])

dev_set

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
2163,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0,79.1,0,20,1.0,1.0,0.5,0.923077,0.164384,0.071871,0.548646,0.376516,0.464021,0.194906,0.400709
154,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,6,79.6,0,20,1.0,1.0,1.5,0.923077,0.237397,0.071956,0.548307,0.366398,0.403880,0.193367,0.397291
1385,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0,72.9,1,12,1.0,1.0,1.0,1.000000,0.123288,0.142655,0.525169,0.235599,0.513404,0.356253,0.294774
2698,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,6,73.4,1,12,1.0,1.0,1.0,1.000000,0.164384,0.144729,0.549210,0.230361,0.435097,0.322395,0.294175
2291,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,12,73.9,1,12,1.0,1.0,1.0,0.961538,0.109589,0.155550,0.527878,0.215944,0.487831,0.342600,0.277552
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2895,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,60,79.8,1,19,1.0,0.0,3.0,0.923077,0.223699,0.170895,0.357020,0.321346,0.310935,0.399047,0.461476
2646,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,102,83.3,1,19,1.0,0.0,3.0,0.846154,0.168904,0.178231,0.352043,0.309095,0.256790,0.372685,0.416478
1962,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0,72.1,0,12,1.0,0.0,0.5,0.884615,0.150685,0.416382,0.602438,0.636654,0.610229,0.743037,0.624631
122,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,12,73.1,0,12,1.0,0.0,1.0,0.961538,0.155205,0.398451,0.608521,0.634650,0.617108,0.729087,0.638477


In [4]:
static_features = ["RID_HASH", "AGE", "PTGENDER_num", "PTEDUCAT", "APOE4"]  # first age
temporal_features = [
    "RID_HASH",
    "VISCODE",
    "DX_num",
    "CDRSB",
    "MMSE",
    "ADAS13",
    "Ventricles",
    "Hippocampus",
    "WholeBrain",
    "Entorhinal",
    "Fusiform",
    "MidTemp",
]  #

dev_set_static = dev_set.sort_values(["RID_HASH", "VISCODE"]).drop_duplicates(
    "RID_HASH"
)[static_features]
dev_set_temporal = dev_set.sort_values(["RID_HASH", "VISCODE"])[temporal_features]

dev_set_static

Unnamed: 0,RID_HASH,AGE,PTGENDER_num,PTEDUCAT,APOE4
2163,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,79.1,0,20,1.0
1385,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,72.9,1,12,1.0
298,0131f7f44ff183309c590b9ff440806b20f639c90c124d...,73.9,0,12,0.0
1762,01513c9ff1e8fcc22cbfc9093845a37ee69307e3493daf...,73.4,1,12,0.0
406,01705aaf2c869203d7a8374472f5907f53f3b15f7b4faa...,70.4,0,16,0.0
...,...,...,...,...,...
2205,ff1d8cc22fb5bf2bd80e31d6d3a6cf1709562bb7e9a22f...,71.5,1,16,1.0
1593,ff21c0f13c9535e8339ce653a268b26df8e4172212ac05...,75.9,1,18,0.0
3458,ff48382bcf5922a2db52db36c791b02910015feee82505...,70.3,0,12,1.0
1438,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,74.8,1,19,0.0


In [5]:
dev_set_temporal

Unnamed: 0,RID_HASH,VISCODE,DX_num,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
2163,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0,1.0,0.5,0.923077,0.164384,0.071871,0.548646,0.376516,0.464021,0.194906,0.400709
154,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,6,1.0,1.5,0.923077,0.237397,0.071956,0.548307,0.366398,0.403880,0.193367,0.397291
1385,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0,1.0,1.0,1.000000,0.123288,0.142655,0.525169,0.235599,0.513404,0.356253,0.294774
2698,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,6,1.0,1.0,1.000000,0.164384,0.144729,0.549210,0.230361,0.435097,0.322395,0.294175
2291,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,12,1.0,1.0,0.961538,0.109589,0.155550,0.527878,0.215944,0.487831,0.342600,0.277552
...,...,...,...,...,...,...,...,...,...,...,...,...
2895,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,60,1.0,3.0,0.923077,0.223699,0.170895,0.357020,0.321346,0.310935,0.399047,0.461476
2646,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,102,1.0,3.0,0.846154,0.168904,0.178231,0.352043,0.309095,0.256790,0.372685,0.416478
1962,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0,1.0,0.5,0.884615,0.150685,0.416382,0.602438,0.636654,0.610229,0.743037,0.624631
122,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,12,1.0,1.0,0.961538,0.155205,0.398451,0.608521,0.634650,0.617108,0.729087,0.638477


In [6]:
dev_1 = pd.read_csv(data_dir / "dev_1.csv")
dev_1 = augment_base_dataset(dev_1)
dev_1[scaled_cols] = scaler.transform(dev_1[scaled_cols])

dev_1

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
2163,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0,,0.0,20.0,1.0,1.0,0.5,0.923077,0.164384,,,0.376516,,,
154,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,6,79.6,0.0,20.0,1.0,1.0,1.5,0.923077,0.237397,0.071956,0.548307,0.366398,0.403880,0.193367,0.397291
1385,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0,,1.0,12.0,,1.0,,,,,0.525169,0.235599,0.513404,0.356253,0.294774
2698,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,6,,1.0,12.0,,1.0,,,,,0.549210,0.230361,0.435097,0.322395,0.294175
2291,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,12,,1.0,12.0,,1.0,,,,,0.527878,0.215944,0.487831,0.342600,0.277552
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2895,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,60,79.8,1.0,19.0,,0.0,,,,0.170895,,0.321346,,,
2646,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,102,83.3,1.0,19.0,,0.0,,,,0.178231,,0.309095,,,
1962,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0,72.1,,12.0,1.0,0.0,0.5,0.884615,0.150685,0.416382,0.602438,,0.610229,0.743037,0.624631
122,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,12,73.1,,12.0,1.0,0.0,1.0,0.961538,0.155205,0.398451,0.608521,,0.617108,0.729087,0.638477


In [7]:
dev_2 = pd.read_csv(data_dir / "dev_2.csv")
dev_2 = augment_base_dataset(dev_2)
dev_2[scaled_cols] = scaler.transform(dev_2[scaled_cols])

dev_2

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
2163,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0,79.1,0.0,20.0,1.0,1.0,0.5,0.923077,0.164384,0.071871,0.548646,0.376516,0.464021,0.194906,0.400709
154,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,6,79.6,,,,1.0,,,,0.071956,0.548307,,0.403880,0.193367,0.397291
1385,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0,72.9,,12.0,1.0,1.0,1.0,1.000000,0.123288,0.142655,0.525169,,0.513404,0.356253,0.294774
2698,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,6,,,12.0,1.0,1.0,1.0,1.000000,0.164384,,0.549210,,0.435097,0.322395,0.294175
2291,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,12,,,12.0,1.0,1.0,1.0,0.961538,0.109589,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2895,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,60,,,19.0,1.0,0.0,3.0,0.923077,0.223699,,0.357020,,0.310935,0.399047,0.461476
2646,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,102,,,19.0,1.0,0.0,3.0,0.846154,0.168904,,0.352043,,0.256790,0.372685,0.416478
1962,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0,72.1,,12.0,,0.0,,,,0.416382,0.602438,,0.610229,0.743037,0.624631
122,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,12,,,12.0,,0.0,,,,,,,,,


In [8]:
submission = pd.read_csv(data_dir / "sample_submission.csv")

submission.values[1]

array(['6b6a7136f42a8dbd469a201b88e2abb54a93667822761357db2f6d620da6af8a_0_Ventricles_test_A',
       40613.0818580834], dtype=object)

In [9]:
test_A = pd.read_csv(data_dir / "test_A.csv")
test_A = augment_base_dataset(test_A)
test_A[scaled_cols] = scaler.transform(test_A[scaled_cols])

test_A

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
247,00d5e0050fbd3b6b610f6673347232eb0862df77b5b7a8...,0,,,16.0,1.0,0.0,0.5,0.961538,0.219178,,,,,,
819,013c6f92763546c7ad9c0831f023886c15f05e7332aa0c...,0,72.5,1.0,12.0,,1.0,,,,0.057498,0.612302,0.423268,0.291182,0.433004,0.329131
276,013c6f92763546c7ad9c0831f023886c15f05e7332aa0c...,6,73.0,1.0,12.0,,1.0,,,,0.067972,,0.399942,,,
350,013c6f92763546c7ad9c0831f023886c15f05e7332aa0c...,12,73.5,1.0,12.0,1.0,1.0,2.0,0.769231,0.365342,0.077516,,0.415324,,,
1268,024efbff9265302acd00190e57ee08ba1fe1b90f561f79...,0,,0.0,14.0,1.0,1.0,2.0,1.000000,0.164384,,,0.515223,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
841,ff2966461950ba81280a0189ed2d504a8bd503d9f6b078...,0,,,18.0,1.0,1.0,1.5,0.807692,0.150685,,,,,,
330,ff2966461950ba81280a0189ed2d504a8bd503d9f6b078...,6,,,18.0,1.0,1.0,1.5,0.769231,0.095890,,,,,,
939,ff2966461950ba81280a0189ed2d504a8bd503d9f6b078...,24,,,18.0,1.0,1.0,1.5,0.769231,0.150685,,,,,,
119,ff2966461950ba81280a0189ed2d504a8bd503d9f6b078...,48,70.9,,18.0,1.0,1.0,2.5,0.807692,0.246575,0.307697,0.420993,,0.392416,0.577719,0.403872


In [10]:
test_B = pd.read_csv(data_dir / "test_B.csv")
test_B = augment_base_dataset(test_B)
test_B[scaled_cols] = scaler.transform(test_B[scaled_cols])

test_B

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
1181,001854e92967164311f3acd5a58be9790f28ab3968bbbc...,0,71.4,,15.0,0.0,2.0,0.0,0.961538,0.077671,0.085164,0.638939,,0.608113,0.424862,0.523781
1426,001854e92967164311f3acd5a58be9790f28ab3968bbbc...,36,74.4,,15.0,0.0,2.0,0.0,1.000000,0.027397,0.089750,,,,,
1201,0059bc7849aea9522b408fa0ddc60276a36cae00206b87...,0,,0.0,,1.0,0.0,0.5,0.846154,0.196301,,0.345711,0.286043,0.312698,0.276821,0.248579
757,0059bc7849aea9522b408fa0ddc60276a36cae00206b87...,6,,0.0,,1.0,0.0,1.0,1.000000,0.283151,,0.345147,0.278219,0.378307,0.289480,0.253793
763,0059bc7849aea9522b408fa0ddc60276a36cae00206b87...,12,,0.0,,1.0,0.0,2.5,0.807692,0.168904,,0.329233,0.253372,0.352028,0.259842,0.222042
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1423,ff4eb5a64e2b89861d5dea81190669893070b227f3a335...,0,,,18.0,1.0,1.0,1.5,0.884615,0.114110,,0.502370,,0.394356,0.397160,0.531003
558,ff4eb5a64e2b89861d5dea81190669893070b227f3a335...,12,,,18.0,1.0,1.0,1.5,0.923077,0.242055,,0.519639,,0.294356,0.416522,0.545575
70,ff4eb5a64e2b89861d5dea81190669893070b227f3a335...,84,,0.0,18.0,1.0,1.0,1.5,1.000000,0.178082,,0.432054,0.483387,0.363316,0.468451,0.508440
480,ffa86109ba8684f31325842d0ff26568e105f0f63b366a...,0,66.3,,13.0,0.0,0.0,0.0,0.923077,0.118767,0.177669,,,,,


In [11]:
test_A.isna().sum()

RID_HASH          0
VISCODE           0
AGE             612
PTGENDER_num    626
PTEDUCAT         65
DX_num          428
APOE4            49
CDRSB           428
MMSE            428
ADAS13          428
Ventricles      612
Hippocampus     668
WholeBrain      626
Entorhinal      668
Fusiform        668
MidTemp         668
dtype: int64

In [12]:
test_A.columns

Index(['RID_HASH', 'VISCODE', 'AGE', 'PTGENDER_num', 'PTEDUCAT', 'DX_num',
       'APOE4', 'CDRSB', 'MMSE', 'ADAS13', 'Ventricles', 'Hippocampus',
       'WholeBrain', 'Entorhinal', 'Fusiform', 'MidTemp'],
      dtype='object')

## Emulate missingness

In [13]:
def copy_missingness(ref_data):
    ref_data_ids = ref_data["RID_HASH"].unique()

    len_to_miss = {}
    for rid in ref_data_ids:
        local_A = ref_data[ref_data["RID_HASH"] == rid]
        # print(len(local_A), local_A.isna().sum().sum())

        local_len = len(local_A)
        if local_len not in len_to_miss:
            len_to_miss[local_len] = []
        for reps in range(5):
            len_to_miss[local_len].append(local_A.notna().reset_index(drop=True))

    out_data = pd.DataFrame([], columns=dev_set.columns)
    out_data_ids = dev_set["RID_HASH"].unique()
    for rid in out_data_ids:
        local_A = dev_set[dev_set["RID_HASH"] == rid].copy().reset_index(drop=True)
        local_len = len(local_A)

        if local_len in len_to_miss and len(len_to_miss[local_len]) > 0:
            target_mask = len_to_miss[local_len].pop(0)
            out_data = pd.concat([out_data, local_A[target_mask]], ignore_index=True)
        else:
            out_data = pd.concat([out_data, local_A], ignore_index=True)

    return out_data

In [14]:
dev_sim_A = copy_missingness(test_A)

dev_sim_A

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
0,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0,79.1,0,20,,1.0,,,,0.071871,0.548646,0.376516,0.464021,0.194906,0.400709
1,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,6,,0,20,1.0,1.0,1.5,0.923077,0.237397,,0.548307,0.366398,0.40388,0.193367,0.397291
2,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0,72.9,1,12,,1.0,,,,0.142655,,0.235599,,,
3,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,6,,1,12,,1.0,,,,,,0.230361,,,
4,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,12,,1,12,,1.0,,,,,,0.215944,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4096,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,60,79.8,1,19.0,1.0,0.0,3.0,0.923077,0.223699,0.170895,0.35702,0.321346,0.310935,0.399047,0.461476
4097,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,102,83.3,1,19.0,1.0,0.0,3.0,0.846154,0.168904,0.178231,0.352043,0.309095,0.25679,0.372685,0.416478
4098,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0,72.1,0.0,12,1.0,0.0,0.5,0.884615,0.150685,0.416382,,0.636654,,,
4099,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,12,,,12,1.0,0.0,1.0,0.961538,0.155205,,,,,,


In [15]:
dev_sim_B = copy_missingness(test_B)

dev_sim_B

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
0,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0,79.1,,20,1.0,1.0,0.5,0.923077,0.164384,0.071871,0.548646,,0.464021,0.194906,0.400709
1,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,6,79.6,,20,1.0,1.0,1.5,0.923077,0.237397,0.071956,,,,,
2,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0,,,12,1.0,1.0,1.0,1.0,0.123288,,0.525169,,0.513404,0.356253,0.294774
3,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,6,,,12,1.0,1.0,1.0,1.0,0.164384,,,,,,
4,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,12,,,12,1.0,1.0,1.0,0.961538,0.109589,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4096,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,60,,1,,1.0,0.0,3.0,0.923077,0.223699,,,0.321346,,,
4097,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,102,,1,,1.0,0.0,3.0,0.846154,0.168904,,,0.309095,,,
4098,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0,,,12,1.0,0.0,0.5,0.884615,0.150685,,,,,,
4099,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,12,,,12,1.0,0.0,1.0,0.961538,0.155205,,,,,,


In [16]:
dataframe_hash(dev_sim_B)

'1718898274424252771'

## Baseline imputation

In [17]:
from hyperimpute.plugins.imputers import Imputers

# VISCODE 6 * x -> AGE 0.5 * x

const_by_patient = ["PTGENDER_num", "PTEDUCAT", "APOE4"]


def prepare_consts(train_data, test_data):
    test_data = test_data.copy()
    train_data = train_data.copy()

    train_data = train_data.sort_values(["RID_HASH", "VISCODE"])
    test_data = test_data.sort_values(["RID_HASH", "VISCODE"])

    for item in test_data.groupby("RID_HASH"):
        local = item[1]

        # fill consts
        for col in const_by_patient:
            if len(local[col].unique()) == 1:
                continue
            rid = local["RID_HASH"].unique()[0]

            val = local[col][~local[col].isna()].unique()[0]
            local[col] = local[col].fillna(val)
            test_data.loc[test_data["RID_HASH"] == rid, col] = test_data[
                test_data["RID_HASH"] == rid
            ][col].fillna(val)
            assert len(local[col].unique()) == 1, col

    return test_data


def prepare_age(train_data, test_data):
    test_data = test_data.copy()
    train_data = train_data.copy()

    train_data = train_data.sort_values(["RID_HASH", "VISCODE"])
    test_data = test_data.sort_values(["RID_HASH", "VISCODE"])

    col = "AGE"

    for rid in test_data["RID_HASH"].unique():
        local = test_data[test_data["RID_HASH"] == rid]

        # fill age
        ages = local["AGE"]
        if ages.isna().sum() == 0:
            continue

        if ages.isna().sum() == len(ages):
            continue

        # forward impute age
        prev_viscode = 0
        prev_age = 0
        for idx, row in local.iterrows():
            current_viscode = row["VISCODE"]
            local_idx = (test_data["VISCODE"] == current_viscode) & (
                test_data["RID_HASH"] == rid
            )
            if prev_age > 0 and prev_age == prev_age:
                pred_age = (current_viscode - prev_viscode) / 6 * 0.5 + prev_age
            else:
                pred_age = row[col]

            if pred_age == pred_age:
                # print("forward imputed", pred_age, current_viscode)
                test_data.loc[local_idx, col] = test_data.loc[local_idx][col].fillna(
                    pred_age
                )

            prev_viscode = row["VISCODE"]
            prev_age = pred_age

        # reverse impute age
        prev_viscode = 0
        prev_age = 0
        for idx, row in local.iloc[::-1].iterrows():
            current_viscode = row["VISCODE"]
            local_idx = (test_data["VISCODE"] == current_viscode) & (
                test_data["RID_HASH"] == rid
            )

            if prev_age > 0 and prev_age == prev_age:
                pred_age = prev_age - (prev_viscode - current_viscode) / 6 * 0.5
            else:
                pred_age = row[col]

            if pred_age == pred_age:
                # print("reversed imputed", pred_age, current_viscode)
                test_data.loc[local_idx, col] = test_data.loc[local_idx][col].fillna(
                    pred_age
                )

            prev_viscode = row["VISCODE"]
            prev_age = pred_age

        # print(test_data[(test_data["RID_HASH"] == rid)][["VISCODE", "AGE"]])
    return test_data


def interm_imputation(train_data, test_data):
    test_data = test_data.copy()

    for rid in test_data["RID_HASH"].unique():
        local = test_data[test_data["RID_HASH"] == rid]

        local = local.ffill()
        local = local.bfill()

        test_data.loc[test_data["RID_HASH"] == rid] = local

    return test_data


def full_imputation(train_data, test_data, random_state: int = 0):
    imputed_test_data = test_data.copy()

    imputer_kwargs = {
        "optimizer": "simple",
        "classifier_seed": ["xgboost"],
        "regression_seed": ["xgboost_regressor"],
        "class_threshold": cat_limit,
        "random_state": random_state,
    }

    imputer = Imputers().get(
        "hyperimpute",
        **imputer_kwargs,
    )
    imputation_input = pd.concat([train_data, test_data], ignore_index=True)
    imputed_test_data = imputer.fit_transform(imputation_input)
    imputed_test_data = imputed_test_data.tail(len(test_data))

    return imputed_test_data


def evaluate_static_imputation(train_data, test_data, static_imputation):
    test_data = test_data.copy()
    train_data = train_data.copy()

    train_data = train_data.sort_values(["RID_HASH", "VISCODE"])
    test_data = test_data.sort_values(["RID_HASH", "VISCODE"])

    for rid in test_data["RID_HASH"].unique():
        patient = test_data[test_data["RID_HASH"] == rid]
        misses = []
        viscodes = []
        for idx, row in patient.iterrows():
            misses.append(row.isna().sum())
            viscodes.append(row["VISCODE"])
        cidx = np.argmin(misses)

        current_viscode = viscodes[cidx]
        local_idx = (test_data["VISCODE"] == current_viscode) & (
            test_data["RID_HASH"] == rid
        )
        imputed_idx = (static_imputation["VISCODE"] == current_viscode) & (
            static_imputation["RID_HASH"] == rid
        )

        if len(test_data[local_idx]) == 0:
            continue

        for col in test_data.columns:
            val = test_data.loc[local_idx][col].values[0]
            if val == val:
                continue
            imputed_val = static_imputation.loc[imputed_idx][col].values[0]
            test_data.loc[local_idx, col] = imputed_val

            # print("imputed", test_data.loc[local_idx, col])

    return test_data


def impute_data_step(
    train_data,
    test_data,
    use_longitudinal=True,
    static_strategy="missmin",
    random_state: int = 0,
):
    test_id = dataframe_hash(test_data)
    train_id = dataframe_hash(train_data)

    print(" >>> Evaluate constants", test_id, test_data.isna().sum().sum())
    test_data = prepare_consts(train_data, test_data)
    test_data = prepare_age(train_data, test_data)
    test_data = interm_imputation(train_data, test_data)

    print(
        " >>> Evaluate static imputation",
        test_id,
        test_data.isna().sum().sum(),
        static_strategy,
    )
    static_imputation = full_imputation(
        train_data, test_data, random_state=random_state
    )
    test_data = evaluate_static_imputation(train_data, test_data, static_imputation)

    print(" >>> Evaluate constants take 2", test_id, test_data.isna().sum().sum())
    test_data = prepare_consts(train_data, test_data)
    test_data = prepare_age(train_data, test_data)

    test_data = interm_imputation(train_data, test_data)

    assert test_data.isna().sum().sum() == 0, test_data

    return test_data


def impute_baseline_data(train_data, test_data, random_state: int = 0):
    test_id = dataframe_hash(test_data)
    train_id = dataframe_hash(train_data)

    bkp_file = (
        workspace
        / f"seed_imputation_{version}_{train_id}_{test_id}_{random_state}_catlimit{cat_limit}.csv"
    )

    if bkp_file.exists():
        print("Using cached ", bkp_file)

        return pd.read_csv(bkp_file)

    print("Evaluate ", bkp_file)
    imputed_test_data = impute_data_step(
        train_data,
        test_data,
        random_state=random_state,
    )
    imputed_test_data.to_csv(bkp_file, index=None)

    return imputed_test_data

In [18]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

nn_scaled_cols = ["AGE", "CDRSB", "PTEDUCAT"]
nn_scaler = MinMaxScaler().fit(dev_set[nn_scaled_cols])
# tabular_encoder_static = TabularEncoder(categorical_limit = cat_limit).fit(dev_set_static.reset_index(drop = True))
# tabular_encoder_temporal = TabularEncoder(categorical_limit = cat_limit).fit(dev_set_temporal.reset_index(drop = True))


def mask_columns_map(s: str):
    return f"masked_{s}"


def generate_testcase(ref_df):
    baseline_imputation = impute_baseline_data(dev_set, ref_df).reset_index(drop=True)

    baseline_imputation_nn = baseline_imputation.copy()
    baseline_imputation_nn[nn_scaled_cols] = nn_scaler.transform(
        baseline_imputation_nn[nn_scaled_cols]
    )

    baseline_imputation_static = baseline_imputation_nn.sort_values(
        ["RID_HASH", "VISCODE"]
    )[static_features]
    baseline_imputation_temporal = baseline_imputation_nn.sort_values(
        ["RID_HASH", "VISCODE"]
    )[temporal_features]

    mask = (
        ref_df.isna()
        .astype(int)
        .drop(columns=["RID_HASH", "VISCODE"])
        .rename(mask_columns_map, axis="columns")
    ).reset_index(drop=True)

    full_input = pd.concat(
        [
            baseline_imputation_static.drop(columns=["AGE"]),
            baseline_imputation_temporal.drop(columns=["RID_HASH"]),
            mask,
        ],
        axis=1,
    )

    return baseline_imputation, full_input

In [19]:
testcases = []
for src in [
    dev_1,
    dev_2,
    dev_sim_A,
    dev_sim_B,
]:
    _, src_input = generate_testcase(src)
    testcases.append(src_input)

full_output = dev_set.copy()
full_output[nn_scaled_cols] = nn_scaler.transform(full_output[nn_scaled_cols])

full_output_static = full_output.sort_values(["RID_HASH", "VISCODE"])[static_features]
# full_output_static =  tabular_encoder_static.transform(full_output_static.reset_index(drop = True))

full_output_temporal = full_output.sort_values(["RID_HASH", "VISCODE"])[
    temporal_features
]
# full_output_temporal =  tabular_encoder_temporal.transform(full_output_temporal.reset_index(drop = True))

full_output_static

Using cached  workspace/seed_imputation_take7_4097467927144633164_8477102391824886331_0_catlimit10.csv
Using cached  workspace/seed_imputation_take7_4097467927144633164_6199915737732549321_0_catlimit10.csv
Using cached  workspace/seed_imputation_take7_4097467927144633164_1578441905907828792_0_catlimit10.csv
Using cached  workspace/seed_imputation_take7_4097467927144633164_1718898274424252771_0_catlimit10.csv


Unnamed: 0,RID_HASH,AGE,PTGENDER_num,PTEDUCAT,APOE4
2163,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0.574419,0,1.0000,1.0
154,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0.586047,0,1.0000,1.0
1385,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0.430233,1,0.5000,1.0
2698,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0.441860,1,0.5000,1.0
2291,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0.453488,1,0.5000,1.0
...,...,...,...,...,...
2895,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,0.590698,1,0.9375,0.0
2646,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,0.672093,1,0.9375,0.0
1962,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0.411628,0,0.5000,0.0
122,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0.434884,0,0.5000,0.0


In [20]:
full_output_temporal

Unnamed: 0,RID_HASH,VISCODE,DX_num,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
2163,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0,1.0,0.03125,0.923077,0.164384,0.071871,0.548646,0.376516,0.464021,0.194906,0.400709
154,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,6,1.0,0.09375,0.923077,0.237397,0.071956,0.548307,0.366398,0.403880,0.193367,0.397291
1385,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0,1.0,0.06250,1.000000,0.123288,0.142655,0.525169,0.235599,0.513404,0.356253,0.294774
2698,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,6,1.0,0.06250,1.000000,0.164384,0.144729,0.549210,0.230361,0.435097,0.322395,0.294175
2291,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,12,1.0,0.06250,0.961538,0.109589,0.155550,0.527878,0.215944,0.487831,0.342600,0.277552
...,...,...,...,...,...,...,...,...,...,...,...,...
2895,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,60,1.0,0.18750,0.923077,0.223699,0.170895,0.357020,0.321346,0.310935,0.399047,0.461476
2646,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,102,1.0,0.18750,0.846154,0.168904,0.178231,0.352043,0.309095,0.256790,0.372685,0.416478
1962,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0,1.0,0.03125,0.884615,0.150685,0.416382,0.602438,0.636654,0.610229,0.743037,0.624631
122,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,12,1.0,0.06250,0.961538,0.155205,0.398451,0.608521,0.634650,0.617108,0.729087,0.638477


In [21]:
# activation_layout_static = tabular_encoder_static.activation_layout(discrete_activation = "softmax")[1 : ]
# activation_layout_temporal = tabular_encoder_temporal.activation_layout(discrete_activation = "softmax")[2 : ]

# activation_layout_static

In [22]:
# activation_layout_temporal

In [23]:
from hyperimpute.utils.serialization import (load_model_from_file,
                                             save_model_to_file)
from sklearn.model_selection import train_test_split

from ts_imputer import TimeSeriesImputer, modes

for mode in modes:
    print("Training", mode)
    bkp_file = workspace / f"nn_imputer_mode_{mode}.bkp"

    if bkp_file.exists():
        continue

    imputer = TimeSeriesImputer(
        n_units_in=testcases[0].shape[-1] - 1,  # DROP RID_HASH
        n_units_out_static=full_output_static.shape[-1] - 1,  # DROP RID_HASH
        n_units_out_temporal=full_output_temporal.shape[-1]
        - 2,  # DROP RID_HASH and VISCODE
        nonlin="relu",
        dropout=0.05,
        # nonlin_out_static = activation_layout_static,
        # nonlin_out_temporal = activation_layout_temporal,
        n_layers_hidden=1,
        n_units_hidden=100,
        n_iter=10000,
        mode=mode,
        residual=True,
    )

    for outer_iter in range(3):
        for idx, full_input in enumerate(testcases):
            (
                train_input,
                test_input,
                train_output_static,
                test_output_static,
                train_output_temporal,
                test_output_temporal,
            ) = train_test_split(
                full_input, full_output_static, full_output_temporal, random_state=0
            )
            imputer.fit(
                train_input,
                train_output_static,
                train_output_temporal,
                test_input,
                test_output_static,
                test_output_temporal,
            )
    save_model_to_file(bkp_file, imputer)

Training LSTM
Training GRU
Training RNN
Training Transformer
Training XceptionTime
Training ResCNN


In [24]:
from hyperimpute.utils.serialization import (load_model_from_file,
                                             save_model_to_file)


def get_latent_imputer(mode: str = "LSTM"):
    bkp_file = workspace / f"nn_imputer_mode_{mode}.bkp"
    return load_model_from_file(bkp_file)


def generate_latent_repr(mode: str = "LSTM"):
    imputer = get_latent_imputer(mode=mode)

    output = []
    for idx, full_input in enumerate(testcases):
        test_id = dataframe_hash(full_input)
        bkp_file = workspace / f"latent_repr_testcase_{test_id}_{mode}.bkp"
        if bkp_file.exists():
            latent = load_model_from_file(bkp_file)
        else:
            latent = imputer.predict_latent(full_input)
            save_model_to_file(bkp_file, latent)
        output.append(latent)
    return output


train_latents = {}
for mode in modes:
    train_latents[mode] = generate_latent_repr(mode=mode)

train_latents["LSTM"][0]

Unnamed: 0,RID_HASH,0,1,2,3,4,5,6,7,8,...,90,91,92,93,94,95,96,97,98,99
0,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bdc713271adea9eaa158,-0.02159,-0.005957,0.006697,0.003949,0.007376,0.010806,-0.002373,-0.013869,-0.00695,...,-0.007564,-0.000847,-0.000008,0.599474,-0.006181,-0.003592,-0.000002,-0.000653,0.053731,-0.002236
1,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8ea199212588d2d672c,-0.007806,-0.002332,0.007325,0.003889,0.001795,0.00777,-0.001229,-0.010535,-0.00616,...,-0.006332,-0.000608,-0.000005,0.581468,-0.00751,-0.001309,-0.000003,-0.001092,0.07342,-0.003886
2,0131f7f44ff183309c590b9ff440806b20f639c90c124da03f0c76b377cd6e2b,-0.016515,-0.005194,0.005585,0.003626,0.00779,0.009186,-0.002661,-0.012105,-0.005843,...,-0.006956,-0.00077,-0.000007,0.286756,-0.005473,-0.003803,-0.000002,-0.000717,0.241561,-0.002195
3,01513c9ff1e8fcc22cbfc9093845a37ee69307e3493daf0697429bd4d177d5e6,-0.007519,-0.001385,0.004775,0.003802,0.003106,0.004674,-0.000583,-0.007968,-0.002259,...,-0.003769,-0.000458,-0.000005,0.490799,-0.008589,-0.001029,-0.000003,-0.000053,0.010912,-0.006351
4,01705aaf2c869203d7a8374472f5907f53f3b15f7b4faa4af169b8843859c4cb,-0.029047,-0.008339,0.007929,0.005037,0.010965,0.014077,-0.002508,-0.01844,-0.008212,...,-0.008052,-0.001367,-0.000013,0.509389,-0.00721,-0.004685,-0.000003,0.000324,0.152769,-0.003189
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1221,ff1d8cc22fb5bf2bd80e31d6d3a6cf1709562bb7e9a22f405074a77fb34ac067,-0.03032,-0.009759,0.01042,0.005631,0.010822,0.016865,-0.00337,-0.021342,-0.011296,...,-0.010843,-0.001465,-0.000013,0.631204,-0.007648,-0.005287,-0.000003,-0.00071,0.182328,-0.001783
1222,ff21c0f13c9535e8339ce653a268b26df8e4172212ac0588b1e6b69cd257dfd8,-0.01952,-0.002734,0.004843,0.004434,0.006304,0.007059,-0.000204,-0.011157,-0.002187,...,-0.003595,-0.000811,-0.000009,0.693429,-0.009381,-0.001682,-0.000003,0.001256,-0.103618,-0.007839
1223,ff48382bcf5922a2db52db36c791b02910015feee82505f411dd74b35cb0f4ce,-0.008655,-0.002018,0.000828,0.002387,0.005119,0.001321,-0.001902,-0.003033,0.000591,...,-0.00355,0.000196,0.0,0.447165,-0.006419,-0.002216,-0.000001,-0.000947,-0.106431,-0.004012
1224,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803bcd30ea99c58dcf91d7,-0.026363,-0.005479,0.008916,0.005572,0.008208,0.013484,-0.001058,-0.018367,-0.007623,...,-0.007045,-0.001481,-0.000014,0.691466,-0.009154,-0.003018,-0.000004,0.001023,0.108964,-0.006259


In [25]:
# Benchmark static predictors

import numpy as np
from hyperimpute.plugins.prediction import Classifiers, Regression
from hyperimpute.utils.tester import evaluate_estimator, evaluate_regression
from sklearn.preprocessing import LabelEncoder


def prepare_static_feature_data(
    dataset: pd.DataFrame, latents: dict, feature: str, mode: str
):
    dbg_Y = (
        dataset.drop_duplicates("RID_HASH")[feature]
        .reset_index(drop=True)
        .values.astype(float)
    )

    labels = pd.Series(dbg_Y.tolist() * len(latents[mode]))

    covariates = []
    for latent in latents[mode]:
        covariates.append(latent.drop(columns=["RID_HASH"]))
    covariates = pd.concat(covariates, ignore_index=True).astype(float)

    return covariates, labels


def prepare_temporal_feature_data(
    dataset: pd.DataFrame, latents: dict, feature: str, mode: str
):
    working_target = [[], [], [], []]
    working_latents = [[], [], [], []]

    for rid in dataset["RID_HASH"].unique():
        patient = dataset[dataset["RID_HASH"] == rid]

        patient_target = patient[feature]
        patient_viscode = patient["VISCODE"]

        for idx, latent in enumerate(latents[mode]):
            patient_latent = latent[latent["RID_HASH"] == rid]
            patient_latent_data = patient_latent.loc[
                patient_latent.index.repeat(len(patient_viscode))
            ].reset_index(drop=True)
            patient_latent_data["VISCODE"] = patient_viscode.values
            patient_latent_data = patient_latent_data.drop(columns=["RID_HASH"])

            working_latents[idx].append(patient_latent_data)
            working_target[idx].append(patient_target)

    full_latents = []
    full_targets = []
    for idx, latent in enumerate(working_latents):
        full_latents.append(pd.concat(latent, ignore_index=True))
        full_targets.append(pd.concat(working_target[idx], ignore_index=True))

    covariates = pd.concat(full_latents, ignore_index=True).astype(float)
    labels = pd.concat(full_targets, ignore_index=True)

    return covariates, labels


def benchmark_static_feature(feature, base_model="xgboost"):
    print("Benchmarking static", feature)
    for mode in modes:
        covariates, labels = prepare_static_feature_data(
            dev_set, train_latents, feature, mode
        )

        if len(np.unique(labels)) < cat_limit:
            encoded_labels = LabelEncoder().fit_transform(labels)

            eval_model = Classifiers().get(base_model)
            score = evaluate_estimator(
                eval_model, covariates, pd.Series(encoded_labels)
            )["str"]
        else:
            eval_model = Regression().get(f"{base_model}_regressor")
            score = evaluate_regression(eval_model, covariates.values, labels.values)[
                "str"
            ]

        print(" >>> ", mode, score)


def benchmark_temporal_feature(feature):
    print("Benchmarking temporal ", feature)
    for mode in modes:
        covariates, labels = prepare_temporal_feature_data(
            dev_set, train_latents, feature, mode
        )

        if len(np.unique(labels)) < cat_limit:
            encoded_labels = LabelEncoder().fit_transform(labels)

            eval_model = Classifiers().get("xgboost")
            score = evaluate_estimator(
                eval_model, covariates, pd.Series(encoded_labels)
            )["str"]
        else:
            eval_model = Regression().get("xgboost_regressor")
            score = evaluate_regression(eval_model, covariates.values, labels.values)[
                "str"
            ]

        print(" >>> ", mode, score)

        # static


benchmark_static_feature("AGE")
benchmark_static_feature("PTGENDER_num")
benchmark_static_feature("PTEDUCAT")
benchmark_static_feature("APOE4")

# temporal
benchmark_temporal_feature("AGE")
benchmark_temporal_feature("DX_num")
benchmark_temporal_feature("CDRSB")
benchmark_temporal_feature("MMSE")
benchmark_temporal_feature("ADAS13")
benchmark_temporal_feature("Ventricles")
benchmark_temporal_feature("Hippocampus")
benchmark_temporal_feature("WholeBrain")
benchmark_temporal_feature("Entorhinal")
benchmark_temporal_feature("Fusiform")
benchmark_temporal_feature("MidTemp")

In [None]:
## Train latent predictors
import numpy as np
from hyperimpute.plugins.prediction import Classifiers, Regression
from sklearn.preprocessing import LabelEncoder

static_features_imputation = ["AGE", "PTGENDER_num", "PTEDUCAT", "APOE4"]
temporal_features_imputation = [
    "AGE",
    "DX_num",
    "CDRSB",
    "MMSE",
    "ADAS13",
    "Ventricles",
    "Hippocampus",
    "WholeBrain",
    "Entorhinal",
    "Fusiform",
    "MidTemp",
]


def train_static_feature_predictor(feature: str, mode: str, base_model="xgboost"):
    bkp_file = workspace / f"latent_predictor_static_{feature}_{mode}_{base_model}.bkp"
    print("Train static", feature, mode, bkp_file)

    if bkp_file.exists():
        model = load_model_from_file(bkp_file)
        return model

    covariates, labels = prepare_static_feature_data(
        dev_set, train_latents, feature, mode
    )

    if len(np.unique(labels)) < cat_limit:
        model = Classifiers().get(base_model)
    else:
        model = Regression().get(f"{base_model}_regressor")

    model.fit(covariates, labels)

    save_model_to_file(bkp_file, model)

    return model


def predict_static_feature(feature: str, mode: str):
    bkp_file = workspace / f"latent_predictor_static_{feature}_{mode}_{base_model}.bkp"
    print("Predict static", feature, mode, bkp_file)

    if not bkp_file.exists():
        raise RuntimeError("fail")
    model = load_model_from_file(bkp_file)

    covariates, labels = prepare_static_feature_data(dev_set, feature, mode)

    if len(np.unique(labels)) < cat_limit:
        model = Classifiers().get(base_model)
    else:
        model = Regression().get(f"{base_model}_regressor")

    model.fit(covariates, labels)

    save_model_to_file(bkp_file, model)

    return model


def train_temporal_feature_predictor(feature, mode: str, base_model="xgboost"):
    bkp_file = (
        workspace / f"latent_predictor_temporal_{feature}_{mode}_{base_model}.bkp"
    )
    print("Train temporal", feature, mode, bkp_file)
    if bkp_file.exists():
        model = load_model_from_file(bkp_file)
        return model

    covariates, labels = prepare_temporal_feature_data(
        dev_set, train_latents, feature, mode
    )

    if len(np.unique(labels)) < cat_limit:
        model = Classifiers().get("xgboost")
    else:
        model = Regression().get("xgboost_regressor")

    model.fit(covariates, labels)

    save_model_to_file(bkp_file, model)
    return model


for mode in modes:
    for feature in static_features_imputation:
        train_static_feature_predictor(feature, mode)

for mode in modes:
    for feature in temporal_features_imputation:
        train_temporal_feature_predictor(feature, mode)

Train static AGE LSTM workspace/latent_predictor_static_AGE_LSTM_xgboost.bkp
Train static PTGENDER_num LSTM workspace/latent_predictor_static_PTGENDER_num_LSTM_xgboost.bkp
Train static PTEDUCAT LSTM workspace/latent_predictor_static_PTEDUCAT_LSTM_xgboost.bkp
Train static APOE4 LSTM workspace/latent_predictor_static_APOE4_LSTM_xgboost.bkp
Train static AGE GRU workspace/latent_predictor_static_AGE_GRU_xgboost.bkp
Train static PTGENDER_num GRU workspace/latent_predictor_static_PTGENDER_num_GRU_xgboost.bkp
Train static PTEDUCAT GRU workspace/latent_predictor_static_PTEDUCAT_GRU_xgboost.bkp
Train static APOE4 GRU workspace/latent_predictor_static_APOE4_GRU_xgboost.bkp
Train static AGE RNN workspace/latent_predictor_static_AGE_RNN_xgboost.bkp
Train static PTGENDER_num RNN workspace/latent_predictor_static_PTGENDER_num_RNN_xgboost.bkp
Train static PTEDUCAT RNN workspace/latent_predictor_static_PTEDUCAT_RNN_xgboost.bkp
Train static APOE4 RNN workspace/latent_predictor_static_APOE4_RNN_xgboost.

In [None]:
# derived from the benchmarks
static_features_config = [
    ("PTGENDER_num", "ResCNN"),
    ("PTEDUCAT", "XceptionTime"),
    ("APOE4", "XceptionTime"),
]

temporal_features_config = [
    ("AGE", "Transformer"),
    ("DX_num", "Transformer"),
    ("CDRSB", "LSTM"),
    ("MMSE", "LSTM"),
    ("ADAS13", "LSTM"),
    ("Ventricles", "Transformer"),
    ("Hippocampus", "Transformer"),
    ("WholeBrain", "Transformer"),
    ("Entorhinal", "LSTM"),
    ("Fusiform", "Transformer"),
    ("MidTemp", "Transformer"),
]

In [None]:
def impute_data(ref_df):
    ref_mask = ref_df.isna().astype(int)
    ref_id = dataframe_hash(ref_df)
    output_df = ref_df.copy().reset_index(drop=True)
    print("Imputing ", ref_id)
    # create baseline imputation and NN input
    baseline_imputation, nn_input = generate_testcase(ref_df)

    # generate latent representations
    latent_bk_file = workspace / f"latent_representation_{ref_id}.bkp"
    if latent_bk_file.exists():
        local_latents = load_model_from_file(latent_bk_file)
    else:
        local_latents = {}
        for mode in modes:
            imputer = get_latent_imputer(mode=mode)

            local_latents[mode] = imputer.predict_latent(nn_input)
        save_model_to_file(latent_bk_file, local_latents)

    # impute constants
    output_df = prepare_consts(dev_set, output_df)
    output_df = prepare_age(dev_set, output_df)

    return local_latents


impute_data(dev_1)["LSTM"]
# impute_data(dev_2)
# impute_data(test_A)
# impute_data(test_B)

In [None]:
raise

In [None]:
def normalize_output(test_data):
    test_data = test_data.copy()
    factor = test_data["CDRSB"] / 0.5
    factor[factor < 0] = 0

    factor = factor.fillna(-1)
    factor = factor.round(0).astype(int)
    factor = factor.replace(-1, np.nan)
    test_data["CDRSB"] = factor * 0.5

    test_data["ADAS13"] = ((test_data["ADAS13"] * 3).round(0) / 3).round(2)
    test_data["MMSE"] = test_data["MMSE"].round(0)

    return test_data


def post_processing(
    imputed,
    baseline_imputation,
    miss_data,
):
    ordered_cols = list(miss_data.columns)

    gt_mask = miss_data.isna().astype(int)
    gt_mask = gt_mask.sort_values(["RID_HASH", "VISCODE"])
    gt_mask = gt_mask[ordered_cols]

    imputed.index = miss_data.index
    imputed = imputed[ordered_cols]

    use_from_baseline = ["AGE", "PTGENDER_num", "DX_num", "PTEDUCAT", "APOE4"]

    imputed[use_from_baseline] = baseline_imputation[use_from_baseline]

    imputed[scaled_cols] = scaler.inverse_transform(imputed[scaled_cols])
    imputed = normalize_output(imputed)
    imputed[scaled_cols] = scaler.transform(imputed[scaled_cols])

    return imputed


imputation_baseline = pd.concat(
    [dev_1_baseline, dev_2_baseline], ignore_index=True
).reset_index(drop=True)
imputation_baseline = imputation_baseline.sort_values(
    ["RID_HASH", "VISCODE"]
).reset_index(drop=True)


miss_data = pd.concat([dev_1, dev_2], ignore_index=True).isna().astype(int)
miss_data = miss_data.sort_values(["RID_HASH", "VISCODE"]).reset_index(drop=True)

final_full_preds = post_processing(full_preds, imputation_baseline, miss_data)

final_full_preds

In [None]:
dev_set

In [None]:
from hyperimpute.utils.benchmarks import RMSE

gt = pd.concat([dev_set, dev_set], ignore_index=True)
gt = gt.sort_values(["RID_HASH", "VISCODE"])
gt_mask = pd.concat([dev_1, dev_2], ignore_index=True).isna().astype(int)
gt_mask = gt_mask.sort_values(["RID_HASH", "VISCODE"])

imputation_baseline = pd.concat([dev_1_baseline, dev_2_baseline], ignore_index=True)
imputation_baseline = imputation_baseline.sort_values(["RID_HASH", "VISCODE"])
imputation_baseline = imputation_baseline.drop(columns=["RID_HASH", "VISCODE"]).values

RMSE(
    imputation_baseline,
    gt.drop(columns=["RID_HASH", "VISCODE"]).values,
    gt_mask.drop(columns=["RID_HASH", "VISCODE"]).values,
)

In [None]:
gt = pd.concat([dev_set, dev_set], ignore_index=True)
gt = gt.sort_values(["RID_HASH", "VISCODE"])
gt_mask = pd.concat([dev_1, dev_2], ignore_index=True).isna().astype(int)
gt_mask = gt_mask.sort_values(["RID_HASH", "VISCODE"])

# nn_preds = tabular_encoder.inverse_transform(preds).drop(columns = ["RID_HASH", "VISCODE"]).values
# full_preds = imputer.predict(pd.concat([training_input_1, training_input_2]))
nn_preds = final_full_preds.drop(columns=["RID_HASH", "VISCODE"]).values

RMSE(
    nn_preds,
    gt.drop(columns=["RID_HASH", "VISCODE"]).values,
    gt_mask.drop(columns=["RID_HASH", "VISCODE"]).values,
)

In [None]:
dev_set

In [None]:
from hyperimpute.plugins.imputers import Imputers
from hyperimpute.utils.benchmarks import benchmark_model
from sklearn.preprocessing import LabelEncoder

gt = pd.concat([dev_set, dev_set], ignore_index=True)
gt = gt.sort_values(["RID_HASH", "VISCODE"]).reset_index(drop=True)

ordered_cols = list(gt.columns)


gt_mask = pd.concat([dev_1, dev_2], ignore_index=True)
gt_mask = gt_mask.sort_values(["RID_HASH", "VISCODE"]).reset_index(drop=True)
gt_mask = gt_mask[ordered_cols]

final_full_preds.index = gt.index
final_full_preds = final_full_preds[ordered_cols]

# nn_preds = tabular_encoder.inverse_transform(preds).drop(columns = ["RID_HASH", "VISCODE"]).values

le = LabelEncoder().fit(gt["RID_HASH"])
gt["RID_HASH"] = le.transform(gt["RID_HASH"])
# full_preds["RID_HASH"] = le.transform(full_preds["RID_HASH"])

plugin = Imputers().get(
    "hyperimpute",
    optimizer="simple",
    classifier_seed=["catboost"],
    regression_seed=["xgboost_regressor", "catboost_regressor"],
    class_threshold=cat_limit,
)

benchmark_model("nn", plugin, gt, final_full_preds, gt_mask)

In [None]:
gt

In [None]:
full_preds

## Submission data

In [None]:
def normalize_output(test_data):
    test_data = test_data.copy()
    factor = test_data["CDRSB"] / 0.5
    factor = factor.fillna(-1)
    factor = factor.round(0).astype(int)
    factor = factor.replace(-1, np.nan)
    test_data["CDRSB"] = factor * 0.5

    test_data["ADAS13"] = ((test_data["ADAS13"] * 3).round(0) / 3).round(2)
    test_data["MMSE"] = test_data["MMSE"].round(0)

    return test_data


def dump_results(imputed_data: pd.DataFrame, fpath: str):
    results = []

    for name, data in [
        ("test_A", test_A.sort_index()),
        ("test_B", test_B.sort_index()),
    ]:
        for idx, row in data.iterrows():
            for col in row.index:
                local = row.T
                val = local[col]
                if val == val:
                    continue
                imputed_id = f"{local['RID_HASH']}_{local['VISCODE']}_{col}_{name}"
                imputed_val = imputed_data[
                    (imputed_data["RID_HASH"] == local["RID_HASH"])
                    & (imputed_data["VISCODE"] == local["VISCODE"])
                ][col].values[0]

                assert imputed_val == imputed_val
                assert imputed_val != ""

                results.append([imputed_id, imputed_val])

    output = pd.DataFrame(results, columns=submission.columns)
    output.to_csv(fpath, index=None)

    return output


def get_submission_data(random_state):
    test_A_eval = impute_data(dev_set, test_A, random_state=random_state).sort_index()
    test_B_eval = impute_data(dev_set, test_B, random_state=random_state).sort_index()

    eval_data = pd.concat([dev_set, test_A_eval, test_B_eval], ignore_index=True)
    eval_data[scaled_cols] = scaler.inverse_transform(eval_data[scaled_cols])

    output_fpath = (
        results_dir
        / f"imputation_results_{version}_{changelog}_normalized_seeds_{random_state}.csv"
    )

    print("Prepare output", output_fpath)
    output_normalized = dump_results(normalize_output(eval_data), output_fpath)

    return output_fpath, output_normalized


for random_state in range(10):
    output_path, output = get_submission_data(random_state=random_state)

In [None]:
output_path, output0 = get_submission_data(random_state=0)

output0

In [None]:
output_path, output1 = get_submission_data(random_state=1)

output1

In [None]:
(output1["Predicted"] - output0["Predicted"]).sum()