# Hyper-parameter tuning for GANITE

This notebook demonstrates hyper-parameter search for the __GANITE__ algorithm (both the TensorFlow and PyTorch versions) on the [Twins](https://bitbucket.org/mvdschaar/mlforhealthlabpub/src/master/data/twins/) dataset.

For details about each implementation, please refer to the dedicated notebooks:
 - [GANITE (TensorFlow) notebook](https://github.com/bcebere/ite-api/blob/main/notebooks/ganite_train_evaluation.ipynb).
 - [GANITE (PyTorch) notebook](https://github.com/bcebere/ite-api/blob/main/notebooks/ganite_pytorch_train_evaluation.ipynb).

## Hyper-parameter tuning

Hyperparameter tuning is the search for the hyperparameter values that give the model its best performance on a specific dataset.

One algorithm for hyperparameter optimization is [__Bayesian Optimization__](https://en.wikipedia.org/wiki/Bayesian_optimization).


__Bayesian Optimization__ is a principled technique, based on [Bayes' theorem](https://en.wikipedia.org/wiki/Bayes%27_theorem), for directing the search in a global optimization problem efficiently. It builds a probabilistic model of the objective function, called the surrogate function, and uses an acquisition function to search that surrogate cheaply and to choose the candidate samples on which the real objective function is then evaluated.
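
As a rough illustration of the surrogate and acquisition loop, the sketch below runs Bayesian optimization on a one-dimensional toy problem using only the Python standard library. It is a toy under stated assumptions, not the implementation used for the GANITE tuning: the `objective` function, the RBF kernel and its length scale, the candidate grid, and every helper name are hypothetical choices made for illustration.

```python
import math
import random


def rbf(a, b, length=0.15):
    """Squared-exponential kernel between two scalar inputs."""
    return math.exp(-0.5 * ((a - b) / length) ** 2)


def solve(matrix, rhs):
    """Solve matrix @ x = rhs with Gauss-Jordan elimination and partial pivoting."""
    n = len(rhs)
    aug = [row[:] + [rhs[i]] for i, row in enumerate(matrix)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(aug[r][col]))
        aug[col], aug[pivot] = aug[pivot], aug[col]
        for r in range(n):
            if r != col:
                factor = aug[r][col] / aug[col][col]
                aug[r] = [v - factor * p for v, p in zip(aug[r], aug[col])]
    return [aug[i][n] / aug[i][i] for i in range(n)]


def gp_posterior(xs, ys, x_new, noise=1e-6):
    """Posterior mean and standard deviation of a GP surrogate at x_new."""
    offset = sum(ys) / len(ys)  # center the targets so a zero-mean prior is reasonable
    gram = [[rbf(a, b) + (noise if i == j else 0.0) for j, b in enumerate(xs)]
            for i, a in enumerate(xs)]
    k_star = [rbf(a, x_new) for a in xs]
    weights = solve(gram, [y - offset for y in ys])
    mean = offset + sum(k * w for k, w in zip(k_star, weights))
    var = rbf(x_new, x_new) - sum(k * v for k, v in zip(k_star, solve(gram, k_star)))
    return mean, math.sqrt(max(var, 1e-12))


def expected_improvement(mean, std, best):
    """Acquisition function: expected improvement over `best` when minimizing."""
    z = (best - mean) / std
    cdf = 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))
    pdf = math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)
    return (best - mean) * cdf + std * pdf


def objective(x):
    """Hypothetical stand-in for an expensive validation loss."""
    return (x - 0.3) ** 2 + 0.1 * math.sin(15 * x)


def bayes_opt(n_init=3, n_iter=10, seed=0):
    """Evaluate n_init random points, then n_iter points chosen by the acquisition."""
    rng = random.Random(seed)
    candidates = [i / 100 for i in range(101)]  # search grid on [0, 1]
    xs = rng.sample(candidates, n_init)
    ys = [objective(x) for x in xs]
    for _ in range(n_iter):
        best = min(ys)
        untried = [c for c in candidates if c not in xs]
        # Query the cheap surrogate everywhere, then spend one real evaluation.
        x_next = max(
            untried,
            key=lambda c: expected_improvement(*gp_posterior(xs, ys, c), best),
        )
        xs.append(x_next)
        ys.append(objective(x_next))
    return xs, ys


xs, ys = bayes_opt()
best_index = ys.index(min(ys))
print(f"best x = {xs[best_index]:.2f}, objective = {ys[best_index]:.4f}")
```

Only 13 evaluations of the real objective are spent here; all other work happens on the surrogate, which is the property that makes the approach attractive when each evaluation means training a model.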

For the tuning, we use the [__Scikit-Optimize__](https://scikit-optimize.github.io/stable/) library, which provides a general toolkit for [Bayesian Optimization](https://en.wikipedia.org/wiki/Bayesian_optimization) that can be used for hyperparameter tuning.

For __GANITE__, we try to optimize the following hyperparameters using the ranges suggested in [[3] Table 6](https://openreview.net/forum?id=ByKWUeWA-):

| Hyperparameter | Search area | Description |
| --- | --- | --- |
| dim_hidden | {dim, int(dim/2), int(dim/3), int(dim/4), int(dim/5)} | the size of the hidden layers. |
| depth |{1, 3, 5, 7, 9} | the number of hidden layers in the generator and inference blocks. |
| alpha | {0, 0.1, 0.5, 1, 2, 5, 10} | weight for the Generator block loss. |
| beta | {0, 0.1, 0.5, 1, 2, 5, 10} | weight for the ITE block loss. |
| num_discr_iterations | [3, 10] | number of iterations executed by the Counterfactual discriminator. |
| minibatch_size | {32, 64, 128, 256} | the size of the dataset batches. |
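
The table can be read as a discrete search space. The sketch below builds it as a plain dictionary and counts how many configurations an exhaustive grid search would have to train, which is one way to see why a guided search is preferable. The helper names, the value `dim = 30`, and the reading of `[3, 10]` as an inclusive integer range are assumptions made for illustration; they are not taken from the linked implementation.

```python
import math


def ganite_search_space(dim):
    """Candidate values per hyperparameter, following the table above."""
    return {
        # dim, int(dim/2), ..., int(dim/5); the set drops duplicates for small dim
        "dim_hidden": sorted({int(dim / k) for k in range(1, 6)}),
        "depth": [1, 3, 5, 7, 9],
        "alpha": [0, 0.1, 0.5, 1, 2, 5, 10],
        "beta": [0, 0.1, 0.5, 1, 2, 5, 10],
        # Assumption: [3, 10] is an inclusive range of integers.
        "num_discr_iterations": list(range(3, 11)),
        "minibatch_size": [32, 64, 128, 256],
    }


def grid_size(space):
    """Number of configurations an exhaustive grid search would evaluate."""
    return math.prod(len(values) for values in space.values())


space = ganite_search_space(dim=30)  # dim = 30 is an assumed input dimension
print(space["dim_hidden"])  # [6, 7, 10, 15, 30]
print(grid_size(space))     # 39200
```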

 

You can find the __GANITE__ hyperparameter tuning implementation [here](https://github.com/bcebere/ite-api/blob/main/src/ite/algs/hyperparam_tuning.py).

## Setup

First, make sure that all the dependencies are installed in the current environment.
```
pip install -r requirements.txt
pip install .
```

Next, we import all the dependencies necessary for the task.

In [6]:
import ite.algs.hyperparam_tuning as tuning
from IPython.display import HTML, display
import tabulate


param_search_names = ["num_discr_iterations", "minibatch_size", "dim_hidden", "alpha", "beta", "depth"]

### GANITE (TensorFlow)

In [8]:
tf_best_params = tuning.search("GANITE", iterations=5000)

100%|██████████| 5000/5000 [00:46<00:00, 107.75it/s]
100%|██████████| 5000/5000 [00:05<00:00, 957.84it/s] 
100%|██████████| 5000/5000 [00:37<00:00, 132.50it/s]
100%|██████████| 5000/5000 [00:04<00:00, 1084.63it/s]
100%|██████████| 5000/5000 [00:28<00:00, 175.32it/s]
100%|██████████| 5000/5000 [00:04<00:00, 1041.94it/s]
100%|██████████| 5000/5000 [01:04<00:00, 77.37it/s] 
100%|██████████| 5000/5000 [00:07<00:00, 636.72it/s]
100%|██████████| 5000/5000 [02:01<00:00, 41.12it/s]
100%|██████████| 5000/5000 [00:11<00:00, 435.71it/s]
100%|██████████| 5000/5000 [00:39<00:00, 127.79it/s]
100%|██████████| 5000/5000 [00:07<00:00, 636.15it/s]
100%|██████████| 5000/5000 [01:00<00:00, 82.15it/s]
100%|██████████| 5000/5000 [00:08<00:00, 588.47it/s]
100%|██████████| 5000/5000 [00:50<00:00, 98.11it/s] 
100%|██████████| 5000/5000 [00:07<00:00, 685.55it/s]
100%|██████████| 5000/5000 [00:38<00:00, 129.34it/s]
100%|██████████| 5000/5000 [00:06<00:00, 727.99it/s]
100%|██████████| 5000/5000 [00:39<00:00, 127.

100%|██████████| 5000/5000 [01:04<00:00, 77.59it/s] 
100%|██████████| 5000/5000 [00:07<00:00, 637.95it/s]
100%|██████████| 5000/5000 [00:41<00:00, 119.61it/s]
100%|██████████| 5000/5000 [00:08<00:00, 621.74it/s]
100%|██████████| 5000/5000 [01:16<00:00, 65.14it/s] 
100%|██████████| 5000/5000 [00:06<00:00, 779.75it/s] 
100%|██████████| 5000/5000 [00:34<00:00, 145.62it/s]
100%|██████████| 5000/5000 [00:08<00:00, 614.05it/s]
100%|██████████| 5000/5000 [00:20<00:00, 244.16it/s]
100%|██████████| 5000/5000 [00:03<00:00, 1391.65it/s]
100%|██████████| 5000/5000 [00:43<00:00, 114.34it/s]
100%|██████████| 5000/5000 [00:07<00:00, 670.87it/s]
100%|██████████| 5000/5000 [00:43<00:00, 114.29it/s]
100%|██████████| 5000/5000 [00:06<00:00, 716.15it/s]
100%|██████████| 5000/5000 [00:57<00:00, 86.78it/s] 
100%|██████████| 5000/5000 [00:04<00:00, 1216.97it/s]
100%|██████████| 5000/5000 [00:25<00:00, 196.44it/s]
100%|██████████| 5000/5000 [00:07<00:00, 671.93it/s]
100%|██████████| 5000/5000 [01:04<00:00, 76

### Hyper-parameter tuning results for GANITE (TensorFlow)

In [9]:
display(HTML(tabulate.tabulate([tf_best_params], headers=param_search_names, tablefmt='html')))

| num_discr_iterations | minibatch_size | dim_hidden | alpha | beta | depth |
| --- | --- | --- | --- | --- | --- |
| 9 | 32 | 7 | 5 | 0.1 | 2 |


### GANITE (PyTorch)

In [4]:
torch_best_params = tuning.search("GANITE_TORCH")

100%|██████████| 2000/2000 [01:13<00:00, 27.04it/s]
100%|██████████| 2000/2000 [00:11<00:00, 168.64it/s]
100%|██████████| 2000/2000 [01:20<00:00, 24.85it/s]
100%|██████████| 2000/2000 [00:10<00:00, 198.42it/s]
100%|██████████| 2000/2000 [00:41<00:00, 48.76it/s]
100%|██████████| 2000/2000 [00:06<00:00, 286.62it/s]
100%|██████████| 2000/2000 [00:58<00:00, 34.48it/s]
100%|██████████| 2000/2000 [00:08<00:00, 249.81it/s]
100%|██████████| 2000/2000 [00:53<00:00, 37.51it/s]
100%|██████████| 2000/2000 [00:09<00:00, 213.55it/s]
100%|██████████| 2000/2000 [00:47<00:00, 41.89it/s]
100%|██████████| 2000/2000 [00:08<00:00, 244.43it/s]
100%|██████████| 2000/2000 [00:31<00:00, 62.96it/s]
100%|██████████| 2000/2000 [00:10<00:00, 184.20it/s]
100%|██████████| 2000/2000 [00:24<00:00, 80.05it/s]
100%|██████████| 2000/2000 [00:07<00:00, 280.38it/s]
100%|██████████| 2000/2000 [00:58<00:00, 34.24it/s]
100%|██████████| 2000/2000 [00:07<00:00, 279.55it/s]
100%|██████████| 2000/2000 [00:56<00:00, 35.50it/s]

### Hyper-parameter tuning results for GANITE (PyTorch)

In [5]:
display(HTML(tabulate.tabulate([torch_best_params], headers=param_search_names, tablefmt='html')))

| num_discr_iterations | minibatch_size | dim_hidden | alpha | beta | depth |
| --- | --- | --- | --- | --- | --- |
| 6 | 256 | 30 | 1 | 10 | 5 |
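
To reuse or compare the two results, it can help to pair each value with its name. The sketch below does this with the values reported above, assuming (as the `tabulate` calls in this notebook imply) that `tuning.search` returns the values in the order of `param_search_names`. The `as_config` helper is hypothetical and not part of the library.

```python
param_search_names = ["num_discr_iterations", "minibatch_size", "dim_hidden", "alpha", "beta", "depth"]

# Values copied from the results reported in this notebook.
tf_best_params = [9, 32, 7, 5, 0.1, 2]
torch_best_params = [6, 256, 30, 1, 10, 5]


def as_config(names, values):
    """Pair hyperparameter names with values, refusing mismatched lengths."""
    if len(names) != len(values):
        raise ValueError("expected one value per hyperparameter name")
    return dict(zip(names, values))


tf_config = as_config(param_search_names, tf_best_params)
torch_config = as_config(param_search_names, torch_best_params)

# Names of the hyperparameters on which the two searches disagree.
differing = [name for name in param_search_names if tf_config[name] != torch_config[name]]

for name in param_search_names:
    print(f"{name:>22}: tensorflow={tf_config[name]!s:<5} pytorch={torch_config[name]}")
```

With these values the two searches disagree on every hyperparameter, so the configurations should not be assumed to transfer between the two implementations.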


## References
1. [Scikit-Optimize for Hyperparameter Tuning in Machine Learning](https://machinelearningmastery.com/scikit-optimize-for-hyperparameter-tuning-in-machine-learning).
2. [scikit-optimize](https://scikit-optimize.github.io/).
3. Jinsung Yoon, James Jordon, Mihaela van der Schaar, "GANITE: Estimation of Individualized Treatment Effects using Generative Adversarial Nets", International Conference on Learning Representations (ICLR), 2018 ([Paper](https://openreview.net/forum?id=ByKWUeWA-)).