-
Notifications
You must be signed in to change notification settings - Fork 1
/
example.py
91 lines (67 loc) · 3.54 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os

import lightning.pytorch as pl
import numpy as np
import torch
from lightning.pytorch.loggers import CSVLogger
from torch.utils.data import Dataset

from models.backdoor import VariationalBackdoor, SeparateTrainingWrapper, FinetuningWrapper
from models.nn import SimpleGaussianNN
from models.vae import GaussianIWAE
# Define some data
class SimpleExample(Dataset):
    """Synthetic confounded dataset used by the variational backdoor example.

    Each row is generated as::

        Z     ~ N(0, I)
        X     = eps_x * Z            with eps_x ~ N(0, I)
        Y     = (X + Z) * eps_y      with eps_y ~ N(0, I)

    so Z is a common cause of X and Y, and X additionally feeds into Y.
    All three arrays have shape ``(size, dim)`` and dtype ``float32``.

    Args:
        size: Number of rows to generate.
        dim: Dimensionality of each of X, Y and Z.
        seed: Seed for a private NumPy ``RandomState``. The default of 0
            reproduces exactly the data earlier versions generated. Pass a
            different seed for an independent draw (e.g. a held-out test set).
    """

    def __init__(self, size, dim, seed=0):
        # Use a private generator rather than ``np.random.seed``: reseeding the
        # global RNG changed NumPy's global state for the caller and made every
        # instance replay the same random stream. ``RandomState(seed)`` draws
        # the same values ``np.random.seed(seed)`` + ``np.random.normal`` would.
        rng = np.random.RandomState(seed)
        Z = rng.normal(size=(size, dim))
        X = rng.normal(size=(size, dim)) * Z
        Y = (X + Z) * rng.normal(size=(size, dim))
        print(Z.shape)
        print(X.shape)
        print(Y.shape)
        self.Z = Z.astype(np.float32)
        self.X = X.astype(np.float32)
        self.Y = Y.astype(np.float32)

    def __len__(self):
        """Return the number of rows in the dataset."""
        return len(self.Z)

    def __getitem__(self, n):
        """Return row ``n`` as the tuple ``(X, Y, Z)``.

        Tuple must be of the form X, Y, Z where
        Z is the confounder (Z -> X and Z -> Y),
        X is the treatment (X -> Y),
        Y is the target.
        """
        return self.X[n], self.Y[n], self.Z[n]
torch.manual_seed(0)
dim = 10
data_size = 50000
train = SimpleExample(data_size, dim)
train_loader = torch.utils.data.DataLoader(train, batch_size=64, shuffle=False)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Parameterize each of the following components with a model that can estimate the
# log-likelihood of the corresponding distribution.
# A variety of different models are given in the "models" folder. The best parameterization depends on the data.
# P(Z)
confounder = GaussianIWAE(feature_size=dim, latent_size=dim, class_size=0, hidden_size=10, num_samples=5).to(device)
# P(Y | X, Z)
target = SimpleGaussianNN(feature_size=dim, class_size=2 * dim).to(device)
# P(Z | X, Y)
encoder = SimpleGaussianNN(feature_size=dim, class_size=2 * dim).to(device)

# Initialize class for Variational Backdoor Adjustment
vb = VariationalBackdoor(confounder_model=confounder,
                         target_model=target,
                         encoder_model=encoder,
                         backdoor_samples=10,  # Number of samples used to compute backdoor adjustment
                         component_samples=10)  # Number of samples used to compute log-likelihood for the inner model

# Two-step process: first train each of the components separately, then finetune the encoder
# to maximize the interventional density.
print('Separate Component Training')
logger1 = CSVLogger('trained_models/example/logs', name='separate_training')
trainer1 = pl.Trainer(max_epochs=10, default_root_dir='trained_models/', logger=logger1)
trainer1.fit(model=SeparateTrainingWrapper(vb), train_dataloaders=train_loader)

print('Encoder Finetuning')
logger2 = CSVLogger('trained_models/example/logs', name='finetuning')
trainer2 = pl.Trainer(max_epochs=2, default_root_dir='trained_models/', logger=logger2)
trainer2.fit(model=FinetuningWrapper(vb), train_dataloaders=train_loader)

checkpoint_path = 'trained_models/example/example.pt'
# Create the output directory explicitly instead of relying on the CSV loggers
# having created it as a side effect of training.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(vb.state_dict(), checkpoint_path)
vb.load_state_dict(torch.load(checkpoint_path, map_location=device))

# Some test data from the same distribution.
# NOTE(review): SimpleExample seeds its RNG with 0 on every construction, so these
# test rows share their Z values with the first 1000 training rows — confirm that
# this overlap is intended.
test = SimpleExample(1000, dim)
test_loader = torch.utils.data.DataLoader(test, batch_size=1000, shuffle=False)

# Pure inference: switch off training-mode behaviour and autograd bookkeeping.
vb.eval()
total_log_likelihood = 0.0
with torch.no_grad():
    for X, Y, Z in test_loader:
        # To obtain interventional likelihood estimate, call get_log_backdoor()
        interventional_likelihood = vb.get_log_backdoor(Y.to(device), X.to(device),
                                                        backdoor_samples=10, component_samples=10)
        # Accumulate across batches so the reported value covers the whole test
        # set regardless of the loader's batch size.
        total_log_likelihood += interventional_likelihood.sum().item()
print('The interventional log-likelihood of the test set is ', total_log_likelihood)