# In [None]:
import sys
import os
import pandas as pd
import numpy as np
from scipy.interpolate import interp1d
import copy
import datetime
import pathlib
from argparse import Namespace
import yaml
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch.optim.lr_scheduler import CosineAnnealingLR
# sys.path.insert(1, os.path.abspath("./"))
from lib import fillers, datasets
from lib.data.datamodule import SpatioTemporalDataModule
from lib.data.imputation_dataset import GraphImputationDataset
from lib.nn import models
from lib.nn.utils.metric_base import MaskedMetric
from lib.nn.utils.metrics import MaskedMAE, MaskedMAPE, MaskedMSE, MaskedMRE
from lib.utils import parser_utils, numpy_metrics, ensure_list, prediction_dataframe

# In [None]:
# Z-Score normalization (statistics estimated from the data itself)
def normalization(data):
    """Z-score normalize ``data`` column-wise, ignoring NaNs.

    Column means and standard deviations come from ``np.nanmean`` /
    ``np.nanstd`` over axis 0. Invalid input entries (NaN or inf) are masked
    during the arithmetic and come back as NaN in the output.

    Returns:
        (normalized, mean, std): the normalized array plus the per-column
        mean and standard deviation that were applied.
    """
    values = np.array(data)
    col_mean = np.nanmean(values, axis=0)
    col_std = np.nanstd(values, axis=0)
    # Keep invalid entries out of the arithmetic; they are restored as NaN below.
    masked = np.ma.masked_invalid(values)
    scaled = (masked - col_mean) / (col_std + 1e-8)  # epsilon guards zero-variance columns
    return scaled.filled(np.nan), col_mean, col_std

# Z-Score normalization using a precomputed mean and std
def normalization_with_min_max(data, mean, std):
    """Z-score normalize ``data`` with caller-supplied statistics.

    Despite the name, ``mean`` and ``std`` are a mean and a standard
    deviation (e.g. the ones returned by ``normalization``), not a min/max pair.

    Returns:
        (normalized, mean, std): the normalized array, with invalid input
        entries (NaN/inf) as NaN, followed by the unchanged ``mean`` and ``std``.
    """
    masked = np.ma.masked_invalid(np.array(data))
    scaled = (masked - mean) / (std + 1e-8)  # epsilon guards a zero std
    return scaled.filled(np.nan), mean, std

# Reverse Z-Score normalization
def denormalization_with_min_max(data, mean, std):
    """Invert the z-score transform of ``normalization_with_min_max``.

    Uses the same ``1e-8`` epsilon on ``std`` as the forward transform, so a
    round trip reproduces the original values. ``mean``/``std`` are a mean and
    a standard deviation despite the function name.
    """
    scale = std + 1e-8
    return data * scale + mean

# Imputation error metrics
def _error_metrics(true_values, pred_values):
    """Return ``(mae, rmse, mape)`` for two equally-shaped 1-D arrays.

    MAPE is in percent and is averaged only over entries whose true value is
    non-zero, because a zero denominator would turn the whole mean into
    inf/NaN. Every metric is NaN when there is nothing to average.
    """
    if true_values.size == 0:
        return np.nan, np.nan, np.nan
    errors = true_values - pred_values
    mae = np.mean(np.abs(errors))
    rmse = np.sqrt(np.mean(errors ** 2))
    nonzero = true_values != 0
    if nonzero.any():
        mape = np.mean(np.abs(errors[nonzero] / true_values[nonzero])) * 100
    else:
        mape = np.nan
    return mae, rmse, mape


def calculate_metrics(true_data, filled_data, miss_mask, x_min, x_max):
    """Score imputed values against ground truth on the masked entries.

    Args:
        true_data: ground-truth array.
        filled_data: imputed array with the same shape as ``true_data``.
        miss_mask: array of the same shape; entries equal to 1 are scored.
        x_min, x_max: despite the names, the per-column mean and standard
            deviation forwarded to ``normalization_with_min_max`` for the
            normalized-scale metrics.

    Returns:
        ``(mae, rmse, mape, normalized_mae, normalized_rmse, normalized_mape)``.
        The MAPE values are percentages over entries whose ground truth is
        non-zero (in the respective scale); any metric is NaN if no entry
        qualifies.
    """
    true_data = np.asarray(true_data)
    filled_data = np.asarray(filled_data)
    selected = np.asarray(miss_mask) == 1

    # Metrics in the original data scale
    mae, rmse, mape = _error_metrics(true_data[selected], filled_data[selected])

    # Metrics in the z-scored scale. Z-scores are often (near) zero, which is
    # why MAPE must skip zero denominators.
    normalized_true_data, _, _ = normalization_with_min_max(true_data, x_min, x_max)
    normalized_filled_data, _, _ = normalization_with_min_max(filled_data, x_min, x_max)
    normalized_mae, normalized_rmse, normalized_mape = _error_metrics(
        normalized_true_data[selected], normalized_filled_data[selected])
    return mae, rmse, mape, normalized_mae, normalized_rmse, normalized_mape


