In [3]:
import pandas as pd
from sklearn.model_selection import train_test_split
from MAEImputer import ReMaskerStep
import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import pickle
from math import sqrt
import os

### Get the data

In [4]:
################ Read Datasets ################
df_test = pd.read_csv('../data/X_test.csv')
print(f'Test values shape: {df_test.shape}')
df_test.head()

Test values shape: (326025, 403)


Unnamed: 0,npval_50971,nptime_50971,npval_50983,nptime_50983,npval_50902,nptime_50902,npval_51221,nptime_51221,npval_50912,nptime_50912,...,nptime_last_51104,npval_last_51078,nptime_last_51078,npval_last_50884,nptime_last_50884,npval_last_51255,nptime_last_51255,first_race,chartyear,hadm_id
0,4.3,18.0,137.0,18.0,102.0,18.0,38.4,18.0,1.0,18.0,...,,,,,,,,WHITE,2190,20000057
1,3.3,16.0,141.0,16.0,103.0,16.0,38.2,16.0,1.1,16.0,...,,,,,,,,WHITE,2190,20000057
2,3.4,17.0,145.0,17.0,103.0,17.0,39.0,17.0,1.0,17.0,...,,,,,,,,WHITE,2190,20000057
3,4.0,12.0,138.0,12.0,103.0,12.0,37.8,12.0,0.8,12.0,...,,,,,,,,WHITE,2193,20000293
4,,,,,,,,,,,...,,,,,,,,BLACK/AFRICAN,2183,20000298


In [5]:
################ Clean Missing Data ################
def clean_missing(df, threshold=20 + 3, missing_per_col=100, cols_to_remove=None):
    # Remove rows with less than 20 values
    df = df.dropna(thresh=threshold)
    print(f"DataFrame after removing rows with at least 20 missing values: {df.shape}")
    
    if type(cols_to_remove) != list:
        if missing_per_col and not cols_to_remove:
            # Get columns where at least 100 values are not missing
            columns_all_nan = df.columns[df.notna().sum() < missing_per_col].tolist()
            # Identify columns that end with a number after the last underscore
            ids = ['_' + col.split('_')[-1] for col in columns_all_nan]

            def ids_in_string(value_list, target_string):
                for value in value_list:
                    if value in target_string:
                        return True
                return False

            cols_to_remove = []
            for column in df.columns:
                if ids_in_string(ids, column):
                    cols_to_remove.append(column)

    print(f'Removing columns: {cols_to_remove}')

    df.drop(columns=cols_to_remove, inplace=True)
    
    return df, cols_to_remove

missing_per_row = 20 + 3 # + 3 because of: first_race, chartyear, hadm_id
missing_per_col = 500

df_test, _ = clean_missing(df_test, missing_per_row, cols_to_remove=[])

DataFrame after removing rows with at least 20 missing values: (313763, 403)
Removing columns: []


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df.drop(columns=cols_to_remove, inplace=True)


In [6]:
# Since is a demo let's use only 1000 rows
df_test = df_test.head(1000)

### Create an instance of the imputer model

In [7]:
################ Create Imputer Instance ################
columns = df_test.shape[1] - 3 # + 3 because of: first_race, chartyear, hadm_id

mask_ratio = 0.25
max_epochs = 1 #400
save_path = 'demo'

batch_size=256 
embed_dim=64
depth=8
decoder_depth=4
num_heads=8
mlp_ratio=4.0
weigths = '100_Labs_Train_0.25Mask_L_V3/epoch390_checkpoint'

imputer = ReMaskerStep(dim=columns, mask_ratio=mask_ratio, max_epochs=max_epochs, save_path=save_path, batch_size=batch_size,
                      embed_dim=embed_dim, depth=depth, decoder_depth=decoder_depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
                      weigths=weigths)


loading model weigths...


