In [1]:
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
from hyperimpute.utils.serialization import load_model_from_file, save_model_to_file
from sklearn.preprocessing import MinMaxScaler

from baseline_imputation import prepare_age, prepare_consts

workspace = Path("workspace")
results_dir = Path("results")
data_dir = Path("data")

workspace.mkdir(parents=True, exist_ok=True)

warnings.filterwarnings("ignore")

cat_limit = 10
n_seeds = 5

version = "take8"
changelog = f"hyperlatent_transformer"

In [2]:
def dataframe_hash(df: pd.DataFrame) -> str:
    cols = sorted(list(df.columns))
    return str(abs(pd.util.hash_pandas_object(df[cols].fillna(0)).sum()))


def augment_base_dataset(df, scaler, scaled_cols):
    df = df.sort_values(["RID_HASH", "VISCODE"]).reset_index(drop=True)
    df = prepare_consts(df)
    df = prepare_age(df, scaler, scaled_cols)

    return df

In [3]:
dev_set = pd.read_csv(data_dir / "dev_set.csv")

scaled_cols = [
    "AGE",
    "PTEDUCAT",
    "MMSE",
    "ADAS13",
    "Ventricles",
    "Hippocampus",
    "WholeBrain",
    "Entorhinal",
    "Fusiform",
    "MidTemp",
]

scaler = MinMaxScaler().fit(dev_set[scaled_cols])
dev_set[scaled_cols] = scaler.transform(dev_set[scaled_cols])

dev_set = augment_base_dataset(dev_set, scaler, scaled_cols)

dev_set

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
0,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0,0.574419,0,1.0000,1.0,1.0,0.5,0.923077,0.164384,0.071871,0.548646,0.376516,0.464021,0.194906,0.400709
1,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,6,0.586047,0,1.0000,1.0,1.0,1.5,0.923077,0.237397,0.071956,0.548307,0.366398,0.403880,0.193367,0.397291
2,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0,0.430233,1,0.5000,1.0,1.0,1.0,1.000000,0.123288,0.142655,0.525169,0.235599,0.513404,0.356253,0.294774
3,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,6,0.441860,1,0.5000,1.0,1.0,1.0,1.000000,0.164384,0.144729,0.549210,0.230361,0.435097,0.322395,0.294175
4,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,12,0.453488,1,0.5000,1.0,1.0,1.0,0.961538,0.109589,0.155550,0.527878,0.215944,0.487831,0.342600,0.277552
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4096,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,60,0.590698,1,0.9375,1.0,0.0,3.0,0.923077,0.223699,0.170895,0.357020,0.321346,0.310935,0.399047,0.461476
4097,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,102,0.672093,1,0.9375,1.0,0.0,3.0,0.846154,0.168904,0.178231,0.352043,0.309095,0.256790,0.372685,0.416478
4098,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0,0.411628,0,0.5000,1.0,0.0,0.5,0.884615,0.150685,0.416382,0.602438,0.636654,0.610229,0.743037,0.624631
4099,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,12,0.434884,0,0.5000,1.0,0.0,1.0,0.961538,0.155205,0.398451,0.608521,0.634650,0.617108,0.729087,0.638477


In [4]:
static_features = ["RID_HASH", "AGE", "PTGENDER_num", "PTEDUCAT", "APOE4"]
temporal_features = [
    "RID_HASH",
    "VISCODE",
    "DX_num",
    "CDRSB",
    "MMSE",
    "ADAS13",
    "Ventricles",
    "Hippocampus",
    "WholeBrain",
    "Entorhinal",
    "Fusiform",
    "MidTemp",
]  #

dev_set_static = dev_set.sort_values(["RID_HASH", "VISCODE"]).drop_duplicates(
    "RID_HASH"
)[static_features]
dev_set_temporal = dev_set.sort_values(["RID_HASH", "VISCODE"])[temporal_features]

dev_set_static

Unnamed: 0,RID_HASH,AGE,PTGENDER_num,PTEDUCAT,APOE4
0,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0.574419,0,1.0000,1.0
2,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0.430233,1,0.5000,1.0
8,0131f7f44ff183309c590b9ff440806b20f639c90c124d...,0.453488,0,0.5000,0.0
16,01513c9ff1e8fcc22cbfc9093845a37ee69307e3493daf...,0.441860,1,0.5000,0.0
26,01705aaf2c869203d7a8374472f5907f53f3b15f7b4faa...,0.372093,0,0.7500,0.0
...,...,...,...,...,...
4076,ff1d8cc22fb5bf2bd80e31d6d3a6cf1709562bb7e9a22f...,0.397674,1,0.7500,1.0
4079,ff21c0f13c9535e8339ce653a268b26df8e4172212ac05...,0.500000,1,0.8750,0.0
4084,ff48382bcf5922a2db52db36c791b02910015feee82505...,0.369767,0,0.5000,1.0
4091,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,0.474419,1,0.9375,0.0


In [5]:
dev_set_temporal

Unnamed: 0,RID_HASH,VISCODE,DX_num,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
0,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0,1.0,0.5,0.923077,0.164384,0.071871,0.548646,0.376516,0.464021,0.194906,0.400709
1,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,6,1.0,1.5,0.923077,0.237397,0.071956,0.548307,0.366398,0.403880,0.193367,0.397291
2,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0,1.0,1.0,1.000000,0.123288,0.142655,0.525169,0.235599,0.513404,0.356253,0.294774
3,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,6,1.0,1.0,1.000000,0.164384,0.144729,0.549210,0.230361,0.435097,0.322395,0.294175
4,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,12,1.0,1.0,0.961538,0.109589,0.155550,0.527878,0.215944,0.487831,0.342600,0.277552
...,...,...,...,...,...,...,...,...,...,...,...,...
4096,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,60,1.0,3.0,0.923077,0.223699,0.170895,0.357020,0.321346,0.310935,0.399047,0.461476
4097,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,102,1.0,3.0,0.846154,0.168904,0.178231,0.352043,0.309095,0.256790,0.372685,0.416478
4098,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0,1.0,0.5,0.884615,0.150685,0.416382,0.602438,0.636654,0.610229,0.743037,0.624631
4099,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,12,1.0,1.0,0.961538,0.155205,0.398451,0.608521,0.634650,0.617108,0.729087,0.638477


