In [1]:
import pandas as pd
import numpy as np
import warnings
from pathlib import Path
from sklearn.preprocessing import MinMaxScaler

from hyperimpute.utils.serialization import load_model_from_file, save_model_to_file


workspace = Path("workspace")
results_dir = Path("results")
data_dir = Path("data")

workspace.mkdir(parents=True, exist_ok=True)

warnings.filterwarnings("ignore")

cat_limit = 10
n_seeds = 5

version = "take7"
changelog = f"ffill_catlimit{cat_limit}"


In [2]:
def dataframe_hash(df: pd.DataFrame) -> str:
    cols = sorted(list(df.columns))
    return str(abs(pd.util.hash_pandas_object(df[cols]).sum()))

def augment_base_dataset(df):
    df = df.sort_values(["RID_HASH", "VISCODE"])

    return df

In [3]:
dev_set = pd.read_csv(data_dir / "dev_set.csv")
dev_set = dev_set.sort_values(["RID_HASH", "VISCODE"])
dev_set = augment_base_dataset(dev_set)

scaled_cols = [
    "MMSE",
    "ADAS13",
    "Ventricles",
    "Hippocampus",
    "WholeBrain",
    "Entorhinal",
    "Fusiform",
    "MidTemp",
]

scaler = MinMaxScaler().fit(dev_set[scaled_cols])
dev_set[scaled_cols] = scaler.transform(dev_set[scaled_cols])

dev_set

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
2163,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0,79.1,0,20,1.0,1.0,0.5,0.923077,0.164384,0.071871,0.548646,0.376516,0.464021,0.194906,0.400709
154,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,6,79.6,0,20,1.0,1.0,1.5,0.923077,0.237397,0.071956,0.548307,0.366398,0.403880,0.193367,0.397291
1385,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0,72.9,1,12,1.0,1.0,1.0,1.000000,0.123288,0.142655,0.525169,0.235599,0.513404,0.356253,0.294774
2698,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,6,73.4,1,12,1.0,1.0,1.0,1.000000,0.164384,0.144729,0.549210,0.230361,0.435097,0.322395,0.294175
2291,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,12,73.9,1,12,1.0,1.0,1.0,0.961538,0.109589,0.155550,0.527878,0.215944,0.487831,0.342600,0.277552
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2895,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,60,79.8,1,19,1.0,0.0,3.0,0.923077,0.223699,0.170895,0.357020,0.321346,0.310935,0.399047,0.461476
2646,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,102,83.3,1,19,1.0,0.0,3.0,0.846154,0.168904,0.178231,0.352043,0.309095,0.256790,0.372685,0.416478
1962,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0,72.1,0,12,1.0,0.0,0.5,0.884615,0.150685,0.416382,0.602438,0.636654,0.610229,0.743037,0.624631
122,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,12,73.1,0,12,1.0,0.0,1.0,0.961538,0.155205,0.398451,0.608521,0.634650,0.617108,0.729087,0.638477


In [4]:
dev_1 = pd.read_csv(data_dir / "dev_1.csv")
dev_1 = augment_base_dataset(dev_1)
dev_1[scaled_cols] = scaler.transform(dev_1[scaled_cols])

dev_1

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
2163,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0,,0.0,20.0,1.0,1.0,0.5,0.923077,0.164384,,,0.376516,,,
154,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,6,79.6,0.0,20.0,1.0,1.0,1.5,0.923077,0.237397,0.071956,0.548307,0.366398,0.403880,0.193367,0.397291
1385,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0,,1.0,12.0,,1.0,,,,,0.525169,0.235599,0.513404,0.356253,0.294774
2698,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,6,,1.0,12.0,,1.0,,,,,0.549210,0.230361,0.435097,0.322395,0.294175
2291,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,12,,1.0,12.0,,1.0,,,,,0.527878,0.215944,0.487831,0.342600,0.277552
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2895,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,60,79.8,1.0,19.0,,0.0,,,,0.170895,,0.321346,,,
2646,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,102,83.3,1.0,19.0,,0.0,,,,0.178231,,0.309095,,,
1962,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0,72.1,,12.0,1.0,0.0,0.5,0.884615,0.150685,0.416382,0.602438,,0.610229,0.743037,0.624631
122,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,12,73.1,,12.0,1.0,0.0,1.0,0.961538,0.155205,0.398451,0.608521,,0.617108,0.729087,0.638477


