These tutorials are licensed by [Bernard Koch](http://www.github.com/kochbj) under a [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License](http://creativecommons.org/licenses/by-nc-sa/4.0/).

# Tutorial 2: Causal Inference Metrics and Hyperparameter Optimization
 
In the last tutorial we introduced the intuition behind representation learning for causal inference and built a simple TARNet model. Given that we are doing statistical inference, **proper model optimization is critical to achieving unbiased estimates.** In this tutorial I'll show you how to assess convergence of your model using Tensorboard, and do hyperparameter tuning using [Keras Tuner](https://keras-team.github.io/keras-tuner/). If you don't know what hyperparameters are, read the spoiler below.
 
<details><summary>What are hyperparameters and why should we tune them?</summary>
 
Hyperparameters are any parameters in your model that are not optimized by your loss function. In other words, they're the knobs you have to tune. To tune hyperparameters we need an objective score to compare different hyperparameterizations. Often this is the validation loss, but we can/should try something more sophisticated for causal inference (see below). A simple strategy called "grid search" is to exhaustively explore all possible hyperparameterizations and pick the best-scoring one. Often this is computationally intractable. Hyperparameter tuning packages like Keras Tuner instead implement more sophisticated algorithms for exploring the hyperparameter space (e.g., Random Search, Hyperband, Bayesian Optimization).
</details>
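
As a rough illustration of grid search, the sketch below scores every combination of two made-up hyperparameters and keeps the lowest-scoring one. The function `toy_score` and the candidate values are placeholders invented for this example; they are not used anywhere else in this tutorial.

```python
import itertools

def toy_score(learning_rate, n_units):
    # Hypothetical stand-in for a validation metric (lower is better).
    return (learning_rate - 0.01) ** 2 + ((n_units - 100) / 1000) ** 2

# Candidate values for each hyperparameter (illustrative only).
grid = {
    "learning_rate": [0.001, 0.01, 0.1],
    "n_units": [50, 100, 200],
}

# Grid search: evaluate every combination and keep the best one.
names = list(grid)
configs = [dict(zip(names, values)) for values in itertools.product(*grid.values())]
best_config = min(configs, key=lambda cfg: toy_score(**cfg))

print(len(configs), "configurations evaluated")
print("best:", best_config)
```

The number of configurations is the product of the number of candidates per hyperparameter, which is why exhaustive search quickly becomes impractical as more knobs are added.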

## Notation
**Causal identification**

- Observed covariates/features: $X$

- Potential outcomes: $Y(0)$ and $Y(1)$

- Treatment: $T$

- Unobservable Individual Treatment Effect: $\tau_i = Y_i(1) - Y_i(0)$

- Average Treatment Effect: $ATE =\mathbb{E}[Y_i(1)-Y_i(0)]= \mathbb{E}[{\tau_i}]$

- Conditional Average Treatment Effect: $CATE(x) =\mathbb{E}[Y_i(1)-Y_i(0)|X=x]$


**Deep learning estimation**

- Predicted outcomes: $\hat{Y}(0)$ and $\hat{Y}(1)$

- Outcome modeling functions: $\hat{Y}(T)=h(X,T)$ 

- Representation functions: $\Phi(X)$

- Propensity score function:
$\pi(X,T)=P(T|X)$ <br>*where $\pi(X,1)=P(T=1|X)$ and $\pi(X,0)=1-\pi(X,1)$* 

- Loss functions: $\mathcal{L}(true,predicted)$, with the mean squared error abbreviated $MSE$ and binary cross-entropy as $BCE$

- Estimated CATE<sup>*</sup>: $\hat{CATE_i} = \hat{\tau}_i = \hat{Y_i}(1)-\hat{Y_i}(0) = h(X,1)-h(X,0)$

- Estimated ATE: $\hat{ATE}=\frac{1}{n}\sum_{i=1}^n\hat{CATE_i}$

- Nearest-neighbor PEHE:
$$PEHE_{nn}=\frac{1}{N}\sum_{i=1}^N\Big(\underbrace{(1-2t_i)\big(y_i^{nn}(1-t_i)-y_i(t_i)\big)}_{CATE_{nn}}-\underbrace{\big(h(\Phi(x_i),1)-h(\Phi(x_i),0)\big)}_{\hat{CATE}}\Big)^2$$ for nearest neighbor $j^*$ of each unit $i$ in representation space such that $t_{j^*}\neq t_i$:
  $$y_i^{nn}(1-t_i) = y_{j^*}, \quad j^* = \arg\min_{j:\,t_j\neq t_i}||\Phi(x_i)-\Phi(x_j)||_2$$

\* We define $\hat{\tau}_i = \hat{CATE_i}$ because we lack the covariates to estimate the ITE.  

# Part 1: Using Metrics and Tensorboard to Evaluate Models

## Training metrics for causal inference

Although our ultimate goal is to estimate the $\hat{CATE}$, the loss function in TARNet only minimizes the factual error to estimate $\hat{Y}$. This is a reflection of the fundamental problem of causal inference: we only observe one potential outcome for each unit.
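
To make "factual error" concrete, here is a small NumPy sketch with toy numbers (all values are invented for illustration). The treatment indicator acts as a switch, so each unit is only scored against the prediction for the outcome we actually observed; the counterfactual prediction never enters the loss.

```python
import numpy as np

# Toy batch: observed outcome y, treatment t, and predictions for both potential outcomes.
y = np.array([1.0, 2.0, 3.0, 4.0])
t = np.array([0.0, 1.0, 0.0, 1.0])
y0_pred = np.array([1.5, 0.0, 3.0, 0.0])
y1_pred = np.array([9.0, 2.5, 9.0, 3.0])

# Control units only contribute through y0_pred, treated units only through y1_pred.
loss0 = np.sum((1.0 - t) * (y - y0_pred) ** 2)
loss1 = np.sum(t * (y - y1_pred) ** 2)
factual_loss = loss0 + loss1
print(factual_loss)
```

Note that the wildly wrong counterfactual predictions (e.g., `y1_pred = 9.0` for the control units) cost nothing here, which is exactly why we need separate metrics for the $\hat{CATE}$.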

Within this literature, it is common practice to evaluate model performance on simulations using the Precision in Estimation of Heterogeneous Effects (PEHE) from [Hill, 2011](https://www.tandfonline.com/doi/abs/10.1198/jcgs.2010.08162?casa_token=b8-rfzagECIAAAAA:QeP7C4lKN6nZ7MkDjJHFrEberXopD9M5qPBMeBqbk84mI_8qGxj01ctgt4jdZtORpu9aZvpVRe07PA). PEHE measures the error in estimates of the $CATE$:

$$PEHE=\frac{1}{N}\sum_{i=1}^N(CATE_i-\hat{CATE_i})^2$$

In order to select hyperparameters in real data, [Johansson et al., 2020](https://arxiv.org/pdf/2001.07426.pdf) propose to use a matching variant of $PEHE$ in which the observed outcome of each unit's nearest Euclidean neighbor from the other treatment group, $y_i^{nn}$, serves as a counterfactual. If we identify the nearest neighbor $j^*$ of each unit $i$ in representation space such that $t_{j^*}\neq t_i$ and take its outcome as

  $$y_i^{nn}(1-t_i) = y_{j^*}, \quad j^* = \arg\min_{j:\,t_j\neq t_i}||\Phi(x_i)-\Phi(x_j)||_2$$
 then,
$$PEHE_{nn}=\frac{1}{N}\sum_{i=1}^N\Big(\underbrace{(1-2t_i)\big(y_i^{nn}(1-t_i)-y_i(t_i)\big)}_{CATE_{nn}}-\underbrace{\big(h(\Phi(x_i),1)-h(\Phi(x_i),0)\big)}_{\hat{CATE}}\Big)^2$$
If we take the square root of the $PEHE_{nn}$ then we get an approximation of the unit-level error.

I think the intuition behind $\sqrt{PEHE_{nn}}$ is solid. If our representation function $\Phi$ is truly learning to balance the treated and control distributions, nearest neighbors in representation space should be comparable units, so $CATE_{nn}$ should be a coarse proxy for the true $CATE$.
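
The following NumPy sketch spells out the matching step on four toy units. It is a simplified illustration, not the TensorFlow implementation used later in this tutorial: the function name `pehe_nn` and all input values are made up, and it follows the sign convention of the `PEHEnn` method below (treated units contribute $y - y^{nn}$, control units $y^{nn} - y$).

```python
import numpy as np

def pehe_nn(phi, t, y, cate_hat):
    """Nearest-neighbor PEHE for representations phi (n, d), binary treatment t (n,),
    factual outcomes y (n,), and estimated CATEs cate_hat (n,)."""
    # Pairwise Euclidean distances between all units in representation space.
    dists = np.linalg.norm(phi[:, None, :] - phi[None, :, :], axis=-1)
    # Only units from the opposite treatment group are eligible neighbors.
    same_group = t[:, None] == t[None, :]
    dists = np.where(same_group, np.inf, dists)
    y_nn = y[np.argmin(dists, axis=1)]
    # Treated units: y - y_nn; control units: y_nn - y.
    cate_nn = (1 - 2 * t) * (y_nn - y)
    return np.mean((cate_nn - cate_hat) ** 2)

phi = np.array([[0.0], [0.1], [1.0], [1.1]])  # toy 1-D representations
t = np.array([0, 1, 0, 1])
y = np.array([1.0, 3.0, 2.0, 5.0])
cate_hat = np.array([2.0, 2.0, 3.0, 3.0])
print(np.sqrt(pehe_nn(phi, t, y, cate_hat)))
```

In this toy example each unit's nearest opposite-group neighbor is the unit next to it in representation space, and the estimated CATEs were chosen to match the matched differences exactly.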

### Additional metrics for simulations (known counterfactuals)

Since we know both potential outcomes in simulations, we might also like to calculate bias in $\hat{ATE}$ and $\hat{CATE}$, as well as the actual $PEHE$:

- $ATE_{bias} = |ATE-\hat{ATE}|$
- $CATE_{bias} = \frac{1}{N}\sum_{i=1}^N |CATE_i-\hat{CATE_i}|$
-  $\sqrt{PEHE}=\sqrt{\frac{1}{N}\sum_{i=1}^N(CATE_i-\hat{CATE_i})^2}$
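
As a quick worked example of these three quantities, here is a NumPy sketch on four toy units (the numbers are invented for illustration):

```python
import numpy as np

cate_true = np.array([4.0, 3.0, 5.0, 4.0])  # toy values of mu_1 - mu_0
cate_hat = np.array([3.5, 3.0, 6.0, 4.5])   # toy model estimates

ate_bias = np.abs(cate_true.mean() - cate_hat.mean())
cate_bias = np.mean(np.abs(cate_true - cate_hat))
sqrt_pehe = np.sqrt(np.mean((cate_true - cate_hat) ** 2))

print(f"ATE bias: {ate_bias:.4f}  CATE bias: {cate_bias:.4f}  sqrt(PEHE): {sqrt_pehe:.4f}")
```

The ATE bias is smaller than the unit-level errors because over- and under-estimates partially cancel when averaged, which is one reason to track both.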


## Reloading the Data

Reload the IHDP data. For more information on the dataset see [this cell](https://colab.research.google.com/drive/1Zx0AkriygB_ws6qXjA7VfqebG-YMwbWl?authuser=2#scrollTo=jVhJelhqCMD7) from the previous tutorial.

In [None]:
#@title First load the data! (Click Play)
import numpy as np
!pip install scikit-learn==0.24.2
from sklearn.preprocessing import StandardScaler
!wget -nc http://www.fredjo.com/files/ihdp_npci_1-100.train.npz
!wget -nc http://www.fredjo.com/files/ihdp_npci_1-100.test.npz 

def load_IHDP_data(training_data,testing_data,i=7):
    with open(training_data,'rb') as trf, open(testing_data,'rb') as tef:
        train_data=np.load(trf); test_data=np.load(tef)
        y=np.concatenate(   (train_data['yf'][:,i],   test_data['yf'][:,i])).astype('float32') #most GPUs only compute 32-bit floats
        t=np.concatenate(   (train_data['t'][:,i],    test_data['t'][:,i])).astype('float32')
        x=np.concatenate(   (train_data['x'][:,:,i],  test_data['x'][:,:,i]),axis=0).astype('float32')
        mu_0=np.concatenate((train_data['mu0'][:,i],  test_data['mu0'][:,i])).astype('float32')
        mu_1=np.concatenate((train_data['mu1'][:,i],  test_data['mu1'][:,i])).astype('float32')

        data={'x':x,'t':t,'y':y,'mu_0':mu_0,'mu_1':mu_1}
        data['t']=data['t'].reshape(-1,1) #we're just padding one dimensional vectors with an additional dimension 
        data['y']=data['y'].reshape(-1,1)
        
        #standardizing y (zero mean, unit variance) often makes training of DL regressors easier
        data['y_scaler'] = StandardScaler().fit(data['y'])
        data['ys'] = data['y_scaler'].transform(data['y'])

    return data

data=load_IHDP_data(training_data='./ihdp_npci_1-100.train.npz',testing_data='./ihdp_npci_1-100.test.npz')


## Adding our metrics to tensorflow

If we use the `.fit` API's built-in metric system, we will be calculating our sample metrics within each mini-batch (e.g. for 64 units), which isn't very useful. Instead, we want to calculate our metrics on the full validation set or entire dataset all at once, at the end of every epoch. We'll need to create a [custom callback](https://www.tensorflow.org/guide/keras/custom_callback) for that. A callback is code that runs dynamically in response to an event (e.g., a TF2 model finishing training). The procedure for subclassing Callbacks is very similar to how one would subclass custom layers or losses (both of which we do in subsequent tutorials).

There are some tricky manipulations in `pdist2sq` and `find_ynn` but don't worry about those. The rest should be clear.

Note that we've added the `tf.summary.scalar` lines to log these metrics for viewing in Tensorboard.
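
To see why mini-batch metrics can mislead, consider the toy NumPy sketch below (the squared errors are randomly generated placeholders). Averaging a root-mean-squared error over mini-batches does not give the same number as computing it once on the full dataset.

```python
import numpy as np

rng = np.random.default_rng(0)
squared_errors = rng.exponential(scale=1.0, size=640)  # toy per-unit squared CATE errors

# Metric computed once on the full dataset.
full_rmse = np.sqrt(squared_errors.mean())

# Same metric computed per mini-batch of 64 units and then averaged.
batches = squared_errors.reshape(-1, 64)
batch_rmse = np.sqrt(batches.mean(axis=1)).mean()

print(f"full dataset: {full_rmse:.4f}  averaged over batches: {batch_rmse:.4f}")
```

Because the square root is concave, the batch-averaged value can never exceed the full-dataset value. For $PEHE_{nn}$ the problem is worse still, since nearest neighbors would have to be found within each mini-batch rather than across the whole sample.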

In [None]:
!pip install -q tensorflow==2.8.0
import tensorflow as tf
import numpy as np
import datetime
#Load the TensorBoard notebook extension so we can view logs inside Colab
%load_ext tensorboard 

In [None]:
from tensorflow.keras.callbacks import Callback

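def pdist2sq(x,y):
    #helper for PEHEnn: squared euclidean distance between every row of x and every row of y
    x2 = tf.reduce_sum(x ** 2, axis=-1, keepdims=True)
    y2 = tf.reduce_sum(y ** 2, axis=-1, keepdims=True)
    dist = x2 + tf.transpose(y2, (1, 0)) - 2. * x @ tf.transpose(y, (1, 0))
    return tf.maximum(dist, 0.) #clip tiny negative values from rounding error so a later sqrt can't return NaN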

'''
def pdist2sq(A, B):
    #helper for PEHEnn
    #calculates squared euclidean distance between rows of two matrices  
    #https://gist.github.com/mbsariyildiz/34cdc26afb630e8cae079048eef91865
    # squared norms of each row in A and B
    na = tf.reduce_sum(tf.square(A), 1)
    nb = tf.reduce_sum(tf.square(B), 1)    
    # na as a row and nb as a column vectors
    na = tf.reshape(na, [-1, 1])
    nb = tf.reshape(nb, [1, -1])
    # return pairwise euclidean difference matrix
    D = tf.sqrt(tf.maximum(na - 2*tf.matmul(A, B, False, True) + nb, 0.0))
    return D
'''

#https://towardsdatascience.com/implementing-macro-f1-score-in-keras-what-not-to-do-e9f1aa04029d
class Full_Metrics(Callback):
    def __init__(self,data, verbose=0):   
        super(Full_Metrics, self).__init__()
        self.data=data #feed the callback the full dataset
        self.verbose=verbose

        #needed for PEHEnn; Called in self.find_ynn
        self.data['o_idx']=tf.range(self.data['t'].shape[0])
        self.data['c_idx']=self.data['o_idx'][self.data['t'].squeeze()==0] #These are the indices of the control units
        self.data['t_idx']=self.data['o_idx'][self.data['t'].squeeze()==1] #These are the indices of the treated units
    
    def split_pred(self,concat_pred):
        #this helps us keep track of things so we don't make mistakes
        preds={}
        preds['y0_pred'] = self.data['y_scaler'].inverse_transform(concat_pred[:, 0].reshape(-1, 1))
        preds['y1_pred'] = self.data['y_scaler'].inverse_transform(concat_pred[:, 1].reshape(-1, 1))
        preds['phi'] = concat_pred[:, 2:]
        return preds

    def find_ynn(self, Phi):
        #helper for PEHEnn
        PhiC, PhiT =tf.dynamic_partition(Phi,tf.cast(tf.squeeze(self.data['t']),tf.int32),2) #separate control and treated reps
        dists=tf.sqrt(pdist2sq(PhiC,PhiT)) #calculate squared distance then sqrt to get euclidean
        yT_nn_idx=tf.gather(self.data['c_idx'],tf.argmin(dists,axis=0)) #get c_idxs of smallest distances for treated units
        yC_nn_idx=tf.gather(self.data['t_idx'],tf.argmin(dists,axis=1)) #get t_idxs of smallest distances for control units
        yT_nn=tf.gather(self.data['y'],yT_nn_idx) #now use these to retrieve y values
        yC_nn=tf.gather(self.data['y'],yC_nn_idx)
        y_nn=tf.dynamic_stitch([self.data['t_idx'],self.data['c_idx']],[yT_nn,yC_nn]) #stitch em back up!
        return y_nn

    def PEHEnn(self,concat_pred):
        p = self.split_pred(concat_pred)
        y_nn = self.find_ynn(p['phi']) #outcomes of each unit's nearest neighbor in the opposite treatment group
        cate_nn_err=tf.reduce_mean( tf.square( (1-2*self.data['t']) * (y_nn-self.data['y']) - (p['y1_pred']-p['y0_pred']) ) )
        return cate_nn_err

    def ATE(self,concat_pred):
        p = self.split_pred(concat_pred)
        return p['y1_pred']-p['y0_pred']

    def PEHE(self,concat_pred):
        #simulation only
        p = self.split_pred(concat_pred)
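        cate_true=(self.data['mu_1']-self.data['mu_0']).reshape(-1,1) #match the (n,1) shape of the predictions to avoid (n,n) broadcasting
        cate_err=tf.reduce_mean( tf.square( cate_true - (p['y1_pred']-p['y0_pred']) ) )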
        return cate_err 

    def on_epoch_end(self, epoch, logs={}):
        concat_pred=self.model.predict(self.data['x'])
        #Calculate Empirical Metrics        
        ate_pred=tf.reduce_mean(self.ATE(concat_pred)); tf.summary.scalar('ate', data=ate_pred, step=epoch)
        pehe_nn=self.PEHEnn(concat_pred); tf.summary.scalar('cate_nn_err', data=tf.sqrt(pehe_nn), step=epoch)
        
        #Simulation Metrics
        ate_true=tf.reduce_mean(self.data['mu_1']-self.data['mu_0'])
        ate_err=tf.abs(ate_true-ate_pred); tf.summary.scalar('ate_err', data=ate_err, step=epoch)
        pehe =self.PEHE(concat_pred); tf.summary.scalar('cate_err', data=tf.sqrt(pehe), step=epoch)
        out_str=f' — ate_err: {ate_err:.4f}  — cate_err: {tf.sqrt(pehe):.4f} — cate_nn_err: {tf.sqrt(pehe_nn):.4f} '
        
        if self.verbose > 0: print(out_str)

## Running the Model
Now reload the model, loss, and fitting boilerplate from the last tutorial.
We've made three minor changes. First, we return `phi` in `concat_pred`. Second, we add `Full_Metrics` as a callback. Third, we add some code and a callback to save metrics for later viewing in TensorBoard:
```
!rm -rf ./logs/ 
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
file_writer = tf.summary.create_file_writer(log_dir + "/metrics")
file_writer.set_as_default()
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
```

In [None]:
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Concatenate
from tensorflow.keras import regularizers
from tensorflow.keras import Model
 
def make_tarnet(input_dim, reg_l2):
    '''
    The first argument is the column dimension of our data.
    It needs to be specified because the functional API creates a static computational graph
    The second argument is the strength of regularization we'll apply to the output layers
    '''
    x = Input(shape=(input_dim,), name='input')
 
    # REPRESENTATION
    #in TF2/Keras it is idiomatic to instantiate a layer and pass its inputs on the same line unless the layer will be reused
    #Note that we apply no regularization to the representation layers 
    phi = Dense(units=200, activation='elu', kernel_initializer='RandomNormal',name='phi_1')(x)
    phi = Dense(units=200, activation='elu', kernel_initializer='RandomNormal',name='phi_2')(phi)
    phi = Dense(units=200, activation='elu', kernel_initializer='RandomNormal',name='phi_3')(phi)
 
    # HYPOTHESIS
    y0_hidden = Dense(units=100, activation='elu', kernel_regularizer=regularizers.l2(reg_l2),name='y0_hidden_1')(phi)
    y1_hidden = Dense(units=100, activation='elu', kernel_regularizer=regularizers.l2(reg_l2),name='y1_hidden_1')(phi)
 
    # second layer
    y0_hidden = Dense(units=100, activation='elu', kernel_regularizer=regularizers.l2(reg_l2),name='y0_hidden_2')(y0_hidden)
    y1_hidden = Dense(units=100, activation='elu', kernel_regularizer=regularizers.l2(reg_l2),name='y1_hidden_2')(y1_hidden)
 
    # third
    y0_predictions = Dense(units=1, activation=None, kernel_regularizer=regularizers.l2(reg_l2), name='y0_predictions')(y0_hidden)
    y1_predictions = Dense(units=1, activation=None, kernel_regularizer=regularizers.l2(reg_l2), name='y1_predictions')(y1_hidden)
 
    #a convenience "layer" that concatenates arrays as columns in a matrix
    #this time we'll return Phi as well to calculate cate_nn_err
    concat_pred = Concatenate(1)([y0_predictions, y1_predictions, phi])
    #the declarations above have specified the computational graph of our network, now we instantiate it
    model = Model(inputs=x, outputs=concat_pred)
 
    return model
 
# every loss function in TF2 takes 2 arguments, a vector of true values and a vector of predictions
def regression_loss(concat_true, concat_pred):
    #computes a standard MSE loss for TARNet
    y_true = concat_true[:, 0] #get individual vectors
    t_true = concat_true[:, 1]
 
    y0_pred = concat_pred[:, 0]
    y1_pred = concat_pred[:, 1]
 
    #The network predicts both potential outcomes for every unit
    #We use t_true as a switch to only calculate the factual loss
    loss0 = tf.reduce_sum((1. - t_true) * tf.square(y_true - y0_pred))
    loss1 = tf.reduce_sum(t_true * tf.square(y_true - y1_pred))
    #note Shi uses tf.reduce_sum for her losses even though mathematically we should be using the mean
    #the sum and the mean differ only by a factor of the batch size, so the sum effectively yields larger gradients, which may make training easier
    return loss0 + loss1
 
### MAIN CODE ####
 
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, TerminateOnNaN
from tensorflow.keras.optimizers import SGD
 
#make model
tarnet_model=make_tarnet(data['x'].shape[1],.01)
 
val_split=0.2
batch_size=64
verbose=True
i = 0
tf.random.set_seed(i)
np.random.seed(i)
yt = np.concatenate([data['ys'], data['t']], 1) #we'll use both y and t to compute the loss
 
# Clear any logs from previous runs
!rm -rf ./logs/ 
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
file_writer = tf.summary.create_file_writer(log_dir + "/metrics")
file_writer.set_as_default()
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
 
sgd_callbacks = [
        TerminateOnNaN(),
        EarlyStopping(monitor='val_loss', patience=40, min_delta=0), 
        #40 is Shi's recommended patience for this dataset, but you should tune for your data 
        ReduceLROnPlateau(monitor='loss', factor=0.5, patience=5, verbose=verbose, mode='auto',
                          min_delta=0, cooldown=0, min_lr=0),
        #This learning rate scheduling is quite aggressive which seems good for this dataset
        Full_Metrics(data,verbose),
        tensorboard_callback
    ]
#optimizer hyperparameters
sgd_lr = 1e-5
momentum = 0.9
tarnet_model.compile(optimizer=SGD(learning_rate=sgd_lr, momentum=momentum, nesterov=True),
                    loss=regression_loss,
                    metrics=regression_loss)
 
tarnet_model.fit(x=data['x'],y=yt,
                callbacks=sgd_callbacks,
                validation_split=val_split,
                epochs=300,
                batch_size=batch_size,
                verbose=verbose)
print("Done!")

## Evaluating the Model Using Tensorboard

A key difference between neural networks and other machine learning models is that we avoid overfitting by stopping training once the validation error stops improving. We discussed in the last tutorial how this code is built into Tensorflow2 through callbacks. In general, it's good to check that your model has both converged and isn't being overfit in Tensorboard.
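
The patience logic behind early stopping can be sketched in a few lines of plain Python. This is a simplified illustration of the idea, not Keras's actual `EarlyStopping` implementation; the function name `early_stopping_epoch` and the loss values are made up for this example.

```python
def early_stopping_epoch(val_losses, patience, min_delta=0.0):
    """Return the index of the epoch at which training would stop, or None if it never stops."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss   # improvement: remember it and reset the counter
            wait = 0
        else:
            wait += 1     # no improvement this epoch
            if wait >= patience:
                return epoch
    return None

val_losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.69, 0.70, 0.70, 0.70]
print(early_stopping_epoch(val_losses, patience=3))
print(early_stopping_epoch(val_losses, patience=2))
```

With a patience of 3 the brief bump after the third epoch is tolerated and training continues to a better loss, whereas a patience of 2 stops before that improvement is found. This is why the patience setting is itself worth tuning.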

In [None]:
%tensorboard --logdir logs/fit

We'll focus on the logs of our losses and metrics in the "Scalars" tab. 

### Losses
Let's begin with `epoch_regression_loss` (the difference between `epoch_regression_loss` and `epoch_loss` is that the latter includes $L_2$ penalties). Everything looks pretty textbook, with training stopping once the validation loss has plateaued. Note that the smoothness of the plateau is because we chose very aggressive settings for `ReduceLROnPlateau`. This may be something you want to tune.

### Metrics
Let's start with our simulation metrics. The $\sqrt{PEHE}$ (`cate_err`) score looks a lot like our validation loss, which is good. This tells us that the model is learning to predict counterfactuals even if it only receives a factual loss. Our $ATE_{bias}$ (`ate_err`) continues to decrease over time.

The other key observation is that nearest neighbor estimates $\sqrt{PEHE_{nn}}$ (`cate_nn_err`) are substantially biased. Because we are only using `cate_nn_err` to tune hyperparameters this isn't necessarily an issue, but we'll really need to check that `cate_nn_err` correlates with `cate_err` across hyperparameterizations.
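
One simple way to run that check, once several trials have finished, is a rank correlation between the two metrics. The sketch below uses made-up numbers for five hypothetical trials (they are not results from this tutorial) and a small helper, `spearman_rho`, written for this example.

```python
import numpy as np

def spearman_rho(a, b):
    """Spearman rank correlation for 1-D arrays without ties."""
    rank_a = np.argsort(np.argsort(a))
    rank_b = np.argsort(np.argsort(b))
    return np.corrcoef(rank_a, rank_b)[0, 1]

# Hypothetical final metrics from five trials (made-up numbers, for illustration only).
cate_nn_err = np.array([2.9, 3.4, 3.1, 4.0, 3.6])
cate_err = np.array([0.8, 1.0, 1.1, 1.9, 1.3])

print(f"Spearman rho: {spearman_rho(cate_nn_err, cate_err):.2f}")
```

A rank correlation is used here because `cate_nn_err` is biased in level; for model selection what matters is whether it orders the configurations the same way `cate_err` does.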

---

# Part 2: Hyperparameter Tuning for Statistical Estimators

Now that we have metrics for model evaluation that are appropriate to causal inference, we can talk about hyperparameter optimization.

**It doesn't matter what your model's theoretical guarantees are: if you do not appropriately tune your hyperparameters to your data, you could get significantly biased estimates.** Beyond conventional hyperparameters (e.g. number of neurons), you should consider the choice of optimizer and its settings as hyperparameters for a statistical inference. Here is a list of potentially tunable hyperparameters for TARNet:

Regularization hyperparameters:
 - $\lambda$ ($L_2$ regularization strength) for outcome layers
 - Dropout for outcome modeling layers
 - Batch normalization

Architectural hyperparameters:
  - Number of representation layers
  - Number of neurons in a representation layer
  - Number of output layers
  - Number of neurons in an output layer
  - Neuronal activation function (e.g. ELU, ReLU, Sigmoid)

Optimization Hyperparameters:
 - Choice of Optimizer (e.g. SGD, Adam)
 - Optimizer Parameters (e.g. Momentum for SGD)
 - Learning Rate Scheduling Parameters
 - Early Stopping Parameters
 - Batch Size

We are now going to do a hyperparameter search using KerasTuner! We'll select hyperparameter settings that minimize the $\sqrt{PEHE_{nn}}$.






## Building a HyperModel using Keras Tuner

Keras Tuner provides an elegant framework for compiling TF2 models with hyperparameters. We simply specify `hp.Int`, `hp.Choice` or `hp.Bool` for hyperparameters we wish to tune. Below, we are allowing the number of representation and hypothesis layers, the number of neurons in each layer, as well as the regularization strength to be hyperparameters. These ranges are loosely informed by suggestions from [Johansson et al., 2020](https://arxiv.org/pdf/2001.07426.pdf) for IHDP.

In [None]:
# Install Keras Tuner
!pip install keras-tuner==1.0.4

In [None]:
import keras_tuner as kt
from keras_tuner.tuners import RandomSearch
def make_hypertarnet(hp):
    """
    Builds and compiles a TARNet whose architecture and regularization are tunable.
    :param hp: Keras Tuner HyperParameters object used to sample each setting
    :return: compiled Keras model
    """
    # hp.Choice takes hyperparam name, list of options, and default
    reg_l2=hp.Choice('l2',[.1,.01,.001],default=.01)
    input_dim=25
    inputs = Input(shape=(input_dim,), name='input')

    # representation
    rep_units = hp.Choice('rep_units', [50,100,200],default=200)
    phi = Dense(units=rep_units, activation='elu', kernel_initializer='RandomNormal',name='phi_1')(inputs)
    for i in range(hp.Int('rep_layers', 1, 2, default=1)):
      #pretty nifty way to dynamically add more layers!
      phi = Dense(units=rep_units, activation='elu', kernel_initializer='RandomNormal',name='phi_'+str(i+2))(phi)

    # HYPOTHESIS
    hyp_units = hp.Choice('hyp_units', [20,50,100,200],default=100)
    y0_hidden = Dense(units=hyp_units, activation='elu', kernel_regularizer=regularizers.l2(reg_l2),name='y0_hidden_1')(phi)
    y1_hidden = Dense(units=hyp_units, activation='elu', kernel_regularizer=regularizers.l2(reg_l2),name='y1_hidden_1')(phi)
    for i in range(hp.Int('hyp_layers', 1, 3, default=2)):
      y0_hidden = Dense(units=hyp_units, activation='elu', kernel_regularizer=regularizers.l2(reg_l2),name='y0_hidden_'+str(i+2))(y0_hidden)
      y1_hidden = Dense(units=hyp_units, activation='elu', kernel_regularizer=regularizers.l2(reg_l2),name='y1_hidden_'+str(i+2))(y1_hidden)
    
    # OUTPUT
    y0_predictions = Dense(units=1, activation=None, kernel_regularizer=regularizers.l2(reg_l2), name='y0_predictions')(y0_hidden)
    y1_predictions = Dense(units=1, activation=None, kernel_regularizer=regularizers.l2(reg_l2), name='y1_predictions')(y1_hidden)

    concat_pred = Concatenate(1)([y0_predictions, y1_predictions,phi])
    model = Model(inputs=inputs, outputs=concat_pred)
    
    sgd_lr = 1e-5
    momentum = 0.9
    
    #`learning_rate` replaces the deprecated `lr` argument
    optimizer=SGD(learning_rate=sgd_lr, momentum=momentum, nesterov=True)
    
    model.compile(optimizer=optimizer,
                      loss=regression_loss,
                 metrics=regression_loss)

    return model

## Tailoring Keras Tuner for our needs

Because we wish to tune models using $\sqrt{PEHE_{nn}}$ instead of the network's own loss function, we can't use the standard Keras Tuner framework. Instead, we subclass Keras Tuner and reimplement the `run_trial` method. Incidentally, this allows us to add other non-model (i.e., optimizer-related) parameters like the batch size and early-stopping patience.

This code should look very familiar by now. The only difference is that we now have a `trial_id` for each parameter configuration, which we need in order to save the model and TensorBoard logs. We also add an additional callback for saving these hyperparameter configurations in TensorBoard. Lastly, the line

`self.oracle.update_trial(trial.trial_id, {'cate_nn_err': tf.sqrt(pehe_nn)})`

reports the $\sqrt{PEHE_{nn}}$ back to Keras Tuner so that it can compare models.

In [None]:
from keras_tuner.engine import tuner_utils
from tensorboard.plugins.hparams import api as hparams_api
!rm -rf my_dir

class TarNetTuner(kt.Tuner):

  def run_trial(self, trial,dataset,*fit_args, **fit_kwargs):
      # *args and **kwargs in Python are positional (list) and keyword (dict) arguments
      verbose = fit_kwargs['verbose']

      log_dir=self.project_dir+'/trial_'+trial.trial_id
      hp = trial.hyperparameters
      
      batch_size = hp.Int('batch_size', 128, 256, step=64, default=128)
      stopping_patience=hp.Int('stopping_patience', 5, 15, step=5, default=5)

      #some of this hacky code will hopefully go away as Keras Tuner gets more polished
      hparams = tuner_utils.convert_hyperparams_to_hparams(trial.hyperparameters)
      file_writer = tf.summary.create_file_writer(log_dir + "/metrics")
      file_writer.set_as_default()
      
      tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
      hparams_callback = hparams_api.KerasCallback(
                        writer=log_dir,
                        hparams=hparams,
                        #I prepend trial_ here to get it to save nicely. Hopefully will be fixed in future version of KT.
                        trial_id='trial_'+trial.trial_id) 
      metrics_callback=Full_Metrics(dataset,verbose=verbose)
      callbacks = [
              TerminateOnNaN(),
              EarlyStopping(monitor='val_loss', patience=stopping_patience, min_delta=0.0001),
              ReduceLROnPlateau(monitor='loss', factor=0.5, patience=5, verbose=verbose, mode='auto',
                                min_delta=0, cooldown=0, min_lr=0),
              metrics_callback,
              tensorboard_callback,
              hparams_callback
          ]

      
      model = self.hypermodel.build(hp)
      model.fit(x=fit_args[0],y=fit_args[1],
                 callbacks=callbacks,
                  validation_split=fit_kwargs['validation_split'],
                  epochs=fit_kwargs['epochs'],
                  batch_size=batch_size, verbose=verbose)
      
      #give the metric to the hyperparameter optimization algorithm
      concat_pred=model.predict(dataset['x']) #use the dataset passed to run_trial rather than the global `data`
      pehe_nn=metrics_callback.PEHEnn(concat_pred)
      self.oracle.update_trial(trial.trial_id, {'cate_nn_err': tf.sqrt(pehe_nn)})
      self.save_model(trial.trial_id, model)


## Running Keras Tuner

Keras Tuner has three hyperparameter optimization strategies: Random Search, HyperBand, and Bayesian Optimization. Explaining (or attempting to explain) Bayesian Optimization is beyond the scope of this tutorial, but we'll go with that. 

We'll simply explore a maximum of 10 configurations to keep this quick. In practice you probably want to set this to as many trials as your resources can accommodate and let Keras Tuner run overnight. At the end we'll print out the ID of the best trial.
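
For a sense of scale, the sketch below counts how many distinct configurations the search space in this tutorial contains. The candidate values are copied from the hypermodel and tuner code above; the name `stopping_patience` assumes the early-stopping patience is registered as its own hyperparameter taking the values 5, 10, and 15.

```python
import math

# Candidate values per hyperparameter, copied from the search space in this tutorial.
search_space = {
    "l2": [0.1, 0.01, 0.001],
    "rep_units": [50, 100, 200],
    "rep_layers": [1, 2],
    "hyp_units": [20, 50, 100, 200],
    "hyp_layers": [1, 2, 3],
    "batch_size": [128, 192, 256],
    "stopping_patience": [5, 10, 15],  # assumed name, see the note above
}

n_configs = math.prod(len(values) for values in search_space.values())
max_trials = 10
print(f"{n_configs} possible configurations; {max_trials} trials cover {max_trials / n_configs:.2%}")
```

Ten trials sample only a tiny fraction of this space, so treat the "best" trial from a short run as a starting point rather than a final answer.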

**WARNING:** This takes about 15 minutes. Make sure you are on a GPU or TPU.

In [None]:

tuner = TarNetTuner(
    #the oracle is the hyperoptimization algorithm
    oracle=kt.oracles.BayesianOptimization(
        objective=kt.Objective('cate_nn_err', 'min'),
        max_trials=10, #we're trying to keep this quick for you.
        #You probably want to do as many trials as your resources allow if you see variance between runs
        seed=0    
),
        directory='my_dir',
        project_name='helloworld',
    hypermodel=make_hypertarnet
    )
tuner.search(data, data['x'],yt, epochs=300,validation_split=.2,verbose=2)

best_trial=tuner.oracle.get_best_trials(num_trials=1)[0]
print("BEST TRIAL ID:",best_trial.trial_id)
best_model=tuner.load_model(best_trial)

## Examining Hyperparameters in Tensorboard

Once KerasTuner is done, we can boot TensorBoard back up. This time we'll focus on the "HPARAMS" tab. In the "Table View" you can compare the best trial to others on the metrics we looked at before. The "Parallel Coordinates View" and "Scatter Plot Matrix View" have more information though.

Let's check out the "Parallel Coordinates View". Here you can see trends across metrics and hyperparameterizations. To the far right are the metrics we really care about: `ate`, `cate_nn_err`, and our true counterfactual error `cate_err`.

First, we note that while the correlation between `cate_nn_err` and `cate_err` isn't perfect, there is basically a cohort of good `cate_nn_err` that correspond with good `cate_err`. This seems to correspond with having two representation layers. Given that we won't know `cate_err` with real data, this makes `cate_nn_err` seem like a reasonable choice for hyperparameter optimization, although we should probably run multiple runs of each model before choosing a final model.

In [None]:
%tensorboard --logdir my_dir/helloworld/

# That's it!

Some final thoughts on hyperparameter tuning. First, I think optimizer settings (early stopping, learning rate scheduling) may be more important hyperparameters than any architectural changes. Second, while hyperparameter optimization is important, I'll emphasize that the differences between our best and worst models are pretty small, especially compared to non-neural network estimators like linear regression or causal forests.

This concludes the tutorial on evaluation and hyperparameter optimization. In this tutorial we:

- Introduced causal inference-specific metrics for DL models

- Wrote a custom callback in Tensorflow

- Learned how to evaluate our models in TensorBoard

- Learned how to tune hyperparameters in KerasTuner and compare them in TensorBoard

# Up Next...

 In the next tutorial we'll implement some more sophisticated models that do not just rely on representation learning for balancing and have some consistency guarantees.