In [1]:
import torch
from omegaconf import DictConfig
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger

from tsl import logger
from tsl.data import ImputationDataset, SpatioTemporalDataModule
from tsl.data.preprocessing import StandardScaler
from tsl.datasets import MetrLA, PemsBay, AirQuality
from tsl.engines import Imputer
from tsl.experiment import Experiment
from tsl.metrics import torch as torch_metrics, numpy as numpy_metrics
from tsl.metrics.torch import MaskedMAE, MaskedMAPE
from tsl.nn.models import RNNImputerModel, BiRNNImputerModel, GRINModel
from tsl.ops.imputation import add_missing_values
from tsl.transforms import MaskInput
from tsl.utils.casting import torch_to_numpy

from electric_data import ElectricData



In [2]:
# Probabilities handed to `add_missing_values` to inject artificial gaps.
p_fault = 0.0015
p_noise = 0.05

# Wrap the normalized electric dataset with injected missing values.
# The seed is fixed, presumably so the injected mask is reproducible
# across runs -- confirm against the tsl documentation.
dataset = add_missing_values(
    ElectricData(normalized=True),
    p_fault=p_fault,
    p_noise=p_noise,
    min_seq=12,
    max_seq=12 * 4,
    seed=56789,
)

# Report the dataset repr followed by a short summary of its masks and
# attributes.
print(f'\n\n{dataset}')
print(
    f"Sampling period: {dataset.freq}\n"
    f"Has missing values: {dataset.has_mask}\n"
    f"Percentage of missing values: {(1 - dataset.mask.mean()) * 100:.2f}%\n"
    f"Percentage of missing values val: {(dataset.eval_mask.mean()) * 100:.2f}%\n"
    f"Has dataset exogenous variables: {dataset.has_covariates}\n"
    f"Relevant attributes: {', '.join(dataset.attributes.keys())}\n\n"
)



MissingValueselectric(length=2208, n_nodes=6, n_channels=1)
Sampling period: None
Has missing values: True
Percentage of missing values: 0.00%
Percentage of missing values val: 9.10%
Has dataset exogenous variables: True
Relevant attributes: 




In [3]:
# Ask the dataset for its graph connectivity using the 'pearson' method
# (presumably a Pearson-correlation similarity between nodes -- confirm in
# ElectricData). The result is passed as `connectivity` to ImputationDataset.
adj = dataset.get_connectivity('pearson')

In [4]:
# Windowed dataset for imputation: 24-step windows taken with stride 1.
# `eval_mask` and `training_mask` both come from the dataset wrapped by
# `add_missing_values`; `MaskInput` is applied as the per-sample transform.
torch_dataset = ImputationDataset(
    target=dataset.dataframe(),
    eval_mask=dataset.eval_mask,
    input_mask=dataset.training_mask,
    transform=MaskInput(),
    connectivity=adj,
    window=24,
    stride=1,
)


In [5]:
# Scaler for the target signal. It is defined here but NOT handed to the
# data module below (the `scalers` argument is commented out), so the data
# module reports `scalers=[]`.
scalers = dict(target=StandardScaler(axis=(0, 1)))

# Splitter configured with val_len=0.1 and test_len=0.2.
splitter = dataset.get_splitter(val_len=0.1, test_len=0.2)

dm = SpatioTemporalDataModule(
    dataset=torch_dataset,
    # scalers=scalers,
    splitter=splitter,
    batch_size=32,
    workers=0,
)
dm

SpatioTemporalDataModule(train_len=None, val_len=None, test_len=None, scalers=[], batch_size=32)

In [6]:
# Run the data module setup. Before this call the repr shows
# train_len/val_len/test_len as None; afterwards they are populated.
dm.setup()
dm

SpatioTemporalDataModule(train_len=1550, val_len=150, test_len=437, scalers=[], batch_size=32)

In [7]:
# Training dataloader produced by the data module (batch_size=32 as set above).
train = dm.train_dataloader()

In [8]:
# Peek at the first training batch only. `range(1)` caps the loop at one
# iteration, so the loader is asked for exactly one batch; `i` and `batch`
# stay bound afterwards for the cells that follow.
for i, batch in zip(range(1), train):
    print(i, batch)

0 StaticBatch(
  input=(x=[b=32, t=24, n=6, f=1], input_mask=[b=32, t=24, n=6, f=1], edge_index=[2, e=36], edge_weight=[e=36]),
  target=(y=[b=32, t=24, n=6, f=1]),
  has_mask=True
)


In [9]:
# Explicit iterator over the training dataloader so batches can be pulled
# one at a time with `next`.
it = iter(train)

In [10]:
# Pull the next batch from the iterator (this rebinds `batch`) and display it.
# NOTE(review): re-running this cell advances the iterator, so later cells
# inspect whichever batch was fetched last.
batch = next(it)
batch

StaticBatch(
  input=(x=[b=32, t=24, n=6, f=1], input_mask=[b=32, t=24, n=6, f=1], edge_index=[2, e=36], edge_weight=[e=36]),
  target=(y=[b=32, t=24, n=6, f=1]),
  has_mask=True
)

