In [19]:
import matplotlib.pyplot as plt
import pandas as pd
import networkx as nx
import numpy as np
from CIoTS import *
from tqdm import trange

In [3]:
dimensions = 3
max_p = 4
incoming_edges = 2

In [17]:
results = pd.DataFrame(columns=['var_mse', 'var_train_bic', 'var_test_bic',
                                'chen_mse', 'chen_train_bic', 'chen_test_bic'])

In [48]:
for i in trange(100):
    generator = CausalTSGenerator(dimensions=dimensions, max_p=max_p, data_length=20000, incoming_edges=incoming_edges)
    ts = generator.generate()
    train_data, test_data = ts[:10000], ts[10000:]

    result = {}

    var_model = VAR(max_p)
    var_model.fit(train_data)
    result['var_train_bic'] = var_model.information_criterion('bic')
    result['var_mse'], result['var_test_bic'] = var_model.evaluate_test_set(train_data[:-4], test_data)

    mapping, data_matrix = transform_ts(train_data, max_p)
    chen_graph = pc_chen_modified(partial_corr_test, ts, max_p, alpha=0.05)
    chen_model = VAR(max_p)
    chen_model.fit_from_graph(dimensions, data_matrix, chen_graph, mapping)
    result['chen_train_bic'] = chen_model.information_criterion('bic')
    result['chen_mse'], result['chen_test_bic'] = chen_model.evaluate_test_set(train_data[:-4], test_data)

    results = results.append(result, ignore_index=True)
results

100%|██████████| 100/100 [02:51<00:00,  1.72s/it]


Unnamed: 0,var_mse,var_train_bic,var_test_bic,chen_mse,chen_train_bic,chen_test_bic
0,1.008025,-inf,-12.492394,1.008025,-18.154321,-11.823757
1,0.985310,-71.849700,-17.495620,0.985605,-20.516508,-14.330809
2,0.979584,-inf,-13.748290,0.982629,-12.687546,-9.169717
3,1.015809,-72.028986,-14.419848,1.016438,-19.157028,-13.846939
4,0.995819,-71.375559,-13.399563,0.995456,-16.558996,-12.428573
5,0.991970,-72.808430,-9.285520,0.991903,-15.562128,-8.969021
6,0.994647,-inf,-15.491953,0.994724,-18.127807,-14.223614
7,0.997721,-109.509359,-17.543876,0.998161,-14.044951,-11.782818
8,0.986040,-inf,-15.545268,0.985706,-17.417473,-12.951063
9,0.994354,-inf,-11.921205,0.994513,-17.955066,-11.135656


In [58]:
results.median()

var_mse            1.000781
var_train_bic    -72.230887
var_test_bic     -13.778173
chen_mse           1.004319
chen_train_bic   -16.289542
chen_test_bic    -11.250013
dtype: float64

In [59]:
results.to_csv('results/bic_test_set/results.csv')