In [5]:
dev_2 = pd.read_csv(data_dir / "dev_2.csv")
dev_2 = augment_base_dataset(dev_2)
dev_2[scaled_cols] = scaler.transform(dev_2[scaled_cols])

dev_2

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
2163,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0,79.1,0.0,20.0,1.0,1.0,0.5,0.923077,0.164384,0.071871,0.548646,0.376516,0.464021,0.194906,0.400709
154,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,6,79.6,,,,1.0,,,,0.071956,0.548307,,0.403880,0.193367,0.397291
1385,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0,72.9,,12.0,1.0,1.0,1.0,1.000000,0.123288,0.142655,0.525169,,0.513404,0.356253,0.294774
2698,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,6,,,12.0,1.0,1.0,1.0,1.000000,0.164384,,0.549210,,0.435097,0.322395,0.294175
2291,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,12,,,12.0,1.0,1.0,1.0,0.961538,0.109589,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2895,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,60,,,19.0,1.0,0.0,3.0,0.923077,0.223699,,0.357020,,0.310935,0.399047,0.461476
2646,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,102,,,19.0,1.0,0.0,3.0,0.846154,0.168904,,0.352043,,0.256790,0.372685,0.416478
1962,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0,72.1,,12.0,,0.0,,,,0.416382,0.602438,,0.610229,0.743037,0.624631
122,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,12,,,12.0,,0.0,,,,,,,,,


In [6]:
submission = pd.read_csv(data_dir / "sample_submission.csv")

submission.values[1]

array(['6b6a7136f42a8dbd469a201b88e2abb54a93667822761357db2f6d620da6af8a_0_Ventricles_test_A',
       40613.0818580834], dtype=object)

In [7]:
test_A = pd.read_csv(data_dir / "test_A.csv")
test_A = augment_base_dataset(test_A)
test_A[scaled_cols] = scaler.transform(test_A[scaled_cols])

test_A

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
247,00d5e0050fbd3b6b610f6673347232eb0862df77b5b7a8...,0,,,16.0,1.0,0.0,0.5,0.961538,0.219178,,,,,,
819,013c6f92763546c7ad9c0831f023886c15f05e7332aa0c...,0,72.5,1.0,12.0,,1.0,,,,0.057498,0.612302,0.423268,0.291182,0.433004,0.329131
276,013c6f92763546c7ad9c0831f023886c15f05e7332aa0c...,6,73.0,1.0,12.0,,1.0,,,,0.067972,,0.399942,,,
350,013c6f92763546c7ad9c0831f023886c15f05e7332aa0c...,12,73.5,1.0,12.0,1.0,1.0,2.0,0.769231,0.365342,0.077516,,0.415324,,,
1268,024efbff9265302acd00190e57ee08ba1fe1b90f561f79...,0,,0.0,14.0,1.0,1.0,2.0,1.000000,0.164384,,,0.515223,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
841,ff2966461950ba81280a0189ed2d504a8bd503d9f6b078...,0,,,18.0,1.0,1.0,1.5,0.807692,0.150685,,,,,,
330,ff2966461950ba81280a0189ed2d504a8bd503d9f6b078...,6,,,18.0,1.0,1.0,1.5,0.769231,0.095890,,,,,,
939,ff2966461950ba81280a0189ed2d504a8bd503d9f6b078...,24,,,18.0,1.0,1.0,1.5,0.769231,0.150685,,,,,,
119,ff2966461950ba81280a0189ed2d504a8bd503d9f6b078...,48,70.9,,18.0,1.0,1.0,2.5,0.807692,0.246575,0.307697,0.420993,,0.392416,0.577719,0.403872


In [8]:
test_B = pd.read_csv(data_dir / "test_B.csv")
test_B = augment_base_dataset(test_B)
test_B[scaled_cols] = scaler.transform(test_B[scaled_cols])

