In [1]:
import sys
sys.path.append("/home/jarlehti/projects/gradu")

In [2]:
import os
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import d3p
from src.napsu_mq.napsu_mq import NapsuMQModel, NapsuMQResult
from src.utils.preprocess_dataset import get_adult_train_small, get_adult_train_large

In [3]:
CURRENT_FOLDER = os.path.dirname(os.path.dirname(os.path.abspath(__name__)))
DATASETS_FOLDER = os.path.join(CURRENT_FOLDER, "data", "datasets")
MODELS = os.path.join(CURRENT_FOLDER, "models")

In [4]:
rng = jax.random.PRNGKey(32346234634265345)

In [5]:
adult_small = get_adult_train_small(DATASETS_FOLDER)

In [6]:
adult_small

Unnamed: 0,age,education-num,marital-status,sex,hours-per-week,compensation
0,35 - 40,13,Never-married,Male,40 - 45,0
1,50 - 55,13,Married-civ-spouse,Male,10 - 15,0
2,35 - 40,9,Divorced,Male,40 - 45,0
3,50 - 55,7,Married-civ-spouse,Male,40 - 45,0
4,25 - 30,13,Married-civ-spouse,Female,40 - 45,0
...,...,...,...,...,...,...
30157,25 - 30,12,Married-civ-spouse,Female,35 - 40,0
30158,40 - 45,9,Married-civ-spouse,Male,40 - 45,1
30159,55 - 60,9,Widowed,Female,40 - 45,0
30160,20 - 25,9,Never-married,Male,20 - 25,0


In [7]:
adult_small.dtypes

age               category
education-num     category
marital-status    category
sex               category
hours-per-week    category
compensation      category
dtype: object

In [None]:
n, d = adult_small.shape

for i in range(1, len(adult_small.columns)):
    columns = list(adult_small.columns)[0:i]
    columns.extend(['compensation'])
    print(f"Testing Laplace approximation and MCMC with columns: {columns}")
    
    adult_subset = adult_small[columns]
    
    columns_str = "_".join(columns)
    
    model = NapsuMQModel()
    result, timer = model.fit(
        data=adult_subset,
        dataset_name="adult_small",
        rng=rng,
        epsilon=50,
        delta=(n ** (-2)),
        column_feature_set=[],
        use_laplace_approximation=True,
        laplace_approximation_algorithm="jaxopt_LBFGS",
        return_timer = True,
    )
    
    result_path = os.path.join(MODELS, f"napsu_testing_adult_small_all_queries_{columns_str}_50e_LA.dill")
    timer.to_csv(f"napsu_testing_timer_adult_small_all_queries_{columns_str}_50e_LA")
    result.store(result_path)

Testing Laplace approximation and MCMC with columns: ['age', 'compensation']
No experiment_id found: <ContextVar name='experiment_id' at 0x7fbb70f2a720>
Setting experiment_id to L2PZ6VBQ
(30162, 2)
Dataframe data n: 30162
Dataframe data d: 2
Domain size: 40
30162
2
Recording:  Query selection
start MST selection
end MST selection
Recording:  Calculating full marginal query
Recording:  Calculating canonical query set
Calculating canonical queries


  0%|                                                                                                                 | 0/4 [00:00<?, ?it/s]
20it [00:00, 42027.09it/s]

2it [00:00, 18808.54it/s]

40it [00:00, 23851.60it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 514.89it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 15477.14it/s]


Calculating new queries


100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 47907.53it/s]

Canonical queries: 39
Domain d: 2
Nodes after elimination: {('age',), ('compensation', 'age')}
Edges after elimination: {(('compensation', 'age'), ('age',)): {'age'}}
Nodes after removing: {('compensation', 'age')}
Edges after removing: {}
Suff stat d: 39
Lambda d: 39
Recording:  Calculating lambda0





{'experiment_id': 'L2PZ6VBQ', 'start': 427948.027620747, 'stop': 427948.292537682, 'timedelta': 0.2649169350042939, 'task': 'Calculating lambda0', 'dataset_name': 'adult_small', 'query_str': 'empty', 'query_list': [], 'epsilon': 50, 'delta': 1.0992076159646117e-09, 'MCMC_algo': 'NUTS', 'laplace_approximation': True, 'missing_query': None, 'discretization': None, 'n_canonical_queries': 39}
Recording:  Calculating lambda0
{'experiment_id': 'L2PZ6VBQ', 'start': 427948.292616976, 'stop': 427950.831624764, 'timedelta': 2.5390077880001627, 'task': 'Calculating lambda0', 'dataset_name': 'adult_small', 'query_str': 'empty', 'query_list': [], 'epsilon': 50, 'delta': 1.0992076159646117e-09, 'MCMC_algo': 'NUTS', 'laplace_approximation': True, 'missing_query': None, 'discretization': None, 'n_canonical_queries': 39}
[]
0
Junction tree width: 0
(39,)
0.24685422448289715
Recording:  Laplace approximation
Started Jaxopt Laplace approximation
Attempting Laplace approximation, 0th try
Initialising mode

2023-03-13 13:29:36.565918: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:65] 
********************************
[Compiling module jit_prim_fun.104] Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2023-03-13 13:30:00.051622: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:133] The operation took 2m23.485779777s

********************************
[Compiling module jit_prim_fun.104] Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2023-03-13 13:42:13.644105: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:133] The operation took 7m36.482896603s

********************************
[Compiling module jit_prim_fun.137] Very slow compile?  If you want to file a bug, run with envvar XLA_FLA

Failed linesearch, try again
Attempting Laplace approximation, 1th try
Initialising model done
Minimising potential function


2023-03-13 13:45:56.495351: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:133] The operation took 2m24.731889187s