In [6]:
raw_dev_1 = pd.read_csv(data_dir / "dev_1.csv")
dev_1 = augment_base_dataset(raw_dev_1, scaler, scaled_cols)
dev_1[scaled_cols] = scaler.transform(dev_1[scaled_cols])

dev_1

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
0,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0,0.585776,0.0,1.0000,1.0,1.0,0.5,0.923077,0.164384,,,0.376516,,,
1,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,6,0.586047,0.0,1.0000,1.0,1.0,1.5,0.923077,0.237397,0.071956,0.548307,0.366398,0.403880,0.193367,0.397291
2,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0,0.543807,1.0,0.5000,,1.0,,,,,0.525169,0.235599,0.513404,0.356253,0.294774
3,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,6,0.544078,1.0,0.5000,,1.0,,,,,0.549210,0.230361,0.435097,0.322395,0.294175
4,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,12,0.544348,1.0,0.5000,,1.0,,,,,0.527878,0.215944,0.487831,0.342600,0.277552
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4096,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,60,0.590698,1.0,0.9375,,0.0,,,,0.170895,,0.321346,,,
4097,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,102,0.672093,1.0,0.9375,,0.0,,,,0.178231,,0.309095,,,
4098,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0,0.411628,0.0,0.5000,1.0,0.0,0.5,0.884615,0.150685,0.416382,0.602438,,0.610229,0.743037,0.624631
4099,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,12,0.434884,0.0,0.5000,1.0,0.0,1.0,0.961538,0.155205,0.398451,0.608521,,0.617108,0.729087,0.638477


In [7]:
raw_dev_2 = pd.read_csv(data_dir / "dev_2.csv")
dev_2 = augment_base_dataset(raw_dev_2, scaler, scaled_cols)
dev_2[scaled_cols] = scaler.transform(dev_2[scaled_cols])

dev_2

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
0,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0,0.574419,0.0,1.0000,1.0,1.0,0.5,0.923077,0.164384,0.071871,0.548646,0.376516,0.464021,0.194906,0.400709
1,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,6,0.586047,0.0,1.0000,,1.0,,,,0.071956,0.548307,,0.403880,0.193367,0.397291
2,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0,0.430233,1.0,0.5000,1.0,1.0,1.0,1.000000,0.123288,0.142655,0.525169,,0.513404,0.356253,0.294774
3,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,6,0.430503,1.0,0.5000,1.0,1.0,1.0,1.000000,0.164384,,0.549210,,0.435097,0.322395,0.294175
4,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,12,0.430773,1.0,0.5000,1.0,1.0,1.0,0.961538,0.109589,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4096,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,60,,1.0,0.9375,1.0,0.0,3.0,0.923077,0.223699,,0.357020,,0.310935,0.399047,0.461476
4097,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,102,,1.0,0.9375,1.0,0.0,3.0,0.846154,0.168904,,0.352043,,0.256790,0.372685,0.416478
4098,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0,0.411628,0.0,0.5000,,0.0,,,,0.416382,0.602438,,0.610229,0.743037,0.624631
4099,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,12,0.412169,0.0,0.5000,,0.0,,,,,,,,,


In [8]:
submission = pd.read_csv(data_dir / "sample_submission.csv")

submission.values[1]

array(['6b6a7136f42a8dbd469a201b88e2abb54a93667822761357db2f6d620da6af8a_0_Ventricles_test_A',
       40613.0818580834], dtype=object)

In [9]:
raw_test_A = pd.read_csv(data_dir / "test_A.csv")
test_A = augment_base_dataset(raw_test_A, scaler, scaled_cols)
test_A[scaled_cols] = scaler.transform(test_A[scaled_cols])

test_A_gt = pd.read_csv(data_dir / "test_A_gt.csv")
test_A_gt = augment_base_dataset(test_A_gt, scaler, scaled_cols)
test_A_gt[scaled_cols] = scaler.transform(test_A_gt[scaled_cols])

assert (test_A["VISCODE"] == test_A_gt["VISCODE"]).all()

test_A_gt

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
0,00d5e0050fbd3b6b610f6673347232eb0862df77b5b7a8...,0.0,0.625581,1.0,0.7500,1.0,0.0,0.5,0.961538,0.219178,0.274673,0.397517,0.272565,0.405996,0.345331,0.505790
1,013c6f92763546c7ad9c0831f023886c15f05e7332aa0c...,0.0,0.420930,1.0,0.5000,1.0,1.0,0.5,0.807692,0.429178,0.057498,0.612302,0.423268,0.291182,0.433004,0.329131
2,013c6f92763546c7ad9c0831f023886c15f05e7332aa0c...,6.0,0.432558,1.0,0.5000,1.0,1.0,2.5,0.615385,0.360685,0.067972,0.576975,0.399942,0.302646,0.415628,0.330157
3,013c6f92763546c7ad9c0831f023886c15f05e7332aa0c...,12.0,0.444186,1.0,0.5000,1.0,1.0,2.0,0.769231,0.365342,0.077516,0.563770,0.415324,0.273721,0.389962,0.316610
4,024efbff9265302acd00190e57ee08ba1fe1b90f561f79...,0.0,0.146512,0.0,0.6250,1.0,1.0,2.0,1.000000,0.164384,0.080886,0.607675,0.515223,0.536155,0.545500,0.577839
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1323,ff2966461950ba81280a0189ed2d504a8bd503d9f6b078...,0.0,0.290698,0.0,0.8750,1.0,1.0,1.5,0.807692,0.150685,0.211471,0.594244,0.515137,0.465785,0.675868,0.427332
1324,ff2966461950ba81280a0189ed2d504a8bd503d9f6b078...,6.0,0.302326,0.0,0.8750,1.0,1.0,1.5,0.769231,0.095890,0.228441,0.510497,0.470574,0.472134,0.622549,0.437075
1325,ff2966461950ba81280a0189ed2d504a8bd503d9f6b078...,24.0,0.337209,0.0,0.8750,1.0,1.0,1.5,0.769231,0.150685,0.243265,0.521219,0.464475,0.476190,0.640719,0.433016
1326,ff2966461950ba81280a0189ed2d504a8bd503d9f6b078...,48.0,0.383721,0.0,0.8750,1.0,1.0,2.5,0.807692,0.246575,0.307697,0.420993,0.525265,0.392416,0.577719,0.403872


