#### setup

In [None]:
import json
import numpy as np, pandas as pd, torch
from torch.utils.data import DataLoader

from library.data_utils import make_splits, set_seed
from library.models import *
from library.training import *
from library.eval_utils import *

In [None]:
# device: use the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# seed: fixed to 0 through the project helper from library.data_utils
set_seed(seed=0)

In [None]:
# experiment configuration
train_size = 1_000     # forwarded to make_splits as `train_size`
dataset = 'synthetic'  # stem of the CSV read from ./data/<dataset>.csv

#### data

In [None]:
# load data and split
# The first CSV column is used as the DataFrame index.
df = pd.read_csv(f"./data/{dataset}.csv", index_col=0)

# NOTE: the original call passed `seed=seed`, but no `seed` variable is
# assigned anywhere in this notebook (set_seed and train_cate both receive
# the literal 0), so the cell raised NameError unless a star import happened
# to export that name. Pass the same literal 0 here.
# The first two values returned by make_splits are not used in this notebook.
_, _, train_df, val_df, test_df = make_splits(df=df, train_size=train_size, seed=0)

# feature columns x0..x9; their count is used as `input_dim` for the models
confounders = [f'x{i}' for i in range(10)]

#### set doubly robust scores

In [None]:
# load models
# All three nuisance models share the same input width and hidden size and are
# restored from state dicts stored under `checkpoint_dir`.
checkpoint_dir = './checkpoints/'
input_dim = len(confounders)

# map_location=device: without it torch.load tries to restore tensors onto the
# device they were saved from, which fails on a CPU-only machine when the
# checkpoint was written from CUDA (and `device` above may well be the CPU).

# e(x)
propensity_model = ClassificationHead(input_dim=input_dim, hidden_dim=64).to(device)
propensity_state = torch.load(checkpoint_dir + 'propensity_model.pt',
                              map_location=device, weights_only=True)
propensity_model.load_state_dict(propensity_state)

# mu0(x)
response_model_control = RegressionHead(input_dim=input_dim, hidden_dim=64).to(device)
control_state = torch.load(checkpoint_dir + 'response_model_control.pt',
                           map_location=device, weights_only=True)
response_model_control.load_state_dict(control_state)

# mu1(x)
response_model_treated = RegressionHead(input_dim=input_dim, hidden_dim=64).to(device)
treated_state = torch.load(checkpoint_dir + 'response_model_treated.pt',
                           map_location=device, weights_only=True)
response_model_treated.load_state_dict(treated_state)

In [None]:
# set scores
# Both splits are scored with the same feature columns, nuisance models and
# device; only the input frame differs. The test split is not scored here.
dr_score_args = (
    confounders,
    propensity_model,
    response_model_control,
    response_model_treated,
    device,
)
train_df = compute_dr_scores(train_df, *dr_score_args)
val_df = compute_dr_scores(val_df, *dr_score_args)

#### train CATE estimator

In [None]:
# set parameters
# Hyperparameters for the CATE model; read by the loader, model and training
# calls in the following cell.
params = {
    "hidden_dim": 64,
    "learning_rate": 5e-4,
    "weight_decay": 0,
    "batch_size": 128,
    "max_epochs": 50,
    "patience": 5,
}

In [None]:
# init data loaders
train_loader, val_loader = make_cate_loaders(
    train_df,
    val_df,
    confounders,
    batch_size=params['batch_size'],
)

# train cate model
cate_model = RegressionHead(input_dim=input_dim, hidden_dim=params["hidden_dim"]).to(device)
training_kwargs = dict(
    lr=params['learning_rate'],
    weight_decay=params['weight_decay'],
    max_epochs=params['max_epochs'],
    patience=params['patience'],
    seed=0,
)
# train_cate hands back the model together with an `info` object
cate_model, info = train_cate(cate_model, train_loader, val_loader, device, **training_kwargs)

# store
# Only the state dict is written; the evaluation section reloads it from this path.
cate_weights = cate_model.state_dict()
torch.save(cate_weights, './checkpoints/cate_model.pt')

#### evaluation

In [None]:
# create test loader
# shuffle=False keeps the batches in dataset order.
test_dataset = EvalDataset(test_df, confounders)
test_loader = DataLoader(
    test_dataset,
    batch_size=params['batch_size'],
    shuffle=False,
)

In [None]:
# load cate model
# Rebuild the architecture with the training-time hidden size, then restore
# the weights saved by the training cell.
cate_model = RegressionHead(hidden_dim=params["hidden_dim"], input_dim=input_dim).to(device)
# map_location=device: lets a checkpoint written from CUDA be restored when
# `device` is the CPU; without it torch.load targets the original device.
cate_state = torch.load('./checkpoints/cate_model.pt', map_location=device, weights_only=True)
cate_model.load_state_dict(cate_state)

In [None]:
# load response models
# Both response models are rebuilt and restored from their checkpoints.
# map_location=device: lets checkpoints written from CUDA be restored when
# `device` is the CPU; without it torch.load targets the original device.
response_model_control = RegressionHead(hidden_dim=64, input_dim=input_dim).to(device)
control_state = torch.load('./checkpoints/response_model_control.pt',
                           map_location=device, weights_only=True)
response_model_control.load_state_dict(control_state)

response_model_treated = RegressionHead(hidden_dim=64, input_dim=input_dim).to(device)
treated_state = torch.load('./checkpoints/response_model_treated.pt',
                           map_location=device, weights_only=True)
response_model_treated.load_state_dict(treated_state)

In [None]:
# get predictions and metrics
# The CATE model and both response models are passed together with the test
# loader; the returned frame is then handed to the metrics helper.
eval_models = (cate_model, response_model_control, response_model_treated)
df_eval = get_estimates_cate(*eval_models, test_loader, device)
results = compute_metrics_cate(df_eval)