# SAMを実装する

In [1]:
import random
import numpy as np
import pandas as pd
import scipy
import matplotlib.pyplot as plt
from sklearn.preprocessing import scale
from scipy.special import expit
import time
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader

In [2]:
np.random.seed(1234)
random.seed(1234)
np.set_printoptions(precision=2, floatmode='fixed', suppress=True)

In [3]:
def make_data(n_data=2000):

    x = np.random.uniform(low=-1, high=1, size=n_data)  # -1から1の一様乱数

    e_z = np.random.randn(n_data)  # ノイズの生成
    z_prob = scipy.special.expit(-5.0 * x + 5 * e_z)
    Z = np.array([])

    for i in range(n_data):
        Z_i = np.random.choice(2, size=1, p=[1-z_prob[i], z_prob[i]])[0]
        Z = np.append(Z, Z_i)

    t = np.zeros(n_data)
    for i in range(n_data):
        if x[i] < 0:
            t[i] = 0.5
        elif x[i] >= 0 and x[i] < 0.5:
            t[i] = 0.7
        elif x[i] >= 0.5:
            t[i] = 1.0

    e_y = np.random.randn(n_data)
    Y = 2.0 + t*Z + 0.3*x + 0.1*e_y 

    Y2 = np.random.choice(
        [1.0, 2.0, 3.0, 4.0, 5.0],
        n_data, p=[0.1, 0.2, 0.3, 0.2, 0.2]
    )

    e_y3 = np.random.randn(n_data)
    Y3 = 2 * x + e_y3

    data = pd.DataFrame({
        'x': x,
        'Z': Z,
        'Y': Y,
        'Y2': Y2,
        'Y3': Y3
    })
    
    return data

In [33]:
class CausalMatrixNN(nn.Module):
    
    def __init__(self, n_datal_col):
        
        super(CausalMatrixNN, self).__init__()

        self.weights = torch.nn.Parameter(torch.ones(n_data_col, n_data_col))
        self.mask = 1 - torch.eye(n_data_col, n_data_col)
        
    def forward(self):
        return self.weights * self.mask
    
    def predict_proba(self):
        return torch.sigmoid(self.weights) * self.mask

In [30]:
class SAMGenerator(nn.Module):
    
    def __init__(self, n_data_col, hidden_layer_size_list):
        
        super(SAMGenerator, self).__init__()
        
        self.n_data_col = n_data_col
        self.weight_input = nn.Parameter(torch.normal(0, std=0.1, size=(n_data_col + 1, n_data_col + 1)))
        self.bias_input = nn.Parameter(torch.normal(0, std=0.1, size=(n_data_col, n_data_col + 1)))
        
        layers = []
        hidden_layer_size_list = [n_data_col] + hidden_layer_size_list
        for i in range(len(hidden_layer_size_list) - 1):
            layer_input = hidden_layer_size_list[i]
            layer_output = hidden_layer_size_list[i + 1]
            layers.append(nn.Linear(layer_input, layer_output))
            layers.append(nn.BatchNorm1d(layer_output))
            layers.append(nn.LeakyReLU(.2))
            
        layers.append(nn.Linear(hidden_layer_size_list[-1], n_data_col))
        self.layers = nn.Sequential(*layers)
        
        
    def forward(self, X, noise, adj_matrix):
        
        h = torch.cat([
            X.unsqueeze(1).expand([X.shape[0], self.n_data_col, self.n_data_col]),
            noise.unsqueeze(2)
        ], 2)
        h = h * adj_matrix.t().unsqueeze(0)

        h = h.matmul(self.weight_input) + self.bias_input
        h = h.sum(axis=2)
        
        return self.layers(h)

In [6]:
class SAMDiscriminator(nn.Module):
    
    def __init__(self, n_data_col, hidden_layer_size_list):
        
        super(SAMDiscriminator, self).__init__()
        
        layers = []
        hidden_layer_size_list = [n_data_col] + hidden_layer_size_list
        for i in range(len(hidden_layer_size_list) - 1):
            layer_input = hidden_layer_size_list[i]
            layer_output = hidden_layer_size_list[i + 1]
            layers.append(torch.nn.Linear(layer_input, layer_output))
            layers.append(torch.nn.BatchNorm1d(layer_output))
            layers.append(torch.nn.LeakyReLU(.2))
            
        layers.append(nn.Linear(hidden_layer_size_list[-1], 1))
        layers.append(nn.Sigmoid())
        self.layers = nn.Sequential(*layers)
        
    def forward(self, X):
        return self.layers(X)

In [7]:
def notears_constr(adj_m, max_pow=None):
    
    m_exp = [adj_m]
    if max_pow is None:
        max_pow = adj_m.shape[1]
    while (m_exp[-1].sum() > 0 and len(m_exp) < max_pow):
        m_exp.append(m_exp[-1] @ adj_m / len(m_exp))
        
    return sum([i.diag().sum() for idx, i in enumerate(m_exp)])

In [20]:
data = make_data(n_data=5000)
learning_data = data.copy()
learning_data.loc[:, :] = scale(learning_data.values) 

In [21]:
data = learning_data.copy()

In [22]:
n_data_col = data.shape[1]
data = torch.from_numpy(data.values.astype('float32') )
batch_size = len(data)

data_iterator = DataLoader(
    data, batch_size=batch_size, shuffle=True, drop_last=True
)

In [23]:
hidden_layer_size_list_gen = [500, 500, 500]
hidden_layer_size_list_dis = [500, 500, 100]
epoch = 1000

In [34]:
structural_gates = CausalMatrixNN(n_data_col)
sam = SAMGenerator(n_data_col=n_data_col, hidden_layer_size_list=hidden_layer_size_list_gen)
dis = SAMDiscriminator(n_data_col=n_data_col, hidden_layer_size_list=hidden_layer_size_list_dis)