# Main model function
def DGBRIN(X, TEST, save=False):
    """Train the DGBRIN graph imputer on ``X`` and impute the ``TEST`` rows.

    ``X`` and ``TEST`` are stacked row-wise and written to
    ``./datasets/Fly/Fly.h5`` with synthetic 5-minute timestamps. The data is
    presumably read back by ``datasets.MissingValuesData`` (not visible here).
    Window start indices ``[0, len(X) - window]`` are used for training and
    also for validation. Starts ``[len(X), len(X) + len(TEST) - window]`` form
    the test split.

    Args:
        X: training rows (time x columns); the script passes z-scored data.
        TEST: test rows with the same number of columns as ``X``.
        save: currently unused; kept for backward compatibility.

    Returns:
        (imputed, eval_mask):
            imputed: the test-slice predictions for the last aggregation
                method in ``args.aggregate_by``.
            eval_mask: ``1 - (dataset.mask | dataset.eval_mask)`` on the test
                slice. It is 1 for entries flagged in neither mask, presumably
                the originally missing values (confirm in lib.datasets).

    Raises:
        ValueError: if either split is shorter than ``args.window``, or if
            ``args.aggregate_by`` is empty.
    """
    ########################################
    # Transform data and add data column labels
    ########################################
    split_idx = [len(X), len(X) + len(TEST)]
    X = np.concatenate((X, TEST), axis=0)
    df = pd.DataFrame(X)
    n = df.shape[1]
    df.columns = list(range(1, n + 1))
    # Synthetic 5-minute timestamps. '5min' is the same offset as the old
    # '5T' alias, which pandas >= 2.2 deprecates.
    df.index = pd.date_range(start='2024-01-01', periods=len(df), freq='5min')
    h5_path = './datasets/Fly/Fly.h5'
    # to_hdf does not create missing parent directories.
    os.makedirs(os.path.dirname(h5_path), exist_ok=True)
    df.to_hdf(h5_path, key='data', format='table')

    args = parse_args()
    args.split_idx = split_idx
    # Fail fast rather than training for many epochs and then discovering
    # that a split cannot yield a single window.
    if split_idx[0] < args.window or split_idx[1] - split_idx[0] < args.window:
        raise ValueError(
            f"X and TEST need at least window={args.window} rows each; "
            f"got {split_idx[0]} and {split_idx[1] - split_idx[0]}")
    aggr_methods = ensure_list(args.aggregate_by)
    if not aggr_methods:
        raise ValueError("args.aggregate_by must name at least one aggregation method")
    print("Arguments:", args)
    torch.set_num_threads(1)

    model_cls, filler_cls = models.DGBRIN, fillers.GraphFiller
    dataset = datasets.MissingValuesData()
    # Pre-fill NaNs by interpolation. The slice assignment only persists if
    # dataset.numpy() returns a view of the dataset's storage (assumed; TODO confirm).
    dataset.numpy()[:] = fill_nan_linear_interpolation_2D(dataset.numpy(), 3)

    ########################################
    # Create log directory and save configuration
    ########################################
    exp_name = f"{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_{args.seed}"
    logdir = os.path.join('logs', args.dataset_name, args.model_name, exp_name)

    pathlib.Path(logdir).mkdir(parents=True)
    with open(os.path.join(logdir, 'config.yaml'), 'w') as fp:
        yaml.dump(parser_utils.config_dict_from_args(args), fp, indent=4, sort_keys=True)

    ########################################
    # Data module construction
    ########################################
    torch_dataset = GraphImputationDataset(*dataset.numpy(return_idx=True),
                                           mask=dataset.training_mask,
                                           eval_mask=dataset.eval_mask,
                                           window=args.window,
                                           stride=args.stride)

    # Window start indices. Assuming a window starting at i spans rows
    # i .. i + window - 1, training windows stay inside X and test windows
    # stay inside TEST.
    train_idxs = range(args.split_idx[0] - args.window + 1)
    # NOTE: validation reuses the training windows, so the 'val_mae' monitored
    # by early stopping / checkpointing below is measured on training data.
    val_idxs = train_idxs
    test_idxs = range(args.split_idx[0], args.split_idx[1] - args.window + 1)

    data_conf = parser_utils.filter_args(args, SpatioTemporalDataModule, return_dict=True)
    dm = SpatioTemporalDataModule(torch_dataset, train_idxs=train_idxs, val_idxs=val_idxs, test_idxs=test_idxs, **data_conf)
    dm.setup()

    ########################################
    # Predictor
    ########################################
    additional_model_hparams = dict(d_in=dm.d_in, n_nodes=dm.n_nodes)
    model_kwargs = parser_utils.filter_args(args={**vars(args), **additional_model_hparams},
                                            target_cls=model_cls,
                                            return_dict=True)
    model_kwargs['number_of_samples'] = torch_dataset.data.shape[0] * 2
    model_kwargs['time_step'] = TIMESTEPS
    model_kwargs['fuse_dim'] = FUSE_DIM

    loss_fn = MaskedMetric(metric_fn=getattr(F, args.loss_fn),
                           compute_on_step=True,
                           metric_kwargs={'reduction': 'none'})

    metrics = {
        'mae': MaskedMAE(compute_on_step=False),
        'mape': MaskedMAPE(compute_on_step=False),
        'mse': MaskedMSE(compute_on_step=False),
        'mre': MaskedMRE(compute_on_step=False)
    }

    scheduler_class = CosineAnnealingLR if args.use_lr_schedule else None
    additional_filler_hparams = dict(model_class=model_cls,
                                     model_kwargs=model_kwargs,
                                     optim_class=torch.optim.Adam,
                                     optim_kwargs={'lr': args.lr, 'weight_decay': args.l2_reg},
                                     loss_fn=loss_fn,
                                     metrics=metrics,
                                     scheduler_class=scheduler_class,
                                     scheduler_kwargs={'eta_min': 0.0001, 'T_max': args.epochs},
                                     alpha=args.alpha,
                                     hint_rate=args.hint_rate,
                                     g_train_freq=args.g_train_freq,
                                     d_train_freq=args.d_train_freq)
    filler_kwargs = parser_utils.filter_args(args={**vars(args), **additional_filler_hparams},
                                             target_cls=filler_cls,
                                             return_dict=True)
    filler = filler_cls(**filler_kwargs)

    ########################################
    # Training
    ########################################
    early_stop_callback = EarlyStopping(monitor='val_mae', patience=args.patience, mode='min')
    checkpoint_callback = ModelCheckpoint(dirpath=logdir, save_top_k=1, monitor='val_mae', mode='min')

    logger = TensorBoardLogger(logdir, name="model")

    trainer = pl.Trainer(max_epochs=args.epochs,
                         logger=logger,
                         default_root_dir=logdir,
                         gradient_clip_val=args.grad_clip_val,
                         gradient_clip_algorithm=args.grad_clip_algorithm,
                         callbacks=[early_stop_callback, checkpoint_callback])
    trainer.fit(filler, datamodule=dm)

    ########################################
    # Testing
    ########################################
    # The map_location lambda keeps every storage as deserialized (on CPU).
    # NOTE(review): torch >= 2.6 defaults to weights_only=True, which may
    # reject checkpoints that pickle hyperparameter objects; confirm the
    # installed torch version.
    filler.load_state_dict(torch.load(checkpoint_callback.best_model_path,
                                      lambda storage, loc: storage)['state_dict'])
    filler.freeze()
    trainer.test(model=filler, datamodule=dm)
    filler.eval()

    if torch.cuda.is_available():
        filler.cuda()
    with torch.no_grad():
        _, y_hat, _ = filler.predict_loader(dm.test_dataloader(), return_mask=True)
    y_hat = y_hat.detach().cpu().numpy().reshape(y_hat.shape[:3])

    train_mask = dataset.mask[dm.test_slice]
    eval_mask_ = dataset.eval_mask[dm.test_slice]
    # 1 where an entry is in neither dataset.mask nor dataset.eval_mask.
    eval_mask = 1 - (train_mask | eval_mask_)
    df_true = dataset.df.iloc[dm.test_slice]
    # Separate name so the torch metric objects above are not shadowed.
    numpy_metric_fns = {
        'mae': numpy_metrics.masked_mae,
        'mse': numpy_metrics.masked_mse,
        'mre': numpy_metrics.masked_mre,
        'mape': numpy_metrics.masked_mape
    }

    index = dm.torch_dataset.data_timestamps(dm.testset.indices, flatten=False)['horizon']
    df_hats = prediction_dataframe(y_hat, index, dataset.df.columns, aggregate_by=aggr_methods)

    df_hats = dict(zip(aggr_methods, df_hats))
    last_df_hat = None
    for aggr_by, df_hat in df_hats.items():
        print(f'- AGGREGATE BY {aggr_by.upper()}')
        for metric_name, metric_fn in numpy_metric_fns.items():
            error = metric_fn(df_hat.values, df_true.values, eval_mask).item()
            print(f' {metric_name}: {error:.4f}')
        last_df_hat = df_hat
    # The returned imputation is the one from the last aggregation method.
    return last_df_hat.values, eval_mask

