In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split
from MAEImputer import ReMaskerStep
import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import pickle
from math import sqrt
import os

### Get the data

In [2]:
################ Read Datasets ################
df = pd.read_csv('../data/X_train.csv')
print(f'Train values shape: {df.shape}')

df_test = pd.read_csv('../data/X_test.csv')
print(f'Test values shape: {df_test.shape}')
df_test.head()

Train values shape: (1417738, 403)
Test values shape: (326025, 403)


Unnamed: 0,npval_50971,nptime_50971,npval_50983,nptime_50983,npval_50902,nptime_50902,npval_51221,nptime_51221,npval_50912,nptime_50912,...,nptime_last_51104,npval_last_51078,nptime_last_51078,npval_last_50884,nptime_last_50884,npval_last_51255,nptime_last_51255,first_race,chartyear,hadm_id
0,4.3,18.0,137.0,18.0,102.0,18.0,38.4,18.0,1.0,18.0,...,,,,,,,,WHITE,2190,20000057
1,3.3,16.0,141.0,16.0,103.0,16.0,38.2,16.0,1.1,16.0,...,,,,,,,,WHITE,2190,20000057
2,3.4,17.0,145.0,17.0,103.0,17.0,39.0,17.0,1.0,17.0,...,,,,,,,,WHITE,2190,20000057
3,4.0,12.0,138.0,12.0,103.0,12.0,37.8,12.0,0.8,12.0,...,,,,,,,,WHITE,2193,20000293
4,,,,,,,,,,,...,,,,,,,,BLACK/AFRICAN,2183,20000298


In [3]:
################ Clean Missing Data ################
def clean_missing(df, threshold=20 + 3, missing_per_col=100, cols_to_remove=None):
    # Remove rows with less than 20 values
    df = df.dropna(thresh=threshold)
    print(f"DataFrame after removing rows with at least 20 missing values: {df.shape}")
    
    if type(cols_to_remove) != list:
        if missing_per_col and not cols_to_remove:
            # Get columns where at least 100 values are not missing
            columns_all_nan = df.columns[df.notna().sum() < missing_per_col].tolist()
            # Identify columns that end with a number after the last underscore
            ids = ['_' + col.split('_')[-1] for col in columns_all_nan]

            def ids_in_string(value_list, target_string):
                for value in value_list:
                    if value in target_string:
                        return True
                return False

            cols_to_remove = []
            for column in df.columns:
                if ids_in_string(ids, column):
                    cols_to_remove.append(column)

    print(f'Removing columns: {cols_to_remove}')

    df.drop(columns=cols_to_remove, inplace=True)
    
    return df, cols_to_remove

missing_per_row = 20 + 3 # + 3 because of: first_race, chartyear, hadm_id
missing_per_col = 500

df, cols_removed = clean_missing(df, missing_per_row)
df_test, _ = clean_missing(df_test, missing_per_row, cols_to_remove=cols_removed)


DataFrame after removing rows with at least 20 missing values: (1364232, 403)
Removing columns: []


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df.drop(columns=cols_to_remove, inplace=True)


DataFrame after removing rows with at least 20 missing values: (313763, 403)
Removing columns: []


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df.drop(columns=cols_to_remove, inplace=True)


In [4]:
# Since is a demo let's use only 1000 rows
df = df.head(1000)
df_test = df_test.head(1000)

In [5]:
################ Create Train-Validation Data ################

# Split the dataframe into train and test sets
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)  # 20% of the data as test set

# Print the shapes of the train and test dataframes
print("Train shape:", train_df.shape)
print("Val shape:", val_df.shape)


Train shape: (800, 403)
Val shape: (200, 403)


### Create an instance of the imputer model

In [6]:
################ Create Imputer Instance ################
columns = df.shape[1] - 3 # + 3 because of: first_race, chartyear, hadm_id

mask_ratio = 0.25
max_epochs = 1 #400
save_path = 'demo'

batch_size=256 
embed_dim=64
depth=8
decoder_depth=4
num_heads=8
mlp_ratio=4.0
weigths = '100_Labs_Train_0.25Mask_L_V3/epoch390_checkpoint'