In [8]:
# Load the normalizations parameters
with open('100_Labs_Train_0.25Mask_L_V3/norm_parameters.pkl', 'rb') as file:
    loaded_norm_parameters = pickle.load(file)
    
imputer.norm_parameters = loaded_norm_parameters

### Train the model

In [7]:
################ Train the model ################
imputer.fit(df_test.drop(columns=['first_race', 'chartyear', 'hadm_id']), val_df.drop(columns=['first_race', 'chartyear', 'hadm_id']))

calculating norm parameters...


100%|██████████| 4/4 [02:45<00:00, 41.25s/it]


Evaluation of epoch 0...
Epoch0 Evaluation for npval_50971: RMSE = 0.2904550859913008, MAE = 0.232860939278787, R2 = 0.6226119597955024

Epoch0 Evaluation for npval_50983: RMSE = 1.0837536956224367, MAE = 0.6108794825037098, R2 = 0.9376735059940294

Epoch0 Evaluation for npval_50902: RMSE = 1.001454948927611, MAE = 0.609597957334039, R2 = 0.9664691356351903

Epoch0 Evaluation for npval_51221: RMSE = 0.26294848397177123, MAE = 0.1400610132420316, R2 = 0.9977859465206198

Epoch0 Evaluation for npval_50912: RMSE = 0.2694394834288218, MAE = 0.14482057585435756, R2 = 0.9147680269188602

Epoch0 Evaluation for npval_51006: RMSE = 3.971334820901262, MAE = 2.7702480284548594, R2 = 0.9326855475558764

Epoch0 Evaluation for npval_50882: RMSE = 0.932510023804176, MAE = 0.5343477592039644, R2 = 0.9239244078147897

Epoch0 Evaluation for npval_50868: RMSE = 0.855768380643851, MAE = 0.462479366345352, R2 = 0.9035696963179912

Epoch0 Evaluation for npval_50931: RMSE = 26.708368025344022, MAE = 20.14971



Epoch0 Evaluation for npval_51104: RMSE = 117.31004630775274, MAE = 91.91630664062495, R2 = 0.8359625982399395

Epoch0 Evaluation for npval_51078: RMSE = 15.607879777779859, MAE = 12.889178466796874, R2 = 0.21478239054399983





Epoch0 Evaluation for npval_50884: RMSE = 0.4154846191406252, MAE = 0.4154846191406252, R2 = nan

Epoch0 Evaluation for npval_51255: RMSE = 1.035256708435501, MAE = 0.7898102521896362, R2 = 0.33015221727455746

Epoch0 Evaluation for npval_last_50971: RMSE = 0.2851983470815638, MAE = 0.2262014270588092, R2 = 0.657869926700084

Epoch0 Evaluation for npval_last_50983: RMSE = 0.9871344442602759, MAE = 0.5764693721648185, R2 = 0.9494814813548322

Epoch0 Evaluation for npval_last_50902: RMSE = 1.018852970130781, MAE = 0.64434936719063, R2 = 0.9660272193418025

Epoch0 Evaluation for npval_last_51221: RMSE = 0.3578235354378583, MAE = 0.2423481303415479, R2 = 0.9959430698752951

Epoch0 Evaluation for npval_last_50912: RMSE = 0.1907504351337875, MAE = 0.12373713923108048, R2 = 0.9600161286029225

Epoch0 Evaluation for npval_last_51006: RMSE = 4.368655636576033, MAE = 2.933789543616466, R2 = 0.9160011971093095

Epoch0 Evaluation for npval_last_50882: RMSE = 1.0980102599037629, MAE = 0.67854770537



Epoch0 Evaluation for npval_last_51255: RMSE = 0.9361315704009681, MAE = 0.7712709307670593, R2 = -2.5053692684055306

1 , 0.12201608502401057


<MAEImputer.ReMaskerStep at 0x384cf82e0>

### Transform (Impute) using the model

In [9]:
eval_batch_size = 256

