In [1]:
from MAE import MaskedAutoencoder
from torch import nn
from functools import partial
import torch
import pandas as pd
import numpy as np

### Load pre-trained model:

The model is a masked autoencoder (MAE) trained on the MIMIC-IV dataset using the 100 most common lab tests and their timestamps. The model was trained using the following parameters:

- batch_size = 256
- epochs = 500 with early stopping
- dim=400 # 2 columns for each of the 100 lab tests (200) + 200 columns for the corresponding timestamps = 400 in total
- embed_dim=64
- depth=8
- decoder_depth=4
- num_heads=8
- mlp_ratio=4.0
- norm_field_loss=False
- encode_func='linear'
- eps = 1e-7


In [2]:
# `dim` is the number of feature columns fed to the model. The original hint for
# deriving it from a dataframe was `df_test.shape[1] - 3`, i.e. every column
# except first_race, chartyear and hadm_id.

# --- Checkpoint / runtime settings ---
weigths = '100_Labs_Train_0.25Mask_L_V3/epoch390_checkpoint'  # (sic) spelling kept so existing references keep working
device = 'cpu'
batch_size = 256

# --- Architecture hyper-parameters (same values as listed for training above) ---
dim = 400  # Columns
embed_dim = 64
depth = 8
decoder_depth = 4
num_heads = 8
mlp_ratio = 4.0
norm_field_loss = False
encode_func = 'linear'
eps = 1e-7

# The decoder reuses the encoder's embedding size and head count.
model = MaskedAutoencoder(
    rec_len=dim,
    embed_dim=embed_dim, depth=depth, num_heads=num_heads,
    decoder_embed_dim=embed_dim, decoder_depth=decoder_depth, decoder_num_heads=num_heads,
    mlp_ratio=mlp_ratio,
    norm_layer=partial(nn.LayerNorm, eps=eps),
    norm_field_loss=norm_field_loss,
    encode_func=encode_func,
)

# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider `weights_only=True` where the installed torch supports it).
# Kept as the last expression so the key-matching report is displayed.
model.load_state_dict(torch.load(weigths, map_location=torch.device(device)))

<All keys matched successfully>

### Let's create some demo data to test the model:

In [3]:
# Build demo input with shape (batch_size=2, columns=400); np.nan marks a missing value.
# Each row is a 20-value pattern tiled 20 times (20 * 20 = 400 columns).
pattern_missing_first = [np.nan, 0.5, np.nan, 0.8, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8,
                         0.9, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
# Same pattern, but with the first value observed (0.5) instead of missing.
pattern_observed_first = [0.5] + pattern_missing_first[1:]

demo_df = pd.DataFrame([
    # Sample 0: the first value of every 20-column group is missing.
    pattern_missing_first * 20,
    # Sample 1: the first two groups start with an observed value, the remaining 18 do not.
    pattern_observed_first * 2 + pattern_missing_first * 18,
])
assert demo_df.shape == (2, 400), "demo data must have 2 rows and 400 columns"

# Convert the dataframe to a float32 tensor (still containing nan for missing values)
X_with_nan = torch.tensor(demo_df.values, dtype=torch.float32)

# Mask tensor: 1 where a value is present, 0 where it is missing.
# torch.isnan keeps the computation in torch instead of relying on NumPy's implicit
# tensor -> array -> tensor round-trip; the cast to uint8 preserves the dtype that
# round-trip produced.
M = (~torch.isnan(X_with_nan)).to(torch.uint8)

# Replace nan with 0 to avoid errors, then add a singleton dimension at position 1
# to match the model input: (batch, columns) -> (batch, 1, columns)
X = torch.nan_to_num(X_with_nan).unsqueeze(dim=1)


In [4]:
# Sanity check: X carries the extra singleton dimension added by unsqueeze, M keeps the 2-D layout
X.shape, M.shape

(torch.Size([2, 1, 400]), torch.Size([2, 400]))

### Example of forward pass:

A forward pass is done by calling the model with the input data. The model returns a tuple with: loss, predictions, mask, nask.

Inputs:
- `X`: input data
- `M`: mask indicating the valid positions in the input data (1 for valid positions, 0 for invalid positions). Invalid positions are missing data that should be predicted by the model.
- `mask_ratio`: ratio of the values to be masked by the MAE model

Outputs:
- `loss`: loss value
- `predictions`: predicted values
- `mask`: mask indicating the valid positions in the input data
- `nask`: mask indicating the invalid positions in the input

Note:
- During inference, set `mask_ratio` to 0.0 and use only the second element of the tuple (`predictions`) as the output.
- During training, set `mask_ratio` to a value between 0.0 and 1.0 and use the first element of the tuple (`loss`) for optimization.

In [5]:
# Forward pass on the demo batch with mask_ratio=0.2.
# NOTE(review): the bare call displays the whole returned tuple; unpack it
# (loss, predictions, mask, nask) if only part of it is needed.
model(X, M, mask_ratio=0.2) # Output is a tuple with: loss, predictions, mask, nask
# Note: during inference set mask_ratio=0.0 and only use the predictions tensor

(tensor(0.0906, grad_fn=<AddBackward0>),
 tensor([[[0.6369],
          [0.4887],
          [0.1396],
          [0.8042],
          [0.2963],
          [0.3945],
          [0.5100],
          [0.5971],
          [0.6908],
          [0.8001],
          [0.9105],
          [0.0924],
          [0.1996],
          [0.2921],
          [0.4126],
          [0.4954],
          [0.6029],
          [0.6395],
          [0.8009],
          [0.9011],
          [0.5689],
          [0.4984],
          [0.4393],
          [0.8064],
          [0.8360],
          [0.3935],
          [0.5050],
          [0.6009],
          [0.7031],
          [0.8045],
          [0.8964],
          [0.1040],
          [0.2049],
          [0.7541],
          [0.4046],
          [0.4938],
          [0.5980],
          [0.6933],
          [0.7968],
          [0.8991],
          [0.3172],
          [0.4958],
          [0.3246],
          [0.7988],
          [0.3504],
          [0.3882],
          [0.4973],
          [0.5988],

### Extract embeddings:

The model has a method to extract the embeddings of the input data. The method receives the input data and the mask and returns the embeddings of the valid positions and the embeddings of the cls token.

Inputs:
- `X`: input data
- `M`: mask indicating the valid positions in the input data (1 for valid positions, 0 for invalid positions). Invalid positions are missing data that should be predicted by the model.
- `mask_ratio` (optional): ratio of the values to be masked by the MAE model. Default is 0.0

Outputs:
- `embeddings`: embeddings of the valid positions in the input data. The shape is (batch_size, seq_len, emb_dim). The embeddings of the invalid positions are filled with `np.nan` values.
- `cls_embeddings`: embeddings of the cls token

In [6]:
# Per-position embeddings plus the cls-token embedding; mask_ratio is left at its default.
# NOTE(review): the displayed result carries a grad_fn, so autograd is tracking this call —
# consider wrapping it in torch.no_grad() when only the embeddings are needed.
embeddings, cls_embedding = model.extract_embeddings(X, M)

In [7]:
# Expected layout: (batch_size, seq_len, emb_dim)
embeddings.shape

torch.Size([2, 400, 64])

In [8]:
# Positions that M marks as missing show up as rows of nan
embeddings

tensor([[[    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [-0.2416, -0.9443,  0.1086,  ..., -0.2403,  0.2612,  0.0051],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         ...,
         [ 0.6908, -0.7115,  0.3388,  ...,  0.7217, -0.4952, -1.0635],
         [-0.1543,  0.4914, -0.5550,  ..., -0.2721, -0.2909, -0.6287],
         [-0.3792,  0.3128,  0.3379,  ...,  0.0617, -0.1295, -0.7110]],

        [[-0.0360, -0.4197,  0.4760,  ...,  0.0408, -0.1766,  0.7799],
         [-0.2475, -0.9181,  0.1032,  ..., -0.2477,  0.2731, -0.0031],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         ...,
         [ 0.6888, -0.7126,  0.3376,  ...,  0.7271, -0.4950, -1.0622],
         [-0.1525,  0.4921, -0.5548,  ..., -0.2719, -0.2894, -0.6280],
         [-0.3814,  0.3154,  0.3387,  ...,  0.0608, -0.1303, -0.7126]]],
       grad_fn=<AsStridedBackward0>)

In [9]:
# One cls-token vector per sample: (batch_size, emb_dim)
cls_embedding.shape

torch.Size([2, 64])

In [11]:
# Inspect the cls-token embedding values for the two demo samples
cls_embedding

tensor([[ 2.5699e-02,  6.3364e-01, -6.0745e-01, -1.5556e-01, -1.0331e+00,
         -1.2777e-01, -1.1509e+00, -9.4867e-01,  9.1756e-02,  1.4694e-01,
         -1.7349e-01, -7.7522e-01, -5.2982e-01, -2.7562e-01,  6.0526e-01,
         -7.7315e-01, -1.2642e-01,  9.6521e-01, -2.1740e-01, -3.2482e-01,
         -1.2552e-01,  3.0561e-01,  6.3130e-02,  3.0150e-01,  2.7939e-01,
          9.0490e-01,  9.1128e-01,  1.1288e-04,  5.6387e-01,  4.4801e-02,
          1.0105e+00, -9.3561e-01, -1.2520e-01,  7.5231e-01,  3.1883e-02,
         -1.4329e-01, -1.5316e-03, -5.2973e-01, -8.6476e-01,  3.1874e-01,
         -1.0763e+00,  2.1091e-01, -2.4990e-02, -4.2509e-02,  1.2121e+00,
          6.0505e-01, -5.7568e-01, -5.6585e-01, -2.0443e-01,  5.9395e-01,
         -8.5318e-02,  5.3784e-01, -3.3979e-01,  7.5654e-02,  6.5231e-01,
         -3.4830e-01, -6.5401e-02,  4.9329e-01, -9.7070e-01,  8.2891e-02,
         -4.9702e-01,  6.9165e-01, -3.1891e-01, -1.7545e-01],
        [ 2.1343e-02,  6.2828e-01, -6.0903e-01, -1