test_B

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
1181,001854e92967164311f3acd5a58be9790f28ab3968bbbc...,0,71.4,,15.0,0.0,2.0,0.0,0.961538,0.077671,0.085164,0.638939,,0.608113,0.424862,0.523781
1426,001854e92967164311f3acd5a58be9790f28ab3968bbbc...,36,74.4,,15.0,0.0,2.0,0.0,1.000000,0.027397,0.089750,,,,,
1201,0059bc7849aea9522b408fa0ddc60276a36cae00206b87...,0,,0.0,,1.0,0.0,0.5,0.846154,0.196301,,0.345711,0.286043,0.312698,0.276821,0.248579
757,0059bc7849aea9522b408fa0ddc60276a36cae00206b87...,6,,0.0,,1.0,0.0,1.0,1.000000,0.283151,,0.345147,0.278219,0.378307,0.289480,0.253793
763,0059bc7849aea9522b408fa0ddc60276a36cae00206b87...,12,,0.0,,1.0,0.0,2.5,0.807692,0.168904,,0.329233,0.253372,0.352028,0.259842,0.222042
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1423,ff4eb5a64e2b89861d5dea81190669893070b227f3a335...,0,,,18.0,1.0,1.0,1.5,0.884615,0.114110,,0.502370,,0.394356,0.397160,0.531003
558,ff4eb5a64e2b89861d5dea81190669893070b227f3a335...,12,,,18.0,1.0,1.0,1.5,0.923077,0.242055,,0.519639,,0.294356,0.416522,0.545575
70,ff4eb5a64e2b89861d5dea81190669893070b227f3a335...,84,,0.0,18.0,1.0,1.0,1.5,1.000000,0.178082,,0.432054,0.483387,0.363316,0.468451,0.508440
480,ffa86109ba8684f31325842d0ff26568e105f0f63b366a...,0,66.3,,13.0,0.0,0.0,0.0,0.923077,0.118767,0.177669,,,,,


In [9]:
test_A.isna().sum()

RID_HASH          0
VISCODE           0
AGE             612
PTGENDER_num    626
PTEDUCAT         65
DX_num          428
APOE4            49
CDRSB           428
MMSE            428
ADAS13          428
Ventricles      612
Hippocampus     668
WholeBrain      626
Entorhinal      668
Fusiform        668
MidTemp         668
dtype: int64

In [10]:
test_A.columns

Index(['RID_HASH', 'VISCODE', 'AGE', 'PTGENDER_num', 'PTEDUCAT', 'DX_num',
       'APOE4', 'CDRSB', 'MMSE', 'ADAS13', 'Ventricles', 'Hippocampus',
       'WholeBrain', 'Entorhinal', 'Fusiform', 'MidTemp'],
      dtype='object')

## Baseline imputation

In [11]:
from hyperimpute.plugins.imputers import Imputers

# VISCODE 6 * x -> AGE 0.5 * x

const_by_patient = ["PTGENDER_num", "PTEDUCAT", "APOE4"]

def prepare_consts(train_data, test_data):
    test_data = test_data.copy()
    train_data = train_data.copy()

    train_data = train_data.sort_values(["RID_HASH", "VISCODE"])
    test_data = test_data.sort_values(["RID_HASH", "VISCODE"])

    for item in test_data.groupby("RID_HASH"):
        local = item[1]

        # fill consts
        for col in const_by_patient:
            if len(local[col].unique()) == 1:
                continue
            rid = local["RID_HASH"].unique()[0]

            val = local[col][~local[col].isna()].unique()[0]
            local[col] = local[col].fillna(val)
            test_data.loc[test_data["RID_HASH"] == rid, col] = test_data[
                test_data["RID_HASH"] == rid
            ][col].fillna(val)
            assert len(local[col].unique()) == 1, col

    return test_data