imputed_test = pd.DataFrame(imputer.transform(df_test.drop(columns=['first_race', 'chartyear', 'hadm_id']), eval_batch_size=eval_batch_size).cpu().numpy())

In [12]:
df_test

Unnamed: 0,npval_50971,nptime_50971,npval_50983,nptime_50983,npval_50902,nptime_50902,npval_51221,nptime_51221,npval_50912,nptime_50912,...,nptime_last_51104,npval_last_51078,nptime_last_51078,npval_last_50884,nptime_last_50884,npval_last_51255,nptime_last_51255,first_race,chartyear,hadm_id
0,4.3,18.0,137.0,18.0,102.0,18.0,38.4,18.0,1.0,18.0,...,,,,,,,,WHITE,2190,20000057
1,3.3,16.0,141.0,16.0,103.0,16.0,38.2,16.0,1.1,16.0,...,,,,,,,,WHITE,2190,20000057
2,3.4,17.0,145.0,17.0,103.0,17.0,39.0,17.0,1.0,17.0,...,,,,,,,,WHITE,2190,20000057
3,4.0,12.0,138.0,12.0,103.0,12.0,37.8,12.0,0.8,12.0,...,,,,,,,,WHITE,2193,20000293
5,3.8,17.0,140.0,17.0,103.0,17.0,27.7,17.0,0.7,17.0,...,,,,,,,,BLACK/AFRICAN,2183,20000298
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1039,5.0,12.0,138.0,12.0,103.0,12.0,24.8,12.0,2.5,12.0,...,,,,,,,,BLACK/AFRICAN AMERICAN,2196,20033924
1040,5.0,10.0,131.0,10.0,96.0,10.0,23.0,10.0,3.6,10.0,...,,,,,,,,BLACK/AFRICAN AMERICAN,2196,20033924
1041,5.0,22.0,131.0,22.0,93.0,22.0,24.0,22.0,4.2,22.0,...,,,,,,,,BLACK/AFRICAN AMERICAN,2196,20033924
1042,3.9,20.0,131.0,20.0,93.0,20.0,23.0,12.0,3.0,20.0,...,,,,,,,,BLACK/AFRICAN AMERICAN,2196,20033924


In [None]:
imputed_test

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,390,391,392,393,394,395,396,397,398,399
0,4.3,18.0,137.0,18.0,102.0,18.0,38.4,18.0,1.0,18.0,...,1.767103,43.245804,596.662231,34.385208,44.843685,36.978500,0.902864,30.409641,1.237766,39.394405
1,3.3,16.0,141.0,16.0,103.0,16.0,38.2,16.0,1.1,16.0,...,1.904227,57.429443,608.033081,138.584167,41.289528,120.660080,0.937603,32.299026,1.230534,52.178501
2,3.4,17.0,145.0,17.0,103.0,17.0,39.0,17.0,1.0,17.0,...,1.855902,59.075882,666.769104,129.137024,44.378906,117.640900,1.064101,32.060921,1.258541,51.411385
3,4.0,12.0,138.0,12.0,103.0,12.0,37.8,12.0,0.8,12.0,...,1.897429,43.416454,639.346802,34.875477,50.585411,35.200684,1.071532,30.379951,1.198120,40.974777
4,3.8,17.0,140.0,17.0,103.0,17.0,27.7,17.0,0.7,17.0,...,1.629823,42.241543,518.730530,31.014757,45.348347,30.422703,0.975406,29.500666,1.256737,33.223385
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,5.0,12.0,138.0,12.0,103.0,12.0,24.8,12.0,2.5,12.0,...,1.521729,53.853996,381.978943,150.453598,49.478943,168.306198,1.016162,35.018742,1.215181,85.001770
996,5.0,10.0,131.0,10.0,96.0,10.0,23.0,10.0,3.6,10.0,...,1.271633,79.960770,429.067200,259.690674,41.761543,260.035583,1.056270,38.480019,1.292287,105.559227
997,5.0,22.0,131.0,22.0,93.0,22.0,24.0,22.0,4.2,22.0,...,1.212541,87.858734,447.379150,294.476288,32.837543,305.014343,1.042971,39.204662,1.413997,120.673744
998,3.9,20.0,131.0,20.0,93.0,20.0,23.0,12.0,3.0,20.0,...,1.159791,98.819763,470.416992,288.705505,33.884087,296.397278,1.099674,43.196106,1.197099,133.847931