In [10]:
raw_test_B = pd.read_csv(data_dir / "test_B.csv")
test_B = augment_base_dataset(raw_test_B, scaler, scaled_cols)
test_B[scaled_cols] = scaler.transform(test_B[scaled_cols])

test_B_gt = pd.read_csv(data_dir / "test_B_gt.csv")
test_B_gt = augment_base_dataset(test_B_gt, scaler, scaled_cols)
test_B_gt[scaled_cols] = scaler.transform(test_B_gt[scaled_cols])

assert (test_B["VISCODE"] == test_B_gt["VISCODE"]).all()

test_B

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
0,001854e92967164311f3acd5a58be9790f28ab3968bbbc...,0,0.395349,,0.6875,0.0,2.0,0.0,0.961538,0.077671,0.085164,0.638939,,0.608113,0.424862,0.523781
1,001854e92967164311f3acd5a58be9790f28ab3968bbbc...,36,0.465116,,0.6875,0.0,2.0,0.0,1.000000,0.027397,0.089750,,,,,
2,0059bc7849aea9522b408fa0ddc60276a36cae00206b87...,0,,0.0,,1.0,0.0,0.5,0.846154,0.196301,,0.345711,0.286043,0.312698,0.276821,0.248579
3,0059bc7849aea9522b408fa0ddc60276a36cae00206b87...,6,,0.0,,1.0,0.0,1.0,1.000000,0.283151,,0.345147,0.278219,0.378307,0.289480,0.253793
4,0059bc7849aea9522b408fa0ddc60276a36cae00206b87...,12,,0.0,,1.0,0.0,2.5,0.807692,0.168904,,0.329233,0.253372,0.352028,0.259842,0.222042
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1465,ff4eb5a64e2b89861d5dea81190669893070b227f3a335...,0,,0.0,0.8750,1.0,1.0,1.5,0.884615,0.114110,,0.502370,,0.394356,0.397160,0.531003
1466,ff4eb5a64e2b89861d5dea81190669893070b227f3a335...,12,,0.0,0.8750,1.0,1.0,1.5,0.923077,0.242055,,0.519639,,0.294356,0.416522,0.545575
1467,ff4eb5a64e2b89861d5dea81190669893070b227f3a335...,84,,0.0,0.8750,1.0,1.0,1.5,1.000000,0.178082,,0.432054,0.483387,0.363316,0.468451,0.508440
1468,ffa86109ba8684f31325842d0ff26568e105f0f63b366a...,0,0.276744,1.0,0.5625,0.0,0.0,0.0,0.923077,0.118767,0.177669,,,,,


In [11]:
test_AB_input = pd.concat([test_A, test_B], ignore_index=True)

test_AB_raw = pd.concat([raw_test_A, raw_test_B], ignore_index=True)
test_AB_raw[scaled_cols] = scaler.transform(test_AB_raw[scaled_cols])
test_AB_raw = test_AB_raw.sort_values(["RID_HASH", "VISCODE"]).reset_index(drop=True)

test_AB_output = pd.concat([test_A_gt, test_B_gt], ignore_index=True)
test_AB_output = test_AB_output.sort_values(["RID_HASH", "VISCODE"]).reset_index(
    drop=True
)

test_AB_output

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
0,001854e92967164311f3acd5a58be9790f28ab3968bbbc...,0.0,0.395349,1.0,0.6875,0.0,2.0,0.0,0.961538,0.077671,0.085164,0.638939,0.523499,0.608113,0.424862,0.523781
1,001854e92967164311f3acd5a58be9790f28ab3968bbbc...,36.0,0.465116,1.0,0.6875,0.0,2.0,0.0,1.000000,0.027397,0.089750,0.605530,0.483793,0.619400,0.427990,0.453528
2,0059bc7849aea9522b408fa0ddc60276a36cae00206b87...,0.0,0.444186,0.0,0.8125,1.0,0.0,0.5,0.846154,0.196301,0.525252,0.345711,0.286043,0.312698,0.276821,0.248579
3,0059bc7849aea9522b408fa0ddc60276a36cae00206b87...,6.0,0.444186,0.0,0.8125,1.0,0.0,1.0,1.000000,0.283151,0.553592,0.345147,0.278219,0.378307,0.289480,0.253793
4,0059bc7849aea9522b408fa0ddc60276a36cae00206b87...,12.0,0.444186,0.0,0.8125,1.0,0.0,2.5,0.807692,0.168904,0.567501,0.329233,0.253372,0.352028,0.259842,0.222042
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2793,ff4eb5a64e2b89861d5dea81190669893070b227f3a335...,0.0,0.462791,0.0,0.8750,1.0,1.0,1.5,0.884615,0.114110,0.125561,0.502370,0.457404,0.394356,0.397160,0.531003
2794,ff4eb5a64e2b89861d5dea81190669893070b227f3a335...,12.0,0.462791,0.0,0.8750,1.0,1.0,1.5,0.923077,0.242055,0.119052,0.519639,0.458203,0.294356,0.416522,0.545575
2795,ff4eb5a64e2b89861d5dea81190669893070b227f3a335...,84.0,0.462791,0.0,0.8750,1.0,1.0,1.5,1.000000,0.178082,0.190065,0.432054,0.483387,0.363316,0.468451,0.508440
2796,ffa86109ba8684f31325842d0ff26568e105f0f63b366a...,0.0,0.276744,1.0,0.5625,0.0,0.0,0.0,0.923077,0.118767,0.177669,0.563866,0.446764,0.498944,0.474951,0.495801


