<a href="https://colab.research.google.com/github/kochbj/Deep-Learning-for-Causal-Inference/blob/main/Tutorial_3_Semi_parametric_extensions_to_TARNet_.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

These tutorials are licensed by [Bernard Koch](http://www.github.com/kochbj) under a [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License](http://creativecommons.org/licenses/by-nc-sa/4.0/).


# Tutorial 3: Semi-parametric extensions to TARNet 

In practice, TARNet is a fairly low-bias $CATE$ estimator, but it lacks a strong theoretical motivation. In this short tutorial we extend the base TARNet model with modifications based on inverse propensity-score weighting that provide consistency guarantees from semi-parametric theory. These models are featured in [Shi et al., 2019](https://arxiv.org/pdf/1906.02120.pdf), and the code is adapted from Shi's [excellent GitHub repository](https://github.com/claudiashi57/dragonnet).  

## Notation
**Causal identification**

- Observed covariates/features: $X$

- Potential outcomes: $Y(0)$ and $Y(1)$

- Treatment: $T$

- Average Treatment Effect: $ATE =\mathbb{E}[Y(1)-Y(0)]$

- Conditional Average Treatment Effect: $CATE =\mathbb{E}[Y(1)-Y(0)|X=x]$


**Deep learning estimation**

- Predicted outcomes: $\hat{Y}(0)$ and $\hat{Y}(1)$

- Outcome modeling functions: $\hat{Y}(T)=h(X,T)$ 

- Representation functions: $\Phi(X)$ (producing representations $\phi$)

- Propensity score function:
$\pi(X,T)=P(T|X)$ <br>*where $\pi(X,1)=P(T=1|X)$ and $\pi(X,0)=1-\pi(X,1)$* 

- Loss functions: $\mathcal{L}(true,predicted)$, with the mean squared error abbreviated $MSE$ and binary cross-entropy as $BCE$

- Estimated CATE: $\hat{CATE}=(1-2t)(\hat{y}(1-t)-\hat{y}(t))=\hat{y}(1)-\hat{y}(0)$

- Estimated ATE: $\hat{ATE}=\frac{1}{n}\sum_{i=1}^n\hat{CATE_i}$

- Nearest-neighbor PEHE (approximate variance in $\hat{CATE}$ error):
$$PEHE_{nn}=\frac{1}{N}\sum_{i=1}^N{\Big(\underbrace{(1-2t_i)\big(y_i^{nn}(1-t_i)-y_i(t_i)\big)}_{CATE_{nn}}-\underbrace{\big(h(\Phi(x_i),1)-h(\Phi(x_i),0)\big)}_{\hat{CATE}}\Big)^2}$$ where $y_i^{nn}(1-t_i)$ is the observed outcome of the nearest neighbor $j$ of unit $i$ in representation space such that $t_j\neq t_i$:
  $$y_i^{nn}(1-t_i) = y_j,\quad j=\underset{j:\,t_j=1-t_i}{\arg\min}\,||\Phi(x_i)-\Phi(x_j)||_2$$
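
As a rough illustration of the nearest-neighbor PEHE, here is a minimal NumPy sketch. The function name `pehe_nn` and the toy arrays are made up for this example; the callback later in the tutorial computes the same quantity with TensorFlow ops.

```python
import numpy as np

def pehe_nn(phi, t, y, y0_hat, y1_hat):
    """Nearest-neighbor PEHE (illustrative sketch, not the tutorial's callback).

    Each unit is matched to the closest unit in the opposite treatment arm
    (in representation space), and that neighbor's observed outcome stands in
    for the missing counterfactual.
    """
    is_treated = t.astype(bool)
    # pairwise squared Euclidean distances between all rows of phi
    d2 = ((phi[:, None, :] - phi[None, :, :]) ** 2).sum(axis=2)
    # forbid matches within the same treatment arm
    d2[is_treated[:, None] == is_treated[None, :]] = np.inf
    y_nn = y[d2.argmin(axis=1)]
    # (1 - 2t)(y_nn - y) is y - y_nn for treated units and y_nn - y for controls
    cate_nn = (1 - 2 * is_treated.astype(float)) * (y_nn - y)
    cate_hat = y1_hat - y0_hat
    return np.mean((cate_nn - cate_hat) ** 2)

# Toy data: two tight treated/control pairs in a 1-D representation space
phi = np.array([[0.0], [0.1], [1.0], [1.1]])
t = np.array([0, 1, 0, 1])
y = np.array([1.0, 3.0, 2.0, 5.0])

# Matched differences are 2, 2, 3, 3, so these predictions have zero error
perfect = pehe_nn(phi, t, y, y0_hat=np.zeros(4), y1_hat=np.array([2.0, 2.0, 3.0, 3.0]))
# A constant predicted effect of 2 misses the second pair by 1
constant = pehe_nn(phi, t, y, y0_hat=np.zeros(4), y1_hat=np.full(4, 2.0))
```

Because it only needs observed outcomes, this metric can be tracked on real data where the true $CATE$ is unknown.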



## Treatment Modeling in Neural Networks
 
Beyond outcome modeling, another approach to reducing confounding is adjusting for selection into treatment. This is typically done using the *propensity score*. If the $ATE$ is identifiable by adjusting for $X$, then the propensity score $\pi(X,T)=P(T|X)$ is sufficient to identify the $ATE$ as well (Rosenbaum and Rubin, 1983). We can estimate the ATE using inverse propensity score weighting:
 
$\hat{ATE}=\frac{1}{N}\sum_{i=1}^N[\frac{t_i}{\pi(x_i,1)}-\frac{1-t_i}{\pi(x_i,0)}]\cdot y_i$
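
To make the IPW estimator concrete, here is a small NumPy sketch on simulated data. Everything about the simulation (a single confounder, a logistic propensity, a true $ATE$ of 2) is an assumption chosen for illustration, and the true propensity score is plugged in rather than estimated.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100_000

# One confounder x drives both treatment assignment and the outcome
x = rng.normal(size=n)
e = 1 / (1 + np.exp(-x))                  # true propensity pi(x, 1)
t = rng.binomial(1, e)
y = 2.0 * t + x + rng.normal(size=n)      # true ATE is 2 by construction

# Naive difference in means is confounded by x
naive = y[t == 1].mean() - y[t == 0].mean()

# IPW: weight treated outcomes by 1/pi(x,1) and control outcomes by 1/pi(x,0)
ipw = np.mean((t / e - (1 - t) / (1 - e)) * y)
```

With the true propensity score, `ipw` lands close to 2, while `naive` is biased upward because treated units tend to have larger `x`.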
 
To use the IPW estimator with a neural network, we can trivially add a third "head" to predict the treatment from the representation $\Phi$ (in fact, if we *just* wanted to do IPW we wouldn't need the other two heads at all),
 
$\hat{ATE}=\frac{1}{N}\sum_{i=1}^N[\frac{t_i}{\pi(\Phi(x_i),1)}-\frac{1-t_i}{\pi(\Phi(x_i),0)}]\cdot y_i$
 
 This is the "Dragonnet" architecture from [Shi et al., 2019](https://arxiv.org/pdf/1906.02120.pdf).
 
<figure><img src=http://drive.google.com/uc?export=view&id=1E20cDRbwvJNdDChqSs0Qp-SNyklxt_rQ width="900"><figcaption>Dragonnet architecture introduced in Shi et al., 2019. This is just TARNet with a third head (single neuron) predicting the propensity score $P(T|\Phi(X))=\pi(\Phi(X),T)$.</figcaption></figure>
 
The third head could be implemented as a single neuron (as in Dragonnet) or using additional layers, as in [Johansson et al., 2018](https://arxiv.org/abs/1903.03448) and [Johansson et al., 2020](https://arxiv.org/abs/2001.07426), to produce a scalar propensity score $P(T|\Phi(X))=\pi(\Phi(X),T)$.
 
The loss function for this network looks like this:
$$\underset{\phi,\pi,h}{\arg \min}\ MSE(Y,h(\Phi(X),T)) + \alpha \cdot \text{BCE}(T,\pi(\Phi(X),T))$$
with $\alpha$ being a hyperparameter to balance the two objectives.
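
The combined objective can be sketched in plain NumPy as follows. The function names are illustrative (they are not the Keras loss defined later), and the treatment head is assumed to output a logit rather than a probability.

```python
import numpy as np

def factual_mse(y, t, y0_hat, y1_hat):
    # Only the head that matches the observed treatment is penalized
    y_hat = t * y1_hat + (1 - t) * y0_hat
    return np.mean((y - y_hat) ** 2)

def bce_from_logits(t, logit):
    # Numerically stable form of -[t*log(p) + (1-t)*log(1-p)] with p = sigmoid(logit)
    return np.mean(np.maximum(logit, 0) - logit * t + np.log1p(np.exp(-np.abs(logit))))

def dragonnet_loss(y, t, y0_hat, y1_hat, logit, alpha=1.0):
    # MSE on the factual outcome plus alpha times BCE on the treatment
    return factual_mse(y, t, y0_hat, y1_hat) + alpha * bce_from_logits(t, logit)

# Tiny check: perfect factual predictions and an uninformative propensity head
y = np.array([1.0, 0.0])
t = np.array([1.0, 0.0])
y0_hat = np.array([5.0, 0.0])   # the 5.0 is a counterfactual prediction, so it is ignored
y1_hat = np.array([1.0, 7.0])   # the 7.0 is a counterfactual prediction, so it is ignored
logit = np.array([0.0, 0.0])    # sigmoid(0) = 0.5 for both units
loss = dragonnet_loss(y, t, y0_hat, y1_hat, logit, alpha=1.0)
```

Here the regression term is zero and the treatment term equals $\log 2$, the cross-entropy of a coin flip.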
 
Below we break down more sophisticated ways that the propensity score is used in [Shi et al., 2019](https://arxiv.org/pdf/1906.02120.pdf) from semi-parametric estimation theory.

## Semi-parametric theory in three paragraphs

The application of semi-parametric theory to causal inference (as far as I understand it) is focused on estimating a target parameter of a distribution $P$ of treatment effects, $T(P):=ATE$. While we do not know the true distribution of treatment effects because we lack counterfactuals, we do know some parameters of this distribution (e.g., the treatment assignment mechanism). We can encode these constraints in the form of a likelihood that parametrically defines a set of possible approximate distributions of $P$ from our existing data that we'll call $\mathcal{P}$. Within this set there is a sample-inferred distribution $\tilde{P}\in\mathcal{P}$ that we can use to estimate $T(P)$ using $T(\tilde{P})$.

### Picking $\tilde{P}$

Regardless of which $\tilde{P}$ is chosen, $\tilde{P}\neq P$ and so, in general, $T(\tilde{P})\neq T(P)$. We don't really know how to pick $\tilde{P}$ with finite data to get the best estimate $T(\tilde{P})$. We can maximize our likelihood function to pick $\tilde{P}$, but there are a lot of "nuisance" parameters in the likelihood that are not our target and that we don't really care about estimating accurately, so this won't necessarily give us the best estimate of $T(P)$. This is where **influence curves** come in. 
 
 We're going to define a "nudge" parameter $\epsilon$ that moves $\tilde{P}$ closer to $P$ (thus moving $T(\tilde{P})$ closer to $T(P)$). An influence curve of $T(P)$ tells us how changes in $\epsilon$ will induce changes in $T(P+\epsilon(\tilde{P}-P))$. We'll use this influence curve to fit $\epsilon$ to get the best approximation of $T(P)$ that we can. In particular, there is a specific **efficient influence curve (EIC)** that provides us with the lowest variance estimates of $T(P)$.




# Part 1: AIPW

The augmented inverse propensity weighting estimator (AIPW or sometimes AIPTW) is an estimator
that solves the efficient influence curve estimating equation for the ATE directly (i.e.,  without a nudge parameter). 

In AIPW (and TMLE), we set the mean of the EIC estimating equation equal to zero which allows us to use it to estimate the $ATE$ linearly. The estimating equation models both the outcome and the treatment. We can specify it as:

$EIC = \frac{1}{N}\sum_{i=1}^N{[(\frac{T}{\pi(\Phi(X),1)}-\frac{1-T}{\pi(\Phi(X),0)})[Y-h(\Phi(X),T)] +[h(\Phi(X),1)-h(\Phi(X),0)]}]-ATE$

$(\text{Set mean of EIC to 0})$


$ATE = \frac{1}{N}\sum_{i=1}^N{[\underbrace{\underbrace{(\frac{T}{\pi(\Phi(X),1)}-\frac{1-T}{\pi(\Phi(X),0)})}_{\text{Treatment Modeling}}\times\underbrace{[Y-h(\Phi(X),T)]}_{\text{Residual Confounding}}}_{\text{Adjustment}} +\underbrace{[h(\Phi(X),1)-h(\Phi(X),0)]}_{\text{Outcome Modeling}}}]$

There is another interpretation of the AIPW as a "doubly robust" estimator. As a doubly robust estimator, we are effectively using Dragonnet to do outcome modeling of $T(\tilde{P})$ in the second term, but account for any residual confounding (second part of the first term) using a function of the propensity score. Doubly robust estimators are appealing because they will produce a consistent estimate of the $ATE$ if either $\pi$ or $h$ is estimated consistently, and are efficient if both are estimated correctly to solve the estimating equation.
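
The double robustness property can be checked numerically. The sketch below is a NumPy illustration on simulated data (the data-generating process and the deliberately "wrong" nuisance models are assumptions made up for this example). It evaluates the AIPW formula with a correct outcome model but wrong propensity, a wrong outcome model but correct propensity, and both wrong.

```python
import numpy as np

def aipw_ate(y, t, e_hat, mu0_hat, mu1_hat):
    mu_t = t * mu1_hat + (1 - t) * mu0_hat        # h(x, T): prediction for the observed arm
    cc = t / e_hat - (1 - t) / (1 - e_hat)        # treatment modeling term
    # adjustment (weighted residual) plus outcome modeling (plug-in difference)
    return np.mean(cc * (y - mu_t) + (mu1_hat - mu0_hat))

rng = np.random.default_rng(0)
n = 100_000
x = rng.normal(size=n)
e = 1 / (1 + np.exp(-x))                          # true propensity
t = rng.binomial(1, e)
y = 2.0 * t + x + rng.normal(size=n)              # true ATE is 2 by construction

wrong_e = np.full(n, 0.5)                         # ignores confounding entirely
wrong_mu = np.zeros(n)                            # ignores x and the treatment effect

ate_outcome_ok = aipw_ate(y, t, wrong_e, x, x + 2.0)        # correct h, wrong pi
ate_propensity_ok = aipw_ate(y, t, e, wrong_mu, wrong_mu)   # wrong h, correct pi
ate_both_wrong = aipw_ate(y, t, wrong_e, wrong_mu, wrong_mu)
```

In this simulation the first two estimates land near the true value of 2, while the third inherits the confounding bias.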

## Implementing AIPW

Let's begin by implementing AIPW with TARNet. We only need to add the propensity score head to the network and write a new loss function. 

Because our loss function has a hyperparameter, we'll now need to either add a closure within our [custom loss function](https://towardsdatascience.com/creating-custom-loss-functions-using-tensorflow-2-96c123d5ce6c), or write a custom loss object. I don't like closures, so let's do an object. 

We'll call the class `Base_Loss` and its main method `standard_loss`, because it's the "standard" machine learning loss for each of the three heads of our network: $MSE$ for the $h(\Phi(X),T)$ heads, and binary crossentropy ($BCE$) for the propensity score head $\pi(\Phi(X),T)$. If you're not familiar with binary crossentropy/log loss, it's the standard binary classification loss used in deep learning and used for likelihood maximization in logistic regression.


In [None]:
import tensorflow as tf
import numpy as np

from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Concatenate
from tensorflow.keras import regularizers
from tensorflow.keras import Model
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras.metrics import binary_accuracy
from tensorflow.keras.losses import Loss
def make_aipw(input_dim, reg_l2):

    x = Input(shape=(input_dim,), name='input')
    # representation
    phi = Dense(units=200, activation='elu', kernel_initializer='RandomNormal',name='phi_1')(x)
    phi = Dense(units=200, activation='elu', kernel_initializer='RandomNormal',name='phi_2')(phi)
    phi = Dense(units=200, activation='elu', kernel_initializer='RandomNormal',name='phi_3')(phi)

    # HYPOTHESIS
    y0_hidden = Dense(units=100, activation='elu', kernel_regularizer=regularizers.l2(reg_l2),name='y0_hidden_1')(phi)
    y1_hidden = Dense(units=100, activation='elu', kernel_regularizer=regularizers.l2(reg_l2),name='y1_hidden_1')(phi)

    # second layer
    y0_hidden = Dense(units=100, activation='elu', kernel_regularizer=regularizers.l2(reg_l2),name='y0_hidden_2')(y0_hidden)
    y1_hidden = Dense(units=100, activation='elu', kernel_regularizer=regularizers.l2(reg_l2),name='y1_hidden_2')(y1_hidden)

    # third
    y0_predictions = Dense(units=1, activation=None, kernel_regularizer=regularizers.l2(reg_l2), name='y0_predictions')(y0_hidden)
    y1_predictions = Dense(units=1, activation=None, kernel_regularizer=regularizers.l2(reg_l2), name='y1_predictions')(y1_hidden)

    #propensity prediction
    #Note that the activation is actually sigmoid, but we will squish it in the loss function for numerical stability reasons
    t_prediction = Dense(units=1,activation=None,name='t_prediction')(phi)

    concat_pred = Concatenate(1)([y0_predictions, y1_predictions,t_prediction,phi])
    model = Model(inputs=x, outputs=concat_pred)
    return model

class Base_Loss(Loss):
    #initialize instance attributes
    def __init__(self, alpha=1.0):
        super().__init__()
        self.alpha = alpha
        self.name='standard_loss'

    def split_pred(self,concat_pred):
        #generic helper to make sure we dont make mistakes
        preds={}
        preds['y0_pred'] = concat_pred[:, 0]
        preds['y1_pred'] = concat_pred[:, 1]
        preds['t_pred'] = concat_pred[:, 2]
        preds['phi'] = concat_pred[:, 3:]
        return preds

    #for logging purposes only
    def treatment_acc(self,concat_true,concat_pred):
        t_true = concat_true[:, 1]
        p = self.split_pred(concat_pred)
        #Since this isn't used as a loss, I've used tf.reduce_mean for interpretability
        return tf.reduce_mean(binary_accuracy(t_true, tf.math.sigmoid(p['t_pred']), threshold=0.5))

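    def treatment_bce(self,concat_true,concat_pred):
        t_true = concat_true[:, 1]
        p = self.split_pred(concat_pred)
        #per-unit BCE from logits; keras' binary_crossentropy would average over the last axis
        #of these 1D tensors and return a scalar, leaving nothing for reduce_sum to sum
        lossP = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels=t_true,logits=p['t_pred']))
        return lossP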
    
    def regression_loss(self,concat_true,concat_pred):
        y_true = concat_true[:, 0]
        t_true = concat_true[:, 1]
        p = self.split_pred(concat_pred)
        loss0 = tf.reduce_sum((1. - t_true) * tf.square(y_true - p['y0_pred']))
        loss1 = tf.reduce_sum(t_true * tf.square(y_true - p['y1_pred']))
        return loss0+loss1

    def standard_loss(self,concat_true,concat_pred):
        lossR = self.regression_loss(concat_true,concat_pred)
        lossP = self.treatment_bce(concat_true,concat_pred)
        return lossR + self.alpha * lossP

    #compute loss
    def call(self, concat_true, concat_pred):        
        return self.standard_loss(concat_true,concat_pred)
        

Now let's add AIPW to our callback.


In [None]:
from tensorflow.keras.callbacks import Callback

def pdist2sq(A, B):
    #helper for PEHEnn
    #calculates squared euclidean distance between rows of two matrices  
    #https://gist.github.com/mbsariyildiz/34cdc26afb630e8cae079048eef91865
    # broadcast A to (nA, 1, d) and B to (1, nB, d), then sum squared differences over d
    # returns the (nA, nB) matrix of pairwise squared euclidean distances
    D=tf.reduce_sum((tf.expand_dims(A, 1)-tf.expand_dims(B, 0))**2,2) 
    return D

#https://towardsdatascience.com/implementing-macro-f1-score-in-keras-what-not-to-do-e9f1aa04029d
class AIPW_Metrics(Callback):
    def __init__(self,data, verbose=0):   
        super(AIPW_Metrics, self).__init__()
        self.data=data #feed the callback the full dataset
        self.verbose=verbose

        #needed for PEHEnn; Called in self.find_ynn
        self.data['o_idx']=tf.range(self.data['t'].shape[0])
        self.data['c_idx']=self.data['o_idx'][self.data['t'].squeeze()==0] #These are the indices of the control units
        self.data['t_idx']=self.data['o_idx'][self.data['t'].squeeze()==1] #These are the indices of the treated units
    
    def split_pred(self,concat_pred):
        preds={}
        #the scaler expects 2D (n, 1) input; flatten back to 1D afterwards
        preds['y0_pred'] = self.data['y_scaler'].inverse_transform(concat_pred[:, 0].reshape(-1, 1)).ravel()
        preds['y1_pred'] = self.data['y_scaler'].inverse_transform(concat_pred[:, 1].reshape(-1, 1)).ravel()
        preds['t_pred'] = concat_pred[:, 2]
        preds['phi'] = concat_pred[:, 3:]
        return preds

    def find_ynn(self, Phi):
        #helper for PEHEnn
        PhiC, PhiT =tf.dynamic_partition(Phi,tf.cast(tf.squeeze(self.data['t']),tf.int32),2) #separate control and treated reps
        dists=tf.sqrt(pdist2sq(PhiC,PhiT)) #calculate squared distance then sqrt to get euclidean
        yT_nn_idx=tf.gather(self.data['c_idx'],tf.argmin(dists,axis=0),1) #get c_idxs of smallest distances for treated units
        yC_nn_idx=tf.gather(self.data['t_idx'],tf.argmin(dists,axis=1),1) #get t_idxs of smallest distances for control units
        yT_nn=tf.gather(self.data['y'],yT_nn_idx,1) #now use these to retrieve y values
        yC_nn=tf.gather(self.data['y'],yC_nn_idx,1)
        y_nn=tf.dynamic_stitch([self.data['t_idx'],self.data['c_idx']],[yT_nn,yC_nn]) #stitch em back up!
        return y_nn

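    def PEHEnn(self,concat_pred):
        p = self.split_pred(concat_pred)
        y_nn = tf.reshape(self.find_ynn(p['phi']), [-1]) #nearest-neighbor outcomes, flattened to 1D
        #flatten t and y so they line up with the 1D predictions instead of broadcasting to (n, n)
        t = self.data['t'].squeeze(); y = self.data['y'].squeeze()
        cate_nn_err=tf.reduce_mean( tf.square( (1-2*t) * (y_nn-y) - (p['y1_pred']-p['y0_pred']) ) )
        return cate_nn_err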

    def ATE(self,concat_pred):
        p = self.split_pred(concat_pred)
        return p['y1_pred']-p['y0_pred']

    def PEHE(self,concat_pred):
        #simulation only
        p = self.split_pred(concat_pred)
        cate_err=tf.reduce_mean( tf.square( ( (self.data['mu_1']-self.data['mu_0']) - (p['y1_pred']-p['y0_pred']) ) ) )
        return cate_err 
   
    #THIS IS THE NEW PART
    def AIPW(self,concat_pred):
        p = self.split_pred(concat_pred)
        t_pred=tf.math.sigmoid(p['t_pred'])
        t_pred = (t_pred + 0.001) / 1.002 # a little numerical stability trick implemented by Shi
        #flatten t and y so they line up with the 1D predictions instead of broadcasting to (n, n)
        t = self.data['t'].squeeze(); y = self.data['y'].squeeze()
        y_pred = p['y0_pred'] * (1 - t) + p['y1_pred'] * t
        #cc stands for "clever covariate", as it is called in the TMLE literature
        #use the squashed, clipped propensity t_pred here, not the raw logit p['t_pred']
        cc = t / t_pred - (1.0 - t) / (1.0 - t_pred)
        cate = cc * (y - y_pred) + p['y1_pred'] - p['y0_pred']
        return cate

    def on_epoch_end(self, epoch, logs={}):
        concat_pred=self.model.predict(self.data['x'])
        #Calculate Empirical Metrics        
        ate_pred=tf.reduce_mean(self.ATE(concat_pred)); tf.summary.scalar('ate', data=ate_pred, step=epoch)
        pehe_nn=self.PEHEnn(concat_pred); tf.summary.scalar('cate_nn_err', data=tf.sqrt(pehe_nn), step=epoch)
        aipw=tf.reduce_mean(self.AIPW(concat_pred)); tf.summary.scalar('aipw', data=aipw, step=epoch)
        
        #Simulation Metrics
        ate_true=tf.reduce_mean(self.data['mu_1']-self.data['mu_0'])
        ate_err=tf.abs(ate_true-ate_pred); tf.summary.scalar('ate_err', data=ate_err, step=epoch)
        pehe =self.PEHE(concat_pred); tf.summary.scalar('cate_err', data=tf.sqrt(pehe), step=epoch)
        aipw_err=tf.abs(ate_true-aipw); tf.summary.scalar('aipw_err', data=aipw_err, step=epoch)
        out_str=f' — ate_err: {ate_err:.4f}  — aipw_err: {aipw_err:.4f} — cate_err: {tf.sqrt(pehe):.4f} — cate_nn_err: {tf.sqrt(pehe_nn):.4f} '
        
        if self.verbose > 0: print(out_str)

Now reload our data...

In [None]:
from sklearn.preprocessing import StandardScaler
!wget -nc http://www.fredjo.com/files/ihdp_npci_1-100.train.npz
!wget -nc http://www.fredjo.com/files/ihdp_npci_1-100.test.npz 

def load_IHDP_data(training_data,testing_data,i=7):
    with open(training_data,'rb') as trf, open(testing_data,'rb') as tef:
        train_data=np.load(trf); test_data=np.load(tef)
        y=np.concatenate(   (train_data['yf'][:,i],   test_data['yf'][:,i])).astype('float32') #most GPUs only compute 32-bit floats
        t=np.concatenate(   (train_data['t'][:,i],    test_data['t'][:,i])).astype('float32')
        x=np.concatenate(   (train_data['x'][:,:,i],  test_data['x'][:,:,i]),axis=0).astype('float32')
        mu_0=np.concatenate((train_data['mu0'][:,i],  test_data['mu0'][:,i])).astype('float32')
        mu_1=np.concatenate((train_data['mu1'][:,i],  test_data['mu1'][:,i])).astype('float32')

        data={'x':x,'t':t,'y':y,'mu_0':mu_0,'mu_1':mu_1}
        data['t']=data['t'].reshape(-1,1) #we're just padding one dimensional vectors with an additional dimension 
        data['y']=data['y'].reshape(-1,1)
        
        #standardizing y (zero mean, unit variance) often makes training of DL regressors easier
        data['y_scaler'] = StandardScaler().fit(data['y'])
        data['ys'] = data['y_scaler'].transform(data['y'])

    return data

data=load_IHDP_data(training_data='./ihdp_npci_1-100.train.npz',testing_data='./ihdp_npci_1-100.test.npz')

Now we're ready to train. Note that I did something I think is nifty this time: I wrote my sublosses as having `y_true` and `y_pred` arguments so I could add them as metrics in `.fit`.

In [None]:
import tensorflow as tf
import numpy as np
import datetime
%load_ext tensorboard

from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, ReduceLROnPlateau, TerminateOnNaN
from tensorflow.keras.optimizers import SGD, Adam

val_split=0.2
batch_size=64
verbose=1
i = 0
tf.random.set_seed(i)
np.random.seed(i)
yt = np.concatenate([data['ys'], data['t']], 1)

# Clear any logs from previous runs
!rm -rf ./logs/ 
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
file_writer = tf.summary.create_file_writer(log_dir + "/metrics")
file_writer.set_as_default()
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

sgd_callbacks = [
        TerminateOnNaN(),
        EarlyStopping(monitor='val_loss', patience=40, min_delta=0.),
        ReduceLROnPlateau(monitor='loss', factor=0.5, patience=5, verbose=verbose, mode='auto',
                          min_delta=0., cooldown=0, min_lr=0),
        tensorboard_callback,
        AIPW_Metrics(data,verbose=verbose)
        ]
      

sgd_lr = 1e-5
momentum = 0.9

aipw_model=make_aipw(data['x'].shape[1],.01)
aipw_loss=Base_Loss(alpha=1.0)

aipw_model.compile(optimizer=SGD(learning_rate=sgd_lr, momentum=momentum, nesterov=True),
                    loss=aipw_loss,
                    metrics=[aipw_loss,aipw_loss.regression_loss,aipw_loss.treatment_acc]
                   )

aipw_model.fit(x=data['x'],y=yt,
                  callbacks=sgd_callbacks,
                  validation_split=val_split,
                  epochs=300,
                  batch_size=batch_size,
                  verbose=verbose)



## Reviewing results in Tensorboard

Let's do a quick comparison to last tutorial where we ran the exact same network without the propensity score loss.

If we look at `treatment_acc`, it's clear that Dragonnet is learning the treatment information. We can also see that there is only a very slight penalty in the network's ability to predict the outcomes (`ate_err`).

It's hard to say whether any other performance differences are significant without doing hyperparameter tuning under both scenarios and looking across multiple simulations. For what it's worth, the `aipw_err` is slightly worse than the raw `ate_err` but it's pretty close, and the statistical guarantees are definitely worth something.  

In [None]:
%tensorboard --logdir logs/fit

The issue with inverse propensity score weighting estimators is that the weights are unstable for units that have a propensity score close to 0 or 1. In the case of the AIPW equation above, extreme weights can push the estimate outside the range of plausible values (e.g., with a binary outcome, an estimated mean potential outcome below zero or above one). This is one of the practical motivations for using nudge parameters as in TMLE or Targeted Regularization.
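
To see how quickly the weights blow up, and what the `(t_pred + 0.001) / 1.002` rescaling used in the callback does about it, here is a short NumPy sketch. The logit values are arbitrary examples.

```python
import numpy as np

# Example logits from a propensity head, including two very confident predictions
logits = np.array([-12.0, -4.0, 0.0, 4.0, 12.0])
p = 1 / (1 + np.exp(-logits))

# Raw inverse-propensity weights explode as p approaches 0 (treated) or 1 (control)
raw_w_treated = 1 / p
raw_w_control = 1 / (1 - p)

# The rescaling maps [0, 1] into roughly [0.001, 0.999],
# so no weight can exceed 1.002 / 0.001 = 1002
p_clip = (p + 0.001) / 1.002
clip_w_treated = 1 / p_clip
clip_w_control = 1 / (1 - p_clip)
```

The rescaling bounds the weights but does not make them small: a single unit can still count about a thousand times more than a unit with a propensity score near 0.5.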

----   

# Part 2: Targeted Regularization




## Aside: Targeted Maximum Likelihood Estimation Algorithm

We won't implement Targeted Maximum Likelihood Estimation (TMLE) *per se*, but "Targeted Regularization" is closely related so it's useful to see it first.

TMLE is an iterative procedure where we actually use the nudge parameter $\epsilon$ to reduce bias when we fit the EIC estimating equation. We then calculate $\hat{ATE}$ by using nudged variants of our outcome predictions (see below). 

We begin again by setting the EIC equation to 0. 

$$EIC = 0 = \frac{1}{N}\sum_{i=1}^N{\Big[\underbrace{\underbrace{\Big(\frac{T}{\pi(x,1)}-\frac{1-T}{\pi(x,0)}\Big)}_{\text{Treatment Modeling}}\times\underbrace{(Y-h(x,T))}_{\text{Residual Confounding}}}_{\text{Adjustment}} +\underbrace{[h(x,1)-h(x,0)]}_{\text{Outcome Modeling}}\Big]}-ATE$$
$$ATE = \frac{1}{N}\sum_{i=1}^N{\Big[\underbrace{\underbrace{\Big(\frac{T}{\pi(x,1)}-\frac{1-T}{\pi(x,0)}\Big)}_{\text{Treatment Modeling}}\times\underbrace{(Y-h(x,T))}_{\text{Residual Confounding}}}_{\text{Adjustment}} +\underbrace{[h(x,1)-h(x,0)]}_{\text{Outcome Modeling}}\Big]}$$
The TMLE procedure consists of four **consecutive** steps:
1. Fit $h$ by predicting outcomes (e.g., using TARNet) and minimizing $MSE(Y,h(X,T))$
2. Fit $\pi$ by predicting treatment (e.g., using logistic regression) and minimizing $BCE(T,\pi(X,T))$
3. Plug-in $h$ and $\pi$ functions to fit $\epsilon$ and estimate $h^{*}(X,T)$ where,
$$h^{*}(X,T)=h(X,T)+(\frac{T}{\pi(X,1)}-\frac{1-T}{\pi(X,0)})\times \epsilon$$
by minimizing $MSE(Y,h^{*}(X,T))$.

(If you look at the EIC, fitting $\epsilon$ this way drives the mean of the "Adjustment" part, evaluated with $h^*$ in place of $h$, to zero.)

4. Plug-in $h^*(X,T)$ to estimate $\hat{ATE}$:
$$\hat{ATE}=\frac{1}{N}\sum_{i=1}^N{[h^*(x_i,1)-h^*(x_i,0)]}$$
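
The four steps can be sketched end to end in NumPy. This is an illustrative toy, not a reference TMLE implementation: the simulated data, the deliberately misspecified outcome model in step 1 (arm means that ignore $X$), and the hand-rolled Newton solver for logistic regression in step 2 are all assumptions made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 50_000
x = rng.normal(size=n)
t = rng.binomial(1, 1 / (1 + np.exp(-x)))
y = 2.0 * t + x + rng.normal(size=n)          # true ATE is 2 by construction

# Step 1: fit h. Deliberately misspecified: each arm's mean outcome, ignoring x
h1 = np.full(n, y[t == 1].mean())
h0 = np.full(n, y[t == 0].mean())
h_t = np.where(t == 1, h1, h0)

# Step 2: fit pi with logistic regression (Newton's method on the BCE)
X = np.column_stack([np.ones(n), x])
w = np.zeros(2)
for _ in range(25):
    p = 1 / (1 + np.exp(-X @ w))
    grad = X.T @ (p - t)
    hess = (X * (p * (1 - p))[:, None]).T @ X
    w -= np.linalg.solve(hess, grad)
pi1 = 1 / (1 + np.exp(-X @ w))                # estimated pi(x, 1)

# Step 3: fit epsilon by least squares of the residual on the clever covariate,
# which minimizes MSE(Y, h*) over epsilon in closed form
cc = t / pi1 - (1 - t) / (1 - pi1)
eps = np.sum(cc * (y - h_t)) / np.sum(cc ** 2)
h_star_t = h_t + eps * cc

# Step 4: plug in the nudged predictions for both arms
h1_star = h1 + eps / pi1                      # clever covariate with T = 1
h0_star = h0 - eps / (1 - pi1)                # clever covariate with T = 0
ate_plugin = np.mean(h1 - h0)                 # before the nudge: confounded
ate_tmle = np.mean(h1_star - h0_star)         # after the nudge
score = np.mean(cc * (y - h_star_t))          # mean "Adjustment" term at h*
```

In this simulation the nudge moves the plug-in estimate from roughly 2.8 back toward the true value of 2, and `score` is zero up to floating-point error.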

## Targeted Regularization

Targeted Regularization, introduced in [Shi et al., 2019](https://arxiv.org/pdf/1906.02120.pdf), takes TMLE and adapts it for a neural network loss function. The main difference is that steps 1 and 2 above are done concurrently by Dragonnet, and that the loss functions for the first three steps are combined into a single loss applied to the whole network at the end of each batch. It requires adding a single free parameter to the Dragonnet network for $\epsilon$.

 At a very intuitive level, Targeted Regularization is appealing because it introduces a loss function to TARNet that explicitly encourages the network to learn the treatment effect distribution and not just the outcome distribution. The Targeted Regularization procedure proceeds as follows:

For each batch:
<ol>
  <li>
    <ol type="a">
    <li>
    Use Dragonnet to predict $h(\Phi(X),T)$ and $\pi(\Phi(X),T)$.
    </li>
    <li>
    Calculate the standard ML loss for the network using a hyperparameter $\alpha$:
    $$MSE(Y,h(\Phi(X),T)) + \alpha \cdot BCE(T,\pi(\Phi(X),T))$$
    </li>
    </ol>
  </li>
  <li>
  <ol type="a">
  <li> Compute $h^{*}(\Phi(X),T)$ as above,
  $$h^{*}(\Phi(X),T)=h(\Phi(X),T)+(\frac{T}{\pi(\Phi(X),1)}-\frac{1-T}{\pi(\Phi(X),0)})\times \epsilon$$ </li>
  <li> Calculate the targeted regularization loss: $MSE(Y,h^*(\Phi(X),T))$ </li>
  </ol>
  </li>
  <li> Combine and minimize the losses from 1 and 2 using a hyperparameter $\beta$,
    $$\underset{\Phi,h,\pi,\epsilon}{\arg \min}\ [MSE(Y,h(\Phi(X),T)) + \alpha \cdot \text{BCE}(T,\pi(\Phi(X),T))]+\beta\cdot MSE(Y,h^*(\Phi(X),T))$$ </li>
</ol>

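**At its minimum with respect to $\epsilon$, the targeted regularization term in Step 3 forces the mean of the "Adjustment" part of the EIC (evaluated at $h^*$) to zero, so the EIC estimating equation is solved by the plug-in estimate below.** 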

At the end of training, we can thus compute the targeted regularization estimate of the ATE, $\hat{ATE_{TR}}$, as in TMLE:
$$\hat{ATE_{TR}}=\frac{1}{N}\sum_{i=1}^N{[h^*(\Phi(x_i),1)-h^*(\Phi(x_i),0)]}$$

Note that because they solve the EIC estimating equation for the ATE, both TMLE and Targeted Regularization are doubly robust estimators.

Again, the key difference between TMLE and Targeted Regularization is that in Targeted Regularization we are jointly tuning $h$ and $\epsilon$ at the same time every batch, rather than fitting them completely in succession and plugging-in as in TMLE.
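
The link between the targeted regularization term and the estimating equation can be checked numerically. The NumPy sketch below uses random stand-ins for the network outputs (all arrays and the function name `tarreg_loss` are assumptions for illustration). It verifies that at the $\epsilon$ minimizing the combined loss, the loss gradient in $\epsilon$ vanishes and the mean of the clever covariate times the $h^*$ residual is zero.

```python
import numpy as np

def tarreg_loss(y, t, y0_hat, y1_hat, logit, eps, alpha=1.0, beta=1.0):
    p = 1 / (1 + np.exp(-logit))
    y_hat = t * y1_hat + (1 - t) * y0_hat
    mse = np.sum((y - y_hat) ** 2)                              # outcome heads
    bce = -np.sum(t * np.log(p) + (1 - t) * np.log(1 - p))      # propensity head
    cc = t / p - (1 - t) / (1 - p)                              # clever covariate
    h_star = y_hat + eps * cc                                   # nudged prediction
    return mse + alpha * bce + beta * np.sum((y - h_star) ** 2)

# Random stand-ins for a batch of data and network outputs
rng = np.random.default_rng(1)
n = 1000
t = rng.binomial(1, 0.5, size=n).astype(float)
y = rng.normal(size=n)
y0_hat, y1_hat = rng.normal(size=n), rng.normal(size=n)
logit = rng.normal(size=n)

# With the other outputs held fixed, the loss is quadratic in eps,
# so its minimizer has a closed form
p = 1 / (1 + np.exp(-logit))
cc = t / p - (1 - t) / (1 - p)
resid = y - (t * y1_hat + (1 - t) * y0_hat)
eps_star = np.sum(cc * resid) / np.sum(cc ** 2)

# Central finite difference of the loss with respect to eps at eps_star
d = 1e-4
grad = (tarreg_loss(y, t, y0_hat, y1_hat, logit, eps_star + d)
        - tarreg_loss(y, t, y0_hat, y1_hat, logit, eps_star - d)) / (2 * d)

# Mean of the "Adjustment" term evaluated at h*
score = np.mean(cc * (resid - eps_star * cc))
loss_at_min = tarreg_loss(y, t, y0_hat, y1_hat, logit, eps_star)
loss_nearby = tarreg_loss(y, t, y0_hat, y1_hat, logit, eps_star + 0.1)
```

During actual training $\epsilon$ is updated by gradient steps alongside the other weights, so this condition only holds approximately until the network converges.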


## Implementing Targeted Regularization
The rest of this tutorial is basically an annotated cut and paste from Shi's elegant repository.

 We'll need to start by adding $\epsilon$ as a parameter in our neural network. Shi does this by creating a [custom layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer) she calls `EpsilonLayer`. 

In [None]:
from tensorflow.keras.layers import Layer
class EpsilonLayer(Layer):

    def __init__(self):
        super(EpsilonLayer, self).__init__()

    def build(self, input_shape):
        # Create a trainable weight variable for this layer.
        self.epsilon = self.add_weight(name='epsilon',
                                       shape=[1, 1],
                                       initializer='RandomNormal',
                                       #  initializer='ones',
                                       trainable=True)
        super(EpsilonLayer, self).build(input_shape)  # Be sure to call this at the end

    def call(self, inputs, **kwargs):
        #note there is only one epsilon were just duplicating it for conformability
        return self.epsilon * tf.ones_like(inputs)[:, 0:1]

def make_dragonnet(input_dim, reg_l2):

    x = Input(shape=(input_dim,), name='input')
    # representation
    phi = Dense(units=200, activation='elu', kernel_initializer='RandomNormal',name='phi_1')(x)
    phi = Dense(units=200, activation='elu', kernel_initializer='RandomNormal',name='phi_2')(phi)
    phi = Dense(units=200, activation='elu', kernel_initializer='RandomNormal',name='phi_3')(phi)

    # HYPOTHESIS
    y0_hidden = Dense(units=100, activation='elu', kernel_regularizer=regularizers.l2(reg_l2),name='y0_hidden_1')(phi)
    y1_hidden = Dense(units=100, activation='elu', kernel_regularizer=regularizers.l2(reg_l2),name='y1_hidden_1')(phi)

    # second layer
    y0_hidden = Dense(units=100, activation='elu', kernel_regularizer=regularizers.l2(reg_l2),name='y0_hidden_2')(y0_hidden)
    y1_hidden = Dense(units=100, activation='elu', kernel_regularizer=regularizers.l2(reg_l2),name='y1_hidden_2')(y1_hidden)

    # third
    y0_predictions = Dense(units=1, activation=None, kernel_regularizer=regularizers.l2(reg_l2), name='y0_predictions')(y0_hidden)
    y1_predictions = Dense(units=1, activation=None, kernel_regularizer=regularizers.l2(reg_l2), name='y1_predictions')(y1_hidden)

    #propensity prediction
    #Note that the activation is actually sigmoid, but we will squish it in the loss function for numerical stability reasons
    t_predictions = Dense(units=1,activation=None,name='t_prediction')(phi)
    #Although the epsilon layer takes an input, it really just houses a free parameter. 
    epsilons = EpsilonLayer()(t_predictions)
    concat_pred = Concatenate(1)([y0_predictions, y1_predictions,t_predictions,epsilons,phi])
    model = Model(inputs=x, outputs=concat_pred)
    return model


Now let's write the loss. You should be able to follow along from the LaTeX above. To save a bit of code we'll subclass the `Base_Loss` we implemented above.

In [None]:
class TarReg_Loss(Base_Loss):
    #initialize instance attributes
    def __init__(self, alpha=1,beta=1):
        super().__init__()
        self.alpha = alpha
        self.beta=beta
        self.name='tarreg_loss'

    def split_pred(self,concat_pred):
        #generic helper to make sure we dont make mistakes
        preds={}
        preds['y0_pred'] = concat_pred[:, 0]
        preds['y1_pred'] = concat_pred[:, 1]
        preds['t_pred'] = concat_pred[:, 2]
        preds['epsilon'] = concat_pred[:, 3] #we're moving epsilon into slot three
        preds['phi'] = concat_pred[:, 4:]
        return preds

    def calc_hstar(self,concat_true,concat_pred):
        #step 2 above
        p=self.split_pred(concat_pred)
        y_true = concat_true[:, 0]
        t_true = concat_true[:, 1]

        t_pred = tf.math.sigmoid(concat_pred[:, 2])
        t_pred = (t_pred + 0.001) / 1.002 # a little numerical stability trick implemented by Shi
        y_pred = t_true * p['y1_pred'] + (1 - t_true) * p['y0_pred']

        #calling it cc for "clever covariate" as in SuperLearner TMLE literature
        cc = t_true / t_pred - (1 - t_true) / (1 - t_pred)
        h_star = y_pred + p['epsilon'] * cc
        return h_star

    def call(self,concat_true,concat_pred):
        y_true = concat_true[:, 0]

        standard_loss=self.standard_loss(concat_true,concat_pred)
        h_star=self.calc_hstar(concat_true,concat_pred)
        #step 2b above: the targeted regularization term
        targeted_regularization = tf.reduce_sum(tf.square(y_true - h_star))

        #step 3 above: combine with the standard loss
        loss = standard_loss + self.beta * targeted_regularization
        return loss

Now we update our callback so that it computes $h^*$ and the final, plug-in $\hat{ATE}_{\text{TR}}$. Looking at the LaTeX may again be helpful. We can save some code lines by subclassing.


In [None]:
class TarReg_Metrics(AIPW_Metrics):
    def __init__(self,data, verbose=0):   
        super().__init__(data,verbose)

    def split_pred(self,concat_pred):
        preds={}
        #the scaler expects 2D (n, 1) input; flatten back to 1D afterwards
        preds['y0_pred'] = self.data['y_scaler'].inverse_transform(concat_pred[:, 0].reshape(-1, 1)).ravel()
        preds['y1_pred'] = self.data['y_scaler'].inverse_transform(concat_pred[:, 1].reshape(-1, 1)).ravel()
        preds['t_pred'] = concat_pred[:, 2]
        preds['epsilon'] = concat_pred[:, 3]
        preds['phi'] = concat_pred[:, 4:]
        return preds
    
    def compute_hstar(self,y0_pred,y1_pred,t_pred,t_true,epsilons):
        #helper for calculating the targeted regularization cate
        y_pred = t_true * y1_pred + (1 - t_true) * y0_pred
        cc = t_true / t_pred - (1 - t_true) / (1 - t_pred)
        h_star = y_pred + epsilons * cc
        return h_star
    
    def TARREG_CATE(self,concat_pred):
        #Final calculation of the Targeted Regularization CATE estimate
        p = self.split_pred(concat_pred)
        t_pred = tf.math.sigmoid(p['t_pred'])
        t_pred = (t_pred + 0.001) / 1.002 # a little numerical stability trick implemented by Shi       
        #epsilon was fit on the standardized outcome, so rescale it to the units of the unscaled predictions
        eps = p['epsilon'] * np.float32(self.data['y_scaler'].scale_[0])
        hstar_0=self.compute_hstar(p['y0_pred'],p['y1_pred'],t_pred,tf.zeros_like(eps),eps)
        hstar_1=self.compute_hstar(p['y0_pred'],p['y1_pred'],t_pred,tf.ones_like(eps),eps)
        return hstar_1-hstar_0

    def on_epoch_end(self, epoch, logs={}):
        concat_pred=self.model.predict(self.data['x'])
        #Calculate Empirical Metrics        
        aipw_pred=tf.reduce_mean(self.AIPW(concat_pred)); tf.summary.scalar('aipw', data=aipw_pred, step=epoch)
        ate_pred=tf.reduce_mean(self.ATE(concat_pred)); tf.summary.scalar('ate', data=ate_pred, step=epoch)
        tarreg_pred=tf.reduce_mean(self.TARREG_CATE(concat_pred)); tf.summary.scalar('tarreg_pred', data=tarreg_pred, step=epoch)
        pehe_nn=self.PEHEnn(concat_pred); tf.summary.scalar('cate_nn_err', data=tf.sqrt(pehe_nn), step=epoch)
        
        #Simulation Metrics
        ate_true=tf.reduce_mean(self.data['mu_1']-self.data['mu_0'])
        ate_err=tf.abs(ate_true-ate_pred); tf.summary.scalar('ate_err', data=ate_err, step=epoch)
        aipw_err=tf.abs(ate_true-aipw_pred); tf.summary.scalar('aipw_err', data=aipw_err, step=epoch)
        tarreg_err=tf.abs(ate_true-tarreg_pred); tf.summary.scalar('tarreg_err', data=tarreg_err, step=epoch)
        pehe =self.PEHE(concat_pred); tf.summary.scalar('cate_err', data=tf.sqrt(pehe), step=epoch)
        out_str=f' — ate_err: {ate_err:.4f}  — aipw_err: {aipw_err:.4f} — tarreg_err: {tarreg_err:.4f} — cate_err: {tf.sqrt(pehe):.4f} — cate_nn_err: {tf.sqrt(pehe_nn):.4f} '
        
        if self.verbose > 0: print(out_str)

Cool. Now we can run it!

In [None]:
import tensorflow as tf
import numpy as np
import datetime
%load_ext tensorboard

from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, ReduceLROnPlateau, TerminateOnNaN
from tensorflow.keras.optimizers import SGD, Adam

val_split=0.2
batch_size=64
verbose=1
i = 0
tf.random.set_seed(i)
np.random.seed(i)
yt = np.concatenate([data['ys'], data['t']], 1)

# Clear any logs from previous runs
!rm -rf ./logs/ 
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
file_writer = tf.summary.create_file_writer(log_dir + "/metrics")
file_writer.set_as_default()
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=0)

sgd_callbacks = [
        TerminateOnNaN(),
        EarlyStopping(monitor='val_loss', patience=40, min_delta=0.),
        ReduceLROnPlateau(monitor='loss', factor=0.5, patience=5, verbose=verbose, mode='auto',
                          min_delta=0., cooldown=0, min_lr=0),
        tensorboard_callback,
        TarReg_Metrics(data, verbose=verbose)    ]

sgd_lr = 1e-5
momentum = 0.9

dragonnet_model=make_dragonnet(data['x'].shape[1],.01)
tarreg_loss=TarReg_Loss(alpha=1)

# `lr` is deprecated in tf.keras optimizers; use `learning_rate`
dragonnet_model.compile(optimizer=SGD(learning_rate=sgd_lr, momentum=momentum, nesterov=True),
                        loss=tarreg_loss,
                        metrics=[tarreg_loss, tarreg_loss.regression_loss, tarreg_loss.treatment_acc])

dragonnet_model.fit(x=data['x'],y=yt,
                 callbacks=sgd_callbacks,
                  validation_split=val_split,
                  epochs=300,
                  batch_size=batch_size,
                  verbose=verbose)

As usual, we'll check things in TensorBoard. Again, we don't see much of a difference in the $ATE$ error (`tarreg_err` vs. `ate_err`), but there is a big gain in `cate_err` on this dataset.

In [None]:
%tensorboard --logdir logs/fit

# Confidence Intervals

Because targeted regularization is essentially TMLE, the standard error of $\hat{ATE}_{TR}$ is the sample-corrected standard deviation of the efficient influence curve (EIC), divided by $\sqrt{n}$:
$$\sigma_{\hat{ATE}_{TR}}=\sqrt{\frac{Var(IC_{\hat{ATE}_{TR}})}{n}}$$ where

$$Var(IC_{\hat{ATE}_{TR}}) = Var\left[\left(\frac{T}{\pi(x,1)}-\frac{1-T}{\pi(x,0)}\right)(Y-h^*(x,T))+(h^*(x,1)-h^*(x,0))-\hat{ATE}_{TR}\right]$$

Alternatively, you can estimate confidence intervals through bootstrapping.
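
The influence-curve standard error above can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not part of the tutorial's pipeline: `tarreg_ate_ci` is a hypothetical helper, the data are simulated, and the true conditional means and propensities stand in for a fitted model's $h^*$ and $\pi$.

```python
import numpy as np

def tarreg_ate_ci(y, t, t_prob, hstar_0, hstar_1, z=1.96):
    """Plug-in ATE_TR, its influence-curve standard error, and a normal-approximation CI."""
    hstar_t = t * hstar_1 + (1 - t) * hstar_0            # h*(x, T) at the observed treatment
    ate_tr = np.mean(hstar_1 - hstar_0)                  # plug-in estimate
    weights = t / t_prob - (1 - t) / (1 - t_prob)        # T/pi(x,1) - (1-T)/pi(x,0)
    ic = weights * (y - hstar_t) + (hstar_1 - hstar_0) - ate_tr
    se = np.sqrt(np.var(ic, ddof=1) / len(y))            # ddof=1: sample-corrected variance
    return ate_tr, se, (ate_tr - z * se, ate_tr + z * se)

# Simulated data (assumption: oracle quantities replace model predictions).
rng = np.random.default_rng(0)
n = 2000
x = rng.normal(size=n)
t_prob = 1 / (1 + np.exp(-0.5 * x))      # true propensity pi(x,1)
t = rng.binomial(1, t_prob)
mu_0 = x                                  # E[Y(0) | x]
mu_1 = 1.0 + 1.5 * x                      # E[Y(1) | x]
y = np.where(t == 1, mu_1, mu_0) + rng.normal(scale=0.5, size=n)

ate_tr, se, (ci_low, ci_high) = tarreg_ate_ci(y, t, t_prob, mu_0, mu_1)
print(f"ATE_TR = {ate_tr:.3f}, SE = {se:.3f}, 95% CI = ({ci_low:.3f}, {ci_high:.3f})")
```

With a trained Dragonnet, the inputs would instead come from the callback's quantities: the clipped propensity predictions and the $h^*$ values under $T=0$ and $T=1$, all on the original outcome scale.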


# Small caveat

Note that semi-parametric theory focuses on estimating the $ATE$, a single parameter of the treatment effect distribution, and not heterogeneous treatment effects. In fact, the Dragonnet paper does not report $CATE$ estimates because they are not the focus of the paper. But in this toy example, it does seem like targeted regularization improved prediction of the $CATE$ as well. It makes sense that adding an explicit treatment effect loss would improve counterfactual prediction.


# That's it!

In this tutorial we:

- Introduced treatment modeling and the Dragonnet architecture.

- Built a custom object-oriented loss for Dragonnet and adapted our callback to compute the AIPW estimate of the $ATE$.

- Explained the TMLE and Targeted Regularization algorithms and implemented targeted regularization.

# Up next...

In the final tutorial, we introduce the Counterfactual Regression Network (CFRNet) described in [Shalit et al., 2017](https://arxiv.org/abs/1606.03976), [Johansson et al. 2018](https://arxiv.org/abs/1903.03448), and [Johansson et al., 2020](https://arxiv.org/abs/2001.07426). Instead of propensity modeling, CFRNet uses integral probability metrics (IPMs) to bound the counterfactual generalization error. A variant called weighted CFRNet combines IPMs with propensity weighting.
