#### setup

In [None]:
import json
import os

import numpy as np, pandas as pd, torch
from torch.utils.data import DataLoader

from library.data_utils import make_splits, set_seed
from library.models import *
from library.training import *
from library.eval_utils import *

In [None]:
# choose the compute device (a CUDA GPU when one is visible, otherwise the CPU)
# and fix the random seed through library.data_utils.set_seed
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
set_seed(seed=0)

In [None]:
# experiment parameters
dataset = "synthetic"  # file stem: the data is read from ./data/<dataset>.csv
train_size = 1_000     # forwarded to make_splits as its train_size argument

#### data

In [None]:
# read the dataset (first CSV column is the index) and split it;
# make_splits returns five items, only the first two (train / validation) are used here
csv_path = f"./data/{dataset}.csv"
df = pd.read_csv(csv_path, index_col=0)
(train_df, val_df, _, _, _) = make_splits(df=df, train_size=train_size, seed=0)

# covariate columns x0 ... x9 used as inputs to every nuisance model
confounders = [f"x{k}" for k in range(10)]

#### train nuisance models

In [None]:
# hyper-parameters shared by the nuisance models trained below
input_dim = len(confounders)  # one input feature per confounder column
params = {
    "hidden_dim": 64,
    "learning_rate": 5e-4,
    "weight_decay": 0,
    "batch_size": 128,
    "max_epochs": 50,
    "patience": 5,
}

In [None]:
# init data loaders from the full train / validation splits
train_loader, val_loader = make_nuisance_loaders(
    train_df, val_df, confounders, params["batch_size"]
)

# train the propensity model: a ClassificationHead over the confounders.
# `info` holds whatever train_propensity reports (kept for inspection).
propensity_model = ClassificationHead(
    input_dim=input_dim, hidden_dim=params["hidden_dim"]
).to(device)
propensity_model, info = train_propensity(
    propensity_model,
    train_loader,
    val_loader,
    device,
    lr=params["learning_rate"],
    weight_decay=params["weight_decay"],
    max_epochs=params["max_epochs"],
    patience=params["patience"],
    seed=0,
)

# store the trained weights; torch.save does not create missing parent
# directories, so make sure ./checkpoints exists first
os.makedirs("./checkpoints", exist_ok=True)
torch.save(propensity_model.state_dict(), "./checkpoints/propensity_model.pt")

In [None]:
# init data loaders on control units only (T == 0);
# the treated-arm response model uses the same procedure with T == 1
train_resp = train_df[train_df["T"] == 0]
val_resp = val_df[val_df["T"] == 0]
if train_resp.empty or val_resp.empty:
    # fail early with a clear message instead of an opaque error inside the loaders
    raise ValueError("no control units (T == 0) in the train or validation split")
train_loader, val_loader = make_nuisance_loaders(
    train_resp, val_resp, confounders, params["batch_size"]
)

# train the control response model: a RegressionHead over the confounders.
# `info` holds whatever train_response reports (kept for inspection).
response_model_control = RegressionHead(
    input_dim=input_dim, hidden_dim=params["hidden_dim"]
).to(device)
response_model_control, info = train_response(
    response_model_control,
    train_loader,
    val_loader,
    device,
    lr=params["learning_rate"],
    weight_decay=params["weight_decay"],
    max_epochs=params["max_epochs"],
    patience=params["patience"],
    seed=0,
)

# store the trained weights; torch.save does not create missing parent
# directories, so make sure ./checkpoints exists first
os.makedirs("./checkpoints", exist_ok=True)
torch.save(response_model_control.state_dict(), "./checkpoints/response_model_control.pt")