In [12]:
test_AB_raw

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
0,001854e92967164311f3acd5a58be9790f28ab3968bbbc...,0,0.395349,,0.6875,0.0,2.0,0.0,0.961538,0.077671,0.085164,0.638939,,0.608113,0.424862,0.523781
1,001854e92967164311f3acd5a58be9790f28ab3968bbbc...,36,0.465116,,0.6875,0.0,2.0,0.0,1.000000,0.027397,0.089750,,,,,
2,0059bc7849aea9522b408fa0ddc60276a36cae00206b87...,0,,0.0,,1.0,0.0,0.5,0.846154,0.196301,,0.345711,0.286043,0.312698,0.276821,0.248579
3,0059bc7849aea9522b408fa0ddc60276a36cae00206b87...,6,,0.0,,1.0,0.0,1.0,1.000000,0.283151,,0.345147,0.278219,0.378307,0.289480,0.253793
4,0059bc7849aea9522b408fa0ddc60276a36cae00206b87...,12,,0.0,,1.0,0.0,2.5,0.807692,0.168904,,0.329233,0.253372,0.352028,0.259842,0.222042
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2793,ff4eb5a64e2b89861d5dea81190669893070b227f3a335...,0,,,0.8750,1.0,1.0,1.5,0.884615,0.114110,,0.502370,,0.394356,0.397160,0.531003
2794,ff4eb5a64e2b89861d5dea81190669893070b227f3a335...,12,,,0.8750,1.0,1.0,1.5,0.923077,0.242055,,0.519639,,0.294356,0.416522,0.545575
2795,ff4eb5a64e2b89861d5dea81190669893070b227f3a335...,84,,0.0,0.8750,1.0,1.0,1.5,1.000000,0.178082,,0.432054,0.483387,0.363316,0.468451,0.508440
2796,ffa86109ba8684f31325842d0ff26568e105f0f63b366a...,0,0.276744,,0.5625,0.0,0.0,0.0,0.923077,0.118767,0.177669,,,,,


In [13]:
test_AB_raw.isna().sum()

RID_HASH           0
VISCODE            0
AGE             1379
PTGENDER_num    1334
PTEDUCAT         338
DX_num           906
APOE4            126
CDRSB            906
MMSE             906
ADAS13           906
Ventricles      1379
Hippocampus     1378
WholeBrain      1334
Entorhinal      1378
Fusiform        1378
MidTemp         1378
dtype: int64

In [14]:
test_A.columns

Index(['RID_HASH', 'VISCODE', 'AGE', 'PTGENDER_num', 'PTEDUCAT', 'DX_num',
       'APOE4', 'CDRSB', 'MMSE', 'ADAS13', 'Ventricles', 'Hippocampus',
       'WholeBrain', 'Entorhinal', 'Fusiform', 'MidTemp'],
      dtype='object')

## Emulate missigness

In [15]:
def copy_missingness(ref_data):
    ref_data_ids = ref_data["RID_HASH"].unique()

    len_to_miss = {}
    for rid in ref_data_ids:
        local_A = ref_data[ref_data["RID_HASH"] == rid]
        # print(len(local_A), local_A.isna().sum().sum())

        local_len = len(local_A)
        if local_len not in len_to_miss:
            len_to_miss[local_len] = []
        for reps in range(4):
            len_to_miss[local_len].append(local_A.notna().reset_index(drop=True))

    out_data = pd.DataFrame([], columns=dev_set.columns)
    out_data_ids = dev_set["RID_HASH"].unique()
    for rid in out_data_ids:
        local_A = dev_set[dev_set["RID_HASH"] == rid].copy().reset_index(drop=True)
        local_len = len(local_A)

        if local_len in len_to_miss and len(len_to_miss[local_len]) > 0:
            target_mask = len_to_miss[local_len].pop(0)
            out_data = pd.concat([out_data, local_A[target_mask]], ignore_index=True)
        else:
            out_data = pd.concat([out_data, local_A], ignore_index=True)

    return out_data

In [16]:
dev_sim_A = copy_missingness(test_A)

dev_sim_A

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
0,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0,0.574419,0,1.0,,1.0,,,,0.071871,0.548646,0.376516,0.464021,0.194906,0.400709
1,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,6,0.586047,0,1.0,1.0,1.0,1.5,0.923077,0.237397,,0.548307,0.366398,0.40388,0.193367,0.397291
2,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0,0.430233,1,0.5,,1.0,,,,0.142655,,0.235599,,,
3,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,6,0.44186,1,0.5,,1.0,,,,,,0.230361,,,
4,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,12,0.453488,1,0.5,,1.0,,,,,,0.215944,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4096,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,60,0.590698,1,0.9375,1.0,0.0,3.0,0.923077,0.223699,,0.35702,,0.310935,0.399047,0.461476
4097,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,102,0.672093,1,0.9375,1.0,0.0,3.0,0.846154,0.168904,0.178231,0.352043,0.309095,0.25679,0.372685,0.416478
4098,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0,0.411628,0,0.5,,0.0,,,,,,0.636654,,,
4099,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,12,0.434884,0,0.5,,0.0,,,,,,0.63465,,,


In [17]:
dev_sim_B = copy_missingness(test_B)

dev_sim_B

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
0,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0,0.574419,,1.0,1.0,1.0,0.5,0.923077,0.164384,0.071871,0.548646,,0.464021,0.194906,0.400709
1,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,6,0.586047,,1.0,1.0,1.0,1.5,0.923077,0.237397,0.071956,,,,,
2,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0,0.430233,1,0.5,1.0,1.0,1.0,1.0,0.123288,,0.525169,,0.513404,0.356253,0.294774
3,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,6,0.44186,1,0.5,1.0,1.0,1.0,1.0,0.164384,,,,,,
4,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,12,0.453488,1,0.5,1.0,1.0,1.0,0.961538,0.109589,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4096,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,60,0.590698,1,0.9375,,0.0,,,,0.170895,,,,,
4097,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,102,0.672093,1,0.9375,,0.0,,,,0.178231,0.352043,,0.25679,0.372685,0.416478
4098,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0,0.411628,0,0.5,1.0,0.0,0.5,0.884615,0.150685,0.416382,0.602438,0.636654,0.610229,0.743037,0.624631
4099,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,12,0.434884,0,0.5,,0.0,,,,0.398451,0.608521,,0.617108,0.729087,0.638477


## Model

In [18]:
from sklearn.preprocessing import LabelEncoder


def mask_columns_map(s: str):
    return f"masked_{s}"


def generate_covariates(ref_df):
    ref_df = ref_df.sort_values(["RID_HASH", "VISCODE"]).reset_index(drop=True)
    mask = (
        ref_df.isna()
        .astype(int)
        .drop(columns=["RID_HASH", "VISCODE"])
        .rename(mask_columns_map, axis="columns")
    ).reset_index(drop=True)
    ref_df = ref_df.fillna(-1)
    test_input = pd.concat([ref_df, mask], axis=1).reset_index(drop=True)

    return test_input


