In [1]:
import os
import sys, random
sys.path.insert(0, '../../')
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler

import torch
from kernels.nn import ImplicitDenseNetKernel
from model.ick import ICK
from model.cmick import CMICK
from benchmarks.cmgp_modified import CMGP
from utils.train import CMICKEnsembleTrainer
from utils.helpers import *
from utils.metrics import policy_risk, att_err

# Pin every random number generator used here (Python, NumPy, PyTorch CPU and
# CUDA) to the same seed so that repeated runs of the notebook give the same output.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed,
                  torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(2020)
del _seed_rng  # do not leave the loop variable in the notebook namespace
# Turn off cuDNN autotuning and request deterministic cuDNN kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True



# 1. Load and preprocess data

In [3]:
def load_and_preprocess_data(train_ratio, test_ratio, random_state):
    """
    Load the JOBS csv files, preprocess them and split them into train / val / test.

    The randomized and non-randomized files are concatenated (randomized rows first),
    the columns 'Age', 'Education' and 'RE75' are standardised on the pooled frame,
    and 'RE78' is replaced by 1.0 where it is positive and 0.0 otherwise.

    Splitting:
      * test  - int(test_ratio * N) rows sampled from the randomized rows only
      * train - int(train_ratio * N) rows sampled from what is left
      * val   - every remaining row
    where N is the total number of pooled rows.

    In each split, column 0 is used as the treatment T, the last column as the
    outcome Y, and all columns in between as the covariates X.

    Arguments
    --------------
    train_ratio: float, fraction of all rows used for training
    test_ratio: float, fraction of all rows used for testing
    random_state: int, seed passed to both DataFrame.sample calls

    Returns
    --------------
    data_generators: the object returned by create_generators_from_data for
        inputs [X, T] and targets Y of each split
    data: dict of numpy arrays keyed by '{X,T,Y}_{train,val,test}'
    """
    rct_frame = pd.read_csv('../../data/JOBS/randomized.csv')
    obs_frame = pd.read_csv('../../data/JOBS/nonrandomized.csv')
    df = pd.concat([rct_frame, obs_frame], ignore_index=True)

    # Standardise the continuous covariates that are actually present in the frame
    continuous_cols = ['Age', 'Education', 'RE75']
    for col in [name for name in df.columns if name in continuous_cols]:
        column_values = df[col].to_numpy().reshape(-1, 1)
        df[col] = StandardScaler().fit_transform(column_values).reshape(-1)
    # Transform to binary classification task
    df['RE78'] = df['RE78'].apply(lambda outcome: 1. if outcome > 0. else 0.)

    total_rows = len(df)
    # Test set only from randomized samples (they occupy the first rows after concat)
    test_df = df[:len(rct_frame)].sample(n=int(test_ratio * total_rows), random_state=random_state)
    remaining_df = df.drop(test_df.index)
    train_df = remaining_df.sample(n=int(train_ratio * total_rows), random_state=random_state)
    val_df = remaining_df.drop(train_df.index)

    def _split_columns(frame):
        # (covariates, treatment, outcome) as 2-D arrays
        values = frame.to_numpy()
        return values[:, 1:-1], values[:, :1], values[:, -1:]

    X_train, T_train, Y_train = _split_columns(train_df)
    X_val, T_val, Y_val = _split_columns(val_df)
    X_test, T_test, Y_test = _split_columns(test_df)

    data = {
        'X_train': X_train, 'T_train': T_train, 'Y_train': Y_train,
        'X_val': X_val, 'T_val': T_val, 'Y_val': Y_val,
        'X_test': X_test, 'T_test': T_test, 'Y_test': Y_test
    }
    data_generators = create_generators_from_data(
        x_train=[X_train, T_train], y_train=Y_train,
        x_val=[X_val, T_val], y_val=Y_val,
        x_test=[X_test, T_test], y_test=Y_test
    )
    return data_generators, data

# 2. Define CMNN model

In [4]:
def _make_ick_component(input_dim):
    """Build one ICK holding a single ImplicitDenseNetKernel with the fixed architecture used in this notebook."""
    return ICK(
        kernel_assignment=['ImplicitDenseNetKernel'],
        kernel_params={
            'ImplicitDenseNetKernel': {
                'input_dim': input_dim,
                'latent_feature_dim': 512,
                'num_blocks': 0,
                'num_layers_per_block': 1,
                'num_units': 512,
                'activation': 'relu'
            }
        }
    )