def prepare_age(train_data, test_data):
    test_data = test_data.copy()
    train_data = train_data.copy()

    train_data = train_data.sort_values(["RID_HASH", "VISCODE"])
    test_data = test_data.sort_values(["RID_HASH", "VISCODE"])

    col = "AGE"

    for rid in test_data["RID_HASH"].unique():
        local = test_data[test_data["RID_HASH"] == rid]

        # fill age
        ages = local["AGE"]
        if ages.isna().sum() == 0:
            continue

        if ages.isna().sum() == len(ages):
            continue

        # forward impute age
        prev_viscode = 0
        prev_age = 0
        for idx, row in local.iterrows():
            current_viscode = row["VISCODE"]
            local_idx = (test_data["VISCODE"] == current_viscode) & (
                test_data["RID_HASH"] == rid
            )
            if prev_age > 0 and prev_age == prev_age:
                pred_age = (current_viscode - prev_viscode) / 6 * 0.5 + prev_age
            else:
                pred_age = row[col]

            if pred_age == pred_age:
                # print("forward imputed", pred_age, current_viscode)
                test_data.loc[local_idx, col] = test_data.loc[local_idx][col].fillna(
                    pred_age
                )

            prev_viscode = row["VISCODE"]
            prev_age = pred_age

        # reverse impute age
        prev_viscode = 0
        prev_age = 0
        for idx, row in local.iloc[::-1].iterrows():
            current_viscode = row["VISCODE"]
            local_idx = (test_data["VISCODE"] == current_viscode) & (
                test_data["RID_HASH"] == rid
            )

            if prev_age > 0 and prev_age == prev_age:
                pred_age = prev_age - (prev_viscode - current_viscode) / 6 * 0.5
            else:
                pred_age = row[col]

            if pred_age == pred_age:
                # print("reversed imputed", pred_age, current_viscode)
                test_data.loc[local_idx, col] = test_data.loc[local_idx][col].fillna(
                    pred_age
                )

            prev_viscode = row["VISCODE"]
            prev_age = pred_age

        # print(test_data[(test_data["RID_HASH"] == rid)][["VISCODE", "AGE"]])
    return test_data

def interm_imputation(train_data, test_data):   
    test_data = test_data.copy()
    
    for rid in test_data["RID_HASH"].unique():
        local = test_data[test_data["RID_HASH"] == rid]
        
        local = local.ffill()
        local = local.bfill()
    
        test_data.loc[test_data["RID_HASH"] == rid] = local
    
    return test_data

def full_imputation(train_data, test_data, random_state: int = 0):
    imputed_test_data = test_data.copy()

    imputer_kwargs = {
        "optimizer": "simple",
        "classifier_seed": ["xgboost"],
        "regression_seed": ["xgboost_regressor"],
        "class_threshold": cat_limit,
        "random_state" : random_state,
    }

    imputer = Imputers().get(
        "hyperimpute",
        **imputer_kwargs,
    )
    imputation_input = pd.concat([train_data, test_data], ignore_index=True)
    imputed_test_data = imputer.fit_transform(imputation_input)
    imputed_test_data = imputed_test_data.tail(len(test_data))
    
    return imputed_test_data


def evaluate_static_imputation(train_data, test_data, static_imputation):
    test_data = test_data.copy()
    train_data = train_data.copy()

    train_data = train_data.sort_values(["RID_HASH", "VISCODE"])
    test_data = test_data.sort_values(["RID_HASH", "VISCODE"])

    for rid in test_data["RID_HASH"].unique():
        patient = test_data[test_data["RID_HASH"] == rid]
        misses = []
        viscodes = []
        for idx, row in patient.iterrows():
            misses.append(row.isna().sum())
            viscodes.append(row["VISCODE"])
        cidx = np.argmin(misses)

        current_viscode = viscodes[cidx]
        local_idx = (test_data["VISCODE"] == current_viscode) & (
            test_data["RID_HASH"] == rid
        )
        imputed_idx = (static_imputation["VISCODE"] == current_viscode) & (
            static_imputation["RID_HASH"] == rid
        )

        if len(test_data[local_idx]) == 0:
            continue

        for col in test_data.columns:
            val = test_data.loc[local_idx][col].values[0]
            if val == val:
                continue
            imputed_val = static_imputation.loc[imputed_idx][col].values[0]
            test_data.loc[local_idx, col] = imputed_val

            # print("imputed", test_data.loc[local_idx, col])

    return test_data


def impute_data_step(
    train_data, test_data, 
    use_longitudinal=True, 
    static_strategy="missmin", 
    random_state: int = 0
):
    test_id = dataframe_hash(test_data)
    train_id = dataframe_hash(train_data)

    print(" >>> Evaluate constants", test_id, test_data.isna().sum().sum())
    test_data = prepare_consts(train_data, test_data)
    test_data = prepare_age(train_data, test_data)
    test_data = interm_imputation(train_data, test_data)

    print(
        " >>> Evaluate static imputation",
        test_id,
        test_data.isna().sum().sum(),
        static_strategy,
    )
    static_imputation = full_imputation(train_data, test_data, random_state = random_state)
    test_data = evaluate_static_imputation(train_data, test_data, static_imputation)

    print(" >>> Evaluate constants take 2", test_id, test_data.isna().sum().sum())
    test_data = prepare_consts(train_data, test_data)
    test_data = prepare_age(train_data, test_data)
    
    test_data = interm_imputation(train_data, test_data)
    
    assert test_data.isna().sum().sum() == 0, test_data
    
    return test_data