### Extract embeddings using the model

In [14]:
eval_batch_size = 256

embeddings = imputer.extract_embeddings(df_test.drop(columns=['first_race', 'chartyear', 'hadm_id']), eval_batch_size=eval_batch_size).cpu().numpy()

In [30]:
df_test

Unnamed: 0,npval_50971,nptime_50971,npval_50983,nptime_50983,npval_50902,nptime_50902,npval_51221,nptime_51221,npval_50912,nptime_50912,...,nptime_last_51104,npval_last_51078,nptime_last_51078,npval_last_50884,nptime_last_50884,npval_last_51255,nptime_last_51255,first_race,chartyear,hadm_id
0,4.3,18.0,137.0,18.0,102.0,18.0,38.4,18.0,1.0,18.0,...,,,,,,,,WHITE,2190,20000057
1,3.3,16.0,141.0,16.0,103.0,16.0,38.2,16.0,1.1,16.0,...,,,,,,,,WHITE,2190,20000057
2,3.4,17.0,145.0,17.0,103.0,17.0,39.0,17.0,1.0,17.0,...,,,,,,,,WHITE,2190,20000057
3,4.0,12.0,138.0,12.0,103.0,12.0,37.8,12.0,0.8,12.0,...,,,,,,,,WHITE,2193,20000293
5,3.8,17.0,140.0,17.0,103.0,17.0,27.7,17.0,0.7,17.0,...,,,,,,,,BLACK/AFRICAN,2183,20000298
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1039,5.0,12.0,138.0,12.0,103.0,12.0,24.8,12.0,2.5,12.0,...,,,,,,,,BLACK/AFRICAN AMERICAN,2196,20033924
1040,5.0,10.0,131.0,10.0,96.0,10.0,23.0,10.0,3.6,10.0,...,,,,,,,,BLACK/AFRICAN AMERICAN,2196,20033924
1041,5.0,22.0,131.0,22.0,93.0,22.0,24.0,22.0,4.2,22.0,...,,,,,,,,BLACK/AFRICAN AMERICAN,2196,20033924
1042,3.9,20.0,131.0,20.0,93.0,20.0,23.0,12.0,3.0,20.0,...,,,,,,,,BLACK/AFRICAN AMERICAN,2196,20033924


In [31]:
embeddings.shape

(1000, 400, 64)

##### Convert embeddings to df:

In [63]:
embeddings_df = df_test.copy()

# Convert the first 400 columns to dtype `object` to store lists
embeddings_df.iloc[:, :400] = embeddings_df.iloc[:, :400].astype(object)

# Replace the values in the first 400 columns with the corresponding embeddings
for i, index in enumerate(embeddings_df.index):
    
    if i % 100 == 0:
        print(f'{i} embeddings processed')
        
    print(f'Row: {i}')
    for j in range(400):
        column = embeddings_df.columns[j]
        if not pd.isna(embeddings_df.loc[index, column]):  # Check if the value is not NaN
            embeddings_df.at[index, column] = embeddings[i, j].tolist()