def build_cmnn_ensemble(input_dim, load_weights=False, num_estimators=2,
                        checkpoint_path='./checkpoints/cmick_jobs.pt'):
    """
    Build an ensemble of CMICK base learners, each made of six ICK components.

    Every base learner owns the components f11, f12, f13, f21, f22, f23, which are
    wired into CMICK as control=[f11, f21], treatment=[f12, f22], shared=[f13, f23],
    with all coefficients initialised to 1.0 and marked trainable.

    Arguments
    --------------
    input_dim: int, passed as 'input_dim' to every ImplicitDenseNetKernel
    load_weights: bool. If True, the state dict of kernels[0] of every component is
        restored from `checkpoint_path`. If False, the freshly initialised state
        dicts are written to `checkpoint_path` so a later call can restore them.
    num_estimators: int, number of base learners in the ensemble
    checkpoint_path: str, file holding {'model_<k>': {'f11': state_dict, ...}, ...}

    Returns
    --------------
    ensemble: list of `num_estimators` CMICK instances

    Raises
    --------------
    Whatever torch.load raises if the checkpoint cannot be read, and KeyError if
    the checkpoint lacks an entry for a requested model or component.
    """
    component_names = ('f11', 'f12', 'f13', 'f21', 'f22', 'f23')
    alpha11, alpha12, alpha13 = 1.0, 1.0, 1.0
    alpha21, alpha22, alpha23 = 1.0, 1.0, 1.0

    # Read the checkpoint once instead of once per component per estimator
    checkpoint = torch.load(checkpoint_path) if load_weights else None

    ensemble, ensemble_weights = [], {}
    for i in range(num_estimators):
        model_key = 'model_' + str(i + 1)
        # Components are created in a fixed order (f11, f12, f13, f21, f22, f23) so
        # that any random initialisation consumes the RNG in a stable sequence.
        components = {name: _make_ick_component(input_dim) for name in component_names}
        if load_weights:
            for name, component in components.items():
                component.kernels[0].load_state_dict(checkpoint[model_key][name])
        else:
            ensemble_weights[model_key] = {
                name: component.kernels[0].state_dict()
                for name, component in components.items()
            }
        # Set output_binary=True for binary Y0 and Y1
        baselearner = CMICK(
            control_components=[components['f11'], components['f21']],
            treatment_components=[components['f12'], components['f22']],
            shared_components=[components['f13'], components['f23']],
            control_coeffs=[alpha11, alpha21],
            treatment_coeffs=[alpha12, alpha22],
            shared_coeffs=[alpha13, alpha23],
            coeff_trainable=True, output_binary=True
        )
        ensemble.append(baselearner)

    if not load_weights:
        checkpoint_dir = os.path.dirname(checkpoint_path)
        if checkpoint_dir:
            # exist_ok avoids the check-then-create race of os.path.exists + makedirs
            os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(ensemble_weights, checkpoint_path)

    return ensemble

# 3. Training and evaluation of CMNN model

In [5]:
def fit_and_evaluate_cmnn(ensemble, data_generators, data, lr, treatment_index=1): 
    """
    Train the CMNN ensemble with CMICKEnsembleTrainer and report test metrics.

    Training uses SGD (momentum 0.99, weight decay 1e-4) for at most 1000 epochs
    with a patience of 10, on CUDA when available and on CPU otherwise. After
    training, the policy risk and the ATT error of the mean test prediction are
    printed.

    Arguments
    --------------
    ensemble: list of base learners handed to the trainer as `model`
    data_generators: generators handed to the trainer
    data: dict providing at least 'Y_test' and 'T_test'
    lr: float, SGD learning rate
    treatment_index: int, position of the treatment array within the model inputs
        (1 for inputs built as [X, T])

    Returns
    --------------
    r_pol: the value returned by policy_risk on the test set
    """
    trainer = CMICKEnsembleTrainer(
        model=ensemble,
        data_generators=data_generators,
        optim='sgd',
        optim_params={'lr': lr, 'momentum': 0.99, 'weight_decay': 1e-4},
        model_save_dir=None,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        epochs=1000,
        patience=10,
        treatment_index=treatment_index
    )
    trainer.train()

    # Only the predictive mean is needed for the metrics below
    mean_test_pred, _, _ = trainer.predict()
    outcomes, treatments = data['Y_test'], data['T_test']
    r_pol = policy_risk(mean_test_pred, outcomes, treatments)
    eps_att = att_err(mean_test_pred, outcomes, treatments)
    for message, value in (
        ('Policy risk (CMNN):             %.4f', r_pol),
        ('Avg treatment effect err (CMNN):             %.4f', eps_att),
    ):
        print(message % (value))

    return r_pol