imputer = ReMaskerStep(dim=columns, mask_ratio=mask_ratio, max_epochs=max_epochs, save_path=save_path, batch_size=batch_size,
                      embed_dim=embed_dim, depth=depth, decoder_depth=decoder_depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
                      weigths=weigths)


loading model weigths...


### Train the model

In [7]:
################ Train the model ################
imputer.fit(train_df.drop(columns=['first_race', 'chartyear', 'hadm_id']), val_df.drop(columns=['first_race', 'chartyear', 'hadm_id']))

calculating norm parameters...


100%|██████████| 4/4 [02:45<00:00, 41.25s/it]


Evaluation of epoch 0...
Epoch0 Evaluation for npval_50971: RMSE = 0.2904550859913008, MAE = 0.232860939278787, R2 = 0.6226119597955024

Epoch0 Evaluation for npval_50983: RMSE = 1.0837536956224367, MAE = 0.6108794825037098, R2 = 0.9376735059940294

Epoch0 Evaluation for npval_50902: RMSE = 1.001454948927611, MAE = 0.609597957334039, R2 = 0.9664691356351903

Epoch0 Evaluation for npval_51221: RMSE = 0.26294848397177123, MAE = 0.1400610132420316, R2 = 0.9977859465206198

Epoch0 Evaluation for npval_50912: RMSE = 0.2694394834288218, MAE = 0.14482057585435756, R2 = 0.9147680269188602

Epoch0 Evaluation for npval_51006: RMSE = 3.971334820901262, MAE = 2.7702480284548594, R2 = 0.9326855475558764

Epoch0 Evaluation for npval_50882: RMSE = 0.932510023804176, MAE = 0.5343477592039644, R2 = 0.9239244078147897

Epoch0 Evaluation for npval_50868: RMSE = 0.855768380643851, MAE = 0.462479366345352, R2 = 0.9035696963179912

Epoch0 Evaluation for npval_50931: RMSE = 26.708368025344022, MAE = 20.14971



Epoch0 Evaluation for npval_51104: RMSE = 117.31004630775274, MAE = 91.91630664062495, R2 = 0.8359625982399395

Epoch0 Evaluation for npval_51078: RMSE = 15.607879777779859, MAE = 12.889178466796874, R2 = 0.21478239054399983





Epoch0 Evaluation for npval_50884: RMSE = 0.4154846191406252, MAE = 0.4154846191406252, R2 = nan

Epoch0 Evaluation for npval_51255: RMSE = 1.035256708435501, MAE = 0.7898102521896362, R2 = 0.33015221727455746

Epoch0 Evaluation for npval_last_50971: RMSE = 0.2851983470815638, MAE = 0.2262014270588092, R2 = 0.657869926700084

Epoch0 Evaluation for npval_last_50983: RMSE = 0.9871344442602759, MAE = 0.5764693721648185, R2 = 0.9494814813548322

Epoch0 Evaluation for npval_last_50902: RMSE = 1.018852970130781, MAE = 0.64434936719063, R2 = 0.9660272193418025

Epoch0 Evaluation for npval_last_51221: RMSE = 0.3578235354378583, MAE = 0.2423481303415479, R2 = 0.9959430698752951

Epoch0 Evaluation for npval_last_50912: RMSE = 0.1907504351337875, MAE = 0.12373713923108048, R2 = 0.9600161286029225

Epoch0 Evaluation for npval_last_51006: RMSE = 4.368655636576033, MAE = 2.933789543616466, R2 = 0.9160011971093095

Epoch0 Evaluation for npval_last_50882: RMSE = 1.0980102599037629, MAE = 0.67854770537



Epoch0 Evaluation for npval_last_51255: RMSE = 0.9361315704009681, MAE = 0.7712709307670593, R2 = -2.5053692684055306

1 , 0.12201608502401057


<MAEImputer.ReMaskerStep at 0x384cf82e0>

### Transform using the model

In [10]:
eval_batch_size = 256

imputed_test = pd.DataFrame(imputer.transform(df_test.drop(columns=['first_race', 'chartyear', 'hadm_id']), eval_batch_size=eval_batch_size).cpu().numpy())

In [12]:
df_test

