In [1]:
import os

import numpy as np
import torch
from sklearn.linear_model import LinearRegression

from load_data import IHDP_Dataset

In [4]:
# Load the IHDP train and test splits.
ihdp_train_D = IHDP_Dataset(train=True, mono=True)
ihdp_test_D = IHDP_Dataset(train=False, mono=True)
# NOTE(review): `train_D` refers to the *test* split (train=False), despite its
# name. It is kept as an alias of `ihdp_test_D` so existing references keep
# working — confirm whether it is still needed.
train_D = ihdp_test_D
# Mean of (mu1 - mu0) over the test split.
ate_true = (ihdp_test_D.mu1 - ihdp_test_D.mu0).mean()

In [7]:
# Fit an OLS outcome model on covariate columns 0 and 1 plus the treatment
# variable `t` (cast to float), regressing the factual outcome `yf`.
X_train = ihdp_train_D.x
feature_train = torch.stack(
    (X_train[:, 0], X_train[:, 1], ihdp_train_D.t.float()),
    dim=1,
)
reg = LinearRegression().fit(feature_train, ihdp_train_D.yf)

# 結果の評価
def f_treat(X, N):
    """Predict f_hat(x, 1): the fitted model's output with treatment set to 1.

    Args:
        X: Covariate tensor; columns 0 and 1 are used as features.
        N: Number of rows, used to size the constant treatment column.

    Returns:
        A tensor wrapping the NumPy array returned by ``reg.predict``.
    """
    treated_column = torch.ones(N)
    columns = (X[:, 0], X[:, 1], treated_column)
    design = torch.stack(columns, dim=1)
    # `reg` is the module-level regressor fitted earlier in the notebook.
    predicted = reg.predict(design)
    return torch.from_numpy(predicted)

def f_control(X, N):
    """Predict f_hat(x, 0): the fitted model's output with treatment set to 0.

    Args:
        X: Covariate tensor; columns 0 and 1 are used as features.
        N: Number of rows, used to size the constant treatment column.

    Returns:
        A tensor wrapping the NumPy array returned by ``reg.predict``.
    """
    control_column = torch.zeros(N)
    columns = (X[:, 0], X[:, 1], control_column)
    design = torch.stack(columns, dim=1)
    # `reg` is the module-level regressor fitted earlier in the notebook.
    predicted = reg.predict(design)
    return torch.from_numpy(predicted)

def loss_pehe(y_treat_hat, y_control_hat, mu1, mu0):
    """Root-mean-squared error between estimated and true per-unit effects.

    Computes ``sqrt(mean(((y_treat_hat - y_control_hat) - (mu1 - mu0)) ** 2))``.

    Args:
        y_treat_hat: Predicted outcomes under treatment.
        y_control_hat: Predicted outcomes under control.
        mu1: True (noiseless) outcomes under treatment.
        mu0: True (noiseless) outcomes under control.

    Returns:
        A scalar tensor holding the error.
    """
    effect_hat = y_treat_hat - y_control_hat  # estimated effect
    effect_true = mu1 - mu0  # true effect
    # Guard against silent broadcasting: if one side is (N,) and the other is
    # (N, 1), the subtraction below would build an N x N matrix of mismatched
    # pairs and return a wrong value without raising. When the element counts
    # agree, align the shapes so units are compared one-to-one. Inputs whose
    # shapes already match are unaffected.
    if (effect_hat.shape != effect_true.shape
            and effect_hat.numel() == effect_true.numel()):
        effect_true = effect_true.reshape(effect_hat.shape)
    return torch.sqrt(torch.mean((effect_hat - effect_true) ** 2))

def loss_ate(y_treat_hat, y_control_hat, mu1, mu0):
    """Absolute difference between the estimated and the true average effect.

    Args:
        y_treat_hat: Predicted outcomes under treatment.
        y_control_hat: Predicted outcomes under control.
        mu1: True (noiseless) outcomes under treatment.
        mu0: True (noiseless) outcomes under control.

    Returns:
        A scalar tensor: ``|mean(y_treat_hat - y_control_hat) - mean(mu1 - mu0)|``.
    """
    estimated_ate = (y_treat_hat - y_control_hat).mean()
    reference_ate = (mu1 - mu0).mean()
    return (estimated_ate - reference_ate).abs()

In [8]:
def evaluation(D, fname):
    """Evaluate the fitted model on dataset ``D``, print the metrics and save them.

    Args:
        D: Dataset object; this function reads ``D.x``, ``D.N``, ``D.mu1``
            and ``D.mu0``.
        fname: Destination handed to ``torch.save``. When it is a filesystem
            path, missing parent directories are created first.

    Returns:
        None. The metrics are printed and written to ``fname``.
    """
    # Estimate both potential outcomes for every unit.
    y_treat_hat = f_treat(D.x, D.N)
    y_control_hat = f_control(D.x, D.N)

    # epsilon_PEHE
    pehe = loss_pehe(y_treat_hat, y_control_hat, D.mu1, D.mu0)

    # epsilon_ATE
    ate_error = loss_ate(y_treat_hat, y_control_hat, D.mu1, D.mu0)

    # Report the results.
    print('pehe = ', pehe.item())
    print('error of ate =', ate_error.item())

    # Make sure the output directory exists so a fresh checkout (no `results/`
    # folder yet) does not fail at save time. Non-path destinations such as
    # file-like objects are passed through untouched.
    if isinstance(fname, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(fname))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    # Persist the results. The key is 'X_test' regardless of which split `D`
    # is; it is kept unchanged so existing readers of these files keep working.
    torch.save({
        'X_test': D.x,
        'pehe': pehe,
        'ate_error': ate_error,
    }, fname)

In [11]:
# Evaluate the estimation error within sample (WS), i.e. on the training split.
evaluation(D=ihdp_train_D, fname='results/ols_WS.pt')

pehe =  0.8659573793411255
error of ate = 0.00019025802612304688


In [12]:
# Evaluate the estimation error out of sample (OoS), i.e. on the test split.
evaluation(D=ihdp_test_D, fname='results/ols_OoS.pt')

pehe =  0.7958115339279175
error of ate = 0.044965267181396484
