In [1]:
import os
import sys

# Make the project's `src` package importable. The project root is taken to be the parent of
# the notebook's working directory -- the same assumption the path-config cell makes for
# data/ and models/ -- instead of a hard-coded, user-specific absolute path.
PROJECT_ROOT = os.path.dirname(os.getcwd())
# Guard so that re-running this cell does not keep appending duplicate sys.path entries.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [2]:
import os
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import d3p
from src.napsu_mq.napsu_mq import NapsuMQModel, NapsuMQResult
from src.utils.preprocess_dataset import get_adult_train_small, get_adult_train_large

In [3]:
# NOTE: despite its name, CURRENT_FOLDER is the *parent* of the notebook's working directory.
# The notebook treats that parent as the project root, where data/ and models/ live.
CURRENT_FOLDER = os.path.dirname(os.getcwd())
DATASETS_FOLDER = os.path.join(CURRENT_FOLDER, "data", "datasets")
MODELS = os.path.join(CURRENT_FOLDER, "models")

In [4]:
# Fixed PRNG seed, so the fit below draws the same random numbers on every Restart & Run All.
SEED = 32346234634265345
rng = jax.random.PRNGKey(SEED)

In [5]:
# Load the small Adult training split from DATASETS_FOLDER. Judging by the displayed output,
# columns such as age and hours-per-week come back already binned (e.g. "35 - 40").
adult_small = get_adult_train_small(DATASETS_FOLDER)

In [6]:
# Inspect the loaded frame; the rich repr shows only the first and last rows.
adult_small

Unnamed: 0,age,education-num,marital-status,sex,hours-per-week,compensation
0,35 - 40,13,Never-married,Male,40 - 45,0
1,50 - 55,13,Married-civ-spouse,Male,10 - 15,0
2,35 - 40,9,Divorced,Male,40 - 45,0
3,50 - 55,7,Married-civ-spouse,Male,40 - 45,0
4,25 - 30,13,Married-civ-spouse,Female,40 - 45,0
...,...,...,...,...,...,...
30157,25 - 30,12,Married-civ-spouse,Female,35 - 40,0
30158,40 - 45,9,Married-civ-spouse,Male,40 - 45,1
30159,55 - 60,9,Widowed,Female,40 - 45,0
30160,20 - 25,9,Never-married,Male,20 - 25,0


In [None]:
# Fit NAPSU-MQ (Laplace approximation + MCMC) on a 5-column subset of the small Adult data,
# then store the fitted result and the per-task timer.
#
# The run's settings are named variables, and both the log message and the output file
# names are built from them. Previously the labels were hard-coded as "forward_mode" and
# "10e" while the call used laplace_approximation_forward_mode=False and epsilon=50, so the
# saved artefacts did not describe what they contained.
columns = ['age', 'education-num', 'marital-status', 'sex', 'compensation']
epsilon = 50
# NOTE(review): the old labels suggest forward mode was intended here; set this to True if so.
# It is kept False so this cell repeats the same computation as the run recorded below.
laplace_forward_mode = False
mode_label = "forward_mode" if laplace_forward_mode else "no_forward_mode"

print(f"Testing Laplace approximation ({mode_label}, epsilon={epsilon}) and MCMC with columns: {columns}")
adult_subset = adult_small[columns]

# n is the number of rows; it sets delta = 1 / n^2 below.
n, d = adult_subset.shape

print(adult_subset)

columns_str = "_".join(columns)

model = NapsuMQModel()
result, timer = model.fit(
    data=adult_subset,
    dataset_name="adult_small",
    rng=rng,
    epsilon=epsilon,
    delta=(n ** (-2)),
    column_feature_set=[],
    use_laplace_approximation=True,
    laplace_approximation_algorithm="jaxopt_LBFGS",
    return_timer=True,
    laplace_approximation_forward_mode=laplace_forward_mode,
)