def impute_baseline_data(train_data, test_data, random_state: int = 0):
    test_id = dataframe_hash(test_data)
    train_id = dataframe_hash(train_data)

    bkp_file = workspace / f"seed_imputation_{version}_{train_id}_{test_id}_{random_state}_catlimit{cat_limit}.csv"
    
    if bkp_file.exists():
        return pd.read_csv(bkp_file)
    
    print("Evaluate ", bkp_file)
    imputed_test_data = impute_data_step(
        train_data, test_data, 
        random_state = random_state,
    )
    imputed_test_data.to_csv(bkp_file, index = None)

    return imputed_test_data

In [12]:
dev_1_baseline = impute_baseline_data(dev_set, dev_1)
dev_2_baseline = impute_baseline_data(dev_set, dev_2)
test_A_baseline = impute_baseline_data(dev_set, test_A)
test_B_baseline = impute_baseline_data(dev_set, test_B)

dev_1_baseline

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp
0,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0,79.1,0.0,20.0,1.0,1.0,0.500000,0.923077,0.164384,0.071956,0.548307,0.376516,0.403880,0.193367,0.397291
1,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,6,79.6,0.0,20.0,1.0,1.0,1.500000,0.923077,0.237397,0.071956,0.548307,0.366398,0.403880,0.193367,0.397291
2,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0,72.9,1.0,12.0,1.0,1.0,1.173578,0.984970,0.133982,0.187193,0.525169,0.235599,0.513404,0.356253,0.294774
3,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,6,73.4,1.0,12.0,1.0,1.0,1.173578,0.984970,0.133982,0.187193,0.549210,0.230361,0.435097,0.322395,0.294175
4,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,12,73.9,1.0,12.0,1.0,1.0,1.173578,0.984970,0.133982,0.187193,0.527878,0.215944,0.487831,0.342600,0.277552
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4096,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,60,79.8,1.0,19.0,1.0,0.0,2.704856,0.922065,0.234043,0.170895,0.339868,0.321346,0.293336,0.400645,0.470582
4097,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,102,83.3,1.0,19.0,1.0,0.0,2.704856,0.922065,0.234043,0.178231,0.339868,0.309095,0.293336,0.400645,0.470582
4098,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0,72.1,0.0,12.0,1.0,0.0,0.500000,0.884615,0.150685,0.416382,0.602438,0.624603,0.610229,0.743037,0.624631
4099,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,12,73.1,0.0,12.0,1.0,0.0,1.000000,0.961538,0.155205,0.398451,0.608521,0.624603,0.617108,0.729087,0.638477


In [13]:
def mask_columns_map(s: str):
    return f"masked_{s}"

dev_1_mask = dev_1.isna().astype(int).drop(columns = ["RID_HASH", "VISCODE"]).rename(mask_columns_map, axis='columns')
dev_2_mask = dev_2.isna().astype(int).drop(columns = ["RID_HASH", "VISCODE"]).rename(mask_columns_map, axis='columns')
test_A_mask = test_A.isna().astype(int).drop(columns = ["RID_HASH", "VISCODE"]).rename(mask_columns_map, axis='columns')
test_B_mask = test_B.isna().astype(int).drop(columns = ["RID_HASH", "VISCODE"]).rename(mask_columns_map, axis='columns')

dev_1_mask

Unnamed: 0,masked_AGE,masked_PTGENDER_num,masked_PTEDUCAT,masked_DX_num,masked_APOE4,masked_CDRSB,masked_MMSE,masked_ADAS13,masked_Ventricles,masked_Hippocampus,masked_WholeBrain,masked_Entorhinal,masked_Fusiform,masked_MidTemp
2163,1,0,0,0,0,0,0,0,1,1,0,1,1,1
154,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1385,1,0,0,1,0,1,1,1,1,0,0,0,0,0
2698,1,0,0,1,0,1,1,1,1,0,0,0,0,0
2291,1,0,0,1,0,1,1,1,1,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2895,0,0,0,1,0,1,1,1,0,1,0,1,1,1
2646,0,0,0,1,0,1,1,1,0,1,0,1,1,1
1962,0,1,0,0,0,0,0,0,0,0,1,0,0,0
122,0,1,0,0,0,0,0,0,0,0,1,0,0,0


