In [10]:
import numpy as np
import pandas as pd

# pipeline core
from sklearn.model_selection import GridSearchCV

from pypots.data import load_specific_dataset
from pypots.imputation import SAITS, Transformer

import plotly.express as px
import plotly.graph_objs as go
from plotly.subplots import make_subplots

In [11]:
# Data loading: read the exported time series from the local CSV dump.
# "Unnamed: 0" is the label pandas assigns to a header-less column at
# position 0 (presumably the saved index holding the timestamps - confirm
# against the CSV); it is renamed to "date" so later cells can refer to it.
df = pd.read_csv('../csv_dump/P_.49_823.csv').rename(columns={"Unnamed: 0": "date"})

df.shape


(34, 9)

In [29]:
# Model input: every column except 'date', with a leading batch axis so the
# whole table becomes a single sample shaped (1, n_rows, n_feature_columns).
X = df.drop(columns=['date'])
X = X.to_numpy().reshape((-1, *X.shape))
# Disabled: artificially mask 10% of the observed values as held-out ground truth.
#X_intact, X, missing_mask, indicating_mask = mcar(X, 0.1) # hold out 10% observed values as ground truth
#X = masked_fill(X, 1 - missing_mask, np.nan)
dataset = {"X": X}
X.shape

(1, 34, 8)

In [84]:
# Model training and imputation with the PyPOTS Transformer imputer.
# NOTE(review): the variable is called `saits`, but the model instantiated
# here is `Transformer`, not `SAITS`.
# NOTE(review): the saved log below this cell lists epochs 0-13 although
# epochs=10 is requested here, so the output probably comes from an earlier
# run with different settings - re-run the cell to refresh it.
saits = Transformer(
    n_steps=X.shape[1],      # length of the series (rows of the table)
    n_features=X.shape[2],   # number of feature columns
    n_layers=4,
    d_model=128,
    d_inner=256,
    n_heads=4,
    d_k=64,
    d_v=64,
    dropout=0.1,
    epochs=10,
)
# The whole dataset is used for training (no train/val/test split); it could
# also be split into separate sets.
saits.fit(dataset)
# Impute the same dataset; the result is indexed with the 'imputation' key
# in the following cells.
imputation = saits.predict(dataset)


2023-11-11 09:41:00 [INFO]: No given device, using default device: cpu
2023-11-11 09:41:00 [INFO]: Model initialized successfully with the number of trainable parameters: 793,224
2023-11-11 09:41:00 [INFO]: epoch 0: training loss 14.7010
2023-11-11 09:41:00 [INFO]: epoch 1: training loss 13.7498
2023-11-11 09:41:00 [INFO]: epoch 2: training loss 15.1572
2023-11-11 09:41:00 [INFO]: epoch 3: training loss 13.3003
2023-11-11 09:41:00 [INFO]: epoch 4: training loss 13.1854
2023-11-11 09:41:00 [INFO]: epoch 5: training loss 13.7621
2023-11-11 09:41:00 [INFO]: epoch 6: training loss 11.8193
2023-11-11 09:41:00 [INFO]: epoch 7: training loss 11.6262
2023-11-11 09:41:01 [INFO]: epoch 8: training loss 13.4413
2023-11-11 09:41:01 [INFO]: epoch 9: training loss 12.7822
2023-11-11 09:41:01 [INFO]: epoch 10: training loss 12.5912
2023-11-11 09:41:01 [INFO]: epoch 11: training loss 12.8752
2023-11-11 09:41:01 [INFO]: epoch 12: training loss 13.0443
2023-11-11 09:41:01 [INFO]: epoch 13: training loss

In [85]:
# Sanity check: with the batch axis collapsed, the imputed array should be 2-D
# with one column per feature of X (X is 3-D, so its last axis is the feature axis).
imputation['imputation'].reshape(-1, X.shape[-1]).shape

(34, 8)

In [86]:
# Turn the imputed array back into a table with one row per time step.
# The column labels must be the feature columns in the order they were fed to
# the model, i.e. every column of `df` except 'date'. Slicing the full column
# list with [0:-1] instead keeps the leading 'date' label and drops the last
# feature name, which shifts every label by one column and lets the 'date'
# assignment below overwrite the first imputed feature.
inputed_df = pd.DataFrame(
    imputation['imputation'].reshape(-1, X.shape[2]),
    columns=df.drop(columns=['date']).columns,
)
# Re-attach the timestamps; the reshape keeps the row order, so rows line up
# with `df` by position and by the default integer index.
inputed_df['date'] = df['date']

In [87]:
# Overlay the imputed NDVI/NDWI series on the original (pre-imputation) ones.
# The subplot spec enables a secondary y-axis, but none of the traces added
# below is placed on it.
fig = make_subplots(specs=[[{"secondary_y": True}]])

# One entry per line: (legend name, source frame, value column, line colour),
# listed in drawing order. Colour channels are kept inside 0-255 and written
# without zero padding; 'rgb(90, 200, 270)' and 'rgb(31, 119, 010)' were
# normalised to the values a renderer that clamps channels would use
# (presumably what plotly did with them - confirm visually).
trace_specs = [
    ('NDVI', inputed_df, 'ndvi', 'rgb(31, 119, 180)'),
    ('NDWI', inputed_df, 'ndwi', 'rgb(90, 200, 70)'),
    ('NDWI_original', df, 'ndwi', 'rgb(90, 200, 255)'),
    ('NDVI_original', df, 'ndvi', 'rgb(31, 119, 10)'),
]

for trace_name, frame, column, line_color in trace_specs:
    fig.add_trace(go.Scatter(
        name=trace_name,
        x=frame['date'],
        y=frame[column],
        mode='lines',
        line=dict(color=line_color),
    ))