These tutorials are licensed by [Bernard Koch](http://www.github.com/kochbj) under a [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License](http://creativecommons.org/licenses/by-nc-sa/4.0/).

# Tutorial 4: Using Integral Probability Metrics to Bound the Counterfactual Loss

Instead of using  semi-parametric corrections based on the propensity score to adjust ATE estimates, [Shalit et al., 2017](https://arxiv.org/abs/1606.03976), [Johansson et al. 2019](https://arxiv.org/abs/1903.03448), and [Johansson et al., 2020](https://arxiv.org/abs/2001.07426) take a different approach to encourage the representation function $\Phi$ to explicitly learn treatment effects: integral probability metrics. 

Integral probability metrics (IPMs) are true metrics that measure the distance between two distributions: they are symmetric and obey the triangle inequality, unlike the KL divergence. **The key idea here is that we can add an IPM as a loss in TARNet to more explicitly encourage the representation layers to balance the covariate distribution of the treated group $T$ and the covariate distribution of the control group $C$.** In recent years, IPMs have also become very popular losses for generative adversarial networks.

Theoretically, this group has developed generalization bounds for the $CATE$ that show that even though the counterfactual loss is unknowable, it can be bounded by the factual loss of $h$ and an IPM between the treated and control distributions. Their claims focus on two IPMs in particular:

## Wasserstein Distance

The Wasserstein or "Earth Mover's" distance fits an interpretable "map" (i.e., a matrix) showing how to efficiently convert one probability mass distribution into another. The Wasserstein distance is most easily understood as an optimal transport problem (i.e., a scenario where we want to transport one distribution to another at minimum cost). The nickname "Earth Mover's distance" comes from the metaphor of shoveling dirt to terraform one landscape into another. In the idealized case in which one distribution can be perfectly transformed into another, the Wasserstein map corresponds exactly to a strategy of perfect one-to-one matching on covariates ([Kallus, 2016](https://arxiv.org/abs/1612.08321)). Because calculating the Wasserstein distance requires optimizing the map, we'll focus on the computationally simpler Maximum Mean Discrepancy metric.
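
To make the transport intuition concrete, here is a minimal sketch (not part of the CFRNet code) for the simplest possible case: two one-dimensional samples of equal size. In that special case the optimal map simply pairs the sorted values, so the Wasserstein-1 distance is the mean absolute difference between the order statistics. The function name `wasserstein1_1d` and the toy data are made up for this illustration.

```python
import numpy as np

def wasserstein1_1d(a, b):
    """Empirical Wasserstein-1 distance between two equal-sized 1-D samples."""
    a = np.sort(np.asarray(a, dtype=float))
    b = np.sort(np.asarray(b, dtype=float))
    if a.shape != b.shape:
        raise ValueError("this sketch assumes equal sample sizes")
    # In 1-D the optimal transport plan matches the i-th smallest value of a
    # with the i-th smallest value of b.
    return float(np.mean(np.abs(a - b)))

control = np.array([0.0, 1.0, 2.0, 3.0])
treated = control[::-1] + 0.5   # same values, shuffled and shifted by 0.5

print(wasserstein1_1d(treated, control))  # every unit of mass moves 0.5
```

In higher dimensions, or with unequal sample sizes, no such shortcut exists and the map has to be optimized, which is the computational cost mentioned above.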

## Maximum Mean Discrepancy (MMD)

The maximum mean discrepancy (MMD) is the normed distance between the means of two distributions $T$ and $C$, after a kernel function has transformed them into a high-dimensional reproducing kernel Hilbert space (RKHS). It relies on the kernel trick to calculate distances in the RKHS. The metric is built on the intuition that if $T$ and $C$ are the same distribution, no function in this high-dimensional space can have different expected values under $T$ and $C$. Formally, we can define the squared MMD as

$$MMD^2(T,C) = ||\mathbb{E}_{X \sim T}\phi(X) - \mathbb{E}_{X \sim C}\phi(X)||^2_{\mathcal{H}}$$

where $\phi$ is associated with a reproducing kernel function $k$ such as the Gaussian radial basis function kernel: $$k(X,X')=\exp({- \frac{||X-X'||^2}{2\sigma^2} }) $$

From [Johansson et al., 2020](https://arxiv.org/abs/2001.07426), an unbiased estimator for the square of the MMD from a sample of size $m$ drawn from treated distribution $T$ and a sample of size $n$ drawn from control distribution $C$ is,

$$\widehat{MMD}^2_k(T,C):=\frac{1}{m(m-1)}\sum_{i=1}^m\sum_{j\neq i}^m k(x_i^T,x_j^T)-\frac{2}{mn}\sum_{i=1}^m\sum_{j=1}^n k(x_i^T,x_j^C)+\frac{1}{n(n-1)}\sum_{i=1}^n\sum_{j\neq i}^n k(x_i^C,x_j^C)$$
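
Before moving to TensorFlow, it may help to see the estimator in plain NumPy. The sketch below is an illustration under assumed toy data, not the tutorial's loss: the function names `rbf_kernel` and `mmd2_unbiased` and the simulated samples are hypothetical. It uses the Gaussian RBF kernel with the $2\sigma^2$ denominator from the formula above and drops the diagonal terms from the within-sample sums.

```python
import numpy as np

def rbf_kernel(a, b, sigma=1.0):
    """Kernel matrix with entries exp(-||a_i - b_j||^2 / (2 sigma^2))."""
    sq_dists = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-sq_dists / (2.0 * sigma ** 2))

def mmd2_unbiased(x_t, x_c, sigma=1.0):
    """Unbiased estimate of the squared MMD between two samples."""
    m, n = len(x_t), len(x_c)
    k_tt = rbf_kernel(x_t, x_t, sigma)
    k_cc = rbf_kernel(x_c, x_c, sigma)
    k_tc = rbf_kernel(x_t, x_c, sigma)
    # Within-sample sums exclude the i == j terms (the diagonal).
    term_tt = (k_tt.sum() - np.trace(k_tt)) / (m * (m - 1))
    term_cc = (k_cc.sum() - np.trace(k_cc)) / (n * (n - 1))
    term_tc = 2.0 * k_tc.sum() / (m * n)
    return float(term_tt + term_cc - term_tc)

rng = np.random.default_rng(0)
same = mmd2_unbiased(rng.normal(size=(200, 2)), rng.normal(size=(200, 2)))
shifted = mmd2_unbiased(rng.normal(size=(200, 2)) + 2.0, rng.normal(size=(200, 2)))
print(f"same distribution: {same:.4f}, shifted distribution: {shifted:.4f}")
```

Because the estimator is unbiased rather than constrained to be non-negative, it can come out slightly below zero when the two samples are drawn from the same distribution.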



## Notation
**Causal identification**

- Observed covariates/features: $X$

- Potential outcomes: $Y(0)$ and $Y(1)$

- Treatment: $T$

- Unobservable Individual Treatment Effect: $\tau_i = Y_i(1) - Y_i(0)$

- Average Treatment Effect: $ATE =\mathbb{E}[Y_i(1)-Y_i(0)]= \mathbb{E}[{\tau_i}]$

- Conditional Average Treatment Effect: $CATE(x) =\mathbb{E}[Y_i(1)-Y_i(0)|X=x]$


**Deep learning estimation**

- Predicted outcomes: $\hat{Y}(0)$ and $\hat{Y}(1)$

- Outcome modeling functions: $\hat{Y}(T)=h(X,T)$ 

- Representation functions: $\Phi(X)$

- Propensity score function:
$\pi(X,T)=P(T|X)$ <br>*where $\pi(X,1)=P(T=1|X)$ and $\pi(X,0)=1-\pi(X,1)$*

- Loss functions: $\mathcal{L}(true,predicted)$, with the mean squared error abbreviated $MSE$ and binary cross-entropy as $BCE$

- Estimated CATE<sup>*</sup>: $\hat{CATE_i} = \hat{\tau}_i = \hat{Y_i}(1)-\hat{Y_i}(0) = h(X,1)-h(X,0)$

- Estimated ATE: $\hat{ATE}=\frac{1}{n}\sum_{i=1}^n\hat{CATE_i}$

- Nearest-neighbor PEHE:
$$PEHE_{nn}=\frac{1}{N}\sum_{i=1}^N\Big(\underbrace{(1-2t_i)(y_i^{nn}(1-t_i)-y_i(t_i))}_{CATE_{nn}}-\underbrace{(h(\Phi(x_i),1)-h(\Phi(x_i),0))}_{\hat{CATE}}\Big)^2$$ where $y_i^{nn}(1-t_i)=y_{j(i)}$ is the observed outcome of the nearest neighbor $j(i)$ of unit $i$ in representation space among units with the opposite treatment ($t_j\neq t_i$):
  $$j(i) = \arg\min_{j:\,t_j=1-t_i}||\Phi(x_i)-\Phi(x_j)||_2$$

\* We define $\hat{\tau}_i = \hat{CATE_i}$ because we lack the covariates to estimate the ITE.
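
The nearest-neighbor PEHE is easier to follow on a toy example. The NumPy sketch below is illustrative only: the function name `pehe_nn` and the four made-up units are assumptions, and it follows the sign convention used later in the metrics callback, `(1-2*t)*(y_nn-y)`. Each unit is matched to its closest neighbor in representation space from the opposite treatment group, and the matched outcome difference stands in for the unobserved CATE.

```python
import numpy as np

def pehe_nn(phi, t, y, cate_hat):
    """Nearest-neighbor PEHE for representations phi of shape (N, d)."""
    phi = np.asarray(phi, dtype=float)
    t = np.asarray(t).astype(int).ravel()
    y = np.asarray(y, dtype=float).ravel()
    cate_hat = np.asarray(cate_hat, dtype=float).ravel()
    # Pairwise Euclidean distances in representation space.
    dists = np.linalg.norm(phi[:, None, :] - phi[None, :, :], axis=-1)
    # Only units from the opposite treatment group are valid matches.
    dists[t[:, None] == t[None, :]] = np.inf
    y_nn = y[dists.argmin(axis=1)]
    # (1 - 2t) flips the sign for treated units so both groups give y(1) - y(0).
    cate_nn = (1 - 2 * t) * (y_nn - y)
    return float(np.mean((cate_nn - cate_hat) ** 2))

phi = np.array([[0.0], [0.1], [1.0], [1.1]])
t = np.array([1, 0, 1, 0])
y = np.array([3.0, 1.0, 5.0, 2.0])

print(pehe_nn(phi, t, y, cate_hat=[2.0, 2.0, 3.0, 3.0]))  # predictions match the matched differences
print(pehe_nn(phi, t, y, cate_hat=[1.0, 1.0, 1.0, 1.0]))  # a constant prediction does worse
```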

# Part 1: Coding up CFRNet using the MMD

Johansson et al. call the variant of TARNet with an IPM loss the Counterfactual Regression Network (CFRNet). Architecturally, CFRNet is identical to TARNet, but the loss function looks like this:

\begin{equation}
\min_{h,\Phi,IPM}\frac{1}{n}\sum_{i=1}^n \mathcal{L}(h(\Phi(x_i),t_i),y_i) + \lambda \mathcal{R}(h)+\alpha\cdot IPM(\Phi(X|T=1),\Phi(X|T=0))\end{equation}
where $\mathcal{R}(h)$ is a model complexity term and $\lambda$ and $\alpha$ are hyperparameters. $IPM(\Phi(X|T=1),\Phi(X|T=0))$ is an IPM distance between the covariate distributions of the treated and control groups after they are projected into representation space.

<figure><img src=https://github.com/kochbj/Deep-Learning-for-Causal-Inference/blob/main/images/CFRNet.png?raw=true width="900"><figcaption>CFRNet architecture introduced in Shalit et al., 2017. Purple indicates inputs, orange indicates network layers, other colors indicate output layers, and white indicates outputs. The dashes between colored shapes indicate an unspecified number of additional hidden layers. The dashed lines on the right indicate non-gradient, plug-in computations that occur after training.</figcaption></figure>
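
As a rough sketch of how the pieces of this objective fit together, the NumPy snippet below computes the factual loss and adds a scaled IPM value. The function names `factual_mse` and `cfr_objective` and the toy numbers are hypothetical, the IPM is passed in as a plain number, and the complexity term $\lambda\mathcal{R}(h)$ is left out (in the Keras model it is handled by the layers' `kernel_regularizer`).

```python
import numpy as np

def factual_mse(y, t, y0_pred, y1_pred):
    """Squared error of the head that matches each unit's observed treatment."""
    y, t = np.asarray(y, dtype=float), np.asarray(t, dtype=float)
    loss0 = np.mean((1.0 - t) * (y - np.asarray(y0_pred, dtype=float)) ** 2)
    loss1 = np.mean(t * (y - np.asarray(y1_pred, dtype=float)) ** 2)
    return float(loss0 + loss1)

def cfr_objective(y, t, y0_pred, y1_pred, ipm, alpha=1.0):
    """Factual loss plus alpha times an IPM value computed elsewhere."""
    return factual_mse(y, t, y0_pred, y1_pred) + alpha * ipm

y = [1.0, 2.0, 3.0, 4.0]
t = [0, 1, 0, 1]
y0_pred = [1.0, 9.0, 2.0, 9.0]  # the 9.0s are counterfactual predictions
y1_pred = [9.0, 2.0, 9.0, 3.0]

print(factual_mse(y, t, y0_pred, y1_pred))
print(cfr_objective(y, t, y0_pred, y1_pred, ipm=0.2, alpha=2.0))
```

Note that the deliberately bad counterfactual predictions (the 9.0 values) never enter the factual loss. This is why the IPM term is the only part of the objective that speaks to the counterfactual error.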

In [None]:
!pip install -q tensorflow==2.8.0
import tensorflow as tf
import numpy as np #numpy is the numerical computing package in python
import datetime #we'll use dates to label our logs
print(tf.__version__)

First reload the data...

In [None]:
!pip install scikit-learn==0.24.2
from sklearn.preprocessing import StandardScaler
!wget -nc http://www.fredjo.com/files/ihdp_npci_1-100.train.npz
!wget -nc http://www.fredjo.com/files/ihdp_npci_1-100.test.npz 

def load_IHDP_data(training_data,testing_data,i=7):
    with open(training_data,'rb') as trf, open(testing_data,'rb') as tef:
        train_data=np.load(trf); test_data=np.load(tef)
        y=np.concatenate(   (train_data['yf'][:,i],   test_data['yf'][:,i])).astype('float32') #most GPUs only compute 32-bit floats
        t=np.concatenate(   (train_data['t'][:,i],    test_data['t'][:,i])).astype('float32')
        x=np.concatenate(   (train_data['x'][:,:,i],  test_data['x'][:,:,i]),axis=0).astype('float32')
        mu_0=np.concatenate((train_data['mu0'][:,i],  test_data['mu0'][:,i])).astype('float32')
        mu_1=np.concatenate((train_data['mu1'][:,i],  test_data['mu1'][:,i])).astype('float32')

        data={'x':x,'t':t,'y':y,'mu_0':mu_0,'mu_1':mu_1}
        data['t']=data['t'].reshape(-1,1) #we're just padding one dimensional vectors with an additional dimension 
        data['y']=data['y'].reshape(-1,1)
        
        #standardizing y (zero mean, unit variance) often makes training of DL regressors easier
        data['y_scaler'] = StandardScaler().fit(data['y'])
        data['ys'] = data['y_scaler'].transform(data['y'])

    return data

data=load_IHDP_data(training_data='./ihdp_npci_1-100.train.npz',testing_data='./ihdp_npci_1-100.test.npz')
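
Since the network is trained on the standardized outcome `ys`, predictions have to be mapped back to the original scale before computing treatment effects. The NumPy sketch below mimics what `StandardScaler` does with made-up numbers (the toy `y` values are an assumption). It also shows that a difference of two standardized predictions only needs to be multiplied by the standard deviation, because the mean cancels.

```python
import numpy as np

y = np.array([[2.0], [4.0], [6.0], [8.0]])  # toy outcomes, shape (N, 1)
mean = float(y.mean())
std = float(y.std())                        # population std (ddof=0), as StandardScaler uses

ys = (y - mean) / std                       # what transform() would return
y_back = ys * std + mean                    # what inverse_transform() would return

# Hypothetical standardized predictions for one unit.
ys0_pred, ys1_pred = -0.5, 0.5
cate_original_units = (ys1_pred * std + mean) - (ys0_pred * std + mean)

print(ys.ravel())
print(cate_original_units, std * (ys1_pred - ys0_pred))
```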

## Creating a custom MMD loss

We'll begin by writing up the MMD estimator described above in a custom loss object. The loss will take $\Phi$ and $T$ as inputs, and output $\widehat{MMD}^2$. We'll implement the unbiased MMD estimator described above with the Gaussian RBF kernel. Note that I've chosen to calculate losses with `tf.reduce_mean` here because there can be exploding gradients in the weighted version of CFRNet (described below).

In [None]:
def pdist2sq(x,y):
    x2 = tf.reduce_sum(x ** 2, axis=-1, keepdims=True)
    y2 = tf.reduce_sum(y ** 2, axis=-1, keepdims=True)
    dist = x2 + tf.transpose(y2, (1, 0)) - 2. * x @ tf.transpose(y, (1, 0))
    return dist

from tensorflow.keras.losses import Loss

class CFRNet_Loss(Loss):
  #initialize instance attributes
  def __init__(self, alpha=1.,sigma=1.):
      super().__init__()
      self.alpha = alpha # balances regression loss and MMD IPM
      self.rbf_sigma=sigma #for gaussian kernel
      self.name='cfrnet_loss'
      
  def split_pred(self,concat_pred):
      #generic helper to make sure we dont make mistakes
      preds={}
      preds['y0_pred'] = concat_pred[:, 0]
      preds['y1_pred'] = concat_pred[:, 1]
      preds['phi'] = concat_pred[:, 2:]
      return preds

  def rbf_kernel(self, x, y):
    #Gaussian RBF kernel; the denominator is sigma^2 rather than 2*sigma^2, so sigma is rescaled relative to the formula above
    return tf.exp(-pdist2sq(x,y)/tf.square(self.rbf_sigma))

  def calc_mmdsq(self, Phi, t):
    Phic, Phit =tf.dynamic_partition(Phi,tf.cast(tf.squeeze(t),tf.int32),2)

    Kcc = self.rbf_kernel(Phic,Phic)
    Kct = self.rbf_kernel(Phic,Phit)
    Ktt = self.rbf_kernel(Phit,Phit)

    m = tf.cast(tf.shape(Phic)[0],Phi.dtype)
    n = tf.cast(tf.shape(Phit)[0],Phi.dtype)

    #subtract the diagonal (k(x,x)=1 for the RBF kernel) so the within-group sums only run over i != j
    mmd = 1.0/(m*(m-1.0))*(tf.reduce_sum(Kcc)-m)
    mmd = mmd + 1.0/(n*(n-1.0))*(tf.reduce_sum(Ktt)-n)
    mmd = mmd - 2.0/(m*n)*tf.reduce_sum(Kct)
    return mmd * tf.ones_like(t)

  def mmdsq_loss(self, concat_true,concat_pred):
    t_true = concat_true[:, 1]
    p=self.split_pred(concat_pred)
    mmdsq_loss = tf.reduce_mean(self.calc_mmdsq(p['phi'],t_true))
    return mmdsq_loss

  def regression_loss(self,concat_true,concat_pred):
      y_true = concat_true[:, 0]
      t_true = concat_true[:, 1]
      p = self.split_pred(concat_pred)
      loss0 = tf.reduce_mean((1. - t_true) * tf.square(y_true - p['y0_pred']))
      loss1 = tf.reduce_mean(t_true * tf.square(y_true - p['y1_pred']))
      return loss0+loss1

  def cfr_loss(self,concat_true,concat_pred):
      lossR = self.regression_loss(concat_true,concat_pred)
      lossIPM = self.mmdsq_loss(concat_true,concat_pred)
      return lossR + self.alpha * lossIPM

  #compute loss
  def call(self, concat_true, concat_pred):        
      return self.cfr_loss(concat_true,concat_pred)
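
Two details of the cell above can be checked in plain NumPy. The first is the identity behind `pdist2sq`, $||x-y||^2=||x||^2+||y||^2-2x\cdot y$, which avoids building an explicit array of differences. The second is that `rbf_kernel` divides by $\sigma^2$ rather than $2\sigma^2$, which is the same kernel as the textbook formula with bandwidth $\sigma/\sqrt{2}$. This is only a sanity-check sketch: `pdist2sq_np` is a hypothetical NumPy port and the random inputs are made up.

```python
import numpy as np

def pdist2sq_np(x, y):
    """NumPy port of pdist2sq: squared Euclidean distances, shape (m, n)."""
    x2 = np.sum(x ** 2, axis=-1, keepdims=True)  # (m, 1)
    y2 = np.sum(y ** 2, axis=-1, keepdims=True)  # (n, 1)
    return x2 + y2.T - 2.0 * x @ y.T

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3))
y = rng.normal(size=(4, 3))

# Brute-force reference: explicit differences between every pair of rows.
brute = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
d = pdist2sq_np(x, y)
print(np.allclose(d, brute))

sigma = 1.5
k_code = np.exp(-d / sigma ** 2)                               # form used in rbf_kernel
k_textbook = np.exp(-d / (2.0 * (sigma / np.sqrt(2.0)) ** 2))  # textbook form, rescaled bandwidth
print(np.allclose(k_code, k_textbook))
```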

## Metrics callback 

Paste our standard callback from previous tutorials...

In [None]:
import tensorflow as tf
from tensorflow.keras.callbacks import Callback
#https://towardsdatascience.com/implementing-macro-f1-score-in-keras-what-not-to-do-e9f1aa04029d
class Base_Metrics(Callback):
    def __init__(self,data, verbose=0):   
        super(Base_Metrics, self).__init__()
        self.data=data #feed the callback the full dataset
        self.verbose=verbose

        #needed for PEHEnn; Called in self.find_ynn
        self.data['o_idx']=tf.range(self.data['t'].shape[0])
        self.data['c_idx']=self.data['o_idx'][self.data['t'].squeeze()==0] #These are the indices of the control units
        self.data['t_idx']=self.data['o_idx'][self.data['t'].squeeze()==1] #These are the indices of the treated units
    
    def split_pred(self,concat_pred):
        preds={}
        preds['y0_pred'] = self.data['y_scaler'].inverse_transform(concat_pred[:, 0].reshape(-1, 1))
        preds['y1_pred'] = self.data['y_scaler'].inverse_transform(concat_pred[:, 1].reshape(-1, 1))
        preds['phi'] = concat_pred[:, 2:]
        return preds

    def find_ynn(self, Phi):
        #helper for PEHEnn
        PhiC, PhiT =tf.dynamic_partition(Phi,tf.cast(tf.squeeze(self.data['t']),tf.int32),2) #separate control and treated reps
        dists=tf.sqrt(pdist2sq(PhiC,PhiT)) #calculate squared distance then sqrt to get euclidean
        yT_nn_idx=tf.gather(self.data['c_idx'],tf.argmin(dists,axis=0),1) #get c_idxs of smallest distances for treated units
        yC_nn_idx=tf.gather(self.data['t_idx'],tf.argmin(dists,axis=1),1) #get t_idxs of smallest distances for control units
        yT_nn=tf.gather(self.data['y'],yT_nn_idx,1) #now use these to retrieve y values
        yC_nn=tf.gather(self.data['y'],yC_nn_idx,1)
        y_nn=tf.dynamic_stitch([self.data['t_idx'],self.data['c_idx']],[yT_nn,yC_nn]) #stitch em back up!
        return y_nn

    def PEHEnn(self,concat_pred):
        p = self.split_pred(concat_pred)
        y_nn = self.find_ynn(p['phi']) #now its 3 plus because 
        cate_nn_err=tf.reduce_mean( tf.square( (1-2*self.data['t']) * (y_nn-self.data['y']) - (p['y1_pred']-p['y0_pred']) ) )
        return cate_nn_err

    def ATE(self,concat_pred):
        p = self.split_pred(concat_pred)
        return p['y1_pred']-p['y0_pred']

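    def PEHE(self,concat_pred):
        #simulation only
        p = self.split_pred(concat_pred)
        cate_true = (self.data['mu_1']-self.data['mu_0']).reshape(-1, 1) #column vector so it lines up with the (N,1) predictions instead of broadcasting to (N,N)
        cate_err=tf.reduce_mean( tf.square( cate_true - (p['y1_pred']-p['y0_pred']) ) )
        return cate_err 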

    def on_epoch_end(self, epoch, logs={}):
        concat_pred=self.model.predict(self.data['x'])
        #Calculate Empirical Metrics        
        ate_pred=tf.reduce_mean(self.ATE(concat_pred)); tf.summary.scalar('ate', data=ate_pred, step=epoch)
        pehe_nn=self.PEHEnn(concat_pred); tf.summary.scalar('cate_nn_err', data=tf.sqrt(pehe_nn), step=epoch)
        
        #Simulation Metrics
        ate_true=tf.reduce_mean(self.data['mu_1']-self.data['mu_0'])
        ate_err=tf.abs(ate_true-ate_pred); tf.summary.scalar('ate_err', data=ate_err, step=epoch)
        pehe =self.PEHE(concat_pred); tf.summary.scalar('cate_err', data=tf.sqrt(pehe), step=epoch)
        out_str=f' — ate_err: {ate_err:.4f}  — cate_err: {tf.sqrt(pehe):.4f} — cate_nn_err: {tf.sqrt(pehe_nn):.4f} '
        
        if self.verbose > 0: print(out_str)
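
One thing worth checking when adapting this callback is array shapes. Arrays such as `data['y']` and the inverse-transformed predictions are column vectors of shape `(N, 1)`, and combining them with a flat array of shape `(N,)` silently broadcasts to an `(N, N)` matrix instead of raising an error. The toy arrays below are made up purely to show the effect.

```python
import numpy as np

n = 4
col = np.arange(n, dtype=float).reshape(-1, 1)  # shape (4, 1), like the scaler output
flat = np.arange(n, dtype=float)                # shape (4,), like an array that was never reshaped

wrong = col - flat                 # broadcasts to every pairwise difference
right = col - flat.reshape(-1, 1)  # elementwise difference, as intended

print(wrong.shape, right.shape)
print(np.mean(wrong ** 2), np.mean(right ** 2))
```

Averaging the squared entries of the `(N, N)` matrix gives a number that looks like a plausible error metric, so a shape assertion is a cheap safeguard.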

## Building the model

You'll notice that our hyperparameterization of TARNet is a little different this time. Even though we haven't been actually comparing models, I wanted to try and implement things closer to the authors' recommended settings for IHDP.

In [None]:
import tensorflow as tf
import numpy as np

from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Concatenate
from tensorflow.keras import regularizers
from tensorflow.keras import Model
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras.metrics import binary_accuracy
from tensorflow.keras.losses import Loss

def make_tarnet(input_dim, reg_l2):

    x = Input(shape=(input_dim,), name='input')

    # representation
    phi = Dense(units=25, activation='elu', kernel_initializer='RandomNormal',name='phi_1')(x)
    phi = Dense(units=25, activation='elu', kernel_initializer='RandomNormal',name='phi_2')(phi)

    # HYPOTHESIS
    y0_hidden = Dense(units=25, activation='elu', kernel_regularizer=regularizers.l2(reg_l2),name='y0_hidden_1')(phi)
    y1_hidden = Dense(units=25, activation='elu', kernel_regularizer=regularizers.l2(reg_l2),name='y1_hidden_1')(phi)

    # second layer
    y0_hidden = Dense(units=25, activation='elu', kernel_regularizer=regularizers.l2(reg_l2),name='y0_hidden_2')(y0_hidden)
    y1_hidden = Dense(units=25, activation='elu', kernel_regularizer=regularizers.l2(reg_l2),name='y1_hidden_2')(y1_hidden)

    # third
    y0_predictions = Dense(units=1, activation=None, kernel_regularizer=regularizers.l2(reg_l2), name='y0_predictions')(y0_hidden)
    y1_predictions = Dense(units=1, activation=None, kernel_regularizer=regularizers.l2(reg_l2), name='y1_predictions')(y1_hidden)

    concat_pred = Concatenate(1)([y0_predictions, y1_predictions,phi])
    model = Model(inputs=x, outputs=concat_pred)

    return model

Again, I'm doing things a little differently this time. We're going to use Adam as the optimizer since it was used in the paper. We also use a bigger batch size because the IPM is very unstable if the batch size is too small (e.g., there are no treated units in the batch). Lastly, we're running for a few more epochs.
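
To get a feel for why batch size matters, note that the unbiased MMD estimator divides by $n(n-1)$, so it is undefined for a batch with zero or one treated units. The sketch below is a back-of-the-envelope binomial approximation, not an exact calculation for sampling without replacement. The function name is hypothetical, and the treated share of 0.19 is an assumed placeholder; the Part 2 training cell prints the actual share as `pT`.

```python
from math import comb

def prob_fewer_than_two_treated(batch_size, p_treated):
    """Binomial approximation to P(a batch contains 0 or 1 treated units)."""
    q = 1.0 - p_treated
    p_zero = q ** batch_size
    p_one = comb(batch_size, 1) * p_treated * q ** (batch_size - 1)
    return p_zero + p_one

# 0.19 is an assumed treated share used only for illustration.
for b in (8, 16, 32, 100):
    print(b, prob_fewer_than_two_treated(b, 0.19))
```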

In [None]:
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, ReduceLROnPlateau, TerminateOnNaN
from tensorflow.keras.optimizers import SGD, Adam
#load the TensorBoard notebook extension so we can view the logs inline
%load_ext tensorboard 

val_split=0.2
batch_size=100
verbose=1
i = 0
tf.random.set_seed(i)
np.random.seed(i)
yt = np.concatenate([data['ys'], data['t']], 1)

# Clear any logs from previous runs
!rm -rf ./logs/ 
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
file_writer = tf.summary.create_file_writer(log_dir + "/metrics")
file_writer.set_as_default()
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=0)

#let's try ADAM this time
adam_callbacks = [
        TerminateOnNaN(),
        EarlyStopping(monitor='val_loss', patience=2, min_delta=0.),
        ReduceLROnPlateau(monitor='loss', factor=0.5, patience=5, verbose=verbose, mode='auto',
                          min_delta=1e-8, cooldown=0, min_lr=0),
        tensorboard_callback,
        Base_Metrics(data,verbose=verbose)
    ]


cfrnet_model=make_tarnet(data['x'].shape[1],.01)
cfrnet_loss=CFRNet_Loss(alpha=1.0)

cfrnet_model.compile(optimizer=Adam(learning_rate=1e-4),
                      loss=cfrnet_loss,
                 metrics=[cfrnet_loss,cfrnet_loss.regression_loss,cfrnet_loss.mmdsq_loss])

cfrnet_model.fit(x=data['x'],y=yt,
                 callbacks=adam_callbacks,
                  validation_split=val_split,
                  epochs=500,
                  batch_size=batch_size,
                  verbose=verbose)



## Reviewing results in Tensorboard
When we run TensorBoard, we notice a couple of things. First, with $\alpha=1$, the MMD IPM contributes only ~5-15% as much to the total gradient as the regression loss does, but this share grows over time, putting an increasingly strong penalty on the representations for not balancing the data. Whether we should use a bigger $\alpha$ is an empirical question for hyperparameter tuning.

The other thing you should notice is that while the IPM definitely shrinks initially, within a single batch it is VERY high variance due to the small batch size, and it's unclear whether it's making a difference. 

In [None]:
%tensorboard --logdir logs/fit

# Part 2: Adding weights to CFRNet

A limitation of the original CFRNet is that it does not provide consistency guarantees. In [Johansson et al. 2019](https://arxiv.org/abs/1903.03448), and [Johansson et al., 2020](https://arxiv.org/abs/2001.07426), the authors introduce estimated weights $\pi(\Phi(X),T)$ derived from the propensity score to provide consistency guarantees. Interestingly, they also use these weights to relax some of the overlap assumptions as long as the weights themselves obey the positivity assumption. These weights are used to adjust both the representations when calculating the IPM and the loss on the predicted outcomes.
<figure><img src=https://github.com/kochbj/Deep-Learning-for-Causal-Inference/blob/main/images/WeightedCFRNet.png?raw=true width="900"><figcaption>Weighted CFRNet architecture introduced in Johansson et al., 2020. Purple indicates inputs, orange indicates network layers, other colors indicate output layers, and white indicates outputs. The dashes between colored shapes indicate an unspecified number of additional hidden layers. The dashed lines on the right indicate non-gradient, plug-in computations that occur after training.</figcaption></figure>

Weighted CFRNet minimizes the following loss function:
\begin{equation}
\min_{h,\Phi,IPM,\pi,\lambda_h,\lambda_\pi}\frac{1}{n}\sum_{i=1}^n \frac{P(t_i)}{\pi(\Phi(x_i),t_i)}\cdot\mathcal{L}(h(\Phi(x_i),t_i),y_i) + \lambda_h \mathcal{R}(h)+\alpha\cdot IPM\left(\frac{P(1)}{\pi(\Phi(X),1)}\cdot\Phi(X|T=1),\frac{P(0)}{\pi(\Phi(X),0)}\cdot\Phi(X|T=0)\right)+\lambda_\pi\frac{||\pi||_2}{n}\end{equation}
where $\mathcal{R}(h)$ is a model complexity term and $\lambda_h$, $\lambda_\pi$ and $\alpha$ are hyperparameters.

Applying the weights during the calculation of the IPM should smooth out some of the noisiness we saw from computing the IPM in tiny batches. This should reduce bias, even if weighting presents a potential variance tradeoff. The last term $\lambda_\pi\frac{||\pi||_2}{n}$ regularizes the weights to reduce variance.
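
The weights $P(t_i)/\pi(\Phi(x_i),t_i)$ can be illustrated with a few lines of NumPy. This is a sketch with made-up propensities rather than the tutorial's TensorFlow loss, and the function names `ipw_weights` and `weighted_factual_mse` are hypothetical. Units whose observed treatment was unlikely given their representation get weights above 1, and units whose treatment was more likely than the marginal rate get weights below 1.

```python
import numpy as np

def ipw_weights(t, propensity, p_treat):
    """Weights P(t) / pi(x, t): marginal over conditional treatment probability."""
    t = np.asarray(t, dtype=float)
    e = np.asarray(propensity, dtype=float)  # estimated P(T=1 | x)
    return np.where(t == 1, p_treat / e, (1.0 - p_treat) / (1.0 - e))

def weighted_factual_mse(y, t, y0_pred, y1_pred, w):
    """Mean of per-unit squared factual errors, each scaled by its weight."""
    y = np.asarray(y, dtype=float)
    t = np.asarray(t)
    y_fact_pred = np.where(t == 1, np.asarray(y1_pred, dtype=float),
                           np.asarray(y0_pred, dtype=float))
    return float(np.mean(np.asarray(w) * (y - y_fact_pred) ** 2))

t = np.array([1, 0, 1, 0])
propensity = np.array([0.5, 0.5, 0.25, 0.25])
w = ipw_weights(t, propensity, p_treat=0.5)
print(w)

y = [1.0, 2.0, 3.0, 4.0]
print(weighted_factual_mse(y, t, y0_pred=[0.0, 1.0, 0.0, 3.0],
                           y1_pred=[1.0, 0.0, 2.0, 0.0], w=w))
```

All arrays here are kept one-dimensional so that the products are elementwise.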

## Coding up weighted CFRNet

This is virtually identical to the code above but we introduce `weights_l2` for $\lambda_\pi$, and adjust our loss function to weight the IPM calculations and outcome predictions. We also adjust the number of layers and neurons per Johansson et al.'s recommendations. 

In [None]:
def make_weighted_cfrnet(input_dim, reg_l2, weights_l2):

    x = Input(shape=(input_dim,), name='input')

    # representation
    phi = Dense(units=32, activation='elu', kernel_initializer='RandomNormal',name='phi_1')(x)
    phi = Dense(units=16, activation='elu', kernel_initializer='RandomNormal',name='phi_2')(phi)

    # HYPOTHESIS
    y0_hidden = Dense(units=16, activation='elu', kernel_regularizer=regularizers.l2(reg_l2),name='y0_hidden_1')(phi)
    y1_hidden = Dense(units=16, activation='elu', kernel_regularizer=regularizers.l2(reg_l2),name='y1_hidden_1')(phi)
    t_hidden = Dense(units=32, activation='elu', kernel_regularizer=regularizers.l2(weights_l2),name='t_hidden_1')(phi)

    # second layer
    #y0_hidden = Dense(units=25, activation='elu', kernel_regularizer=regularizers.l2(reg_l2),name='y0_hidden_2')(y0_hidden)
    #y1_hidden = Dense(units=25, activation='elu', kernel_regularizer=regularizers.l2(reg_l2),name='y1_hidden_2')(y1_hidden)
    t_hidden = Dense(units=16, activation='elu', kernel_regularizer=regularizers.l2(weights_l2),name='t_hidden_2')(t_hidden)

    # third
    y0_predictions = Dense(units=1, activation=None, kernel_regularizer=regularizers.l2(reg_l2), name='y0_predictions')(y0_hidden)
    y1_predictions = Dense(units=1, activation=None, kernel_regularizer=regularizers.l2(reg_l2), name='y1_predictions')(y1_hidden)
    t_predictions = Dense(units=1, activation='sigmoid', kernel_regularizer=regularizers.l2(weights_l2), name='t_predictions')(t_hidden)

    concat_pred = Concatenate(1)([y0_predictions, y1_predictions,t_predictions,phi])
    model = Model(inputs=x, outputs=concat_pred)

    return model


class Weighted_CFRNet_Loss(CFRNet_Loss):
    #initialize instance attributes
    def __init__(self, prob_treat,alpha=1.0,sigma=1.0):
        super().__init__()
        self.pT=prob_treat
        self.alpha = alpha
        self.rbf_sigma=sigma
        self.name='weighted_cfrnet_loss'
    
    def split_pred(self,concat_pred):
        #generic helper to make sure we dont make mistakes
        preds={}
        preds['y0_pred'] = concat_pred[:, 0]
        preds['y1_pred'] = concat_pred[:, 1]
        preds['t_pred'] = concat_pred[:, 2]
        preds['t_pred'] = (preds['t_pred'] + 0.001) / 1.002 #squash propensities away from 0 and 1 so the weights stay finite
        preds['phi'] = concat_pred[:, 3:]
        return preds
    
    #for logging purposes only
    def treatment_acc(self,concat_true,concat_pred):
        t_true = concat_true[:, 1]
        p = self.split_pred(concat_pred)
        #Since this isn't used as a loss, I've used tf.reduce_mean for interpretability
        return tf.reduce_mean(binary_accuracy(t_true, p['t_pred'], threshold=0.5))

    def calc_weighted_mmdsq(self, Phi, t_true, t_pred):
        t_predC, t_predT  =tf.dynamic_partition(t_pred,tf.cast(tf.squeeze(t_true),tf.int32),2) #propensity
        PhiC, PhiT =tf.dynamic_partition(Phi,tf.cast(tf.squeeze(t_true),tf.int32),2) #representation
        weightC=tf.expand_dims((1-self.pT)/(1-t_predC),axis=-1)
        weightT=tf.expand_dims(self.pT/t_predT,axis=-1)

        wPhiC = weightC * PhiC
        wPhiT = weightT * PhiT

 
        Kcc = self.rbf_kernel(wPhiC,wPhiC)
        Kct = self.rbf_kernel(wPhiC,wPhiT)
        Ktt = self.rbf_kernel(wPhiT,wPhiT)

        m = tf.cast(tf.shape(PhiC)[0],Phi.dtype)
        n = tf.cast(tf.shape(PhiT)[0],Phi.dtype)

        mmd = 1.0/(m*(m-1.0))*(tf.reduce_sum(Kcc)-m)
        mmd = mmd + 1.0/(n*(n-1.0))*(tf.reduce_sum(Ktt)-n)
        mmd = mmd - 2.0/(m*n)*tf.reduce_sum(Kct)
        return mmd * tf.ones_like(t_true)

    def weighted_mmdsq_loss(self, concat_true,concat_pred):
        t_true = concat_true[:, 1]
        p=self.split_pred(concat_pred)
        mmdsq = tf.reduce_mean(self.calc_weighted_mmdsq(p['phi'],t_true, p['t_pred']))
        return mmdsq

    def weights(self,concat_true,concat_pred):
        p = self.split_pred(concat_pred)
        weightT=tf.expand_dims(self.pT/p['t_pred'],axis=-1)
        return tf.reduce_mean(weightT)

    def weighted_regression_loss(self, concat_true, concat_pred):
        y_true = concat_true[:, 0]
        t_true = concat_true[:, 1]
        p = self.split_pred(concat_pred)

        #keep the weights as vectors of shape (batch,): expanding them to (batch,1) would broadcast the products below to (batch,batch)
        weightC=(1-self.pT)/(1.-p['t_pred'])
        weightT=self.pT/p['t_pred']

        loss0 = tf.reduce_mean((1. - t_true) * tf.square(y_true - p['y0_pred'])*weightC)
        loss1 = tf.reduce_mean(t_true * tf.square(y_true - p['y1_pred'])* weightT)
        return loss0 + loss1

    def call(self, concat_true, concat_pred):
        return self.weighted_regression_loss(concat_true,concat_pred) + self.alpha * self.weighted_mmdsq_loss(concat_true,concat_pred)
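
A short note on the line `(preds['t_pred'] + 0.001) / 1.002` in `split_pred`: it squashes the sigmoid output from $[0,1]$ into a slightly narrower interval, which caps how large the inverse-propensity weights can get. The sketch below reproduces that arithmetic in plain Python. The function name `squash_propensity` and the example treated share are hypothetical.

```python
def squash_propensity(p, eps=0.001):
    """Map a probability in [0, 1] into [eps / (1 + 2 eps), (1 + eps) / (1 + 2 eps)]."""
    return (p + eps) / (1.0 + 2.0 * eps)

p_treat = 0.2  # assumed marginal treatment probability, for illustration only

lowest = squash_propensity(0.0)
highest = squash_propensity(1.0)
print(lowest, highest)

# Largest possible weight for a treated unit: P(T=1) over the smallest propensity.
print(p_treat / lowest)
```

The midpoint 0.5 is left unchanged, so the squashing only matters for propensities close to 0 or 1.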


## Metrics callback
We just need to adjust our metric callback to account for our propensity score predictions...

In [None]:
class Weighted_Metrics(Base_Metrics):
    def __init__(self,data, verbose=0):   
        super().__init__(data,verbose)

    def split_pred(self,concat_pred):
        preds={}
        preds['y0_pred'] = self.data['y_scaler'].inverse_transform(concat_pred[:, 0].reshape(-1, 1)) #reshape to (N,1) as in Base_Metrics
        preds['y1_pred'] = self.data['y_scaler'].inverse_transform(concat_pred[:, 1].reshape(-1, 1))
        preds['t_pred'] = concat_pred[:, 2]
        preds['phi'] = concat_pred[:, 3:]

        return preds

In [None]:
val_split=0.2
batch_size=100
verbose=1
i = 0
tf.random.set_seed(i)
np.random.seed(i)
yt = np.concatenate([data['ys'], data['t']], 1)
pT=data['t'][data['t']==1].shape[0]/data['t'].shape[0]
print("Probability of treatment:", pT)
# Clear any logs from previous runs
!rm -rf ./logs/ 
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
file_writer = tf.summary.create_file_writer(log_dir + "/metrics")
file_writer.set_as_default()
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=0)
adam_callbacks = [
        TerminateOnNaN(),
        EarlyStopping(monitor='val_loss', patience=2, min_delta=0.),
        ReduceLROnPlateau(monitor='loss', factor=0.5, patience=5, verbose=verbose, mode='auto',
                          min_delta=1e-8, cooldown=0, min_lr=0),
        tensorboard_callback,
        Weighted_Metrics(data,verbose=verbose)
    ]


sgd_lr = 1e-4
momentum = 0.9

weighted_cfrnet_model=make_weighted_cfrnet(data['x'].shape[1],.001,.1)
cfrnet_loss=Weighted_CFRNet_Loss(prob_treat=pT,alpha=1.)

weighted_cfrnet_model.compile(optimizer=Adam(learning_rate=1e-4),
                      loss=cfrnet_loss,
                 metrics=[cfrnet_loss,cfrnet_loss.weights,cfrnet_loss.treatment_acc,cfrnet_loss.weighted_mmdsq_loss])

weighted_cfrnet_model.fit(x=data['x'],y=yt,
                 callbacks=adam_callbacks,
                  validation_split=val_split,
                  epochs=1000,
                  batch_size=batch_size,
                  verbose=verbose)


## Reviewing results in Tensorboard

The most noticeable thing in our logs is that the `mmdsq_loss` is significantly smoother than in the previous trial. FWIW, the authors also demonstrate in the paper that including the weights lessens the dependence on $\alpha$ because the weights are learned adaptively.

In [None]:
%tensorboard --logdir logs/fit

# That's it!

- In this tutorial we learned about integral probability metrics.

- We applied an IPM loss to TARNet to create CFRNet.

- We extended CFRNet with propensity weights.

- We built a (weighted) CFRNet model and tested it on the IHDP data.

## Thanks

I hope you found these tutorials useful. If you have questions or would like me to do more (I've thought about adding GANITE and exploring how to create confidence intervals), please let me know through my [GitHub](http://www.github.com/kochbj) or the email on my [website](http://bernardjkoch.com).
