In [1]:
import sys

# Root folder that is put on the import path (presumably so that the
# `src.napsu_mq` import in the next cell resolves -- confirm).
# NOTE(review): this is a machine-specific absolute path; adjust it when
# running the notebook on another machine.
PROJECT_ROOT = "/home/jarlehti/projects/gradu"

# Guard against duplicate entries so that re-running this cell is idempotent.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [2]:
import os
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import d3p
from src.napsu_mq.napsu_mq import NapsuMQModel, NapsuMQResult

In [3]:
# Folder layout, resolved relative to the kernel's working directory.
# CURRENT_FOLDER is the *parent* of the working directory. The previous
# expression, dirname(dirname(abspath(__name__))), produced the same value
# only because abspath() resolved the module name against the working
# directory; os.getcwd() states that dependency explicitly.
CURRENT_FOLDER = os.path.dirname(os.getcwd())
# Input CSV files are read from <CURRENT_FOLDER>/data/datasets.
DATASETS_FOLDER = os.path.join(CURRENT_FOLDER, "data", "datasets")
# Fitted results are stored under and loaded from <CURRENT_FOLDER>/models.
MODELS = os.path.join(CURRENT_FOLDER, "models")

In [4]:
# Fixed seed for the PRNG key handed to the model fit, so that reruns start
# from the same key.
FIT_SEED = 452345235
rng = jax.random.PRNGKey(FIT_SEED)

In [5]:
# Read the binary4d CSV into a DataFrame and report its dimensions:
# n = number of rows, d = number of columns (printed one per line).
binary4d_csv = os.path.join(DATASETS_FOLDER, "binary4d.csv")
dataset = pd.read_csv(binary4d_csv)
n, d = dataset.shape
print(n, d, sep="\n")

100000
4


In [6]:
# Column pairs passed to the fit as `column_feature_set`: each of A, B and C
# is paired with D (presumably the marginals to query -- confirm against
# NapsuMQModel.fit).
column_feature_set = [(column, 'D') for column in ('A', 'B', 'C')]

# Keyword arguments for the fit, collected in one place for readability.
# NOTE(review): epsilon is 50 here, while the result file written later in
# this notebook is tagged "500e" -- confirm which value is intended.
fit_options = dict(
    data=dataset,
    dataset_name="binary4d",
    rng=rng,
    epsilon=50,
    delta=n ** -2,  # delta = 1 / n^2, with n the number of rows in the dataset
    column_feature_set=column_feature_set,
    use_laplace_approximation=True,
    laplace_approximation_algorithm="jaxopt_LBFGS",
)

model = NapsuMQModel()
result = model.fit(**fit_options)

No experiment_id found: <ContextVar name='experiment_id' at 0x7f01a0395e50>
Setting experiment_id to 13F45B7N
(100000, 4)
Dataframe data n: 100000
Dataframe data d: 4
Domain size: 16
100000
4
Recording:  Query selection
start MST selection


  epsilon = np.sqrt(8 * rho / (r - 1))


end MST selection
Recording:  Calculating full marginal query
Recording:  Calculating canonical query set
Calculating canonical queries


  0%|                                                                                                                                                                  | 0/8 [00:00<?, ?it/s]
4it [00:00, 6729.73it/s]

2it [00:00, 8136.38it/s]

4it [00:00, 14488.10it/s]

4it [00:00, 17772.47it/s]

2it [00:00, 4637.15it/s]

2it [00:00, 5555.37it/s]

2it [00:00, 5687.19it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 539.94it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 61984.79it/s]


Calculating new queries


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 9765.55it/s]


Canonical queries: 7
Domain d: 4
Nodes after elimination: {('A', 'D'), ('B', 'D'), ('D', 'C'), ('D',)}
Edges after elimination: {(('A', 'D'), ('D',)): {'D'}, (('B', 'D'), ('D',)): {'D'}, (('D', 'C'), ('D',)): {'D'}}
Nodes after removing: {('A', 'D'), ('B', 'D'), ('D', 'C')}
Edges after removing: {(('A', 'D'), ('B', 'D')): {'D'}, (('A', 'D'), ('D', 'C')): {'D'}}
Suff stat d: 7
Lambda d: 7
start lambda0
end lambda0
start suff stat mean and cov
end suff stat mean and cov
[TreeNode(variables=('B', 'D'), parent=('A', 'D'), children=[]), TreeNode(variables=('D', 'C'), parent=('A', 'D'), children=[])]
1
Junction tree width: 2
(7,)
0.44162884978050604
Recording:  Laplace approximation
Started Jaxopt Laplace approximation
Attempting Laplace approximation, 0th try
Initialising model done
Minimising potential function
{'lambdas': DeviceArray([-0.01410392, -0.01932848,  0.28121972, -0.68804395,
              0.07540693,  0.39751783, -1.5307403 ], dtype=float32)}
