In [2]:
import os
import sys

# Make the project's `src` package importable (needed for `src.napsu_mq` below).
# The root can be overridden with the GRADU_ROOT environment variable instead of
# editing a hardcoded absolute path; the membership check keeps repeated runs of
# this cell from appending duplicate entries to sys.path.
PROJECT_ROOT = os.environ.get("GRADU_ROOT", "/home/jarlehti/projects/gradu")
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [11]:
import os
import numpy as np
import pandas as pd
import scipy.stats as stats
import jax
import d3p
import matplotlib.pyplot as plt
from src.napsu_mq.napsu_mq import NapsuMQModel, NapsuMQResult

In [5]:
# Project root = parent directory of the notebook's working directory
# (presumably the notebook runs from a subfolder of the repo — TODO confirm).
# Use the cwd explicitly: `__name__` is a module name, not a file path.
CURRENT_FOLDER = os.path.dirname(os.getcwd())
DATASETS_FOLDER = os.path.join(CURRENT_FOLDER, "data", "datasets")
MODELS = os.path.join(CURRENT_FOLDER, "models")

In [24]:
# Root JAX PRNG key for the notebook; the JAX random streams used below are
# derived from it.
SEED = 3452345346
rng = jax.random.PRNGKey(SEED)

In [None]:
# Repeat the experiment R times: simulate independent Poisson columns with rates
# mu = 1..5, fit NAPSU-MQ under (epsilon, delta)-DP, draw synthetic data, and
# compare the synthetic column means with the true rates.
R = 10
N_SAMPLES = 10000               # rows in each simulated original dataset
N_SYNTHETIC_PER_SAMPLE = 10000  # synthetic rows per posterior parameter sample
N_PARAMETER_SAMPLES = 10        # posterior parameter samples per fit
EPSILON = 8
# The Poisson draws come from scipy/numpy rather than JAX, so they need their
# own seed; otherwise every re-run simulates different original data.
DATA_SEED = 3452345346

training_rngs = jax.random.split(rng, R)
data_seeds = np.random.SeedSequence(DATA_SEED).spawn(R)
mus = list(range(1, 6))
mle_means_per_run = []

for i in range(R):
    model_rng, sampling_rng = jax.random.split(training_rngs[i], 2)
    data_rng = np.random.default_rng(data_seeds[i])

    data = {
        f'mu_{mu}': stats.poisson.rvs(mu, size=N_SAMPLES, random_state=data_rng)
        for mu in mus
    }
    dataset = pd.DataFrame(data, dtype="category")
    n, d = dataset.shape

    if i == 0:
        # Show the simulated marginals once, on a dedicated labelled figure,
        # instead of stacking 5 * R histograms onto one implicit figure.
        fig, ax = plt.subplots()
        for column, values in data.items():
            ax.hist(values, bins=np.arange(values.max() + 2) - 0.5,
                    alpha=0.5, label=column)
        ax.set(title="Simulated Poisson marginals (run 0)",
               xlabel="value", ylabel="count")
        ax.legend()
        plt.show()

    model = NapsuMQModel()
    result: NapsuMQResult = model.fit(
        data=dataset,
        dataset_name="poisson_data",
        rng=model_rng,
        epsilon=EPSILON,
        delta=(n ** (-2)),
        column_feature_set=[],
        use_laplace_approximation=True,
        laplace_approximation_algorithm="torch_LBFGS",
    )

    synthetic_datasets = result.generate_extended(
        rng=sampling_rng,
        num_data_per_parameter_sample=N_SYNTHETIC_PER_SAMPLE,
        num_parameter_samples=N_PARAMETER_SAMPLES,
        single_dataframe=True,
    )

    # Column-wise means estimate each mu_k. Averaging along axis=1 would mix the
    # different mu_k columns of a row and estimate none of them. Cast to float
    # because the columns may be categorical like the input data; selecting
    # dataset.columns drops any extra bookkeeping columns in the extended output
    # (NOTE(review): confirm generate_extended's output schema).
    MLE_means = synthetic_datasets[dataset.columns].astype(float).mean(axis=0)
    mle_means_per_run.append(MLE_means)
    print(MLE_means)

# One row per replicate, one column per true rate mu_k.
mle_means_summary = pd.DataFrame(mle_means_per_run)
mle_means_summary

Domain size: 240240
Recording:  Query selection
start MST selection
end MST selection
Recording:  Calculating full marginal query
MST query set: [('mu_1', 'mu_4'), ('mu_2', 'mu_5'), ('mu_3', 'mu_5'), ('mu_4', 'mu_5')]
Full set of marginal queries: 729
Recording:  Calculating canonical query set
Calculating canonical queries, clique_set length: 10


  0%|                                                                                                                                                                 | 0/10 [00:00<?, ?it/s]
208it [00:00, 25509.22it/s]

13it [00:00, 14509.30it/s]

240it [00:00, 36628.81it/s]

105it [00:00, 29710.71it/s]

16it [00:00, 8251.43it/s]

176it [00:00, 31053.24it/s]

11it [00:00, 27316.37it/s]

7it [00:00, 30615.36it/s]

15it [00:00, 28886.39it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 187.56it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 131072.00it/s]


Calculating new queries, not_original_clique_queries length: 57


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 57/57 [00:00<00:00, 19280.27it/s]

Canonical queries: 681
Junction tree width: 2
(681,)
Number of marginal queries: 4
2.116778525183835