In [14]:
from tabular_encoder import TabularEncoder

whitelist = ["RID_HASH", "VISCODE"]

tabular_encoder = TabularEncoder(whitelist = whitelist)
tabular_encoder.fit(dev_set)

encoded_dev_1_baseline = tabular_encoder.transform(dev_1_baseline)
encoded_dev_2_baseline = tabular_encoder.transform(dev_2_baseline)
encoded_test_A_baseline = tabular_encoder.transform(test_A_baseline)
encoded_test_B_baseline = tabular_encoder.transform(test_B_baseline)
encoded_dev_set_baseline = tabular_encoder.transform(dev_set)

activation_layout = tabular_encoder.activation_layout(discrete_activation = "softmax", continuous_activation = "tanh")

encoded_dev_1_baseline

Unnamed: 0,RID_HASH,VISCODE,AGE.normalized,AGE.component_0,AGE.component_1,AGE.component_2,AGE.component_3,AGE.component_4,AGE.component_5,AGE.component_6,...,MidTemp.component_0,MidTemp.component_1,MidTemp.component_2,MidTemp.component_3,MidTemp.component_4,MidTemp.component_5,MidTemp.component_6,MidTemp.component_7,MidTemp.component_8,MidTemp.component_9
0,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0,0.027571,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
1,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,6,0.195424,0.0,0.0,0.0,0.0,1.0,0.0,0.0,...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0,0.087946,0.0,0.0,0.0,0.0,0.0,0.0,1.0,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,6,0.039147,0.0,0.0,1.0,0.0,0.0,0.0,0.0,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,12,0.085334,0.0,0.0,1.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4096,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,60,0.082681,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4097,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,102,0.003604,0.0,0.0,0.0,0.0,0.0,1.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
4098,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0,-0.080938,0.0,0.0,1.0,0.0,0.0,0.0,0.0,...,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4099,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,12,0.276366,1.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [15]:
training_input_1 = pd.concat([dev_1_baseline, dev_1_mask], axis = 1)
training_input_2 = pd.concat([dev_2_baseline, dev_2_mask], axis = 1)
training_output = pd.concat([dev_set], ignore_index = True)

training_input_1

Unnamed: 0,RID_HASH,VISCODE,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,...,masked_APOE4,masked_CDRSB,masked_MMSE,masked_ADAS13,masked_Ventricles,masked_Hippocampus,masked_WholeBrain,masked_Entorhinal,masked_Fusiform,masked_MidTemp
0,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,0,79.1,0.0,20.0,1.0,1.0,0.500000,0.923077,0.164384,...,0,0,0,0,1,0,1,0,0,0
1,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bd...,6,79.6,0.0,20.0,1.0,1.0,1.500000,0.923077,0.237397,...,0,0,0,0,0,0,1,0,0,0
2,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,0,72.9,1.0,12.0,1.0,1.0,1.173578,0.984970,0.133982,...,0,0,0,0,0,0,0,0,0,0
3,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,6,73.4,1.0,12.0,1.0,1.0,1.173578,0.984970,0.133982,...,1,0,0,0,0,1,1,1,1,1
4,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8...,12,73.9,1.0,12.0,1.0,1.0,1.173578,0.984970,0.133982,...,0,0,0,0,0,1,0,1,1,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4096,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,60,79.8,1.0,19.0,1.0,0.0,2.704856,0.922065,0.234043,...,0,0,0,0,0,1,1,1,1,1
4097,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803b...,102,83.3,1.0,19.0,1.0,0.0,2.704856,0.922065,0.234043,...,0,0,0,0,0,1,0,1,1,1
4098,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,0,72.1,0.0,12.0,1.0,0.0,0.500000,0.884615,0.150685,...,0,0,0,0,0,0,1,0,0,0
4099,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c8...,12,73.1,0.0,12.0,1.0,0.0,1.000000,0.961538,0.155205,...,0,0,0,0,1,1,0,1,1,1


In [17]:
from ts_imputer import TimeSeriesImputer, modes
from sklearn.model_selection import train_test_split

