# *In progress*: Automatic Inference
I am in the process of repurposing the code written for the final exam of the course discussed below. The original code is commented out at the end of this notebook and will gradually be incorporated into a (hopefully) coherent codebase.

Author: Connacher Murphy

I implement the procedure described in [Farrell, Liang, and Misra (2021)](https://arxiv.org/abs/2010.14694). I make use of R code provided in the Causal Machine Learning course offered in the Fall of 2023 by Max Farrell and Sanjog Misra.

The parameter of interest is $\mu_0 = \mathbb{E}[\mathbf{H}(\mathbf{X},\mathbf{\theta}(\mathbf{X}); \mathbf{Z})]$. The outcome variable $Y$ is linked to the parameter functions $\mathbf{\theta}(\cdot)$ by the equality $\mathbb{E}[Y | \mathbf{X} = \mathbf{x}, \mathbf{Z} = \mathbf{z}] = G(\mathbf{\theta}(\mathbf{x}), \mathbf{z})$.

When projecting the Hessian of the loss function onto $\mathbf{X}$ for the estimation of $\mathbf{\Lambda}(\mathbf{X})$, it is sometimes possible to avoid estimation. For example, with a linear $G(\mathbf{\theta}(\mathbf{X}), \mathbf{Z})$ and squared loss, we can compute the Hessian directly. This code does _not_ account for such possibilities and will rely on automatic differentiation for the Hessian and a DNN for the projection of this Hessian onto $\mathbf{X}$.

_Caution_: Some parts of this code are specialized to the $\operatorname{dim}(\mathbf{\theta}) = 2$ case. I plan to make the code more flexible along this dimension.

## 0. Libraries

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import automatic_inference as auto_inf

# import numpy as np
# import torch.linalg as linalg
# import torch.autograd as autograd
# import matplotlib.pyplot as plt
# import math

## 1. Data generating process (DGP)

In [2]:
# Fix the RNG state so that every run draws the same sample
torch.manual_seed(12345)

N = 15000  # number of observations
K = 2  # number of DNN features

# Features: an N x K matrix of independent standard normal draws
dnn_features = torch.randn(N, K)

# Structural parameters are functions of the features: the first parameter is
# the first feature, the second parameter is the second feature shifted by 3
first_parameter = dnn_features[:, 0:1]
second_parameter = 3 + dnn_features[:, 1:2]
structural_parameters = torch.cat([first_parameter, second_parameter], dim=1)
structural_parameters_dim = structural_parameters.size(1)

# The structural feature is a binary treatment indicator: it equals 1 when a
# fresh standard normal draw is positive and 0 otherwise (stored as int64)
treatment_draw = torch.randn(N, 1)
structural_features = (treatment_draw > 0).long()


# Define the correspondence between structural parameters and structural features
# We use a linear correspondence here; let's not get too crazy
# CM: this is a pretty common structural layer, so I should move it to the .py file
def structural_layer(structural_parameters, structural_features):
    """Evaluate the linear structural layer row by row.

    The first column of ``structural_parameters`` enters additively; the
    remaining columns are multiplied elementwise by ``structural_features``
    and summed across columns.

    Args:
        structural_parameters: 2-D tensor with one row per observation. Column
            0 is the additive term; columns 1 onward are the coefficients on
            the structural features.
        structural_features: tensor that broadcasts against
            ``structural_parameters[:, 1:]``.

    Returns:
        A 2-D tensor with a single column (the column sum is taken with
        ``keepdim=True``) holding the layer's value for each row.
    """
    additive_term = structural_parameters[:, 0:1]
    coefficients = structural_parameters[:, 1:]

    # Weighted sum of the structural features, kept as a column
    weighted_features = (structural_features * coefficients).sum(
        dim=1, keepdim=True
    )

    return additive_term + weighted_features


# Outcomes: the structural layer evaluated at the true structural parameters,
# plus additive standard normal noise (one draw per observation)
outcomes_structural = structural_layer(
    structural_parameters=structural_parameters,
    structural_features=structural_features,
)
outcome_noise = torch.randn(N, 1)
outcomes = outcomes_structural + outcome_noise

In [3]:
# CM: create a function for sample splits

# Randomly partition the observations into num_splits disjoint splits
perm = torch.randperm(N)  # random permutation of the observation indices

num_splits = 3  # number of splits

# Size of the smallest split; kept in case other cells reference it
split_size = N // num_splits

# Every index is assigned to exactly one split. Slicing perm in blocks of
# N // num_splits would silently drop the last N % num_splits observations
# whenever N is not a multiple of num_splits; torch.tensor_split instead gives
# the first N % num_splits splits one extra observation. When N is a multiple
# of num_splits the result is identical to equal-sized contiguous slices.
splits = []  # store splits in a list of dictionaries
for indices in torch.tensor_split(perm, num_splits):
    # Use indices to select this split's rows from every data tensor
    split = {
        "dnn_features": dnn_features[indices],
        "structural_features": structural_features[indices],
        "outcomes": outcomes[indices],
        "structural_parameters": structural_parameters[indices],
    }

    # Add the split to the list of splits
    splits.append(split)

## 2. Estimation

### 2.1. Estimate structural parameters

In [4]:
# Settings shared by every per-split fit
hidden_sizes = [30, 30]
dropout_rate = 0.0
learning_rate = 5e-3
weight_decay = 0.0  # zero weight decay, i.e., no L2 regularization
num_epochs = 2000

# Loss function: squared error averaged over observations
loss_function = nn.MSELoss(reduction="mean")

# The model and the optimizer are initialized inside each loop iteration below

In [5]:
print("Estimating structural parameters")

# One network per split, appended in split order
model_structural_parameters = []

for s in range(num_splits):
    print(f"Split {s + 1}")

    split_data = splits[s]

    # Fresh network for this split: K inputs, one output per structural
    # parameter
    model = auto_inf.DeepNeuralNetworkReLU(
        input_dim=K,
        hidden_sizes=hidden_sizes,
        output_dim=structural_parameters_dim,
        dropout_rate=dropout_rate,
    )

    # Fresh SGD optimizer bound to this split's network
    optimizer = optim.SGD(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay,
    )

    # NOTE(review): the return value is stored but never used; the list below
    # keeps `model` itself, so this presumably relies on train_dnn updating
    # the network in place -- confirm against automatic_inference.train_dnn
    model_fit = auto_inf.train_dnn(
        split_data["dnn_features"],
        split_data["structural_features"],
        split_data["outcomes"],
        structural_layer,
        model,
        loss_function,
        optimizer,
        num_epochs,
        noisily=False,
    )

    model_structural_parameters.append(model)

Estimating theta
Split 1


100%|██████████| 2000/2000 [00:02<00:00, 749.09it/s]


Split 2


100%|██████████| 2000/2000 [00:02<00:00, 691.97it/s]


Split 3


100%|██████████| 2000/2000 [00:02<00:00, 758.21it/s]


In [None]:
# CM: add diagnostics for structural parameter estimation

### 2.2. Estimate $\mathbf{\Lambda}(\mathbf{X})$

In [20]:
# CM: up next

# Old code:

### 1.2. Estimate $\mathbf{\Lambda}(\mathbf{X})$

In [10]:
# def identity(theta, Z): # simple identity function for G(Theta)
#     return(theta)

# # Estimate Lambda for a given split and DNN
# def estimate_Lambda(split, dnn, theta_dim, G, loss_function):
#     print(f'Split {split + 1} with DNN {dnn + 1}')

#     # Number of observations in the split
#     N = splits[split]['X'].size(0)

#     # Evaluate the structural parameter functions
#     theta = model_theta[dnn](splits[split]['X'])

#     # Predict outcomes
#     outputs = G(theta, splits[split]['Z'])

#     # Calculate loss
#     loss = loss_function(outputs, splits[split]['Y'])

#     # Gradient of loss w.r.t. theta
#     theta_grad = autograd.grad(
#         loss, theta, create_graph = True, retain_graph= True
#     )[0]

#     # Initialize Hessian
#     Lambda = torch.zeros([N, theta_dim, theta_dim])
    
#     for k in range(theta_dim):
#         hess_row = autograd.grad( # row k of Hessian
#             theta_grad[:,k].sum(), theta, retain_graph = True
#         )

#         for j in range(k, theta_dim): # only need to compute upper triangle
#             print(f'Element ({k}, {j})')
            
#             # Extract Hessian element (k,j)
#             hess_element = hess_row[0][:,j]

#             # Project Hessian element (k,j) onto X
#             hess_element_projection = train_DNN( 
#                 splits[split]['X'], hess_element.view(N, 1), splits[split]['Z'],
#                 1, identity, [30, 30], 0.0, nn.MSELoss(reduction = 'mean'),
#                 learning_rate, weight_decay, num_epochs
#             )

#             # Store projection
#             Lambda[:,k,j] = hess_element_projection(splits[split]['X']).view(N)

#             if k != j: # reflect upper triangle to lower triangle
#                 Lambda[:,j,k] = Lambda[:,k,j]
#                 print(f'Reflecting to element ({j}, {k})')
#     print('\n')

#     return(Lambda)

In [11]:
# loss_function = nn.MSELoss(reduction = 'sum')

# print('Projecting Hessian onto X\n')

# Lambdas = []

# Lambdas.append(estimate_Lambda(0, 2, theta_dim, G, loss_function))
# Lambdas.append(estimate_Lambda(1, 0, theta_dim, G, loss_function))
# Lambdas.append(estimate_Lambda(2, 1, theta_dim, G, loss_function))

### 1.3. Influence function

In [12]:
# # Estimate influence function for a given split, DNN, and Lambda
# def estimate_influence_function(split, dnn, Lambda, theta_dim, G, H, loss_function):
#     print(f'Split {split + 1} with DNN {dnn + 1}')
    
#     # Number of observations in the split
#     N = splits[split]['X'].size(0)

#     # Evaluate the structural parameter functions
#     theta = model_theta[dnn](splits[split]['X'])
    
#     # Predict outcomes
#     outputs = G(theta, splits[split]['Z'])

#     # Calculate loss
#     loss = loss_function(outputs, splits[split]['Y'])

#     # Gradient of loss w.r.t. theta
#     theta_grad = autograd.grad(
#         loss, theta, create_graph = True, retain_graph= True
#     )[0]

#     # Evaluate H(.)
#     H_eval = H(theta, splits[split]['Z'])

#     # Gradient of H for adjustment term
#     H_theta = autograd.grad(H_eval.sum(), theta)[0]

#     influence_function = H_eval - torch.matmul(
#         torch.matmul(
#             H_theta.view(N, 1, theta_dim), linalg.pinv(Lambdas[Lambda])
#         ).view(N, 1, theta_dim),
#         theta_grad.view(N, theta_dim, 1)
#     ).view(N, 1)

#     return(influence_function)

In [13]:
# # Define H(theta) function
# def H(theta, Z): # we let H(theta(X)) = beta(X)
#     N = Z.size(0)
#     return(theta[:,1].view(N, 1))

# loss_function = nn.MSELoss(reduction = 'sum')

# print('Estimating influence function')

# influence_function = [] # store influence function estimates

# influence_function.append(
#     estimate_influence_function(0, 1, 2, theta_dim, G, H, loss_function)
# )
# influence_function.append(
#     estimate_influence_function(1, 2, 0, theta_dim, G, H, loss_function)
# )
# influence_function.append(
#     estimate_influence_function(2, 0, 1, theta_dim, G, H, loss_function)
# )

In [14]:
# # Concatenate influence function values across splits and store as np array
# influence_function_np = np.concatenate([
#     influence_function[0].detach().numpy(),
#     influence_function[1].detach().numpy(),
#     influence_function[2].detach().numpy()
# ])

# # Calculate estimate and standard error from concatenated influence function
# est = influence_function_np.mean()
# se = math.sqrt(influence_function_np.var() / N)

In [15]:
# # Report results
# print('Mean:', round(est, 4))
# print('S.E.:', round(se, 4))
# print('95% CI: [', round(est - 1.96 * se, 4), ', ',
#       round(est + 1.96 * se, 4), ']', sep = '')