DP noise mean: -0.018754324753820772
[ 4.60035265e+01  1.25884845e+02  2.14492236e+02  2.68122228e+02
  2.29982821e+02  2.16906882e+02  1.45183178e+02  1.05788529e+02
  4.73081256e+01  3.08683535e+01  2.94090442e+00 -1.86178696e-01
 -1.41870972e+00  3.03636809e+00  2.14792138e+00  7.05978405e+01
  1.90248246e+02  3.09207641e+02  3.67234050e+02  3.95412005e+02
  3.37345889e+02  2.60307109e+02  1.33036213e+02  9.25940041e+01
  4.13759331e+01  1.73568576e+01  7.82079048e+00  3.34526307e+00
  2.51992512e+00 -4.20206040e-01  7.55792351e+01  1.61598645e+02
  3.13761640e+02  3.88723386e+02  3.65552463e+02  3.27625208e+02
  2.41752998e+02  1.56053390e+02  7.16505738e+01  3.22809814e+01
  1.80620829e+01  1.10991691e+01  2.58373840e+00 -1.53137095e+00
  1.79677386e+00  5.44458698e+01  1.62491768e+02  2.34703850e+02
  2.97054383e+02  3.07067572e+02  2.42526321e+02  2.15670620e+02
  1.17641387e+02  6.40929991e+01  2.67972814e+01  1.48351674e+01
  7.38304781e+00  4.39852068e-02  1.99749025e-01 -1.7

Running Laplace optimization
Laplace approximation time: {'experiment_id': 'TZ6VARUJ', 'start': 166049.2623535, 'stop': 167309.257749071, 'timedelta': 1259.9953955709934, 'task': 'Laplace approximation', 'dataset_name': 'poisson_data', 'query_str': 'empty', 'query_list': [], 'epsilon': 8, 'delta': 1e-08, 'MCMC_algo': 'NUTS', 'laplace_approximation': True, 'missing_query': None, 'discretization': None, 'no_privacy': False, 'disable_MST': False, 'n_canonical_queries': 681, 'junction_tree_width': 2, 'suff_stat_dim': (681,), 'suff_stat': array([ 46, 127, 214, 267, 232, 218, 142, 106,  45,  33,   6,   2,   1,
         0,   0,  74, 192, 306, 370, 396, 339, 261, 138,  91,  43,  17,
         7,   5,   1,   0,  79, 164, 317, 388, 365, 323, 238, 157,  73,
        33,  19,   8,   1,   1,   0,  55, 163, 230, 294, 309, 245, 213,
       118,  62,  25,  12,   7,   2,   0,   1,  24,  82, 153, 210, 170,
       162,  91,  57,  35,  13,   9,   3,   2,   1,   0,  13,  45,  80,
        92,  87,  83,  55,  

  mcmc = numpyro.infer.MCMC(


[ 4.60035265e+01  1.25884845e+02  2.14492236e+02  2.68122228e+02
  2.29982821e+02  2.16906882e+02  1.45183178e+02  1.05788529e+02
  4.73081256e+01  3.08683535e+01  2.94090442e+00 -1.86178696e-01
 -1.41870972e+00  3.03636809e+00  2.14792138e+00  7.05978405e+01
  1.90248246e+02  3.09207641e+02  3.67234050e+02  3.95412005e+02
  3.37345889e+02  2.60307109e+02  1.33036213e+02  9.25940041e+01
  4.13759331e+01  1.73568576e+01  7.82079048e+00  3.34526307e+00
  2.51992512e+00 -4.20206040e-01  7.55792351e+01  1.61598645e+02
  3.13761640e+02  3.88723386e+02  3.65552463e+02  3.27625208e+02
  2.41752998e+02  1.56053390e+02  7.16505738e+01  3.22809814e+01
  1.80620829e+01  1.10991691e+01  2.58373840e+00 -1.53137095e+00
  1.79677386e+00  5.44458698e+01  1.62491768e+02  2.34703850e+02
  2.97054383e+02  3.07067572e+02  2.42526321e+02  2.15670620e+02
  1.17641387e+02  6.40929991e+01  2.67972814e+01  1.48351674e+01
  7.38304781e+00  4.39852068e-02  1.99749025e-01 -1.77978185e+00
  2.29170143e+01  8.03219

2023-05-21 20:47:58.717862: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:65] 
********************************
[Compiling module jit_backward_pass.94] Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2023-05-21 20:53:02.219805: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:133] The operation took 7m3.503081641s

********************************
[Compiling module jit_backward_pass.94] Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2023-05-21 21:00:19.107212: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:133] The operation took 6m47.937538624s

********************************
[Compiling module jit_backward_pass.97] Very slow compile?  If you want to file a bug, run with env

In [22]:
# Inspect the simulated Poisson dataset currently bound to `dataset`
# (the replicate loop reassigns it on every iteration, so this is the last one).
dataset

Unnamed: 0,mu_1,mu_2,mu_3,mu_4,mu_5
0,2,1,1,4,8
1,1,3,4,2,4
2,0,2,2,3,4
3,2,3,3,2,4
4,1,3,2,2,5
...,...,...,...,...,...
9995,1,2,1,4,4
9996,0,1,4,5,6
9997,1,0,3,3,3
9998,1,1,3,5,7


In [None]:
# Single NAPSU-MQ fit on the simulated Poisson dataset currently bound to
# `dataset`.
# Derive a fresh key instead of passing `rng` directly: `rng` has already been
# consumed by jax.random.split, and JAX keys must not be reused after splitting.
fit_rng = jax.random.fold_in(rng, 0)

model = NapsuMQModel()
result: NapsuMQResult = model.fit(
    data=dataset,
    # Label the run with the dataset it actually uses (the replicate loop logs
    # the same data as "poisson_data"); "binary4d" was a copy-paste leftover.
    dataset_name="poisson_data",
    rng=fit_rng,
    epsilon=8,
    # Compute n from the data here rather than relying on `n` leaking out of
    # another cell.
    delta=(len(dataset) ** (-2)),
    column_feature_set=[],
    use_laplace_approximation=True,
    laplace_approximation_algorithm="torch_LBFGS",
)