********************************
[Compiling module jit_prim_fun.145] Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2023-03-13 13:52:39.455021: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:65] 
********************************
[Compiling module jit_prim_fun.147] Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2023-03-13 13:58:19.338523: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:133] The operation took 7m39.883661075s

********************************
[Compiling module jit_prim_fun.147] Very slow compile?  If you want to file a bug, run with envvar XLA_FLA

Failed linesearch, try again
Attempting Laplace approximation, 2th try
Initialising model done
Minimising potential function


2023-03-13 14:02:05.431296: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:133] The operation took 2m26.56849226s

********************************
[Compiling module jit_prim_fun.150] Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2023-03-13 14:08:48.467687: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:65] 
********************************
[Compiling module jit_prim_fun.152] Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2023-03-13 14:14:26.969491: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:133] The operation took 7m38.501968301s

********************************
[Compiling module jit_prim_fun.152] Very slow compile?  If you want to file a bug, run with envvar XLA_FLAG

Calculating Hessian
Recording:  MCMC


  mcmc = numpyro.infer.MCMC(



                      mean       std    median      5.0%     95.0%     n_eff     r_hat
 norm_lambdas[0]     -0.91      1.44     -0.64     -3.12      1.24   5479.73      1.00
 norm_lambdas[1]     -1.91      2.50     -1.39     -5.67      1.66   3814.66      1.00
 norm_lambdas[2]      0.73      1.51      0.80     -1.04      2.87    517.97      1.00
 norm_lambdas[3]      1.30      1.46      1.15     -0.97      3.33    935.31      1.00
 norm_lambdas[4]      0.25      1.02      0.25     -1.40      1.94   5754.60      1.00
 norm_lambdas[5]      0.07      1.01      0.06     -1.62      1.76  10358.86      1.00
 norm_lambdas[6]      0.03      1.00      0.02     -1.66      1.62   9693.72      1.00
 norm_lambdas[7]      0.01      1.02      0.02     -1.62      1.65  11568.98      1.00
 norm_lambdas[8]      0.03      1.00      0.04     -1.61      1.64  11009.16      1.00
 norm_lambdas[9]      0.02      0.99      0.03     -1.52      1.70   9855.51      1.00
norm_lambdas[10]     -0.00      0.99      

  0%|                                                                                                                 | 0/6 [00:00<?, ?it/s]
2it [00:00, 13273.11it/s]

32it [00:00, 33123.82it/s]

16it [00:00, 17028.38it/s]

40it [00:00, 18367.87it/s]

20it [00:00, 20651.42it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:00<00:00, 437.58it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 46345.90it/s]


Calculating new queries


100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:00<00:00, 36517.57it/s]

Canonical queries: 69
Domain d: 3
Nodes after elimination: {('compensation',), ('education-num', 'compensation'), ('compensation', 'age')}
Edges after elimination: {(('education-num', 'compensation'), ('compensation',)): {'compensation'}, (('compensation', 'age'), ('compensation',)): {'compensation'}}
Nodes after removing: {('education-num', 'compensation'), ('compensation', 'age')}
Edges after removing: {(('education-num', 'compensation'), ('compensation', 'age')): {'compensation'}}
Suff stat d: 69
Lambda d: 69
Recording:  Calculating lambda0





{'experiment_id': 'L2PZ6VBQ', 'start': 431032.993564738, 'stop': 431033.436578496, 'timedelta': 0.44301375799113885, 'task': 'Calculating lambda0', 'dataset_name': 'adult_small', 'query_str': 'empty', 'query_list': [], 'epsilon': 50, 'delta': 1.0992076159646117e-09, 'MCMC_algo': 'NUTS', 'laplace_approximation': True, 'missing_query': None, 'discretization': None, 'n_canonical_queries': 69}
Recording:  Calculating lambda0
{'experiment_id': 'L2PZ6VBQ', 'start': 431033.436666832, 'stop': 431037.646653186, 'timedelta': 4.209986353991553, 'task': 'Calculating lambda0', 'dataset_name': 'adult_small', 'query_str': 'empty', 'query_list': [], 'epsilon': 50, 'delta': 1.0992076159646117e-09, 'MCMC_algo': 'NUTS', 'laplace_approximation': True, 'missing_query': None, 'discretization': None, 'n_canonical_queries': 69}
[TreeNode(variables=('compensation', 'age'), parent=('education-num', 'compensation'), children=[])]
1
Junction tree width: 1
(69,)
0.34910459219280565
Recording:  Laplace approximatio

2023-03-13 14:22:58.259600: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:65] 
********************************
[Compiling module jit_prim_fun.387] Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2023-03-13 14:29:35.204699: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:133] The operation took 8m36.945169522s

********************************
[Compiling module jit_prim_fun.387] Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2023-03-13 14:51:28.478408: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:65] 
********************************
[Compiling module jit_prim_fun.394] Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach

Failed linesearch, try again
Attempting Laplace approximation, 1th try
Initialising model done
Minimising potential function


2023-03-13 15:36:32.316316: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:133] The operation took 9m41.580129829s

********************************
[Compiling module jit_prim_fun.399] Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2023-03-13 15:57:29.809419: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:65] 
********************************
[Compiling module jit_prim_fun.401] Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2023-03-13 16:26:36.316781: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:133] The operation took 31m6.507023942s

********************************
[Compiling module jit_prim_fun.401] Very slow compile?  If you want to file a bug, run with envvar XLA_FLA

Failed linesearch, try again
Attempting Laplace approximation, 2th try
Initialising model done
Minimising potential function


2023-03-13 16:33:02.246476: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:65] 
********************************
[Compiling module jit_prim_fun.404] Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2023-03-13 16:40:30.533513: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:133] The operation took 9m28.286763745s

********************************
[Compiling module jit_prim_fun.404] Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