def generate_testcase(ref_df, out_df, target_column: str, cat_thresh: int = cat_limit):
    assert len(ref_df) == len(out_df)
    ref_df = ref_df.sort_values(["RID_HASH", "VISCODE"]).reset_index(drop=True)
    out_df = out_df.sort_values(["RID_HASH", "VISCODE"]).reset_index(drop=True)

    assert (ref_df["RID_HASH"].values == out_df["RID_HASH"].values).all()
    assert (ref_df["VISCODE"].values == out_df["VISCODE"].values).all()

    test_input = generate_covariates(ref_df)

    target_mask = (ref_df.isna().astype(int)).reset_index(drop=True)
    target_mask = target_mask[["RID_HASH", "VISCODE", target_column]]
    target_mask[["RID_HASH", "VISCODE"]] = ref_df[["RID_HASH", "VISCODE"]]

    test_output = out_df[["RID_HASH", "VISCODE", target_column]].reset_index(drop=True)

    n_units_out = 1
    if len(dev_set[target_column].unique()) < cat_thresh:
        encoding_data = pd.concat([dev_set, out_df, test_AB_output], ignore_index=True)
        n_units_out = len(encoding_data[target_column].unique())

        encoder = LabelEncoder().fit(encoding_data[[target_column]])

        test_output[target_column] = encoder.transform(out_df[[target_column]])

    return test_input, test_output, target_mask, n_units_out


def prepare_dataset(target_column: str, cat_thresh: int = cat_limit):
    df_input_1, df_output_1, target_mask_1, n_units_out = generate_testcase(
        dev_1, dev_set, target_column, cat_thresh=cat_thresh
    )
    df_input_2, df_output_2, target_mask_2, _ = generate_testcase(
        dev_2, dev_set, target_column, cat_thresh=cat_thresh
    )
    df_input_sim_A, df_output_sim_A, target_mask_sim_A, _ = generate_testcase(
        dev_sim_A, dev_set, target_column, cat_thresh=cat_thresh
    )
    df_input_sim_B, df_output_sim_B, target_mask_sim_B, _ = generate_testcase(
        dev_sim_B, dev_set, target_column, cat_thresh=cat_thresh
    )

    return (
        [df_input_1, df_output_1, target_mask_1],
        [df_input_2, df_output_2, target_mask_2],
        [df_input_sim_A, df_output_sim_A, target_mask_sim_A],
        [df_input_sim_B, df_output_sim_B, target_mask_sim_B],
    ), n_units_out


def prepare_test_dataset(target_column: str, cat_thresh: int = cat_limit):
    test_input, test_output, test_mask, _ = generate_testcase(
        test_AB_input, test_AB_output, target_column, cat_thresh=cat_thresh
    )
    return test_input, test_output, test_mask

In [19]:
from sklearn.model_selection import train_test_split

from ts_imputer_v2 import TimeSeriesImputerTemporal


def get_imputer_for_column(target_column: str, n_units_hidden: int = 50):
    bkp_file = workspace / f"dedicated_imputer_col_{target_column}_{n_units_hidden}.bkp"

    # print("Training imputer for", bkp_file)
    if bkp_file.exists():
        return load_model_from_file(bkp_file)

    cat_thresh = 30
    testcases, n_units_out = prepare_dataset(
        target_column=target_column, cat_thresh=cat_thresh
    )
    test_in, test_output, test_target_mask = prepare_test_dataset(
        target_column=target_column, cat_thresh=cat_thresh
    )

    if len(dev_set[target_column].unique()) < cat_thresh:
        task_type = "classification"
    else:
        task_type = "regression"

    imputer = TimeSeriesImputerTemporal(
        task_type=task_type,
        n_units_in=testcases[0][0].shape[-1] - 1,  # DROP RID_HASH
        n_units_out=n_units_out,  # DROP RID_HASH and VISCODE
        nonlin="relu",
        dropout=0.05,
        # nonlin_out = activation_layout,
        n_layers_hidden=2,
        n_units_hidden=n_units_hidden,
        n_iter=10000,
        residual=False,
        patience = 5,
    )

    for repeat in range(3):
        for train_input, train_output, train_target_mask in testcases:
            imputer.fit(
                train_input,
                train_output,
                train_target_mask,
                test_in,
                test_output,
                test_target_mask,
            )
    save_model_to_file(bkp_file, imputer)

    return imputer

n_units_hidden = 50

for _target_col in [
    "APOE4",
    "CDRSB",
    "MMSE",
    "AGE",
    "PTGENDER_num",
    "PTEDUCAT",
    "DX_num",
    "ADAS13",
    "Ventricles",
    "Hippocampus",
    "WholeBrain",
    "Entorhinal",
    "Fusiform",
    "MidTemp",
]:
    get_imputer_for_column(target_column=_target_col, n_units_hidden = n_units_hidden)

In [20]:
from hyperimpute.plugins.prediction import Classifiers, Regression
from sklearn.metrics import r2_score
from sklearn.preprocessing import LabelEncoder

target_columns = [
    "WholeBrain",
    "Fusiform",
    "APOE4",
    "CDRSB",
    "MMSE",
    "AGE",
    "PTGENDER_num",
    "PTEDUCAT",
    "DX_num",
    "ADAS13",
    "Ventricles",
    "Hippocampus",
    "Entorhinal",
    "MidTemp",
]

target_encoders = {}

for _target_col in target_columns:
    if len(dev_set[_target_col].unique()) < 30:
        encoding_data = pd.concat([dev_set, test_AB_output], ignore_index=True)
        target_encoders[_target_col] = LabelEncoder().fit(encoding_data[[_target_col]])

def map_latent_columns(s: str):
    return f"latent_{s}"


def generate_latent_repr(ref_df: pd.DataFrame, target_column: str):
    ref_id = dataframe_hash(ref_df)
    bkp_file = workspace / f"df_latent_repr_{ref_id}_{target_column}_{n_units_hidden}.csv"
    if bkp_file.exists():
        return pd.read_csv(bkp_file)

    covs = generate_covariates(ref_df)
    imputer = get_imputer_for_column(target_column)

    latent = imputer.predict_latent(covs)
    latent = latent.rename(map_latent_columns, axis="columns")
    assert (latent["latent_RID_HASH"].values == ref_df["RID_HASH"].values).all()

    latent.to_csv(bkp_file, index=None)

    return latent