Row: 0, Column: npval_50971
Row: 0, Column: nptime_50971
Row: 0, Column: npval_50983
Row: 0, Column: nptime_50983
Row: 0, Column: npval_50902
Row: 0, Column: nptime_50902
Row: 0, Column: npval_51221
Row: 0, Column: nptime_51221
Row: 0, Column: npval_50912
Row: 0, Column: nptime_50912
Row: 0, Column: npval_51006
Row: 0, Column: nptime_51006
Row: 0, Column: npval_50882
Row: 0, Column: nptime_50882
Row: 0, Column: npval_50868
Row: 0, Column: nptime_50868
Row: 0, Column: npval_50931
Row: 0, Column: nptime_50931
Row: 0, Column: npval_51265
Row: 0, Column: nptime_51265
Row: 0, Column: npval_51222
Row: 0, Column: nptime_51222
Row: 0, Column: npval_51301
Row: 0, Column: nptime_51301
Row: 0, Column: npval_51249
Row: 0, Column: nptime_51249
Row: 0, Column: npval_51279
Row: 0, Column: nptime_51279
Row: 0, Column: npval_51250
Row: 0, Column: nptime_51250
Row: 0, Column: npval_51248
Row: 0, Column: nptime_51248
Row: 0, Column: npval_51277
Row: 0, Column: nptime_51277
Row: 1, Column: npval_50971
Row

In [66]:
embeddings_df