# 4. Benchmark 1: original CMGP

In [6]:
def fit_and_evaluate_original_cmgp(data):
    """
    Evaluate the original CMGP benchmark on the test split.

    A CMGP is constructed from the training arrays (no separate fit call is made
    here, so fitting presumably happens in the constructor - TODO confirm). Its
    two predicted outcome columns for the test covariates are concatenated
    side by side and scored with policy_risk and att_err; both scores are printed.

    Arguments
    --------------
    data: dict providing '{X,T,Y}_train' and '{X,T,Y}_test'

    Returns
    --------------
    r_pol: the value returned by policy_risk on the test set
    """
    train_x, train_t, train_y = (data[key] for key in ('X_train', 'T_train', 'Y_train'))
    test_x, test_t, test_y = (data[key] for key in ('X_test', 'T_test', 'Y_test'))
    cmgp_model = CMGP(train_x, train_t, train_y)

    # Column 0: first returned prediction (mu0), column 1: second (mu1)
    potential_outcomes = np.concatenate(
        list(cmgp_model.predict(test_x, return_var=False)), axis=1
    )
    r_pol = policy_risk(potential_outcomes, test_y, test_t)
    eps_att = att_err(potential_outcomes, test_y, test_t)
    for message, value in (
        ('Policy risk (CMGP):             %.4f', r_pol),
        ('Avg treatment effect err (CMGP):             %.4f', eps_att),
    ):
        print(message % (value))

    return r_pol

# Main function

In [7]:
def main():
    """
    Run the whole JOBS experiment end to end.

    Loads the data with 56% of the rows for training and 20% for testing
    (random_state=1), builds a freshly initialised CMNN ensemble, trains and
    evaluates it with learning rate 1e-4, then evaluates the CMGP benchmark on
    the same split. Metrics are printed by the callees; nothing is returned.
    """
    data_generators, data = load_and_preprocess_data(0.56, 0.20, random_state=1)
    # Number of covariate columns determines the kernel input size
    n_covariates = data['X_train'].shape[1]
    ensemble = build_cmnn_ensemble(n_covariates, load_weights=False)
    fit_and_evaluate_cmnn(ensemble, data_generators, data, 1e-4)
    fit_and_evaluate_original_cmgp(data)

# Run the experiment only when this code is executed directly (script or
# notebook cell), not when it is imported as a module.
if __name__ == "__main__":
    main()

Training started:

Epoch 1/1000
Learning rate: 0.000100
treatment params:
tensor(-0.0014)
control params:
tensor(-0.0008)
treatment params:
tensor(-0.0042)
control params:
tensor(-0.0030)
treatment params:
tensor(-0.0014)
control params:
tensor(0.0016)
treatment params:
tensor(-0.0042)
control params:
tensor(-0.0030)
treatment params:
tensor(-0.0014)
control params:
tensor(0.0040)
treatment params:
tensor(-0.0042)
control params:
tensor(-0.0030)
treatment params:
tensor(-0.0014)
control params:
tensor(0.0064)
treatment params:
tensor(-0.0042)
control params:
tensor(-0.0029)
treatment params:
tensor(-0.0014)
control params:
tensor(0.0088)
treatment params:
tensor(-0.0042)
control params:
tensor(-0.0028)
treatment params:
tensor(-0.0014)
control params:
tensor(0.0111)
treatment params:
tensor(-0.0042)
control params:
tensor(-0.0028)
treatment params:
tensor(-0.0014)
control params:
tensor(0.0134)
treatment params:
tensor(-0.0042)
control params:
tensor(-0.0027)
treatment params:
tensor(-

Epoch 2/1000
Learning rate: 0.000100
treatment params:
tensor(-0.0014)
control params:
tensor(0.1050)
treatment params:
tensor(-0.0042)
control params:
tensor(0.0004)
treatment params:
tensor(-0.0014)
control params:
tensor(0.1064)
treatment params:
tensor(-0.0042)
control params:
tensor(0.0004)
treatment params:
tensor(-0.0014)
control params:
tensor(0.1078)
treatment params:
tensor(-0.0042)
control params:
tensor(0.0005)
treatment params:
tensor(-0.0014)
control params:
tensor(0.1091)
treatment params:
tensor(-0.0042)
control params:
tensor(0.0005)
treatment params:
tensor(-0.0014)
control params:
tensor(0.1104)
treatment params:
tensor(-0.0042)
control params:
tensor(0.0006)
treatment params:
tensor(-0.0014)
control params:
tensor(0.1117)
treatment params:
tensor(-0.0042)
control params:
tensor(0.0006)
treatment params:
tensor(-0.0014)
control params:
tensor(0.1130)
treatment params:
tensor(-0.0042)
control params:
tensor(0.0006)
treatment params:
tensor(-0.0014)
control params:
ten

Epoch 3/1000
Learning rate: 0.000100
treatment params:
tensor(-0.0014)
control params:
tensor(0.1647)
treatment params:
tensor(-0.0042)
control params:
tensor(0.0024)
treatment params:
tensor(-0.0014)
control params:
tensor(0.1655)
treatment params:
tensor(-0.0042)
control params:
tensor(0.0024)
treatment params:
tensor(-0.0014)
control params:
tensor(0.1663)
treatment params:
tensor(-0.0042)
control params:
tensor(0.0024)
treatment params:
tensor(-0.0014)
control params:
tensor(0.1670)
treatment params:
tensor(-0.0042)
control params:
tensor(0.0025)
treatment params:
tensor(-0.0014)
control params:
tensor(0.1678)
treatment params:
tensor(-0.0042)
control params:
tensor(0.0025)
treatment params:
tensor(-0.0014)
control params:
tensor(0.1685)
treatment params:
tensor(-0.0042)
control params:
tensor(0.0025)
treatment params:
tensor(-0.0014)
control params:
tensor(0.1692)
treatment params:
tensor(-0.0042)
control params:
tensor(0.0025)
treatment params:
tensor(-0.0014)
control params:
ten

treatment params:
tensor(-0.0042)
control params:
tensor(0.0035)
treatment params:
tensor(-0.0042)
control params:
tensor(0.0035)
treatment params:
tensor(-0.0042)
control params:
tensor(0.0035)
treatment params:
tensor(-0.0042)
control params:
tensor(0.0035)
0s - loss 710410048.0000

Epoch 4/1000
Learning rate: 0.000100
treatment params:
tensor(-0.0014)
control params:
tensor(0.1984)
treatment params:
tensor(-0.0042)
control params:
tensor(0.0035)
treatment params:
tensor(-0.0014)
control params:
tensor(0.1988)
treatment params:
tensor(-0.0042)
control params:
tensor(0.0035)
treatment params:
tensor(-0.0014)
control params:
tensor(0.1992)
treatment params:
tensor(-0.0042)
control params:
tensor(0.0035)
treatment params:
tensor(-0.0014)
control params:
tensor(0.1997)
treatment params:
tensor(-0.0042)
control params:
tensor(0.0036)
treatment params:
tensor(-0.0014)
control params:
tensor(0.2001)
treatment params:
tensor(-0.0042)
control params:
tensor(0.0036)
treatment params:
tensor(-0

Epoch 5/1000
Learning rate: 0.000100
treatment params:
tensor(-0.0014)
control params:
tensor(0.2173)
treatment params:
tensor(-0.0042)
control params:
tensor(0.0041)
treatment params:
tensor(-0.0014)
control params:
tensor(0.2176)
treatment params:
tensor(-0.0042)
control params:
tensor(0.0042)
treatment params:
tensor(-0.0014)
control params:
tensor(0.2178)
treatment params:
tensor(-0.0042)
control params:
tensor(0.0042)
treatment params:
tensor(-0.0014)
control params:
tensor(0.2181)
treatment params:
tensor(-0.0042)
control params:
tensor(0.0042)
treatment params:
tensor(-0.0014)
control params:
tensor(0.2183)
treatment params:
tensor(-0.0042)
control params:
tensor(0.0042)
treatment params:
tensor(-0.0014)
control params:
tensor(0.2185)
treatment params:
tensor(-0.0042)
control params:
tensor(0.0042)
treatment params:
tensor(-0.0014)
control params:
tensor(0.2188)
treatment params:
tensor(-0.0042)
control params:
tensor(0.0042)
treatment params:
tensor(-0.0014)
control params:
ten

Epoch 6/1000
Learning rate: 0.000100
treatment params:
tensor(-0.0014)
control params:
tensor(0.2280)
treatment params:
tensor(-0.0041)
control params:
tensor(0.0045)
treatment params:
tensor(-0.0014)
control params:
tensor(0.2282)
treatment params:
tensor(-0.0041)
control params:
tensor(0.0045)
treatment params:
tensor(-0.0014)
control params:
tensor(0.2283)
treatment params:
tensor(-0.0041)
control params:
tensor(0.0045)
treatment params:
tensor(-0.0014)
control params:
tensor(0.2284)
treatment params:
tensor(-0.0041)
control params:
tensor(0.0045)
treatment params:
tensor(-0.0014)
control params:
tensor(0.2286)
treatment params:
tensor(-0.0041)
control params:
tensor(0.0045)
treatment params:
tensor(-0.0014)
control params:
tensor(0.2287)
treatment params:
tensor(-0.0041)
control params:
tensor(0.0045)
treatment params:
tensor(-0.0014)
control params:
tensor(0.2288)
treatment params:
tensor(-0.0041)
control params:
tensor(0.0045)
treatment params:
tensor(-0.0014)
control params:
ten

Epoch 7/1000
Learning rate: 0.000100
treatment params:
tensor(-0.0014)
control params:
tensor(0.2341)
treatment params:
tensor(-0.0041)
control params:
tensor(0.0047)
treatment params:
tensor(-0.0014)
control params:
tensor(0.2341)
treatment params:
tensor(-0.0041)
control params:
tensor(0.0047)
treatment params:
tensor(-0.0014)
control params:
tensor(0.2342)
treatment params:
tensor(-0.0041)
control params:
tensor(0.0047)
treatment params:
tensor(-0.0014)
control params:
tensor(0.2343)
treatment params:
tensor(-0.0041)
control params:
tensor(0.0047)
treatment params:
tensor(-0.0014)
control params:
tensor(0.2344)
treatment params:
tensor(-0.0041)
control params:
tensor(0.0047)
treatment params:
tensor(-0.0014)
control params:
tensor(0.2344)
treatment params:
tensor(-0.0041)
control params:
tensor(0.0047)
treatment params:
tensor(-0.0014)
control params:
tensor(0.2345)
treatment params:
tensor(-0.0041)
control params:
tensor(0.0047)
treatment params:
tensor(-0.0014)
control params:
ten

treatment params:
tensor(-0.0041)
control params:
tensor(0.0048)
0s - loss 710410048.0000

Epoch 8/1000
Learning rate: 0.000100
treatment params:
tensor(-0.0014)
control params:
tensor(0.2375)
treatment params:
tensor(-0.0041)
control params:
tensor(0.0048)
treatment params:
tensor(-0.0014)
control params:
tensor(0.2375)
treatment params:
tensor(-0.0041)
control params:
tensor(0.0048)
treatment params:
tensor(-0.0014)
control params:
tensor(0.2375)
treatment params:
tensor(-0.0041)
control params:
tensor(0.0048)
treatment params:
tensor(-0.0014)
control params:
tensor(0.2376)
treatment params:
tensor(-0.0041)
control params:
tensor(0.0048)
treatment params:
tensor(-0.0014)
control params:
tensor(0.2376)
treatment params:
tensor(-0.0041)
control params:
tensor(0.0048)
treatment params:
tensor(-0.0014)
control params:
tensor(0.2377)
treatment params:
tensor(-0.0041)
control params:
tensor(0.0048)
treatment params:
tensor(-0.0014)
control params:
tensor(0.2377)
treatment params:
tensor(-0

tensor(-0.0014)
control params:
tensor(0.2394)
treatment params:
tensor(-0.0041)
control params:
tensor(0.0049)
treatment params:
tensor(-0.0041)
control params:
tensor(0.0049)
treatment params:
tensor(-0.0041)
control params:
tensor(0.0049)
treatment params:
tensor(-0.0041)
control params:
tensor(0.0049)
treatment params:
tensor(-0.0041)
control params:
tensor(0.0049)
treatment params:
tensor(-0.0041)
control params:
tensor(0.0049)
treatment params:
tensor(-0.0041)
control params:
tensor(0.0049)
treatment params:
tensor(-0.0041)
control params:
tensor(0.0049)
treatment params:
tensor(-0.0041)
control params:
tensor(0.0049)
treatment params:
tensor(-0.0041)
control params:
tensor(0.0049)
treatment params:
tensor(-0.0041)
control params:
tensor(0.0049)
0s - loss 710410048.0000

Epoch 9/1000
Learning rate: 0.000100
treatment params:
tensor(-0.0014)
control params:
tensor(0.2394)
treatment params:
tensor(-0.0041)
control params:
tensor(0.0049)
treatment params:
tensor(-0.0014)
control par

KeyboardInterrupt: 