imputer = TimeSeriesImputer(
    n_units_in = training_input_1.shape[-1] - 1, # DROP RID_HASH
    n_units_out = training_output.shape[-1] - 2,
    #nonlin_out = activation_layout,
    n_iter = 1000,
    mode = "Transformer",
)

imputer.fit(training_input_1, training_output)
imputer.fit(training_input_2, training_output)
imputer.fit(training_input_1, training_output)
imputer.fit(training_input_2, training_output)


preds = imputer.predict(training_input_1)

Epoch 49 loss 1.7090247749772847
Epoch 99 loss 1.1691701878339817
Epoch 149 loss 1.157567172478407
Epoch 199 loss 0.7765469519246337
Epoch 249 loss 1.0073704676240938
Epoch 299 loss 0.4381127821074592
Epoch 349 loss 0.8838678756330767
Epoch 399 loss 0.6739168369617218
Epoch 449 loss 0.9380338416140304
Epoch 499 loss 0.7985916787233108
Epoch 549 loss 0.9327157147419758
Epoch 599 loss 0.6328605901227038
Epoch 649 loss 0.5227136141978778
Epoch 699 loss 0.336947701107233
Epoch 749 loss 0.5560902796494656
Epoch 799 loss 0.4525597699177571
Epoch 849 loss 0.3780128992138765
Epoch 899 loss 0.3518501045102747
Epoch 949 loss 0.4928133814380719
Epoch 999 loss 0.4169980382435342
Epoch 49 loss 0.5699108314310384
Epoch 99 loss 0.448140953723182
Epoch 149 loss 0.5468364340117854
Epoch 199 loss 0.29200037855368394
Epoch 249 loss 0.3206739251812299
Epoch 299 loss 0.3825725665968707
Epoch 349 loss 0.39062835263390827
Epoch 399 loss 0.34866757136888993
Epoch 449 loss 0.45150339909088916
Epoch 499 loss 0.

In [18]:
preds

Unnamed: 0,AGE,PTGENDER_num,PTEDUCAT,DX_num,APOE4,CDRSB,MMSE,ADAS13,Ventricles,Hippocampus,WholeBrain,Entorhinal,Fusiform,MidTemp,VISCODE,RID_HASH
0,77.412292,0.378526,19.75346,0.979201,1.323903,1.253184,0.960497,0.240416,0.27158,0.510428,0.459216,0.473073,0.454357,0.460835,0,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bdc713271adea9eaa158
1,77.907539,0.362056,19.724102,1.025998,1.338204,1.455162,0.947653,0.252951,0.284123,0.49595,0.447975,0.460189,0.442397,0.451033,6,001c7955017f905ccf78d55c94e81070a1cca7b1efb5bdc713271adea9eaa158
2,70.698006,0.558444,11.742214,0.804613,1.073633,0.677403,0.918203,0.205593,0.173671,0.534382,0.4216,0.476191,0.427295,0.441008,0,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8ea199212588d2d672c
3,71.20697,0.542338,11.720034,0.852156,1.0895,0.878785,0.906243,0.218489,0.185929,0.521107,0.411215,0.464273,0.416033,0.431817,6,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8ea199212588d2d672c
4,71.685806,0.526299,11.686798,0.876667,1.086947,0.949755,0.900397,0.223688,0.19518,0.513975,0.406704,0.460345,0.41125,0.42736,12,00e6fb56250581a8c8b5133f91443dd8c037e3cd8d0ba8ea199212588d2d672c
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4096,81.984993,0.202276,18.982885,1.197948,-0.14986,2.777208,0.886612,0.282895,0.347757,0.510707,0.445091,0.488741,0.432699,0.431508,60,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803bcd30ea99c58dcf91d7
4097,85.450417,0.187772,18.983086,1.25853,-0.286669,4.422988,0.820123,0.332222,0.374831,0.543153,0.454334,0.53124,0.457327,0.459483,102,ff59785f0d6b12fc51a07f09bb3a02790e54d04bb0803bcd30ea99c58dcf91d7
4098,70.786026,0.528944,11.944712,0.805118,0.331211,0.844123,0.909576,0.207488,0.188356,0.540381,0.419957,0.476149,0.421753,0.432415,0,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c81b37c40ecf646cb0c6
4099,71.755417,0.496969,11.88087,0.870515,0.337264,1.085363,0.893208,0.223645,0.208898,0.521928,0.406705,0.462574,0.407423,0.420107,12,ff98c50c3e97b776ab61db883cf1c8fd5a6d304d7165c81b37c40ecf646cb0c6