def generate_ext_repr(ref_df: pd.DataFrame, target_column: str):
    latent = generate_latent_repr(ref_df, target_column)
    covs = generate_covariates(ref_df)
    latent = latent.drop(columns=["latent_RID_HASH"])

    return pd.concat([covs, latent], axis="columns")


def generate_ext_train_dataset(target_column: str):
    bkp_file = workspace / f"ext_data_input_{target_column}_{n_units_hidden}.bkp"
    # print("generate_ext_dataset", bkp_file)
    if bkp_file.exists():
        return load_model_from_file(bkp_file)

    imputer = get_imputer_for_column(target_column)

    train_cases, _ = prepare_dataset(target_column=target_column, cat_thresh=30)
    train_inputs = []
    train_outputs = []
    train_masks = []
    for train_input, train_output, train_target_mask in train_cases:
        train_latents = imputer.predict_latent(train_input)
        train_latents = train_latents.rename(map_latent_columns, axis="columns")
        assert (
            train_latents["latent_RID_HASH"].values == train_input["RID_HASH"].values
        ).all()

        train_latents = train_latents.drop(columns=["latent_RID_HASH"])

        train_input_full = pd.concat([train_input, train_latents], axis="columns")
        train_output_full = train_output[target_column]

        train_inputs.append(train_input_full)
        train_outputs.append(train_output_full)
        train_masks.append(train_target_mask[target_column])

    train_inputs = pd.concat(train_inputs, ignore_index=True)
    train_outputs = pd.concat(train_outputs, ignore_index=True)
    train_masks = pd.concat(train_masks, ignore_index=True)

    if target_column in target_encoders:
        train_outputs = target_encoders[target_column].inverse_transform(train_outputs)
        
        
    save_model_to_file(bkp_file, (train_inputs, train_outputs, train_masks))

    return train_inputs, train_outputs, train_masks


def generate_ext_test_dataset(target_column: str):
    bkp_file = workspace / f"ext_testdata_input_{target_column}_{n_units_hidden}.bkp"
    # print("generate_ext_test dataset", bkp_file)
    if bkp_file.exists():
        return load_model_from_file(bkp_file)

    imputer = get_imputer_for_column(target_column)

    test_in, test_output, test_target_mask = prepare_test_dataset(
        target_column=target_column, cat_thresh=30
    )
    latents = imputer.predict_latent(test_in)
    latents = latents.rename(map_latent_columns, axis="columns")

    assert (latents["latent_RID_HASH"].values == test_in["RID_HASH"].values).all()

    latents = latents.drop(columns=["latent_RID_HASH"])

    input_full = pd.concat([test_in, latents], axis="columns")
    output_full = test_output[target_column]
    
    if target_column in target_encoders:
        output_full = target_encoders[target_column].inverse_transform(output_full)
        
    mask_full = test_target_mask[target_column]

    save_model_to_file(bkp_file, (input_full, output_full, mask_full))

    return input_full, output_full, mask_full



#         continue


#         if column in target_encoders:
#             test_output[_target_col] = target_encoders[column].inverse_transform(test_output[[_target_col]]).squeeze()
#             preds[_target_col] = target_encoders[column].inverse_transform(preds[[_target_col]]).squeeze()

#         assert (test_output["RID_HASH"].values == preds['RID_HASH'].values).all()
#         assert (test_output["VISCODE"].values == preds['VISCODE'].values).all()

#         output_mask = test_target_mask[_target_col].values.astype(bool)
#         y_truth = test_output[_target_col].values[output_mask]
#         y_pred = preds[_target_col].values[output_mask]

In [21]:
# from hyperimpute.utils.optimizer import EarlyStoppingExceeded, create_study
# import optuna
# from typing import Any

# for _target_col in [
#     "MMSE",
#     "CDRSB",
#     "PTEDUCAT",
# ]: #target_columns:
#     print(f" >> {_target_col}")
#     scores = {}
#     for (clf_type, reg_type) in [
#         ("xgboost", "xgboost_regressor"),
#         ("logistic_regression", "linear_regression"),
#         ("catboost", "catboost_regressor"),
#         ("random_forest", "random_forest_regressor"),
#     ]:
#         train_inputs, train_outputs, train_masks = generate_ext_train_dataset(_target_col)
#         test_in, test_output, test_target_mask = generate_ext_test_dataset(_target_col)
#         if _target_col in target_encoders:
#             plugin = Classifiers().get_type(clf_type)
#         else:
#             plugin = Regression().get_type(reg_type)

#         study, pruner = create_study(
#             study_name=f"imputation_{plugin.name()}_{_target_col}",
#             direction="maximize",
#             load_if_exists = True,
#         )

#         def evaluate_args(**kwargs: Any) -> float:
#             model = plugin(**kwargs)
#             model.fit(train_inputs.drop(columns = ["RID_HASH"]).astype(float), train_outputs)
#             preds = model.predict(test_in.drop(columns = ["RID_HASH"]).astype(float))

#             output_mask = test_target_mask.values.astype(bool)
#             y_truth = test_output.values[output_mask]
#             y_pred = preds.values[output_mask]
#             #print(f"      >> bechnmark ", model.name(), kwargs, r2_score(y_truth, y_pred))
#             return r2_score(y_truth, y_pred)

#         def objective(trial: optuna.Trial) -> float:
#             args = plugin.sample_hyperparameters(trial)
#             pruner.check_trial(trial)

#             try:
#                 score = evaluate_args(**args)
#             except BaseException:
#                 print("      failed evaluation", plugin.name(), args)
#                 return -1

#             pruner.report_score(score)

#             return score

#         try:
#             study.optimize(objective, n_trials=50, timeout=60 * 10)
#         except EarlyStoppingExceeded:
#             pass

#         baseline_score = evaluate_args()

#         if study.best_value > baseline_score:
#             score = study.best_value
#             args = study.best_trial.params
#         else:
#             score = baseline_score
#             args = {}
#         print(f"    >> {plugin.name()} -- {args}", score)

#         scores[plugin.name()] = score

#     print(f" >> {_target_col} selected {max(scores, key=scores.get)}")