Unnamed: 0,npval_50971,nptime_50971,npval_50983,nptime_50983,npval_50902,nptime_50902,npval_51221,nptime_51221,npval_50912,nptime_50912,...,nptime_last_51104,npval_last_51078,nptime_last_51078,npval_last_50884,nptime_last_50884,npval_last_51255,nptime_last_51255,first_race,chartyear,hadm_id
0,4.3,18.0,137.0,18.0,102.0,18.0,38.4,18.0,1.0,18.0,...,,,,,,,,WHITE,2190,20000057
1,3.3,16.0,141.0,16.0,103.0,16.0,38.2,16.0,1.1,16.0,...,,,,,,,,WHITE,2190,20000057
2,3.4,17.0,145.0,17.0,103.0,17.0,39.0,17.0,1.0,17.0,...,,,,,,,,WHITE,2190,20000057
3,4.0,12.0,138.0,12.0,103.0,12.0,37.8,12.0,0.8,12.0,...,,,,,,,,WHITE,2193,20000293
5,3.8,17.0,140.0,17.0,103.0,17.0,27.7,17.0,0.7,17.0,...,,,,,,,,BLACK/AFRICAN,2183,20000298
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1039,5.0,12.0,138.0,12.0,103.0,12.0,24.8,12.0,2.5,12.0,...,,,,,,,,BLACK/AFRICAN AMERICAN,2196,20033924
1040,5.0,10.0,131.0,10.0,96.0,10.0,23.0,10.0,3.6,10.0,...,,,,,,,,BLACK/AFRICAN AMERICAN,2196,20033924
1041,5.0,22.0,131.0,22.0,93.0,22.0,24.0,22.0,4.2,22.0,...,,,,,,,,BLACK/AFRICAN AMERICAN,2196,20033924
1042,3.9,20.0,131.0,20.0,93.0,20.0,23.0,12.0,3.0,20.0,...,,,,,,,,BLACK/AFRICAN AMERICAN,2196,20033924


In [13]:
imputed_test

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,390,391,392,393,394,395,396,397,398,399
0,4.3,18.0,137.0,18.0,102.0,18.0,38.4,18.0,1.0,18.0,...,1.767103,43.245804,596.662231,34.385208,44.843685,36.978500,0.902864,30.409641,1.237766,39.394405
1,3.3,16.0,141.0,16.0,103.0,16.0,38.2,16.0,1.1,16.0,...,1.904227,57.429443,608.033081,138.584167,41.289528,120.660080,0.937603,32.299026,1.230534,52.178501
2,3.4,17.0,145.0,17.0,103.0,17.0,39.0,17.0,1.0,17.0,...,1.855902,59.075882,666.769104,129.137024,44.378906,117.640900,1.064101,32.060921,1.258541,51.411385
3,4.0,12.0,138.0,12.0,103.0,12.0,37.8,12.0,0.8,12.0,...,1.897429,43.416454,639.346802,34.875477,50.585411,35.200684,1.071532,30.379951,1.198120,40.974777
4,3.8,17.0,140.0,17.0,103.0,17.0,27.7,17.0,0.7,17.0,...,1.629823,42.241543,518.730530,31.014757,45.348347,30.422703,0.975406,29.500666,1.256737,33.223385
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,5.0,12.0,138.0,12.0,103.0,12.0,24.8,12.0,2.5,12.0,...,1.521729,53.853996,381.978943,150.453598,49.478943,168.306198,1.016162,35.018742,1.215181,85.001770
996,5.0,10.0,131.0,10.0,96.0,10.0,23.0,10.0,3.6,10.0,...,1.271633,79.960770,429.067200,259.690674,41.761543,260.035583,1.056270,38.480019,1.292287,105.559227
997,5.0,22.0,131.0,22.0,93.0,22.0,24.0,22.0,4.2,22.0,...,1.212541,87.858734,447.379150,294.476288,32.837543,305.014343,1.042971,39.204662,1.413997,120.673744
998,3.9,20.0,131.0,20.0,93.0,20.0,23.0,12.0,3.0,20.0,...,1.159791,98.819763,470.416992,288.705505,33.884087,296.397278,1.099674,43.196106,1.197099,133.847931