noise = torch.randn(batch_size, n_data_col)

optimizer_gen = optim.Adam(sam.parameters())
optimizer_dis = optim.Adam(dis.parameters())
optimizer_gate = optim.Adam(structural_gates.parameters())

criterion = nn.BCEWithLogitsLoss()

start_time = time.time()
for loop in range(epoch):
    for data_batched in data_iterator:
        
        optimizer_gen.zero_grad()
        optimizer_gate.zero_grad()
        optimizer_dis.zero_grad()

        adj_matrix = torch.cat([structural_gates(), torch.ones(1, n_data_col)], 0)
        noise.normal_()

        data_gen = sam(X=data_batched, noise=noise, adj_matrix=adj_matrix)
        output_gen_detach = dis(data_gen.detach())
        output_true = dis(data_batched)

        loss_dis = criterion(output_gen_detach,  torch.zeros(output_gen_detach.shape)) + criterion(
            output_true, torch.ones(output_true.shape)
        )
        
        loss_dis.backward()
        optimizer_dis.step()
        
        output_gen = dis(data_gen)
        gates = structural_gates()
        loss_gen = criterion(output_gen, torch.ones(output_gen.shape))
        loss = loss_gen + gates.sum()
        if loop > epoch / 2:
            gates = structural_gates.predict_proba()
            loss = loss + notears_constr(gates * gates) / n_data_col
        
        loss.backward()
        optimizer_gen.step()
        optimizer_gate.step()
        
        if (loop + 1) % 10 == 0:
            
            print('[{}/ {}] Loss Dis: {:.2f}, Loss Gen: {:.2f}, {:.0f}[s]'.format(
                loop + 1, epoch, loss_dis, loss_gen, time.time() - start_time
            ))
            start_time = time.time()
        
        if loop % 100 == 0:
            print('')
            print(structural_gates().detach().numpy())
            print('')


[[0.00 1.00 1.00 1.00 1.00]
 [1.00 0.00 1.00 1.00 1.00]
 [1.00 1.00 0.00 1.00 1.00]
 [1.00 1.00 1.00 0.00 1.00]
 [1.00 1.00 1.00 1.00 0.00]]

[10/ 1000] Loss Dis: 1.42, Loss Gen: 0.52, 11[s]
[20/ 1000] Loss Dis: 1.41, Loss Gen: 0.54, 12[s]
[30/ 1000] Loss Dis: 1.39, Loss Gen: 0.58, 11[s]
[40/ 1000] Loss Dis: 1.39, Loss Gen: 0.58, 11[s]
[50/ 1000] Loss Dis: 1.39, Loss Gen: 0.59, 11[s]
[60/ 1000] Loss Dis: 1.38, Loss Gen: 0.62, 11[s]
[70/ 1000] Loss Dis: 1.36, Loss Gen: 0.61, 11[s]
[80/ 1000] Loss Dis: 1.35, Loss Gen: 0.62, 11[s]
[90/ 1000] Loss Dis: 1.33, Loss Gen: 0.63, 11[s]
[100/ 1000] Loss Dis: 1.33, Loss Gen: 0.64, 11[s]

[[0.00 0.90 0.90 0.90 0.90]
 [0.90 0.00 0.90 0.90 0.90]
 [0.90 0.90 0.00 0.90 0.90]
 [0.90 0.90 0.90 0.00 0.90]
 [0.90 0.90 0.90 0.90 0.00]]

[110/ 1000] Loss Dis: 1.31, Loss Gen: 0.64, 11[s]
[120/ 1000] Loss Dis: 1.27, Loss Gen: 0.65, 11[s]
[130/ 1000] Loss Dis: 1.30, Loss Gen: 0.66, 11[s]
[140/ 1000] Loss Dis: 1.28, Loss Gen: 0.67, 12[s]
[150/ 1000] Loss Dis: 1

KeyboardInterrupt: 

In [None]:
notears_constr(gates * gates)

In [None]:
criterion(output_gen_detach,  torch.zeros(output_gen_detach.shape))

In [121]:
criterion(
            output_true, torch.ones(output_true.shape)
        )

tensor(0.4858, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)

In [118]:
adj_matrix = torch.cat([structural_gates(), torch.ones(1, n_data_col)], 0)
noise.normal_()

sam(X=data, noise=noise, adj_matrix=adj_matrix)

tensor([[-0.4530,  0.0726,  0.4511,  0.1313, -0.1351, -0.1234],
        [ 0.5551, -0.1602, -0.0725,  0.1496, -0.2958,  0.1390],
        [ 0.5793, -0.1725, -0.0989,  0.1751, -0.3039,  0.1149],
        ...,
        [ 0.4760, -0.1629, -0.0676,  0.1756, -0.2365,  0.1156],
        [ 0.5232, -0.1560, -0.0699,  0.1628, -0.2727,  0.1226],
        [ 0.6975, -0.2295, -0.1397,  0.1622, -0.3318,  0.1525]],
       grad_fn=<AddmmBackward>)

In [119]:
data_batched

tensor([[-1.0288,  0.9588, -0.0717, -0.9565, -0.7468, -0.8526],
        [ 0.4281,  0.9588,  1.5529, -0.1527,  0.3245,  0.2860],
        [-0.1957, -1.0429, -0.3203, -0.1527,  0.2314,  0.5084],
        ...,
        [ 1.2523, -1.0429, -0.2571, -0.1527, -0.0278,  0.3842],
        [ 1.5519, -1.0429, -0.1780,  0.6510,  0.1190,  0.2537],
        [ 1.2993, -1.0429, -0.4257,  0.6510,  0.6279,  0.8868]])