# Missing Data Imputation Experiments with BLRHS

In [1]:
import os
# The two XLA client variables are written to the environment *before* the
# `import jax` line below, so they are already set when JAX is imported.
# Per the JAX GPU memory allocation documentation, PREALLOCATE="false" turns
# off up-front GPU memory preallocation and ALLOCATOR="platform" makes the
# client allocate and free memory on demand.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform"
import jax
# Explicitly keep JAX's 64-bit mode switched off for this experiment.
jax.config.update("jax_enable_x64", False)
import jax.numpy as jnp
import numpy as np
# Project-local helpers: the variational Bayes routine plus data loading,
# detrending and error-metric utilities used by the experiment cell.
from lrhs.blrhs import cp_vb_missing
from lrhs.blrhs import get_trend, get_X_imputation_experiment, mape, rmse

Please download the dataset shared by Chen et al. (2019) before starting the experiment. If you have already downloaded the data, you can comment out the cell below.

Chen, X., He, Z., & Sun, L. (2019). A Bayesian tensor decomposition approach for spatiotemporal traffic data imputation. Transportation Research Part C: Emerging Technologies, 98, 73–84.


In [None]:
!wget https://github.com/mcgill-smart-transport/bgcp_imputation/raw/master/random_matrix.mat -P data/guangzhou/
!wget https://github.com/mcgill-smart-transport/bgcp_imputation/raw/master/random_tensor.mat -P data/guangzhou/
!wget https://github.com/mcgill-smart-transport/bgcp_imputation/raw/master/tensor.mat -P data/guangzhou/

The variables below correspond to (1) the experimental settings defined in the main paper, (2) model hyperparameters, or (3) the data representation. Please refer to the original paper for instructions on replicating the experimental results exactly. Results may vary slightly due to numerical precision differences between hardware/software platforms, but never enough to change the ordering of the models.

In [3]:
# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
all_results = {}
data_folder = "data/guangzhou/"
num_epochs = 20000
# Keep seed = 0 to replicate paper results
seed = 0
# Ratio of data to be used as validation set for early stopping and model selection
es_ratio = 0.01
# Frequency with which validation performance will be checked
es_freq = 100
# Number of VB epochs without improvement after which the procedure stops
es_epochs = 5000
# Prints out results with the said frequency; must be an integer multiple of es_freq
print_freq = 1000

# Missing data scenario
missing_scenario = "fiber"
# Missing data ratio
missing_ratio = 0.5
# Data representation, select 3 for replicating paper results
num_total_dimensions = 3
# Detrending
detrending_type = "min" # available options ["min", "none"]

# Select rank
R = 40
# Select hyperparameter a
a = 12.5e4

# Fail fast on settings that contradict the constraints documented above,
# instead of discovering the problem partway through a long run.
if es_freq <= 0 or print_freq % es_freq != 0:
    raise ValueError(
        f"print_freq ({print_freq}) must be a positive integer multiple of es_freq ({es_freq})"
    )
if detrending_type not in ("min", "none"):
    raise ValueError(
        f"detrending_type must be 'min' or 'none', got {detrending_type!r}"
    )

# ---------------------------------------------------------------------------
# Data preparation
# ---------------------------------------------------------------------------
# Returns the censored data, the ground truth, the early-stopping and test
# index sets, and the einsum expression used for reconstruction further down.
X_censored, X, es_idx, test_idx, einsum = get_X_imputation_experiment(
    data_folder, missing_scenario, missing_ratio, num_total_dimensions, es_ratio
)
trend = get_trend(X_censored, detrending_type)

# ---------------------------------------------------------------------------
# Model fitting
# ---------------------------------------------------------------------------
# The model is fitted on the detrended censored data; X and trend are passed
# along with es_idx for the validation metrics reported during fitting.
elbo, val_evals, log_Ms = cp_vb_missing(
    X=jnp.array(X_censored - trend),
    R=R,
    a=a,
    num_epochs=num_epochs,
    print_freq=print_freq,
    X_gt=X,
    es_epochs=es_epochs,
    es_freq=es_freq,
    es_idx=es_idx,
    trend=trend,
    einsum=einsum,
    seed=seed,
)

# ---------------------------------------------------------------------------
# Reconstruction and test metrics
# ---------------------------------------------------------------------------
# Exponentiate the last (ndim + 1) entries of log_Ms -- presumably the
# log-domain factors consumed by the einsum expression; confirm against cp_vb_missing.
res = [np.exp(r) for r in log_Ms[-(len(X.shape) + 1):]]

# Contract the factors, rescale by exp(log_Ms[0]) and add the trend back.
X_hat = np.einsum(einsum, *res)  * jnp.exp(log_Ms[0]) + trend
cur_rmse = rmse(X_hat[test_idx], X[test_idx])
cur_mape = mape(X_hat[test_idx], X[test_idx])
print("Experiment completed. Test results are:")
print(f"MAPE: {cur_mape:.6f}, RMSE: {cur_rmse:.6f}")

Epoch: 1000, ELBO: -2899648, Val MAPE:  0.106084, Val RMSE:  4.362101
Epoch: 2000, ELBO: -2893888, Val MAPE:  0.105677, Val RMSE:  4.352413
Epoch: 3000, ELBO: -2892192, Val MAPE:  0.105886, Val RMSE:  4.366357
Epoch: 4000, ELBO: -2891328, Val MAPE:  0.105834, Val RMSE:  4.362789
Epoch: 5000, ELBO: -2890784, Val MAPE:  0.105611, Val RMSE:  4.352549
Epoch: 6000, ELBO: -2890432, Val MAPE:  0.105218, Val RMSE:  4.335884
Epoch: 7000, ELBO: -2890048, Val MAPE:  0.104736, Val RMSE:  4.313793
Epoch: 8000, ELBO: -2889664, Val MAPE:  0.104387, Val RMSE:  4.296974
Epoch: 9000, ELBO: -2889600, Val MAPE:  0.104193, Val RMSE:  4.286645
Epoch: 10000, ELBO: -2889504, Val MAPE:  0.104081, Val RMSE:  4.279346
Epoch: 11000, ELBO: -2889280, Val MAPE:  0.104004, Val RMSE:  4.273397
Epoch: 12000, ELBO: -2889280, Val MAPE:  0.103964, Val RMSE:  4.269267
Epoch: 13000, ELBO: -2889376, Val MAPE:  0.103953, Val RMSE:  4.267170
Epoch: 14000, ELBO: -2889280, Val MAPE:  0.103951, Val RMSE:  4.267012
Epoch: 15000, E