# With laplace_forward_mode=True and epsilon=10, these reproduce the previous file names exactly.
run_tag = f"adult_small_{mode_label}_{columns_str}_{epsilon}e_LA"
result_path = os.path.join(MODELS, f"napsu_testing_{run_tag}.dill")
# NOTE(review): the timer CSV is written to the working directory, not MODELS like the result -- confirm intended.
timer.to_csv(f"napsu_testing_timer_{run_tag}.csv")
result.store(result_path)

Testing Laplace approximation with forward mode and MCMC with columns: ['age', 'education-num', 'marital-status', 'sex', 'compensation']
           age education-num      marital-status     sex compensation
0      35 - 40            13       Never-married    Male            0
1      50 - 55            13  Married-civ-spouse    Male            0
2      35 - 40             9            Divorced    Male            0
3      50 - 55             7  Married-civ-spouse    Male            0
4      25 - 30            13  Married-civ-spouse  Female            0
...        ...           ...                 ...     ...          ...
30157  25 - 30            12  Married-civ-spouse  Female            0
30158  40 - 45             9  Married-civ-spouse    Male            1
30159  55 - 60             9             Widowed  Female            0
30160  20 - 25             9       Never-married    Male            0
30161  50 - 55             9  Married-civ-spouse  Female            1

[30162 rows x 5 columns]

  0%|                                                                           | 0/10 [00:00<?, ?it/s]
14it [00:00, 28068.96it/s]

7it [00:00, 17270.66it/s]

2it [00:00, 10551.71it/s]

32it [00:00, 30538.73it/s]

14it [00:00, 29418.97it/s]

16it [00:00, 44709.44it/s]

140it [00:00, 28318.02it/s]

20it [00:00, 29056.49it/s]

2it [00:00, 21959.71it/s]
100%|█████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 353.33it/s]
100%|████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 159783.01it/s]


Calculating new queries


100%|███████████████████████████████████████████████████████████████| 42/42 [00:00<00:00, 30451.30it/s]

Canonical queries: 183
Domain d: 5
Nodes after elimination: {('compensation', 'education-num'), ('compensation', 'marital-status'), ('marital-status',), ('sex', 'marital-status'), ('age', 'marital-status')}
Edges after elimination: {(('compensation', 'education-num'), ('compensation', 'marital-status')): {'compensation'}, (('sex', 'marital-status'), ('marital-status',)): {'marital-status'}, (('age', 'marital-status'), ('marital-status',)): {'marital-status'}, (('compensation', 'marital-status'), ('marital-status',)): {'marital-status'}}
Nodes after removing: {('compensation', 'education-num'), ('compensation', 'marital-status'), ('sex', 'marital-status'), ('age', 'marital-status')}
Edges after removing: {(('compensation', 'education-num'), ('compensation', 'marital-status')): {'compensation'}, (('compensation', 'marital-status'), ('sex', 'marital-status')): {'marital-status'}, (('compensation', 'marital-status'), ('age', 'marital-status')): {'marital-status'}}
Suff stat d: 183
Lambda d




{'experiment_id': 'J62G8L4W', 'start': 456965.124414994, 'stop': 456966.467993109, 'timedelta': 1.3435781150474213, 'task': 'Calculating lambda0', 'dataset_name': 'adult_small', 'query_str': 'empty', 'query_list': [], 'epsilon': 50, 'delta': 1.0992076159646117e-09, 'MCMC_algo': 'NUTS', 'laplace_approximation': True, 'missing_query': None, 'discretization': None, 'n_canonical_queries': 183}
Recording:  Calculating suff stat mean and cov
{'experiment_id': 'J62G8L4W', 'start': 456966.468072571, 'stop': 456981.787616657, 'timedelta': 15.31954408599995, 'task': 'Calculating suff stat mean and cov', 'dataset_name': 'adult_small', 'query_str': 'empty', 'query_list': [], 'epsilon': 50, 'delta': 1.0992076159646117e-09, 'MCMC_algo': 'NUTS', 'laplace_approximation': True, 'missing_query': None, 'discretization': None, 'n_canonical_queries': 183}
[TreeNode(variables=('compensation', 'marital-status'), parent=('compensation', 'education-num'), children=[('sex', 'marital-status'), ('age', 'marital-s