In [1]:
import itertools
import sys
import os
sys.path.append("../") # go to parent dir

import jax
import jax.random as jr
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import numpy as np
from scipy.stats import rankdata
import scipy.stats as ss
import seaborn as sns
from sklearn.model_selection import KFold

# from data.create_sim_data import *
import data.template_causl_simulations as causl_py
from data.run_all_simulations import plot_simulation_results
import data.hyperparam_and_bootstrapping as hb
from frugal_flows.causal_flows import independent_continuous_marginal_flow, get_independent_quantiles, train_frugal_flow
from frugal_flows.bijections import UnivariateNormalCDF

import rpy2.robjects as ro
from rpy2.robjects.packages import importr
from rpy2.robjects import pandas2ri
from rpy2.robjects.packages import SignatureTranslatedAnonymousPackage

# Activate automatic conversion of rpy2 objects to pandas objects
pandas2ri.activate()

# Import the R library causl. If it is not installed, try to install it and then
# import it again, so that `causl` is always bound once this cell has run.
try:
    causl = importr('causl')
except Exception:
    utils = importr('utils')
    # Pass a list: a bare string would be split into one-character package names.
    # NOTE(review): install.packages only searches the given CRAN-style repo; if
    # causl is not published there, install it manually from within R instead.
    utils.install_packages(ro.StrVector(['causl']), repos='https://cloud.r-project.org')
    causl = importr('causl')

# Enable 64-bit floats in JAX (arrays created from here on default to float64).
jax.config.update("jax_enable_x64", True)

# Default model/training hyperparameters: one fixed configuration, as opposed to
# the search grid defined further down.
hyperparams_dict = dict(
    learning_rate=5e-3,
    RQS_knots=8,
    flow_layers=5,
    nn_width=50,
    nn_depth=4,
    max_patience=50,
    max_epochs=10000,
)

# Simulation settings.
# NOTE(review): NUM_ITER is presumably the number of simulation repetitions and
# TRUE_PARAMS the ground truth that estimates are compared against. Confirm.
NUM_ITER = 10
TRUE_PARAMS = dict(ate=1, const=1, scale=1)
CAUSAL_PARAMS = [1, 1]

## Generate Data

In [7]:
# Simulate 1000 samples with the causal parameters set in the config cell, so the
# two cannot drift apart.
data = causl_py.generate_gaussian_samples(N=1000, causal_params=CAUSAL_PARAMS, seed=0)

# Index with [] rather than .get so a missing key fails here with a KeyError,
# instead of passing None on to the hyperparameter search.
Z_cont = data['Z_cont']
X = data['X']
Y = data['Y']

In [1]:
# Sanity check: list the devices JAX will run on. The recorded run was CPU-only.
# `jax` is already imported in the first cell, so it is not re-imported here.
jax.devices()

[CpuDevice(id=0)]

In [18]:
# Hyperparameter search space: each key maps to its list of candidate values.
# NOTE(review): hb.generate_param_combinations presumably expands this into the
# full Cartesian product (3*3*3*3*2 = 162 configurations). Confirm.
param_grid = dict(
    RQS_knots=[4, 6, 8],
    flow_layers=[4, 6, 8],
    nn_width=[20, 40, 60],
    nn_depth=[4, 6, 8],
    learning_rate=[3e-3, 5e-3],
    # Fixed settings, given as single-element lists so they take part in the grid.
    batch_size=[1000],
    max_patience=[50],
    max_epochs=[10000],
)
param_combinations = hb.generate_param_combinations(param_grid)

In [19]:
# Run the hyperparameter search over `param_combinations` for a Gaussian outcome.
# In the recorded run, each training (up to max_epochs, with early stopping) took
# roughly 0.5-3 minutes on CPU, and the cell was interrupted after a few
# configurations.
# TODO: cache the results so that Restart & Run All stays feasible.
hyperparam_fits = hb.gaussian_outcome_hyperparameter_search(
    X,
    Y,
    Z_cont=Z_cont,
    Z_disc=None,
    param_combinations=param_combinations,
    seed=0,
)

 11%|██████████▌                                                                                     | 1096/10000 [00:49<06:40, 22.21it/s, train=-0.1539257789469866, val=0.2261695015260868 (Max patience reached)]
  7%|██████▊                                                                                         | 711/10000 [00:25<05:37, 27.54it/s, train=-0.13549111643113435, val=0.5551765969044515 (Max patience reached)]
 14%|████████████▋                                                                                 | 1351/10000 [01:09<07:25, 19.42it/s, train=-0.34658835963764434, val=0.23451815353378067 (Max patience reached)]
 22%|████████████████████▋                                                                          | 2182/10000 [02:36<09:21, 13.93it/s, train=-1.7833379680170227, val=-0.8207068324662511 (Max patience reached)]
 22%|█████████████████████                                                                          | 2213/10000 [02:46<09:46, 13.28it/s, train=-0.9

KeyboardInterrupt: 

In [17]:
# Show the 20 configurations with the smallest min_loss (presumably the best
# validation loss). The search returns every row labelled 0, so renumber the rows
# by rank.
# NOTE(review): the saved output predates the current param_grid: it shows
# flow_layers=2 and nn_width values that are not in the grid, and the search cell
# was interrupted. Re-run this cell once the search completes.
hyperparam_fits.sort_values('min_loss', ignore_index=True).head(20)

Unnamed: 0,RQS_knots,flow_layers,nn_width,nn_depth,learning_rate,batch_size,max_patience,max_epochs,ate,const,scale,min_loss
0,4,2,60,8,0.003,1000,50,10000,0.7150669160365015,1.374007716934389,1.4217528496579643,-1.326096336170809
0,4,2,35,2,0.003,1000,50,10000,0.8266724675551651,1.5860350125175438,2.29718594724751,-1.3028112910257834
0,4,2,25,4,0.003,1000,50,10000,1.0179928860883678,1.665140150428186,2.572571572273175,-0.6892003449425886
0,4,2,45,6,0.003,1000,50,10000,0.8571017907300323,1.7412041588663232,2.490030621054472,-0.6708703916855526
0,6,2,50,8,0.003,1000,50,10000,0.9446084566625048,1.5971757854820765,2.3495064545679925,-0.5243316967847134
0,6,2,25,4,0.003,1000,50,10000,1.4900972867872,1.937130902510924,2.862495278420609,-0.5242080519884562
0,4,2,50,4,0.003,1000,50,10000,1.035237754176492,1.684316556302863,2.4932209816214552,-0.5143605829525498
0,8,2,25,4,0.003,1000,50,10000,1.736149392203217,1.9078992491636035,2.872326192508355,-0.4606091843220257
0,4,2,25,2,0.003,1000,50,10000,2.1145454335751657,2.2542805387658795,3.0904339231313105,-0.3695486533062358
0,4,2,45,8,0.003,1000,50,10000,1.7178410146193837,1.9942493167948057,2.925035631203214,-0.3574400345667037