In [24]:
from hyperimpute.plugins.prediction import Classifiers, Regression
from hyperimpute.utils.benchmarks import RMSE
from hyperimpute.utils.optimizer import EarlyStoppingExceeded, create_study
import optuna
from typing import Any
from joblib import parallel_backend

prev_configurations = {}

target_columns = [
    "WholeBrain",
    "Hippocampus",
    "Fusiform",
    "MidTemp",
    "ADAS13",
    "MMSE",
    "Ventricles",
    "CDRSB",
    "Entorhinal",
    "APOE4",
    "AGE",
    "PTEDUCAT",
    "PTGENDER_num",
    "DX_num",
]


def _compute_err(out, mask, gt, target_col = None):
    out = out.copy()
    gt = gt.copy()
    
    out[scaled_cols] = scaler.inverse_transform(out[scaled_cols])
    gt[scaled_cols] = scaler.inverse_transform(gt[scaled_cols])
    
    if target_col is None:
        out = out.copy().fillna(out.mean())

        assert (mask.columns == out.columns).all()
        assert (gt.columns == out.columns).all()

        return RMSE(
            out.drop(columns=["RID_HASH"]).values,
            gt.drop(columns=["RID_HASH"]).values,
            mask.drop(columns=["RID_HASH"]).values,
        )
    else:
        out = out[target_col].fillna(out[target_col].mean())

        return RMSE(
            out.values,
            gt[target_col].values,
            mask[target_col].values,
        )
    
def train_predict_column(target_column, target_plugin, target_plugin_args, 
                         col_X_train, col_y_train, col_X_test):
    assert len(col_X_train) == len(col_y_train)

    orig_train, orig_out, _ = generate_ext_train_dataset(target_column)
    orig_out = pd.Series(orig_out, index=orig_train.index)

    assert (orig_train.columns == col_X_train.columns).all()
    assert (orig_train.columns == col_X_test.columns).all()

    X_train = pd.concat([orig_train, col_X_train], ignore_index=True)
    y_train = pd.concat([orig_out, col_y_train], ignore_index=True)

    predictor = target_plugin(**target_plugin_args)
    #print("     >> Training", target_column, predictor.name())

    predictor.fit(X_train.drop(columns=["RID_HASH"]).astype(float), y_train)
    preds = predictor.predict(col_X_test.drop(columns=["RID_HASH"]).astype(float))

    if target_column in static_features:
        actual_ids = col_X_test["RID_HASH"].reset_index(drop=True)

        preds = pd.concat([actual_ids, preds], axis="columns")
        preds.columns = ["RID_HASH", "preds"]

        for rid in preds["RID_HASH"].unique():
            local_preds = preds[preds["RID_HASH"] == rid]["preds"]
            if target_column in ["PTGENDER_num", "APOE4"]:
                value_counts = local_preds.value_counts().to_dict()
                first_value = max(value_counts, key=value_counts.get)
            elif target_column == "AGE":
                first_value = local_preds.values[0]
            elif target_column == "PTEDUCAT":
                first_value = local_preds.mean()
            else:
                raise ValueError(f"unhandled {target_column}")

            local_out = [np.nan] * len(local_preds)
            local_out[0] = first_value
            preds.loc[preds["RID_HASH"] == rid, "preds"] = local_out

        return preds["preds"]
    else:
        return preds


def run_imputation_iteration(
    working_df: pd.DataFrame,
    ref_df: pd.DataFrame,
    target_column: str,
    working_mask: pd.DataFrame,
    ref_mask: pd.DataFrame,
    gt: pd.DataFrame,
):
    col_mask = ~working_mask[target_column]

    ext_repr = generate_ext_repr(working_df, target_column)
    target_col = ref_df[target_column]

    train_X = ext_repr[col_mask]
    train_y = target_col[col_mask]

    test_X = ext_repr[~col_mask]
    assert test_X.isna().sum().sum() == 0, test_X
    assert train_X.isna().sum().sum() == 0, train_X
    assert train_y.isna().sum() == 0, train_y.isna().sum()

    #target_plugin, target_plugin_args = target_ref_plugins[target_column]
    
    def evaluate_args(target_plugin, target_plugin_args = {}):
        imputed_y = train_predict_column(target_column, target_plugin, target_plugin_args, train_X, train_y, test_X)

        assert len(imputed_y) == len(test_X)

        candidate_df = working_df.copy()
        candidate_df[target_column][~col_mask] = imputed_y.values.squeeze()

        if target_column in static_features:
            candidate_df = prepare_consts(candidate_df)
            candidate_df = prepare_age(candidate_df, scaler, scaled_cols)

        return candidate_df, _compute_err(candidate_df, ref_mask, gt, target_col = target_column)
    
    if target_column in prev_configurations:
        best_target_plugin, best_target_plugin_args = prev_configurations[target_column]
        _, best_score = evaluate_args(best_target_plugin, best_target_plugin_args)
    else:
        best_score = 999999
        best_target_plugin, best_target_plugin_args = None, None
    
    for (clf_type, reg_type) in [
        ("xgboost", "xgboost_regressor"),
        (None, "linear_regression"),
        ("catboost", "catboost_regressor"),
        (None, "random_forest_regressor"),
        ("kneighbors", "kneighbors_regressor"),
    ]:
        if len(dev_set[target_column].unique()) < 10:
            if clf_type is None:
                continue
            plugin = Classifiers().get_type(clf_type)
        else:
            if reg_type is None:
                continue
            plugin = Regression().get_type(reg_type)
        
        study, pruner = create_study(
            study_name=f"imputation_{plugin.name()}_{target_column}_rmse_with_latent",
            direction="minimize",
            load_if_exists = False,
        )
        
        def objective(trial: optuna.Trial) -> float:
            args = plugin.sample_hyperparameters(trial)
            pruner.check_trial(trial)

            try:
                _, score = evaluate_args(plugin, args)
            except BaseException:
                #print("      failed evaluation", plugin.name(), args)
                return 9999999

            #print(f"    >>  {plugin.name()} {args} -> {score}")
            pruner.report_score(score)

            return score

        try:
            study.optimize(objective, n_trials=10, timeout=60 * 10)
        except EarlyStoppingExceeded:
            pass

        _, baseline_score = evaluate_args(plugin)

        if study.best_value < baseline_score:
            score = study.best_value
            args = study.best_trial.params
        else:
            score = baseline_score
            args = {}
            
        if score < best_score:
            best_score = score
            best_target_plugin, best_target_plugin_args = plugin, args
        
    # print(f"     >> [{_col}][{it}] Miss", out.isna().sum().sum())
    # print(f"     >> [{target_column}][{it}] Error", _compute_err(working_df, ref_mask))
    print(f"     >> Selected {best_target_plugin.name()} -- {best_target_plugin_args}", best_score)
    prev_configurations[target_column] = (best_target_plugin, best_target_plugin_args)
    
    candidate_df, _ = evaluate_args(best_target_plugin, best_target_plugin_args)
    return candidate_df