# Linear interpolation for 2D data to pre-fill NaNs
def fill_nan_linear_interpolation_2D(data, prev_count):
    """Fill NaNs column by column, in place, using linear interpolation.

    For each column:
      * interior NaN runs are linearly interpolated between the nearest
        observed values on either side;
      * leading NaNs are back-filled with the first observed value;
      * trailing NaNs are forward-filled with the last observed value;
      * a column with no observed values is left untouched.

    Args:
        data: 2-D float ndarray (rows x columns); modified in place.
        prev_count: accepted for backward compatibility and ignored. Linear
            interpolation only needs the nearest observed value on each side,
            so how many earlier points are considered cannot change the result.

    Returns:
        ``data`` itself, after filling.
    """
    row_idx = np.arange(data.shape[0])
    for col in range(data.shape[1]):
        missing = np.isnan(data[:, col])
        if not missing.any() or missing.all():
            continue
        observed = ~missing
        # np.interp clamps outside the observed range to the first/last
        # observed value, which gives the leading back-fill and the trailing
        # forward-fill (the last observed value, not an earlier one).
        data[missing, col] = np.interp(row_idx[missing],
                                       row_idx[observed],
                                       data[observed, col])
    return data

# In [None]:
# Constants and file paths
MODEL_NAME = 'DGBRIN'
FUSE_DIM = 30   # passed to the model as model_kwargs['fuse_dim']
TIMESTEPS = 30  # window length; also passed to the model as model_kwargs['time_step']