In [11]:
# Model input for sample 0, first three time steps, all nodes
# (x is laid out as [batch, time, nodes, features] per the batch repr).
batch.input.x[0,:3,:]

tensor([[[0.6713],
         [0.7015],
         [0.1856],
         [0.0351],
         [0.5445],
         [0.5789]],

        [[0.7011],
         [0.7252],
         [0.2199],
         [0.0311],
         [0.4819],
         [0.0000]],

        [[0.7165],
         [0.7649],
         [0.2206],
         [0.0328],
         [0.4439],
         [0.0000]]])

In [14]:
# Element-wise check of where the target equals the input for sample 0 over
# the first three time steps; False marks entries where the two differ.
batch.target.y[0,0:3,:] == batch.input.x[0,0:3,:]

tensor([[[ True],
         [ True],
         [ True],
         [ True],
         [ True],
         [ True]],

        [[ True],
         [ True],
         [ True],
         [ True],
         [ True],
         [False]],

        [[ True],
         [ True],
         [ True],
         [ True],
         [ True],
         [False]]])

In [47]:
# Input mask for the same slice (sample 0, first three time steps).
# NOTE(review): presumably True means "observed and fed to the model" --
# confirm against MaskInput / ImputationDataset in the installed tsl version.
batch.input.input_mask[0,0:3,:]

tensor([[[ True],
         [ True],
         [ True],
         [ True],
         [ True],
         [ True]],

        [[ True],
         [ True],
         [ True],
         [ True],
         [ True],
         [False]],

        [[ True],
         [ True],
         [ True],
         [ True],
         [ True],
         [ True]]])

In [50]:
# Shape of the edge weights in the batch: one weight per edge in edge_index.
batch.input.edge_weight.shape

torch.Size([36])

In [62]:
# GRIN hyperparameters. The input size and the number of nodes are read from
# the data module so the model matches the dataset it will be fed.
model_kwargs = {
    'input_size': dm.n_channels,
    'n_nodes': dm.n_nodes,
    'hidden_size': 64,
    'ff_size': 64,
    'embedding_size': 8,
    'n_layers': 1,
    'kernel_size': 2,
    'decoder_order': 1,
    'layer_norm': False,
    'dropout': 0,
    'ff_dropout': 0,
    'merge_mode': 'mlp',
}

# Adam settings and cosine-annealing schedule parameters.
optim_kwargs = {'lr': 0.001, 'weight_decay': 0}
scheduler_kwargs = {'eta_min': 0.0001, 'T_max': 300}

# `MaskedMAE` / `MaskedMAPE` come from the notebook's import cell; `Imputer`
# and `GRINModel` were already imported there, so nothing is imported here.
# The loss and the logged MAE are separate metric instances.
loss_fn = MaskedMAE()
log_metrics = {
    'mae': MaskedMAE(),
    'mape': MaskedMAPE(),
}

# Engine wrapping the GRIN model together with its optimizer, scheduler,
# loss and logged metrics.
# NOTE(review): the meaning of `whiten_prob`, `prediction_loss_weight`,
# `impute_only_missing` and `warm_up_steps` is defined by tsl's Imputer --
# confirm against the installed version before tuning them.
imputer = Imputer(
    model_class=GRINModel,
    model_kwargs=model_kwargs,
    optim_class=torch.optim.Adam,
    optim_kwargs=optim_kwargs,
    loss_fn=loss_fn,
    metrics=log_metrics,
    scheduler_class=torch.optim.lr_scheduler.CosineAnnealingLR,
    scheduler_kwargs=scheduler_kwargs,
    whiten_prob=0.05,
    prediction_loss_weight=1.0,
    impute_only_missing=False,
    warm_up_steps=0,
)

In [68]:
# Sanity-check forward pass: feed the tensors of the current batch through
# the engine by keyword. The result is indexed below as y[0] (prediction)
# and y[1] (four auxiliary outputs).
# NOTE(review): this runs with gradients enabled and the module in its
# default mode; wrap in torch.no_grad() if only shapes are of interest.
y = imputer(
    x=batch.input.x,
    edge_index=batch.input.edge_index,
    edge_weight=batch.input.edge_weight,
    input_mask=batch.input.input_mask,
)

In [75]:
# First element of the forward output is the imputed prediction; its shape
# matches the input x ([batch, time, nodes, features]).
prediction = y[0]
prediction.shape

torch.Size([32, 24, 6, 1])

In [76]:
# Second element of the forward output holds four auxiliary tensors.
# NOTE(review): presumably the forward/backward-direction outputs and
# predictions of the bidirectional GRIN model -- confirm in GRINModel.forward.
fwd_out, bwd_out, fwd_pred, bwd_pred = y[1]

In [78]:
# Check that the auxiliary outputs share the prediction's shape.
print(fwd_out.shape,fwd_pred.shape)

torch.Size([32, 24, 6, 1]) torch.Size([32, 24, 6, 1])