Unnamed: 0,npval_50971,nptime_50971,npval_50983,nptime_50983,npval_50902,nptime_50902,npval_51221,nptime_51221,npval_50912,nptime_50912,...,nptime_last_51104,npval_last_51078,nptime_last_51078,npval_last_50884,nptime_last_50884,npval_last_51255,nptime_last_51255,first_race,chartyear,hadm_id
0,"[-0.21556080877780914, -0.3242534399032593, 0....","[-0.6569632887840271, -0.463593453168869, -0.0...","[-0.09285791963338852, 0.8375188708305359, 0.2...","[-0.6618356108665466, -0.3700826168060303, 0.0...","[0.013593505136668682, -0.2996731996536255, 0....","[-0.7051290273666382, -0.45784589648246765, -0...","[-0.2403661161661148, 0.29884040355682373, 0.6...","[-0.06651720404624939, -0.5991768836975098, 0....","[0.03270414471626282, -0.7626336216926575, -0....","[-0.425289124250412, -0.49396321177482605, 0.5...",...,,,,,,,,WHITE,2190,20000057
1,"[-0.21447211503982544, -0.9519215822219849, 0....","[-0.5148253440856934, -0.5279613733291626, 0.4...","[0.07857577502727509, 0.48031479120254517, 0.2...","[-0.4946562349796295, -0.720716655254364, 0.44...","[0.10744830220937729, -0.6078659892082214, 0.7...","[-0.6184892654418945, -0.6325526833534241, 0.3...","[-0.08367390930652618, 0.28744521737098694, 0....","[-0.053649552166461945, -0.7167236804962158, 0...","[-0.1718740463256836, -0.4471355676651001, -0....","[-0.3533383905887604, -0.5063391327857971, 0.6...",...,,,,,,,,WHITE,2190,20000057
2,"[-0.43300604820251465, -0.21052400767803192, 0...","[-0.6069528460502625, -0.5679804682731628, 0.2...","[-0.051011696457862854, 0.6867652535438538, 0....","[-0.6439066529273987, -0.6460733413696289, 0.3...","[0.029766876250505447, -0.5743981599807739, 0....","[-0.7327744960784912, -0.5740265250205994, 0.2...","[-0.11488408595323563, 0.3521929979324341, 0.6...","[-0.08411288261413574, -0.5903710126876831, 0....","[-0.356184184551239, -0.3310367166996002, -0.2...","[-0.4784030616283417, -0.5105301737785339, 0.5...",...,,,,,,,,WHITE,2190,20000057
3,"[0.020396579056978226, -0.2867753505706787, 0....","[-0.18402642011642456, -0.41139349341392517, 0...","[-0.14544197916984558, 0.5753598213195801, -0....","[-0.31529271602630615, -0.6057738661766052, 0....","[-0.04707568883895874, -0.1467037945985794, 0....","[-0.31918832659721375, -0.48963063955307007, 0...","[-0.35930466651916504, 0.21418845653533936, 0....","[0.6954495906829834, -0.633435845375061, 0.091...","[-0.11310791224241257, -0.5142297148704529, -0...","[0.5332318544387817, -0.2631949186325073, 0.23...",...,,,,,,,,WHITE,2193,20000293
5,"[-0.20482966303825378, -0.3735973536968231, 0....","[-0.6477938294410706, -0.5340641140937805, 0.1...","[-0.23959676921367645, 0.49599915742874146, -0...","[-0.6233546733856201, -0.6148471236228943, 0.2...","[0.021557040512561798, -0.3770830035209656, 0....","[-0.7222349643707275, -0.5528132915496826, 0.1...","[-0.42759373784065247, 0.02741774171590805, 0....","[-0.09627114236354828, -0.7949388027191162, 0....","[-0.27730420231819153, -0.625002920627594, -0....","[-0.5091067552566528, -0.33270934224128723, 0....",...,,,,,,,,BLACK/AFRICAN,2183,20000298
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1039,"[-0.7891439199447632, -0.7837693691253662, 0.4...","[-1.1731988191604614, -0.1696205884218216, -0....","[-0.35602229833602905, 0.48437315225601196, 0....","[-1.0430480241775513, -0.3261987566947937, -0....","[-0.3210268020629883, -0.4004751443862915, 0.2...","[-1.198649525642395, -0.5245359539985657, -0.3...","[0.03978034853935242, -0.11834651976823807, 0....","[-0.3812251389026642, -0.5347013473510742, 0.0...","[-0.34562307596206665, -0.3381168246269226, -0...","[-1.198569416999817, -0.601617157459259, 0.010...",...,,,,,,,,BLACK/AFRICAN AMERICAN,2196,20033924
1040,"[-0.5188016891479492, -0.7312965393066406, 0.2...","[-1.0501741170883179, 0.034674592316150665, -0...","[-0.10639360547065735, 1.3086516857147217, -0....","[-0.948326051235199, 0.03777060657739639, -0.1...","[-0.19571082293987274, 0.00019254861399531364,...","[-1.0075918436050415, -0.3644290864467621, -0....","[0.006651674397289753, 0.6553080677986145, 0.4...","[-0.19910217821598053, -0.37376293540000916, 0...","[-0.15418754518032074, -0.20213651657104492, -...","[-0.9103807806968689, -0.3613031208515167, -0....",...,,,,,,,,BLACK/AFRICAN AMERICAN,2196,20033924
1041,"[-0.28403815627098083, -0.551350474357605, 0.3...","[-0.7295982837677002, -0.02876741625368595, -0...","[-0.28302621841430664, 1.180008888244629, -0.2...","[-0.6340966820716858, -0.11073702573776245, -0...","[-0.18467365205287933, 0.44464969635009766, 0....","[-0.7803656458854675, -0.1762542426586151, -0....","[0.02880602329969406, 0.4070505201816559, 0.47...","[-0.46318817138671875, -0.5604344010353088, 0....","[-0.042061448097229004, -0.04986490681767464, ...","[-0.9409776329994202, -0.33275049924850464, 0....",...,,,,,,,,BLACK/AFRICAN AMERICAN,2196,20033924
1042,"[-0.04057621210813522, -0.4664139747619629, 0....","[-0.46559005975723267, 0.20921838283538818, -0...","[-0.015267172828316689, 1.0410341024398804, -0...","[-0.35629943013191223, 0.09066826105117798, 0....","[-0.03323885798454285, 0.4105779528617859, 0.5...","[-0.500695526599884, 0.14516475796699524, 0.09...","[0.1840108036994934, 0.6892815232276917, 0.189...","[0.001081228256225586, -0.7352004647254944, 0....","[0.3694098889827728, -0.24272549152374268, -0....","[-0.6941320896148682, -0.4309729039669037, 0.2...",...,,,,,,,,BLACK/AFRICAN AMERICAN,2196,20033924