LbfgsState(iter_num=DeviceArray(100

  mcmc = numpyro.infer.MCMC(



                     mean       std    median      5.0%     95.0%     n_eff     r_hat
norm_lambdas[0]     -0.00      0.99     -0.01     -1.58      1.63  12069.24      1.00
norm_lambdas[1]     -0.03      1.02     -0.02     -1.77      1.59  12782.77      1.00
norm_lambdas[2]     -0.00      1.00     -0.01     -1.57      1.70  11899.02      1.00
norm_lambdas[3]     -0.01      0.99     -0.02     -1.70      1.57  11671.73      1.00
norm_lambdas[4]     -0.01      1.00     -0.00     -1.59      1.69  11751.86      1.00
norm_lambdas[5]      0.00      0.98      0.01     -1.63      1.56  12084.04      1.00
norm_lambdas[6]      0.02      0.99      0.01     -1.68      1.59  13881.00      1.00

Number of divergences: 0
[norm_lambdas]	 max r_hat: 0.9983
Potential energy	min: [93.84338379  0.          0.          0.        ]	max: [106.05001068   0.           0.           0.        ]	mean: [97.17044067  0.          0.          0.        ]	std: [1.82203531 0.         0.         0.        ]
Acceptance pr

In [7]:
# Write the fitted result into the models folder.
result_filename = "napsu_testing_binary4d_all_queries_500e_NUTS_LA.dill"
result_path = os.path.join(MODELS, result_filename)
result.store(result_path)

In [8]:
# Reload the stored result from disk and draw synthetic datasets from it.
# The path is rebuilt here (instead of reusing the variable from the storing
# cell) so this cell also works on a kernel where the fit was not re-run.
loaded_result_path = os.path.join(MODELS, "napsu_testing_binary4d_all_queries_500e_NUTS_LA.dill")

# Use a context manager so the file handle is closed once loading finishes;
# previously the handle returned by open() was never closed.
# NOTE(review): the ".dill" suffix suggests a pickle-style format -- only
# load files from a trusted source.
with open(loaded_result_path, "rb") as napsu_result_read_file:
    loaded_result = NapsuMQResult.load(napsu_result_read_file)

# Separate key for sampling, independent of the key used for fitting.
# NOTE(review): this seed does not fit in 32 bits; how JAX handles such seeds
# may depend on the JAX version / x64 setting -- confirm it is intended.
sampling_rng = jax.random.PRNGKey(234513465234)

datasets = loaded_result.generate_extended(
    rng=sampling_rng,
    num_data_per_parameter_sample=2000,
    num_parameter_samples=5,
)

Generating data with 2000 points and 5 datasets


In [9]:
# Display the value returned by generate_extended() in the previous cell
# (shown below as a list of DataFrames with columns A, B, C, D).
datasets

[      A  B  C  D
 0     0  0  1  1
 1     1  0  1  1
 2     0  1  0  1
 3     1  1  1  1
 4     1  0  1  1
 ...  .. .. .. ..
 1995  0  0  1  1
 1996  1  1  1  1
 1997  0  0  0  0
 1998  1  1  1  1
 1999  0  1  0  1
 
 [2000 rows x 4 columns],
       A  B  C  D
 0     0  1  1  1
 1     1  0  1  1
 2     0  0  0  0
 3     0  0  1  1
 4     1  0  0  0
 ...  .. .. .. ..
 1995  1  0  0  1
 1996  0  1  1  1
 1997  1  0  0  1
 1998  0  1  0  0
 1999  1  1  1  1
 
 [2000 rows x 4 columns],
       A  B  C  D
 0     1  1  1  1
 1     0  0  1  1
 2     1  0  1  1
 3     0  1  1  1
 4     0  1  1  1
 ...  .. .. .. ..
 1995  0  0  1  1
 1996  1  1  1  1
 1997  1  1  1  1
 1998  1  0  0  1
 1999  0  1  1  1
 
 [2000 rows x 4 columns],
       A  B  C  D
 0     0  1  1  0
 1     0  1  1  1
 2     0  1  0  1
 3     0  1  0  1
 4     1  0  1  1
 ...  .. .. .. ..
 1995  1  0  1  1
 1996  0  0  0  0
 1997  1  1  1  1
 1998  1  0  0  1
 1999  0  0  1  1
 
 [2000 rows x 4 columns],
       A  B  C  D
 0    