In [21]:
from hyperimpute.utils.benchmarks import RMSE

gt = pd.concat([dev_set,dev_set], ignore_index = True)
gt = gt.sort_values(["RID_HASH", "VISCODE"])
gt_mask = pd.concat([dev_1,dev_2], ignore_index = True)
gt_mask = gt_mask.sort_values(["RID_HASH", "VISCODE"])  

imputation_baseline = pd.concat([dev_1_baseline, dev_2_baseline], ignore_index = True)
imputation_baseline = imputation_baseline.sort_values(["RID_HASH", "VISCODE"])
imputation_baseline = imputation_baseline.drop(columns = ["RID_HASH", "VISCODE"]).values

RMSE(
    imputation_baseline,
    gt.drop(columns = ["RID_HASH", "VISCODE"]).values, 
    gt_mask.drop(columns = ["RID_HASH", "VISCODE"]).values,
)

0.6574952123188846

In [22]:
gt = pd.concat([dev_set,dev_set], ignore_index = True)
gt = gt.sort_values(["RID_HASH", "VISCODE"])
gt_mask = pd.concat([dev_1,dev_2], ignore_index = True)
gt_mask = gt_mask.sort_values(["RID_HASH", "VISCODE"])  

#nn_preds = tabular_encoder.inverse_transform(preds).drop(columns = ["RID_HASH", "VISCODE"]).values
full_preds = imputer.predict(pd.concat([training_input_1, training_input_2]))
nn_preds = full_preds.drop(columns = ["RID_HASH", "VISCODE"]).values

RMSE(
    nn_preds,
    gt.drop(columns = ["RID_HASH", "VISCODE"]).values, 
    gt_mask.drop(columns = ["RID_HASH", "VISCODE"]).values,
)

0.7266754472056841

## Submission data

In [None]:
def normalize_output(test_data):
    test_data = test_data.copy()
    factor = test_data["CDRSB"] / 0.5
    factor = factor.fillna(-1)
    factor = factor.round(0).astype(int)
    factor = factor.replace(-1, np.nan)
    test_data["CDRSB"] = factor * 0.5

    test_data["ADAS13"] = ((test_data["ADAS13"] * 3).round(0) / 3).round(2)
    test_data["MMSE"] = test_data["MMSE"].round(0)

    return test_data



def dump_results(imputed_data: pd.DataFrame, fpath: str):
    results = []

    for name, data in [
        ("test_A", test_A.sort_index()),
        ("test_B", test_B.sort_index()),
    ]:
        for idx, row in data.iterrows():
            for col in row.index:
                local = row.T
                val = local[col]
                if val == val:
                    continue
                imputed_id = f"{local['RID_HASH']}_{local['VISCODE']}_{col}_{name}"
                imputed_val = imputed_data[
                    (imputed_data["RID_HASH"] == local["RID_HASH"])
                    & (imputed_data["VISCODE"] == local["VISCODE"])
                ][col].values[0]
                
                assert imputed_val == imputed_val
                assert imputed_val != ""
                
                results.append([imputed_id, imputed_val])

    output = pd.DataFrame(results, columns=submission.columns)
    output.to_csv(fpath, index=None)

    return output

def get_submission_data(random_state):
    test_A_eval = impute_data(dev_set, test_A, random_state = random_state).sort_index()
    test_B_eval = impute_data(dev_set, test_B, random_state = random_state).sort_index()

    eval_data = pd.concat([dev_set, test_A_eval, test_B_eval], ignore_index=True)
    eval_data[scaled_cols] = scaler.inverse_transform(eval_data[scaled_cols])

    output_fpath = results_dir / f"imputation_results_{version}_{changelog}_normalized_seeds_{random_state}.csv"

    print("Prepare output", output_fpath)
    output_normalized = dump_results(normalize_output(eval_data), output_fpath)
    
    return output_fpath, output_normalized

for random_state in range(10):
    output_path, output = get_submission_data(random_state = random_state)

In [None]:
output_path, output0 = get_submission_data(random_state = 0)

output0

In [None]:
output_path, output1 = get_submission_data(random_state = 1)

output1

In [None]:
(output1["Predicted"] - output0["Predicted"]).sum()