# Input tables: training rows, test rows with gaps, and the ground truth for the test rows.
TRAIN_DATA_PATH = 'timeseries_datasets/example_train.csv'
TEST_DATA_PATH = 'timeseries_datasets/example_test.csv'
REAL_DATA_PATH = 'timeseries_datasets/example_real.csv'

# Parse arguments for the model
def parse_args():
    """Build the hard-coded experiment configuration.

    Returns the ``argparse.Namespace`` of fixed hyperparameters below, after
    passing it through ``models.DGBRIN.add_model_specific_args``.
    """
    base = Namespace(
        seed=1,
        model_name='DGBRIN',
        dataset_name='fly',
        in_sample=False,
        val_len=0.1,
        test_len=0,
        aggregate_by='mean',
        lr=0.001,
        epochs=200,
        patience=20,
        l2_reg=0.0,
        scaled_target=True,
        grad_clip_val=5.0,
        grad_clip_algorithm='norm',
        loss_fn='l1_loss',
        use_lr_schedule=True,
        consistency_loss=False,
        whiten_prob=0.05,
        pred_loss_weight=1.0,
        warm_up=0,
        adj_threshold=0.1,
        alpha=20.0,
        hint_rate=0.7,
        g_train_freq=1,
        d_train_freq=5,
        batch_size=32,
        window=TIMESTEPS,  # window length shared with the model's time_step
        horizon=24,
        delay=0,
        stride=1,
        scaling_axis="channels",
        scaling_type="std",
        scale=True,
        workers=0,
        samples_per_epoch=None,
    )
    return models.DGBRIN.add_model_specific_args(base)


# Load the training rows, the test rows (with gaps) and the test ground truth
X, TEST, R = (np.array(pd.read_csv(path))
              for path in (TRAIN_DATA_PATH, TEST_DATA_PATH, REAL_DATA_PATH))

# Normalize with statistics from the training rows only.
# NOTE: despite their names, x_min / x_max hold the per-column mean and std.
x_zz_norm, x_min, x_max = normalization(X.copy())
test_zz_norm, _, _ = normalization_with_min_max(TEST, x_min, x_max)

# Train, impute in normalized space, then map back to the original scale
time_steps, num_nodes = TEST.shape
y_zz_norm, MASK = DGBRIN(x_zz_norm, test_zz_norm, save=True)
Y = denormalization_with_min_max(y_zz_norm, x_min, x_max)

# Score imputations against the ground truth on the entries selected by MASK
(mae, rmse, mape,
 normalized_mae, normalized_rmse, normalized_mape) = calculate_metrics(R, Y, MASK, x_min, x_max)

# Report
print("MAE:", mae)
print("RMSE:", rmse)
print("MAPE:", mape)
print("Normalized MAE:", normalized_mae)
print("Normalized RMSE:", normalized_rmse)
print("Normalized MAPE:", normalized_mape)