def run_imputation(ref_df: pd.DataFrame, gt: pd.DataFrame, num_iter: int = 10, working_df = None):
    ref_df = ref_df.copy()
    ref_mask = test_AB_raw.isna()
    print(" >> Baseline error", _compute_err(ref_df, ref_mask, gt))

    ref_df = augment_base_dataset(ref_df, scaler, scaled_cols)
    working_mask = ref_df.isna()

    if working_df is None:
        working_df = ref_df.copy()

    print(" >> Initial miss", working_df.isna().sum().sum())
    for it in range(num_iter):
        prev_out = working_df.copy().fillna(0)
        for _col in target_columns:
            print(f"[{_col}][{it}]")
            working_df = run_imputation_iteration(
                working_df, ref_df, _col, working_mask, ref_mask, gt
            )
            print(f"     >> [{_col}][{it}] Error", _compute_err(working_df, ref_mask, gt))

        assert working_df.isna().sum().sum() == 0
        inf_norm = np.linalg.norm(
            working_df.drop(columns=["RID_HASH"]).values
            - prev_out.drop(columns=["RID_HASH"]).values,
            ord=np.inf,
            axis=None,
        )
        prev_out = working_df.copy()
        
        print(f"[{it}] Diff norm = {inf_norm}")
        if inf_norm < 1e-8:
            print(f"[{it}] Early stopping.. inf_norm = {inf_norm}")
            break

    return working_df

In [None]:
bkp_file = workspace / f"test_AB_imputed_transformer_latent_{n_units_hidden}.csv"

for niter in range(10):
    test_AB_imputed = None
    if bkp_file.exists():
        test_AB_imputed = pd.read_csv(bkp_file)

    test_AB_imputed = run_imputation(test_AB_raw, 
                                     gt = test_AB_output, 
                                     working_df = test_AB_imputed, 
                                     num_iter = 2,
                                    )
    test_AB_imputed.to_csv(bkp_file, index = None)

test_AB_imputed

 >> Baseline error 32845.89786514254
 >> Initial miss 13086
[WholeBrain][0]
     >> Selected xgboost_regressor -- {'reg_lambda': 1.5282805283838838, 'reg_alpha': 3.0387078093102393, 'max_depth': 4, 'n_estimators': 213, 'lr': 0.01} 63547.864878803266
     >> [WholeBrain][0] Error 20251.219896844126
[Hippocampus][0]
     >> Selected xgboost_regressor -- {'reg_lambda': 9.7795606822451, 'reg_alpha': 3.791452511149702, 'max_depth': 4, 'n_estimators': 237, 'lr': 0.001} 696.1685810208094
     >> [Hippocampus][0] Error 20249.148146875115
[Fusiform][0]
     >> Selected random_forest_regressor -- {'criterion': 0, 'max_features': 1, 'min_samples_split': 5, 'min_samples_leaf': 10, 'max_depth': 2} 2353.2191162430454
     >> [Fusiform][0] Error 20244.964415176226
[MidTemp][0]
     >> Selected random_forest_regressor -- {'criterion': 0, 'max_features': 2, 'min_samples_split': 10, 'min_samples_leaf': 2, 'max_depth': 1} 2768.925890761426
     >> [MidTemp][0] Error 20241.284597147805
[ADAS13][0]
     >>

In [None]:
test_AB_imputed.describe()

In [None]:
test_AB_output.describe()

In [None]:
raise

## Evaluation

In [None]:
predictions = test_AB_imputed.copy()
assert predictions.isna().sum().sum() == 0

## Submission data

In [None]:
def normalize_output(test_data):
    test_data = test_data.copy()

    test_data[scaled_cols] = scaler.inverse_transform(test_data[scaled_cols])

    factor = test_data["CDRSB"] / 0.5
    factor[factor < 0] = 0
    factor = factor.fillna(-1)
    factor = factor.round(0).astype(int)
    factor = factor.replace(-1, np.nan)
    test_data["CDRSB"] = factor * 0.5

    test_data["ADAS13"] = ((test_data["ADAS13"] * 3).round(0) / 3).round(2)
    test_data["MMSE"] = test_data["MMSE"].round(0)

    return test_data



def dump_results(imputed_data: pd.DataFrame, fpath: str):
    results = []

    for name, data in [
        ("test_A", raw_test_A.sort_index()),
        ("test_B", raw_test_B.sort_index()),
    ]:
        for idx, row in data.iterrows():
            for col in row.index:
                local = row.T
                val = local[col]
                if val == val:
                    continue
                imputed_id = f"{local['RID_HASH']}_{local['VISCODE']}_{col}_{name}"
                imputed_val = imputed_data[
                    (imputed_data["RID_HASH"] == local["RID_HASH"])
                    & (imputed_data["VISCODE"] == local["VISCODE"])
                ][col].values[0]

                assert imputed_val == imputed_val
                assert imputed_val != ""

                results.append([imputed_id, imputed_val])

    output = pd.DataFrame(results, columns=submission.columns)
    output.to_csv(fpath, index=None)

    return output


def get_submission_data(imputed_data):
    imputed_data = imputed_data.copy()
    imputed_data = normalize_output(imputed_data)

    output_fpath = (
        results_dir
        / f"imputation_results_{version}_{changelog}_normalized.csv"
    )

    print("Prepare output", output_fpath)
    output_normalized = dump_results(imputed_data, output_fpath)

    return output_fpath, output_normalized


normalize_output(predictions)

In [None]:
fpath, output = get_submission_data(predictions)

output

##### 