In [1]:
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
from hyperimpute.utils.serialization import load_model_from_file, save_model_to_file
from sklearn.preprocessing import MinMaxScaler
from baseline_imputation import prepare_consts, prepare_age

workspace = Path("workspace")  # cache for trained imputer checkpoints (.bkp files)
results_dir = Path("results")  # destination for submission CSVs
data_dir = Path("data")  # input CSVs (dev/test splits, sample submission)

workspace.mkdir(parents=True, exist_ok=True)
# Submission CSVs are written here later; make sure the directory exists up front.
results_dir.mkdir(parents=True, exist_ok=True)

# NOTE(review): this silences every warning (including sklearn/pandas deprecation and
# data-conversion warnings) for the whole notebook.
warnings.filterwarnings("ignore")

# Default cardinality threshold: a target with fewer distinct values is label-encoded.
cat_limit = 10
n_seeds = 5

version = "take8"
changelog = "hyperlatent_transformer"

In [2]:
def dataframe_hash(df: pd.DataFrame) -> str:
    """Fingerprint a DataFrame's contents independently of its column order.

    Columns are put in sorted-name order and NaNs are replaced by 0 (so NaN and 0 hash
    alike). Each row is then hashed with ``pd.util.hash_pandas_object`` (index included
    by default). The row hashes are summed, and the absolute value is returned as a string.
    """
    by_name = df[sorted(df.columns)]
    row_hashes = pd.util.hash_pandas_object(by_name.fillna(0))
    return str(abs(row_hashes.sum()))


def augment_base_dataset(df, scaler, scaled_cols):
    """Order visits per patient and run the baseline_imputation preprocessing helpers.

    Rows are sorted by (RID_HASH, VISCODE); ``sort_values`` returns a new frame, so
    ``df`` itself is not reordered. The sorted frame then goes through
    ``prepare_consts`` and ``prepare_age``; the latter also receives the scaler and
    the list of scaled columns.
    """
    ordered = df.sort_values(["RID_HASH", "VISCODE"])
    with_consts = prepare_consts(ordered)
    return prepare_age(with_consts, scaler, scaled_cols)

In [3]:
dev_set = pd.read_csv(data_dir / "dev_set.csv")

# Continuous columns, min-max scaled. The scaler is fit on dev_set only and reused for
# every other split below.
scaled_cols = [
    "AGE", "PTEDUCAT", "MMSE", "ADAS13", "Ventricles",
    "Hippocampus", "WholeBrain", "Entorhinal", "Fusiform", "MidTemp",
]

# NOTE(review): dev_set is scaled *before* augment_base_dataset, while the other splits
# below are augmented first and scaled afterwards. prepare_age also receives the scaler,
# so this order may matter; confirm which convention prepare_age expects.
scaler = MinMaxScaler()
dev_set[scaled_cols] = scaler.fit_transform(dev_set[scaled_cols])

dev_set = augment_base_dataset(dev_set, scaler, scaled_cols)

dev_set

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
2163,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0,0.574419,0,1.0000,1.0,1.0,0.5,0.923077,0.164384,0.071871,0.548646,0.376516,0.464021,0.194906,0.400709
154,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,6,0.586047,0,1.0000,1.0,1.0,1.5,0.923077,0.237397,0.071956,0.548307,0.366398,0.403880,0.193367,0.397291
1385,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0,0.430233,1,0.5000,1.0,1.0,1.0,1.000000,0.123288,0.142655,0.525169,0.235599,0.513404,0.356253,0.294774
2698,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,6,0.441860,1,0.5000,1.0,1.0,1.0,1.000000,0.164384,0.144729,0.549210,0.230361,0.435097,0.322395,0.294175
2291,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,12,0.453488,1,0.5000,1.0,1.0,1.0,0.961538,0.109589,0.155550,0.527878,0.215944,0.487831,0.342600,0.277552
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2895,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,60,0.590698,1,0.9375,1.0,0.0,3.0,0.923077,0.223699,0.170895,0.357020,0.321346,0.310935,0.399047,0.461476
2646,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,102,0.672093,1,0.9375,1.0,0.0,3.0,0.846154,0.168904,0.178231,0.352043,0.309095,0.256790,0.372685,0.416478
1962,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0,0.411628,0,0.5000,1.0,0.0,0.5,0.884615,0.150685,0.416382,0.602438,0.636654,0.610229,0.743037,0.624631
122,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,12,0.434884,0,0.5000,1.0,0.0,1.0,0.961538,0.155205,0.398451,0.608521,0.634650,0.617108,0.729087,0.638477


In [4]:
# Split dev_set into per-patient constants (one row per RID_HASH, taken from the visit
# with the smallest VISCODE) and the per-visit measurements.
static_features = ["RID_HASH", "PTGENDER_num", "PTEDUCAT", "APOE4"]
temporal_features = [
    "RID_HASH", "VISCODE", "DX_num", "CDRSB", "MMSE", "ADAS13",
    "Ventricles", "Hippocampus", "WholeBrain", "Entorhinal", "Fusiform", "MidTemp",
]

dev_set_by_visit = dev_set.sort_values(["RID_HASH", "VISCODE"])
dev_set_static = dev_set_by_visit.drop_duplicates("RID_HASH")[static_features]
dev_set_temporal = dev_set_by_visit[temporal_features]

dev_set_static

Unnamed: 0,RID_HASH,PTGENDER_num,PTEDUCAT,APOE4
2163,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0,1.0000,1.0
1385,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,1,0.5000,1.0
298,0131f7f44ff183309c590b9ff440806b20f639c90c124d...,0,0.5000,0.0
1762,01513c9ff1e8fcc22cbfc9093845a37ee69307e3493daf...,1,0.5000,0.0
406,01705aaf2c869203d7a8374472f5907f53f3b15f7b4faa...,0,0.7500,0.0
...,...,...,...,...
2205,ff1d8cc22fb5bf2bd80e31d6d3a6cf1709562bb7e9a22f...,1,0.7500,1.0
1593,ff21c0f13c9535e8339ce653a268b26df8e4172212ac05...,1,0.8750,0.0
3458,ff48382bcf5922a2db52db36c791b02910015feee82505...,0,0.5000,1.0
1438,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,1,0.9375,0.0


In [5]:
# Per-visit measurements of dev_set (built in the previous cell).
dev_set_temporal

Unnamed: 0,RID_HASH,VISCODE,DX_num,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
2163,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0,1.0,0.5,0.923077,0.164384,0.071871,0.548646,0.376516,0.464021,0.194906,0.400709
154,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,6,1.0,1.5,0.923077,0.237397,0.071956,0.548307,0.366398,0.403880,0.193367,0.397291
1385,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0,1.0,1.0,1.000000,0.123288,0.142655,0.525169,0.235599,0.513404,0.356253,0.294774
2698,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,6,1.0,1.0,1.000000,0.164384,0.144729,0.549210,0.230361,0.435097,0.322395,0.294175
2291,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,12,1.0,1.0,0.961538,0.109589,0.155550,0.527878,0.215944,0.487831,0.342600,0.277552
...,...,...,...,...,...,...,...,...,...,...,...,...
2895,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,60,1.0,3.0,0.923077,0.223699,0.170895,0.357020,0.321346,0.310935,0.399047,0.461476
2646,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,102,1.0,3.0,0.846154,0.168904,0.178231,0.352043,0.309095,0.256790,0.372685,0.416478
1962,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0,1.0,0.5,0.884615,0.150685,0.416382,0.602438,0.636654,0.610229,0.743037,0.624631
122,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,12,1.0,1.0,0.961538,0.155205,0.398451,0.608521,0.634650,0.617108,0.729087,0.638477


In [6]:
# dev_1: a copy of dev_set with missing values (used as model input against dev_set as
# ground truth in prepare_dataset). It is augmented first, then scaled with the scaler
# fitted on dev_set.
# NOTE(review): dev_set itself was scaled before augmentation; confirm the order is intended.
raw_dev_1 = pd.read_csv(data_dir / "dev_1.csv")
dev_1 = augment_base_dataset(raw_dev_1, scaler, scaled_cols)
dev_1[scaled_cols] = scaler.transform(dev_1[scaled_cols])

dev_1

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
2163,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0,0.585776,0.0,1.0000,1.0,1.0,0.5,0.923077,0.164384,,,0.376516,,,
154,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,6,0.586047,0.0,1.0000,1.0,1.0,1.5,0.923077,0.237397,0.071956,0.548307,0.366398,0.403880,0.193367,0.397291
1385,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0,0.543807,1.0,0.5000,,1.0,,,,,0.525169,0.235599,0.513404,0.356253,0.294774
2698,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,6,0.544078,1.0,0.5000,,1.0,,,,,0.549210,0.230361,0.435097,0.322395,0.294175
2291,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,12,0.544348,1.0,0.5000,,1.0,,,,,0.527878,0.215944,0.487831,0.342600,0.277552
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2895,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,60,0.590698,1.0,0.9375,,0.0,,,,0.170895,,0.321346,,,
2646,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,102,0.672093,1.0,0.9375,,0.0,,,,0.178231,,0.309095,,,
1962,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0,0.411628,0.0,0.5000,1.0,0.0,0.5,0.884615,0.150685,0.416382,0.602438,,0.610229,0.743037,0.624631
122,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,12,0.434884,0.0,0.5000,1.0,0.0,1.0,0.961538,0.155205,0.398451,0.608521,,0.617108,0.729087,0.638477


In [7]:
# dev_2: a second copy of dev_set with missing values, processed exactly like dev_1
# (augment first, then scale with the dev_set scaler).
raw_dev_2 = pd.read_csv(data_dir / "dev_2.csv")
dev_2 = augment_base_dataset(raw_dev_2, scaler, scaled_cols)
dev_2[scaled_cols] = scaler.transform(dev_2[scaled_cols])

dev_2

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
2163,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0,0.574419,0.0,1.0000,1.0,1.0,0.5,0.923077,0.164384,0.071871,0.548646,0.376516,0.464021,0.194906,0.400709
154,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,6,0.586047,0.0,1.0000,,1.0,,,,0.071956,0.548307,,0.403880,0.193367,0.397291
1385,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0,0.430233,1.0,0.5000,1.0,1.0,1.0,1.000000,0.123288,0.142655,0.525169,,0.513404,0.356253,0.294774
2698,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,6,0.430503,1.0,0.5000,1.0,1.0,1.0,1.000000,0.164384,,0.549210,,0.435097,0.322395,0.294175
2291,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,12,0.430773,1.0,0.5000,1.0,1.0,1.0,0.961538,0.109589,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2895,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,60,,1.0,0.9375,1.0,0.0,3.0,0.923077,0.223699,,0.357020,,0.310935,0.399047,0.461476
2646,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,102,,1.0,0.9375,1.0,0.0,3.0,0.846154,0.168904,,0.352043,,0.256790,0.372685,0.416478
1962,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0,0.411628,0.0,0.5000,,0.0,,,,0.416382,0.602438,,0.610229,0.743037,0.624631
122,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,12,0.412169,0.0,0.5000,,0.0,,,,,,,,,


In [8]:
# Sample submission: each id is "<RID_HASH>_<VISCODE>_<column>_<split>" with one value.
# Its columns are reused for the CSV written by dump_results.
submission = pd.read_csv(data_dir / "sample_submission.csv")

submission.values[1]

array(['6b6a7136f42a8dbd469a201b88e2abb54a93667822761357db2f6d620da6af8a_0_Ventricles_test_A',
       40613.0818580834], dtype=object)

In [9]:
# test_A (input with missing values) and its ground truth, processed like dev_1/dev_2:
# augment first, then scale with the dev_set scaler.
raw_test_A = pd.read_csv(data_dir / "test_A.csv")
test_A = augment_base_dataset(raw_test_A, scaler, scaled_cols)
test_A[scaled_cols] = scaler.transform(test_A[scaled_cols])

test_A_gt = pd.read_csv(data_dir / "test_A_gt.csv")
test_A_gt = augment_base_dataset(test_A_gt, scaler, scaled_cols)
test_A_gt[scaled_cols] = scaler.transform(test_A_gt[scaled_cols])

# Both frames were sorted by (RID_HASH, VISCODE) in augment_base_dataset; check that the
# visit codes line up. Series `==` also requires identically-labelled indices.
assert (test_A["VISCODE"] == test_A_gt["VISCODE"]).all()

test_A_gt

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
247,00d5e0050fbd3b6b610f6673347232eb0862df77b5b7a8...,0.0,0.625581,1.0,0.7500,1.0,0.0,0.5,0.961538,0.219178,0.274673,0.397517,0.272565,0.405996,0.345331,0.505790
819,013c6f92763546c7ad9c0831f023886c15f05e7332aa0c...,0.0,0.420930,1.0,0.5000,1.0,1.0,0.5,0.807692,0.429178,0.057498,0.612302,0.423268,0.291182,0.433004,0.329131
276,013c6f92763546c7ad9c0831f023886c15f05e7332aa0c...,6.0,0.432558,1.0,0.5000,1.0,1.0,2.5,0.615385,0.360685,0.067972,0.576975,0.399942,0.302646,0.415628,0.330157
350,013c6f92763546c7ad9c0831f023886c15f05e7332aa0c...,12.0,0.444186,1.0,0.5000,1.0,1.0,2.0,0.769231,0.365342,0.077516,0.563770,0.415324,0.273721,0.389962,0.316610
1268,024efbff9265302acd00190e57ee08ba1fe1b90f561f79...,0.0,0.146512,0.0,0.6250,1.0,1.0,2.0,1.000000,0.164384,0.080886,0.607675,0.515223,0.536155,0.545500,0.577839
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
841,ff2966461950ba81280a0189ed2d504a8bd503d9f6b078...,0.0,0.290698,0.0,0.8750,1.0,1.0,1.5,0.807692,0.150685,0.211471,0.594244,0.515137,0.465785,0.675868,0.427332
330,ff2966461950ba81280a0189ed2d504a8bd503d9f6b078...,6.0,0.302326,0.0,0.8750,1.0,1.0,1.5,0.769231,0.095890,0.228441,0.510497,0.470574,0.472134,0.622549,0.437075
939,ff2966461950ba81280a0189ed2d504a8bd503d9f6b078...,24.0,0.337209,0.0,0.8750,1.0,1.0,1.5,0.769231,0.150685,0.243265,0.521219,0.464475,0.476190,0.640719,0.433016
119,ff2966461950ba81280a0189ed2d504a8bd503d9f6b078...,48.0,0.383721,0.0,0.8750,1.0,1.0,2.5,0.807692,0.246575,0.307697,0.420993,0.525265,0.392416,0.577719,0.403872


In [10]:
# Inspect the PTGENDER_num value recorded at each visit of the test_A ground truth.
test_A_gt[["RID_HASH", "PTGENDER_num"]]

Unnamed: 0,RID_HASH,PTGENDER_num
247,00d5e0050fbd3b6b610f6673347232eb0862df77b5b7a8...,1.0
819,013c6f92763546c7ad9c0831f023886c15f05e7332aa0c...,1.0
276,013c6f92763546c7ad9c0831f023886c15f05e7332aa0c...,1.0
350,013c6f92763546c7ad9c0831f023886c15f05e7332aa0c...,1.0
1268,024efbff9265302acd00190e57ee08ba1fe1b90f561f79...,0.0
...,...,...
841,ff2966461950ba81280a0189ed2d504a8bd503d9f6b078...,0.0
330,ff2966461950ba81280a0189ed2d504a8bd503d9f6b078...,0.0
939,ff2966461950ba81280a0189ed2d504a8bd503d9f6b078...,0.0
119,ff2966461950ba81280a0189ed2d504a8bd503d9f6b078...,0.0


In [11]:
# test_B (input with missing values) and its ground truth, processed exactly like test_A.
raw_test_B = pd.read_csv(data_dir / "test_B.csv")
test_B = augment_base_dataset(raw_test_B, scaler, scaled_cols)
test_B[scaled_cols] = scaler.transform(test_B[scaled_cols])

test_B_gt = pd.read_csv(data_dir / "test_B_gt.csv")
test_B_gt = augment_base_dataset(test_B_gt, scaler, scaled_cols)
test_B_gt[scaled_cols] = scaler.transform(test_B_gt[scaled_cols])

# Visit codes of the input and ground truth must line up row by row.
assert (test_B["VISCODE"] == test_B_gt["VISCODE"]).all()

test_B

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
1181,001854e92967164311f3acd5a58be9790f28ab3968bbbc...,0,0.395349,,0.6875,0.0,2.0,0.0,0.961538,0.077671,0.085164,0.638939,,0.608113,0.424862,0.523781
1426,001854e92967164311f3acd5a58be9790f28ab3968bbbc...,36,0.465116,,0.6875,0.0,2.0,0.0,1.000000,0.027397,0.089750,,,,,
1201,0059bc7849aea9522b408fa0ddc60276a36cae00206b87...,0,,0.0,,1.0,0.0,0.5,0.846154,0.196301,,0.345711,0.286043,0.312698,0.276821,0.248579
757,0059bc7849aea9522b408fa0ddc60276a36cae00206b87...,6,,0.0,,1.0,0.0,1.0,1.000000,0.283151,,0.345147,0.278219,0.378307,0.289480,0.253793
763,0059bc7849aea9522b408fa0ddc60276a36cae00206b87...,12,,0.0,,1.0,0.0,2.5,0.807692,0.168904,,0.329233,0.253372,0.352028,0.259842,0.222042
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1423,ff4eb5a64e2b89861d5dea81190669893070b227f3a335...,0,,0.0,0.8750,1.0,1.0,1.5,0.884615,0.114110,,0.502370,,0.394356,0.397160,0.531003
558,ff4eb5a64e2b89861d5dea81190669893070b227f3a335...,12,,0.0,0.8750,1.0,1.0,1.5,0.923077,0.242055,,0.519639,,0.294356,0.416522,0.545575
70,ff4eb5a64e2b89861d5dea81190669893070b227f3a335...,84,,0.0,0.8750,1.0,1.0,1.5,1.000000,0.178082,,0.432054,0.483387,0.363316,0.468451,0.508440
480,ffa86109ba8684f31325842d0ff26568e105f0f63b366a...,0,0.276744,1.0,0.5625,0.0,0.0,0.0,0.923077,0.118767,0.177669,,,,,


In [12]:
# Stack both test splits: the inputs (with missing values) and their ground truth.
test_AB_input, test_AB_output = (
    pd.concat(splits, ignore_index=True)
    for splits in ([test_A, test_B], [test_A_gt, test_B_gt])
)

test_AB_output

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
0,00d5e0050fbd3b6b610f6673347232eb0862df77b5b7a8...,0.0,0.625581,1.0,0.7500,1.0,0.0,0.5,0.961538,0.219178,0.274673,0.397517,0.272565,0.405996,0.345331,0.505790
1,013c6f92763546c7ad9c0831f023886c15f05e7332aa0c...,0.0,0.420930,1.0,0.5000,1.0,1.0,0.5,0.807692,0.429178,0.057498,0.612302,0.423268,0.291182,0.433004,0.329131
2,013c6f92763546c7ad9c0831f023886c15f05e7332aa0c...,6.0,0.432558,1.0,0.5000,1.0,1.0,2.5,0.615385,0.360685,0.067972,0.576975,0.399942,0.302646,0.415628,0.330157
3,013c6f92763546c7ad9c0831f023886c15f05e7332aa0c...,12.0,0.444186,1.0,0.5000,1.0,1.0,2.0,0.769231,0.365342,0.077516,0.563770,0.415324,0.273721,0.389962,0.316610
4,024efbff9265302acd00190e57ee08ba1fe1b90f561f79...,0.0,0.146512,0.0,0.6250,1.0,1.0,2.0,1.000000,0.164384,0.080886,0.607675,0.515223,0.536155,0.545500,0.577839
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2793,ff4eb5a64e2b89861d5dea81190669893070b227f3a335...,0.0,0.462791,0.0,0.8750,1.0,1.0,1.5,0.884615,0.114110,0.125561,0.502370,0.457404,0.394356,0.397160,0.531003
2794,ff4eb5a64e2b89861d5dea81190669893070b227f3a335...,12.0,0.462791,0.0,0.8750,1.0,1.0,1.5,0.923077,0.242055,0.119052,0.519639,0.458203,0.294356,0.416522,0.545575
2795,ff4eb5a64e2b89861d5dea81190669893070b227f3a335...,84.0,0.462791,0.0,0.8750,1.0,1.0,1.5,1.000000,0.178082,0.190065,0.432054,0.483387,0.363316,0.468451,0.508440
2796,ffa86109ba8684f31325842d0ff26568e105f0f63b366a...,0.0,0.276744,1.0,0.5625,0.0,0.0,0.0,0.923077,0.118767,0.177669,0.563866,0.446764,0.498944,0.474951,0.495801


In [13]:
# Missing cells per column in test_A; each one becomes a row of the submission.
test_A.isna().sum()

RID_HASH          0
VISCODE           0
AGE             230
PTGENDER_num    261
PTEDUCAT         31
DX_num          428
APOE4            17
CDRSB           428
MMSE            428
ADAS13          428
Ventricles      612
Hippocampus     668
WholeBrain      626
Entorhinal      668
Fusiform        668
MidTemp         668
dtype: int64

In [14]:
# Columns present after augment_base_dataset.
test_A.columns

Index(['RID_HASH', 'VISCODE', 'AGE', 'PTGENDER_num', 'PTEDUCAT', 'DX_num',
       'APOE4', 'CDRSB', 'MMSE', 'ADAS13', 'Ventricles', 'Hippocampus',
       'WholeBrain', 'Entorhinal', 'Fusiform', 'MidTemp'],
      dtype='object')

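## Emulate missingness

Copy the per-patient missingness patterns of test_A / test_B onto dev_set to create extra training inputs whose gaps look like the test data.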

In [15]:
def copy_missingness(ref_data, base_data=None, n_reuse: int = 4):
    """Transplant the per-patient missingness patterns of ``ref_data`` onto ``base_data``.

    Each patient in ``ref_data`` contributes its not-NaN mask ``n_reuse`` times to a FIFO
    pool keyed by that patient's number of visits. Patients in ``base_data`` are then
    processed in order of first appearance. Each one takes the next mask from the pool
    matching its own visit count, and cells where that mask is False become NaN. If no
    mask is left for that visit count, the patient is copied unchanged. Rows with a
    missing RID_HASH are dropped.

    Args:
        ref_data: frame whose missingness pattern is copied (e.g. test_A).
        base_data: fully observed frame to corrupt; defaults to the global ``dev_set``.
        n_reuse: how many ``base_data`` patients each reference mask may be applied to.

    Returns:
        A new frame with ``base_data``'s columns and a fresh RangeIndex.
    """
    from collections import deque

    if base_data is None:
        base_data = dev_set

    masks_by_len = {}
    for _, ref_patient in ref_data.groupby("RID_HASH", sort=False):
        visit_mask = ref_patient.notna().reset_index(drop=True)
        masks_by_len.setdefault(len(ref_patient), deque()).extend([visit_mask] * n_reuse)

    # Collect per-patient frames and concatenate once. Growing a frame from an empty
    # placeholder is quadratic and leaves the columns with object dtype.
    patient_frames = []
    for _, patient in base_data.groupby("RID_HASH", sort=False):
        patient = patient.reset_index(drop=True)
        pool = masks_by_len.get(len(patient))
        if pool:
            # Masks are indexed 0..n-1 like `patient`; columns align by name.
            patient_frames.append(patient.where(pool.popleft()))
        else:
            patient_frames.append(patient)

    if not patient_frames:
        return pd.DataFrame([], columns=base_data.columns)
    return pd.concat(patient_frames, ignore_index=True)

In [16]:
# dev_set with test_A's per-patient missingness patterns copied onto it.
dev_sim_A = copy_missingness(test_A)

dev_sim_A

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
0,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0,0.574419,0,1.0,,1.0,,,,0.071871,0.548646,0.376516,0.464021,0.194906,0.400709
1,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,6,0.586047,0,1.0,1.0,1.0,1.5,0.923077,0.237397,,0.548307,0.366398,0.40388,0.193367,0.397291
2,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0,0.430233,1,0.5,,1.0,,,,0.142655,,0.235599,,,
3,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,6,0.44186,1,0.5,,1.0,,,,,,0.230361,,,
4,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,12,0.453488,1,0.5,,1.0,,,,,,0.215944,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4096,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,60,0.590698,1,0.9375,1.0,0.0,3.0,0.923077,0.223699,,0.35702,,0.310935,0.399047,0.461476
4097,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,102,0.672093,1,0.9375,1.0,0.0,3.0,0.846154,0.168904,0.178231,0.352043,0.309095,0.25679,0.372685,0.416478
4098,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0,0.411628,0,0.5,,0.0,,,,,,0.636654,,,
4099,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,12,0.434884,0,0.5,,0.0,,,,,,0.63465,,,


In [17]:
# dev_set with test_B's per-patient missingness patterns copied onto it.
dev_sim_B = copy_missingness(test_B)

dev_sim_B

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
0,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0,0.574419,,1.0,1.0,1.0,0.5,0.923077,0.164384,0.071871,0.548646,,0.464021,0.194906,0.400709
1,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,6,0.586047,,1.0,1.0,1.0,1.5,0.923077,0.237397,0.071956,,,,,
2,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0,0.430233,1,0.5,1.0,1.0,1.0,1.0,0.123288,,0.525169,,0.513404,0.356253,0.294774
3,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,6,0.44186,1,0.5,1.0,1.0,1.0,1.0,0.164384,,,,,,
4,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,12,0.453488,1,0.5,1.0,1.0,1.0,0.961538,0.109589,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4096,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,60,0.590698,1,0.9375,,0.0,,,,0.170895,,,,,
4097,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,102,0.672093,1,0.9375,,0.0,,,,0.178231,0.352043,,0.25679,0.372685,0.416478
4098,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0,0.411628,0,0.5,1.0,0.0,0.5,0.884615,0.150685,0.416382,0.602438,0.636654,0.610229,0.743037,0.624631
4099,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,12,0.434884,0,0.5,,0.0,,,,0.398451,0.608521,,0.617108,0.729087,0.638477


## Model

In [18]:
from sklearn.preprocessing import LabelEncoder


def mask_columns_map(s: str):
    """Return the name of the missingness-indicator column paired with column ``s``."""
    return "masked_{}".format(s)


def generate_testcase(ref_df, out_df, target_column: str, cat_thresh: int = cat_limit):
    """Turn an (input-with-gaps, ground-truth) pair into one imputer testcase.

    Args:
        ref_df: frame with missing values (model input). Must contain the same
            (RID_HASH, VISCODE) rows as ``out_df``.
        out_df: ground truth for the same rows.
        target_column: the column this testcase targets.
        cat_thresh: if ``dev_set[target_column]`` has fewer distinct values than this,
            the target in ``test_output`` is label-encoded.

    Returns:
        test_input: ``ref_df`` with NaNs replaced by -1, plus one ``masked_<col>`` 0/1
            column per non-key column (1 = value was missing).
        test_output: RID_HASH, VISCODE and the ground-truth target (label-encoded for
            categorical targets).
        target_mask: RID_HASH, VISCODE and a 0/1 flag marking where the target was missing.

    Raises:
        AssertionError: if the two frames do not have matching rows.
    """
    assert len(ref_df) == len(out_df)
    # Align both frames row by row on (RID_HASH, VISCODE).
    ref_df = ref_df.sort_values(["RID_HASH", "VISCODE"]).reset_index(drop=True)
    out_df = out_df.sort_values(["RID_HASH", "VISCODE"]).reset_index(drop=True)

    assert (ref_df["RID_HASH"].values == out_df["RID_HASH"].values).all()
    assert (ref_df["VISCODE"].values == out_df["VISCODE"].values).all()

    mask = (
        ref_df.isna()
        .astype(int)
        .drop(columns=["RID_HASH", "VISCODE"])
        .rename(mask_columns_map, axis="columns")
    ).reset_index(drop=True)

    # Missingness of the target only, keyed by the actual RID_HASH/VISCODE values.
    # .copy() so the key assignment below writes to an owned frame, not a selection.
    target_mask = (ref_df.isna().astype(int)).reset_index(drop=True)
    target_mask = target_mask[["RID_HASH", "VISCODE", target_column]].copy()
    target_mask[["RID_HASH", "VISCODE"]] = ref_df[["RID_HASH", "VISCODE"]]

    # -1 marks missing inputs; the masked_<col> columns carry the actual missingness.
    ref_df = ref_df.fillna(-1)
    test_input = pd.concat([ref_df, mask], axis=1).reset_index(drop=True)
    test_output = out_df[["RID_HASH", "VISCODE", target_column]].reset_index(drop=True)

    if dev_set[target_column].nunique(dropna=False) < cat_thresh:
        # Fit on every available split so each class that can appear is known.
        encoding_data = pd.concat([dev_set, out_df, test_AB_output], ignore_index=True)
        # LabelEncoder expects a 1-D target, so pass Series rather than 1-column frames.
        encoder = LabelEncoder().fit(encoding_data[target_column])
        test_output[target_column] = encoder.transform(out_df[target_column])

    return test_input, test_output, target_mask


def prepare_dataset(target_column: str, cat_thresh: int = cat_limit):
    """Build one training testcase per gap-filled copy of ``dev_set``.

    The inputs are, in order: dev_1, dev_2, dev_sim_A, dev_sim_B. Each testcase is the
    list ``[inputs, outputs, target_mask]`` from ``generate_testcase``, with ``dev_set``
    as ground truth. The four testcases are returned as a tuple.
    """
    corrupted_sources = (dev_1, dev_2, dev_sim_A, dev_sim_B)
    testcases = []
    for source in corrupted_sources:
        testcases.append(
            list(generate_testcase(source, dev_set, target_column, cat_thresh=cat_thresh))
        )
    return tuple(testcases)


def prepare_test_dataset(target_column: str, cat_thresh: int = cat_limit):
    """Build the held-out testcase: stacked test_A + test_B inputs vs. their ground truth."""
    testcase = generate_testcase(
        ref_df=test_AB_input,
        out_df=test_AB_output,
        target_column=target_column,
        cat_thresh=cat_thresh,
    )
    return testcase

In [None]:
from ts_imputer_v2 import TimeSeriesImputerTemporal
from sklearn.model_selection import train_test_split


def train_inputer_for_column(
    target_column: str,
    n_units_hidden: int = 150,
    cat_thresh: int = 30,
    n_repeats: int = 3,
):
    """Train, or load from cache, a dedicated temporal imputer for one column.

    The checkpoint lives in ``workspace`` under a name built from ``target_column`` and
    ``n_units_hidden``. If it exists, it is loaded and returned without retraining.
    NOTE(review): the cache key ignores every other hyperparameter (cat_thresh,
    n_repeats, n_iter, dropout, ...); delete stale checkpoints after changing them.

    Args:
        target_column: column to impute.
        n_units_hidden: hidden-layer width (also part of the cache key).
        cat_thresh: columns with fewer distinct values in ``dev_set`` are trained as
            classification with one output unit per value; others as regression.
        n_repeats: number of passes over all training testcases.

    Returns:
        The fitted or loaded ``TimeSeriesImputerTemporal``.
    """
    bkp_file = workspace / f"dedicated_imputer_col_{target_column}_{n_units_hidden}.bkp"

    if bkp_file.exists():
        print("Loading imputer from", bkp_file)
        return load_model_from_file(bkp_file)

    print("Training imputer for", bkp_file)

    testcases = prepare_dataset(target_column=target_column, cat_thresh=cat_thresh)
    test_in, test_output, test_target_mask = prepare_test_dataset(
        target_column=target_column, cat_thresh=cat_thresh
    )

    # Same threshold as passed to the dataset builders, so label encoding and task type agree.
    n_classes = len(dev_set[target_column].unique())
    if n_classes < cat_thresh:
        n_units_out = n_classes
        task_type = "classification"
    else:
        n_units_out = 1
        task_type = "regression"

    imputer = TimeSeriesImputerTemporal(
        task_type=task_type,
        # Every generated input column except the string RID_HASH.
        n_units_in=testcases[0][0].shape[-1] - 1,
        # One unit per class for classification, a single unit for regression.
        n_units_out=n_units_out,
        nonlin="relu",
        dropout=0.05,
        n_layers_hidden=2,
        n_units_hidden=n_units_hidden,
        n_iter=10000,
        residual=False,
    )

    # Cycle through all training testcases. The stacked test split is passed alongside on
    # every fit call (presumably as validation data; the training log reports a val loss).
    for _ in range(n_repeats):
        for train_input, train_output, train_target_mask in testcases:
            imputer.fit(
                train_input,
                train_output,
                train_target_mask,
                test_in,
                test_output,
                test_target_mask,
            )
    save_model_to_file(bkp_file, imputer)

    return imputer

# Columns that get a dedicated imputer. PTGENDER_num, PTEDUCAT, DX_num, APOE4, CDRSB and
# MMSE are currently excluded.
imputed_columns = [
    "AGE",
    "ADAS13",
    "Ventricles",
    "Hippocampus",
    "WholeBrain",
    "Entorhinal",
    "Fusiform",
    "MidTemp",
]

# Keep every trained/loaded imputer so later cells do not depend on function locals.
imputers = {}
for target_column in imputed_columns:
    imputers[target_column] = train_inputer_for_column(target_column=target_column)
    

Training imputer for workspace/dedicated_imputer_col_AGE_150.bkp
   >>> Epoch 99 train loss  = 0.31008122639226104 val loss 0.42521642233646734
   >>> Epoch 199 train loss  = 0.22757615878557166 val loss 0.41981815843610626
   >>> Epoch 299 train loss  = 0.14488891574243704 val loss 0.3962608958846715
   >>> Epoch 399 train loss  = 0.11936264174679916 val loss 0.39616788511637274
   >>> Epoch 499 train loss  = 0.15471014079715437 val loss 0.4131661773086399
   >>> Epoch 599 train loss  = 0.09216465075345089 val loss 0.46781581298521707
   >>> Epoch 699 train loss  = 0.0979943454634243 val loss 0.40021212303330234
   >>> Epoch 799 train loss  = 0.11325509126860804 val loss 0.36771297421774934
   >>> Epoch 899 train loss  = 0.1142052881186828 val loss 0.40147016456013507
   >>> Epoch 999 train loss  = 0.08598399034235626 val loss 0.4140964564174326
   >>> Epoch 1099 train loss  = 0.08091516084580992 val loss 0.41623896270492255
   >>> Epoch 1199 train loss  = 0.10231056341338747 val loss

In [None]:
# `imputer` and `train_input` only exist inside `train_inputer_for_column`, so they are
# undefined on a fresh kernel. Rebuild them explicitly for the last column trained above:
# the checkpoint is loaded from the workspace cache if present (otherwise it is trained),
# and `train_input` is the last training testcase (dev_sim_B), the one the training loop
# ends on.
latent_probe_column = "MidTemp"
imputer = train_inputer_for_column(target_column=latent_probe_column)
latent_testcases = prepare_dataset(target_column=latent_probe_column, cat_thresh=30)
train_input = latent_testcases[-1][0]

imputer.predict_latent(train_input)

## Evaluation

In [None]:
# Sanity check: the imputed test frame must contain no NaNs.
# NOTE(review): `predictions` is never assigned anywhere in this notebook, so this cell
# and every evaluation/submission cell below raise NameError on a fresh kernel.
# TODO: add a cell that builds `predictions` (presumably test_A + test_B with RID_HASH,
# VISCODE and all value columns imputed) from the trained imputers.
assert predictions.isna().sum().sum() == 0

In [None]:
# full
from hyperimpute.plugins.imputers import Imputers
from hyperimpute.utils.benchmarks import benchmark_model

# Score the imputed test rows against the ground truth on the cells that were missing
# in the test input.
sort_keys = ["RID_HASH", "VISCODE"]

gt = (
    pd.concat([test_A_gt, test_B_gt], ignore_index=True)
    .sort_values(sort_keys)
    .reset_index(drop=True)
)
ordered_cols = list(gt.columns)

# 1 where the test input was missing, i.e. the cells that had to be imputed.
gt_mask = (
    pd.concat([test_A, test_B], ignore_index=True)
    .sort_values(sort_keys)
    .reset_index(drop=True)[ordered_cols]
    .isna()
    .astype(int)
)

# Align predictions with the ground truth: same row order and same column order.
predictions = predictions.sort_values(sort_keys).reset_index(drop=True)[ordered_cols]

plugin = Imputers().get("hyperimpute", optimizer="simple")

benchmark_model(
    "nn",
    plugin,
    *(frame.drop(columns=["RID_HASH"]) for frame in (gt, predictions, gt_mask)),
)

In [None]:
# only baseline
from hyperimpute.plugins.imputers import Imputers
from hyperimpute.utils.benchmarks import benchmark_model

# Same scoring as the full evaluation, restricted to baseline visits (VISCODE == 0).
gt = pd.concat([test_A_gt, test_B_gt], ignore_index=True)
gt = gt.sort_values(["RID_HASH", "VISCODE"]).reset_index(drop=True)
gt = gt[gt["VISCODE"] == 0].reset_index(drop=True)
# Use this cell's own column order instead of relying on `ordered_cols` from another cell.
baseline_cols = list(gt.columns)

gt_mask = pd.concat([test_A, test_B], ignore_index=True)
gt_mask = gt_mask.sort_values(["RID_HASH", "VISCODE"]).reset_index(drop=True)
gt_mask = gt_mask[gt_mask["VISCODE"] == 0].reset_index(drop=True)
gt_mask = gt_mask[baseline_cols].isna().astype(int)

# Sort and reorder here so the rows/columns match `gt` even if `predictions` was not
# already sorted in place by another cell.
predictions_single_visit = (
    predictions.sort_values(["RID_HASH", "VISCODE"])
    .loc[lambda frame: frame["VISCODE"] == 0, baseline_cols]
    .reset_index(drop=True)
)

plugin = Imputers().get(
    "hyperimpute",
    optimizer="simple",
    classifier_seed=["catboost"],
)

benchmark_model(
    "nn",
    plugin,
    gt.drop(columns=["RID_HASH"]),
    predictions_single_visit.drop(columns=["RID_HASH"]),
    gt_mask.drop(columns=["RID_HASH"]),
)

In [None]:
# by patient
from hyperimpute.utils.benchmarks import RMSE

sort_keys = ["RID_HASH", "VISCODE"]

gt = (
    pd.concat([test_A_gt, test_B_gt], ignore_index=True)
    .sort_values(sort_keys)
    .reset_index(drop=True)
)

test_data = (
    pd.concat([test_A, test_B], ignore_index=True)
    .sort_values(sort_keys)
    .reset_index(drop=True)
)

# Compare all frames in the same row and column order; RID_HASH is only used for grouping.
value_cols = [col for col in gt.columns if col != "RID_HASH"]
predictions_sorted = predictions.sort_values(sort_keys).reset_index(drop=True)

rids = list(predictions_sorted["RID_HASH"].unique())
patient_errors = []
for rid in rids:
    predicted_patient = predictions_sorted[predictions_sorted["RID_HASH"] == rid]
    gt_patient = gt[gt["RID_HASH"] == rid]
    input_patient = test_data[test_data["RID_HASH"] == rid]
    assert len(predicted_patient) == len(gt_patient) == len(input_patient)
    assert (gt_patient["VISCODE"].values == predicted_patient["VISCODE"].values).all()

    # RMSE needs a 0/1 mask of the imputed cells (missing in the test input), matching
    # the `gt_mask` used for the full evaluation. Passing the raw input values (NaNs and
    # measurements) as the mask does not select the imputed cells.
    patient_mask = input_patient[value_cols].isna().astype(int)

    patient_err = RMSE(
        predicted_patient[value_cols].values,
        gt_patient[value_cols].values,
        patient_mask.values,
    )
    patient_errors.append(patient_err)

np.argmax(patient_errors)

In [None]:
# Patient positions ordered by RMSE, ascending; pick the second-largest error.
# NOTE(review): np.argsort sorts NaN to the end, so if any per-patient RMSE is NaN the
# tail of this ranking consists of those NaN entries; confirm whether that can happen.
err_rank = np.argsort(patient_errors)
err_id = err_rank[-2]

In [None]:
# Patient selected by the ranking cell; show its imputed rows.
worst_rid = rids[err_id]

predictions[predictions["RID_HASH"] == worst_rid]

In [None]:
# Ground truth for the same patient.
gt[gt["RID_HASH"] == worst_rid]

In [None]:
# Test input (with missing values) for the same patient.
test_data[test_data["RID_HASH"] == worst_rid]

In [None]:
# Ground-truth rows of patients with exactly one visit, restricted to patients present
# in the test input. A single vectorised filter replaces growing a frame row-group by
# row-group from an empty placeholder, which is quadratic and can leave object columns.
visits_per_patient = gt.groupby("RID_HASH")["RID_HASH"].transform("size")
single_visits = gt[
    visits_per_patient.eq(1) & gt["RID_HASH"].isin(test_data["RID_HASH"])
].reset_index(drop=True)

single_visits

In [None]:
# ref score (0.4787056257938771, 0.3545751382105585)
# fillna(0): (0.7638940096168367, 2.5436176061930693)

In [None]:
# Deliberate stop so "Run All" does not continue into the submission section.
# A bare `raise` only worked by accident ("No active exception to reraise").
raise RuntimeError("Stopping before the submission section; run the cells below manually.")

## Submission data

In [None]:
def normalize_output(test_data):
    """Map imputed values back to raw units and snap clinical scores to their grids.

    - ``scaled_cols`` are inverse-transformed with the global ``scaler``.
    - CDRSB: negatives are clipped to 0, then rounded to a multiple of 0.5.
    - ADAS13: rounded to a multiple of 1/3, then to 2 decimals.
    - MMSE: rounded to a whole number.

    NaNs are preserved. The input frame is not modified; a new frame is returned.
    """
    normalized = test_data.copy()
    normalized[scaled_cols] = scaler.inverse_transform(normalized[scaled_cols])

    # Number of 0.5-point steps; negative predictions count as zero steps.
    half_steps = (normalized["CDRSB"] / 0.5).mask(lambda steps: steps < 0, 0)
    # Round through a nullable integer dtype so NaNs survive the integer cast.
    half_steps = half_steps.round(0).astype("Int64").astype(float)
    normalized["CDRSB"] = half_steps * 0.5

    normalized["ADAS13"] = normalized["ADAS13"].mul(3).round(0).div(3).round(2)
    normalized["MMSE"] = normalized["MMSE"].round(0)

    return normalized


def dump_results(imputed_data: pd.DataFrame, fpath: str):
    """Write the submission CSV: one row per cell that is missing in test_A / test_B.

    Each id is ``{RID_HASH}_{VISCODE}_{column}_{split}``. Its value comes from the first
    ``imputed_data`` row with the same RID_HASH and VISCODE.

    Args:
        imputed_data: imputed frame with RID_HASH, VISCODE and every column that can be
            missing in the test splits.
        fpath: output CSV path.

    Returns:
        The submission frame (columns of the global ``submission``) that was written.

    Raises:
        KeyError: if a missing test cell has no matching row or column in ``imputed_data``.
        AssertionError: if an imputed value is NaN or an empty string.
    """
    # (RID_HASH, VISCODE) -> position of its first row, so each lookup is O(1) instead
    # of re-filtering the whole frame for every missing cell.
    row_position = {}
    for pos, key in enumerate(zip(imputed_data["RID_HASH"], imputed_data["VISCODE"])):
        row_position.setdefault(key, pos)

    results = []
    for name, data in [
        ("test_A", test_A.sort_index()),
        ("test_B", test_B.sort_index()),
    ]:
        for _, row in data.iterrows():
            key = (row["RID_HASH"], row["VISCODE"])
            for col, val in row.items():
                if not pd.isna(val):
                    continue  # only cells missing in the test input are submitted
                if key not in row_position:
                    raise KeyError(
                        f"No imputed row for RID_HASH={key[0]!r}, VISCODE={key[1]!r}"
                    )
                imputed_val = imputed_data[col].iloc[row_position[key]]

                assert not pd.isna(imputed_val)
                assert imputed_val != ""

                results.append([f"{key[0]}_{key[1]}_{col}_{name}", imputed_val])

    output = pd.DataFrame(results, columns=submission.columns)
    output.to_csv(fpath, index=False)

    return output


def get_submission_data(inputed_data, eval_mode: str = "test_AB"):
    """Convert imputed data back to raw units and write the normalized submission CSV.

    Args:
        inputed_data: imputed frame in scaled units. It is not modified, because
            ``normalize_output`` works on a copy.
        eval_mode: free-form tag embedded in the output filename.

    Returns:
        ``(output_fpath, output_normalized)``: the CSV path and the written submission frame.
    """
    output_fpath = (
        results_dir
        / f"imputation_results_{version}_{changelog}_{eval_mode}_normalized.csv"
    )
    results_dir.mkdir(parents=True, exist_ok=True)

    print("Prepare output", output_fpath)
    output_normalized = dump_results(normalize_output(inputed_data), output_fpath)

    return output_fpath, output_normalized


# Preview the predictions converted back to raw units.
# NOTE(review): `predictions` is not assigned anywhere in this notebook; define it first.
normalize_output(predictions)

In [None]:
# Imputed test frame in scaled units.
# NOTE(review): `predictions` is not assigned anywhere in this notebook.
predictions

In [None]:
# Raw test_B input (scaled, with missing values), for comparison with the predictions.
test_B

In [None]:
# Write the normalized submission CSV from the imputed predictions and display it.
# NOTE(review): requires `predictions`, which is not assigned anywhere in this notebook.
fpath, output = get_submission_data(predictions)

output