# SAITS: Ejemplo de Uso

SAITS (Self-Attention-based Imputations for Time Series) es un modelo de 2023 que pretende servir para imputar datos en POTS (Partially Observed Time Series). 

In [12]:
# Data preprocessing. Tedious, but PyPOTS can help.
import numpy as np
import benchpots
from sklearn.preprocessing import StandardScaler
from pygrinder import mcar
from pypots.data import load_specific_dataset

data = benchpots.datasets.preprocess_physionet2012('all', 0.1)

2025-02-13 15:20:54 [INFO]: You're using dataset physionet_2012, please cite it properly in your work. You can find its reference information at the below link: 
https://github.com/WenjieDu/TSDB/tree/main/dataset_profiles/physionet_2012
2025-02-13 15:20:54 [INFO]: Dataset physionet_2012 has already been downloaded. Processing directly...
2025-02-13 15:20:54 [INFO]: Dataset physionet_2012 has already been cached. Loading from cache directly...
2025-02-13 15:20:54 [INFO]: Loaded successfully!
2025-02-13 15:21:07 [INFO]: 69553 values masked out in the val set as ground truth, take 10.01% of the original observed values
2025-02-13 15:21:07 [INFO]: 86007 values masked out in the test set as ground truth, take 9.97% of the original observed values
2025-02-13 15:21:07 [INFO]: Total sample number: 11988
2025-02-13 15:21:07 [INFO]: Training set size: 7671 (63.99%)
2025-02-13 15:21:07 [INFO]: Validation set size: 1918 (16.00%)
2025-02-13 15:21:07 [INFO]: Test set size: 2399 (20.01%)
2025-02-13 1

In [20]:
data

{'n_classes': 2,
 'n_steps': 48,
 'n_features': 37,
 'scaler': StandardScaler(),
 'train_X': array([[[            nan,             nan,             nan, ...,
                      nan, -3.29921035e+00,             nan],
         [            nan,             nan,             nan, ...,
                      nan,             nan,             nan],
         [            nan,             nan,             nan, ...,
                      nan,             nan,             nan],
         ...,
         [            nan,             nan,             nan, ...,
                      nan,             nan,             nan],
         [            nan,             nan,             nan, ...,
                      nan,             nan,             nan],
         [            nan,             nan,             nan, ...,
                      nan,             nan,             nan]],
 
        [[            nan,             nan,             nan, ...,
                      nan, -2.79253393e-01,  2.29159306e-

In [21]:
print(data.keys())  
print(type(data["X"]))  

dict_keys(['n_classes', 'n_steps', 'n_features', 'scaler', 'train_X', 'train_y', 'train_ICUType', 'val_X', 'val_y', 'val_ICUType', 'test_X', 'test_y', 'test_ICUType', 'val_X_ori', 'test_X_ori'])


KeyError: 'X'

In [18]:
num_samples = len(X['RecordID'].unique())
X = X.drop(['RecordID', 'Time'], axis = 1)
X = StandardScaler().fit_transform(X.to_numpy())
X = X.reshape(num_samples, 48, -1)
X_ori = X  # keep X_ori for validation


IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices

In [None]:

X = mcar(X, 0.1)  # randomly hold out 10% observed values as ground truth
dataset = {"X": X}  # X for model input
print(X.shape)  # (11988, 48, 37), 11988 samples and each sample has 48 time steps, 37 features

# Model training. This is PyPOTS showtime.
from pypots.imputation import SAITS
from pypots.utils.metrics import calc_mae
saits = SAITS(n_steps=48, n_features=37, n_layers=2, d_model=256, n_heads=4, d_k=64, d_v=64, d_ffn=128, dropout=0.1, epochs=10)
# Here I use the whole dataset as the training set because ground truth is not visible to the model, you can also split it into train/val/test sets
saits.fit(dataset)  # train the model on the dataset
imputation = saits.impute(dataset)  # impute the originally-missing values and artificially-missing values
indicating_mask = np.isnan(X) ^ np.isnan(X_ori)  # indicating mask for imputation error calculation
mae = calc_mae(imputation, np.nan_to_num(X_ori), indicating_mask)  # calculate mean absolute error on the ground truth (artificially-missing values)
saits.save("save_it_here/saits_physionet2012.pypots")  # save the model for future use
saits.load("save_it_here/saits_physionet2012.pypots")  # reload the serialized model file for following imputation or training