# SAMを解剖する

In [1]:
import random
import numpy as np
import pandas as pd
import scipy
import matplotlib.pyplot as plt
from sklearn.preprocessing import scale
from scipy.special import expit
from tqdm import tqdm
import torch
import torch as th
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from cdt.utils.torch import ChannelBatchNorm1d, MatrixSampler, Linear3D

No GPU automatically detected. Setting SETTINGS.GPU to 0, and SETTINGS.NJOBS to cpu_count.


In [2]:
np.random.seed(1234)
random.seed(1234)
np.set_printoptions(precision=2, floatmode='fixed', suppress=True)

## 関数定義

In [3]:
def make_data(n_data=2000):
    """Generate a synthetic dataset with a known causal structure.

    The variables are produced in this order, each from the ones before it:

    * ``x``  - uniform on [-1, 1).
    * ``Z``  - binary (stored as float 0.0/1.0), drawn with probability
      ``expit(-5 * x + 5 * noise)``.
    * ``Y``  - ``2.0 + t * Z + 0.3 * x + 0.1 * noise`` where the effect ``t``
      of ``Z`` is piecewise in ``x`` (0.5 / 0.7 / 1.0).
    * ``Y2`` - categorical in {1, ..., 5}, independent of everything else.
    * ``Y3`` - ``3 * Y + Y2 + noise``.
    * ``Y4`` - ``3 * Y3 + 2 * noise + 5``.

    Randomness comes from the global ``np.random`` state, so seed it with
    ``np.random.seed`` beforehand for reproducible output.

    Args:
        n_data: Number of rows to generate.

    Returns:
        A ``pandas.DataFrame`` with columns ``x, Z, Y, Y2, Y3, Y4``.
    """
    x = np.random.uniform(low=-1, high=1, size=n_data)  # uniform on [-1, 1)

    e_z = np.random.randn(n_data)  # noise driving the assignment of Z
    z_prob = scipy.special.expit(-5.0 * x + 5 * e_z)

    # One Bernoulli draw per row. The per-row np.random.choice call is kept on
    # purpose so the stream of random numbers is unchanged; only the quadratic
    # np.append accumulation is replaced by building the array once.
    Z = np.array(
        [np.random.choice(2, size=1, p=[1 - p, p])[0] for p in z_prob],
        dtype=float,
    )

    # Piecewise effect of Z on Y; np.select picks the first matching condition.
    t = np.select(
        [x < 0, x < 0.5, x >= 0.5],
        [0.5, 0.7, 1.0],
        default=0.0,
    )

    e_y = np.random.randn(n_data)
    Y = 2.0 + t*Z + 0.3*x + 0.1*e_y

    Y2 = np.random.choice(
        [1.0, 2.0, 3.0, 4.0, 5.0],
        n_data, p=[0.1, 0.2, 0.3, 0.2, 0.2]
    )

    e_y3 = np.random.randn(n_data)
    Y3 = 3 * Y + Y2 + e_y3

    e_y4 = np.random.randn(n_data)
    Y4 = 3 * Y3 + 2 * e_y4 + 5

    data = pd.DataFrame({
        'x': x,
        'Z': Z,
        'Y': Y,
        'Y2': Y2,
        'Y3': Y3,
        'Y4': Y4
    })

    return data

In [4]:
class SAMDiscriminator(nn.Module):
    """MLP discriminator that scores rows of data as real or generated.

    The network is a stack of ``Linear -> BatchNorm1d -> LeakyReLU(0.2)``
    blocks followed by a final ``Linear`` producing one logit per row.
    """

    def __init__(self, n_data_col, n_hidden_layer, n_hidden_layers):
        """Build the network.

        Args:
            n_data_col: Number of input columns (variables).
            n_hidden_layer: Width of every hidden layer.
            n_hidden_layers: Number of hidden blocks; at least one block is
                always created.
        """
        super().__init__()

        self.n_data_col = n_data_col

        # The first block is unconditional, the remaining ones are optional.
        n_blocks = 1 + max(n_hidden_layers - 1, 0)
        widths = [n_data_col] + [n_hidden_layer] * n_blocks

        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.BatchNorm1d(fan_out))
            stack.append(nn.LeakyReLU(.2))
        stack.append(nn.Linear(n_hidden_layer, 1))
        self.layers = nn.Sequential(*stack)

        # Identity matrix with a leading broadcast axis: (1, n_col, n_col).
        identity = torch.eye(n_data_col, n_data_col).unsqueeze(0)
        self.register_buffer('mask', identity)

    def forward(self, input, obs_data=None):
        """Score a batch.

        Args:
            input: Batch to score; when ``obs_data`` is given these are the
                generated values.
            obs_data: Optional observed batch. If given, one mixed batch is
                built per column k, taking column k from ``input`` and all
                other columns from ``obs_data``.

        Returns:
            A single logit tensor when ``obs_data`` is None, otherwise a list
            with one logit tensor per column.
        """
        if obs_data is None:
            return self.layers(input)

        diag = self.mask
        mixed = obs_data.unsqueeze(1) * (1 - diag) + input.unsqueeze(1) * diag
        return [self.layers(view) for view in torch.unbind(mixed, 1)]

    def reset_parameters(self):
        """Re-initialise every sub-layer that provides ``reset_parameters``."""
        for module in self.layers:
            reset = getattr(module, 'reset_parameters', None)
            if reset is not None:
                reset()

In [5]:
class SAMGenerator(nn.Module):
    """Generator that re-creates every variable from the other ones plus noise.

    The model is ``Linear3D -> ChannelBatchNorm1d -> Tanh -> Linear3D`` where
    ``Linear3D`` and ``ChannelBatchNorm1d`` come from ``cdt.utils.torch``.
    """

    def __init__(self, n_data_col, n_hidden_layer):
        """Build the generator.

        Args:
            n_data_col: Number of variables (columns) in the data.
            n_hidden_layer: Number of hidden units per variable.
        """
        super(SAMGenerator, self).__init__()

        # (n_col + 1, n_col) matrix of ones with a zero diagonal in its square
        # part; multiplied into the adjacency matrix in forward() so that no
        # variable is used as its own input. The extra last row is all ones.
        skeleton = 1 - torch.eye(n_data_col + 1, n_data_col)

        self.register_buffer('skeleton', skeleton)

        self.input_layer = Linear3D((n_data_col, n_data_col + 1, n_hidden_layer))

        layers = []
        layers.append(ChannelBatchNorm1d(n_data_col, n_hidden_layer))
        layers.append(nn.Tanh())
        self.layers = nn.Sequential(*layers)

        self.output_layer = Linear3D((n_data_col, n_hidden_layer, 1))

    def forward(self, data, noise, adj_matrix, drawn_neurons=None):
        """Generate one value per variable for every row.

        Args:
            data: Observed batch.
            noise: Noise batch passed to the input layer.
            adj_matrix: Sampled adjacency matrix with one extra row for the
                noise input; it is masked by ``self.skeleton`` before use.
            drawn_neurons: Optional mask forwarded to the output layer as its
                ``adj_matrix`` argument.

        Returns:
            The output layer's result with its axis 2 squeezed away.
        """
        x = self.input_layer(data, noise, adj_matrix * self.skeleton)

        x = self.layers(x)

        output = self.output_layer(x, noise=None, adj_matrix=drawn_neurons)

        return output.squeeze(2)

    def reset_parameters(self):
        """Re-initialise the input/output layers and every resettable hidden layer."""
        self.input_layer.reset_parameters()
        self.output_layer.reset_parameters()

        for layer in self.layers:
            # Layers without parameters (e.g. Tanh) have no reset_parameters.
            if hasattr(layer, 'reset_parameters'):
                layer.reset_parameters()

In [6]:
def notears_constr(adj_m, max_pow=None):
    """Acyclicity penalty: summed traces of scaled powers of ``adj_m``.

    Starting from ``adj_m``, each further term is the previous term multiplied
    by ``adj_m`` and divided by the number of terms computed so far. Expansion
    stops once a term sums to zero or ``max_pow`` terms exist.

    Args:
        adj_m: Square torch tensor of edge weights.
        max_pow: Maximum number of terms; defaults to ``adj_m.shape[1]``.

    Returns:
        The sum of the diagonals of all terms, as a scalar tensor.
    """
    limit = adj_m.shape[1] if max_pow is None else max_pow

    terms = [adj_m]
    while terms[-1].sum() > 0 and len(terms) < limit:
        terms.append(terms[-1] @ adj_m / len(terms))

    return sum(term.diag().sum() for term in terms)

In [7]:
def train_sam_model(
    data,
    n_hidden_layer_gen=100, n_hidden_layer_dis=100,
    n_hidden_layers_dis=2,
    lr_gen=0.01*0.5, lr_dis=0.01*0.5*2,
    dag_start_rate=0.5, dag_penalization_increase=0.001*10,
    epochs_train=100, epochs_test=100,
    lambda1=5.0*20, lambda2=0.005*20
):
    """Train a SAM generator/discriminator pair and return the learned graph.

    The loop runs ``epochs_train`` training epochs followed by ``epochs_test``
    evaluation epochs. During evaluation no optimizer step is taken; the edge
    probabilities and the generator loss are accumulated and averaged.

    Args:
        data: ``pandas.DataFrame`` of observations; the whole frame is used as
            a single batch.
        n_hidden_layer_gen: Hidden units per variable in the generator.
        n_hidden_layer_dis: Width of the discriminator's hidden layers.
        n_hidden_layers_dis: Number of hidden blocks in the discriminator.
        lr_gen: Learning rate for the generator and both samplers.
        lr_dis: Learning rate for the discriminator.
        dag_start_rate: Fraction of ``epochs_train`` after which the
            acyclicity penalty is added to the generator loss.
        dag_penalization_increase: Per-epoch growth of the acyclicity
            penalty weight.
        epochs_train: Number of training epochs.
        epochs_test: Number of evaluation epochs.
        lambda1: Weight of the penalty on the number of sampled edges.
        lambda2: Weight of the penalty on the number of sampled neurons.

    Returns:
        A tuple ``(graph, loss)`` of numpy arrays: the edge-probability matrix
        averaged over the evaluation epochs, and the generator loss averaged
        over the evaluation epochs and divided by the number of columns.
    """
    n_data_col = len(data.columns)
    data = torch.from_numpy(data.values.astype('float32'))
    batch_size = len(data)

    # batch_size equals the dataset size, so each epoch yields one batch.
    data_iterator = DataLoader(
        data, batch_size=batch_size, shuffle=True, drop_last=True
    )

    sam = SAMGenerator(n_data_col, n_hidden_layer_gen)
    sam.reset_parameters()
    sampler_graph = MatrixSampler(n_data_col, mask=None, gumbel=False)
    sampler_neuron = MatrixSampler((n_hidden_layer_gen, n_data_col), mask=False, gumbel=True)

    sampler_graph.weights.data.fill_(2)

    discriminator = SAMDiscriminator(
        n_data_col=n_data_col, n_hidden_layer=n_hidden_layer_dis, n_hidden_layers=n_hidden_layers_dis
    )
    discriminator.reset_parameters()

    optimizer_gen = optim.Adam(sam.parameters(), lr=lr_gen)
    optimizer_dis = optim.Adam(discriminator.parameters(), lr=lr_dis)
    optimizer_graph = optim.Adam(sampler_graph.parameters(), lr=lr_gen)
    optimizer_neuron = optim.Adam(sampler_neuron.parameters(), lr=lr_gen)

    criterion = nn.BCEWithLogitsLoss()

    # Scalar targets, expanded to the shape of each discriminator output.
    _true = torch.ones(1)
    _false = torch.zeros(1)

    noise = torch.randn(batch_size, n_data_col)
    # Extra adjacency row of ones appended below the sampled graph.
    noise_row = torch.ones(1, n_data_col)

    # Accumulators filled during the evaluation epochs.
    output = torch.zeros(n_data_col, n_data_col)
    output_loss = torch.zeros(1, 1)

    pbar = tqdm(range(epochs_train + epochs_test))
    for epoch in pbar:
        for data_batched in data_iterator:

            optimizer_gen.zero_grad()
            optimizer_graph.zero_grad()
            optimizer_neuron.zero_grad()
            optimizer_dis.zero_grad()

            drawn_graph = sampler_graph()
            drawn_neurons = sampler_neuron()

            noise.normal_()
            generated_variables = sam(
                data=data_batched,
                noise=noise,
                adj_matrix=torch.cat(
                    [drawn_graph, noise_row], 0
                ),
                drawn_neurons=drawn_neurons
            )

            # Detached copy trains the discriminator only; the attached copy
            # lets gradients reach the generator and the samplers.
            dis_vars_d = discriminator(generated_variables.detach(), data_batched)
            dis_vars_g = discriminator(generated_variables, data_batched)
            true_vars_dis = discriminator(data_batched)

            loss_dis = sum(
                [criterion(gen, _false.expand_as(gen)) for gen in dis_vars_d]
            ) / n_data_col + criterion(
                true_vars_dis, _true.expand_as(true_vars_dis)
            )

            loss_gen = sum([criterion(gen, _true.expand_as(gen)) for gen in dis_vars_g])

            if epoch < epochs_train:
                loss_dis.backward()
                optimizer_dis.step()

            loss_struc = lambda1 / batch_size * drawn_graph.sum()
            loss_func = lambda2 / batch_size * drawn_neurons.sum()

            loss_regul = loss_struc + loss_func

            filters = None
            if epoch <= epochs_train * dag_start_rate:
                loss = loss_gen + loss_regul
            else:
                filters = sampler_graph.get_proba()
                loss = loss_gen + loss_regul + (
                    (epoch - epochs_train * dag_start_rate) * dag_penalization_increase
                ) * notears_constr(filters * filters)

            if epoch >= epochs_train:
                # The DAG branch above is skipped when dag_start_rate >= 1 or
                # epochs_train == 0, so fetch the probabilities here instead
                # of relying on that branch having run.
                if filters is None:
                    filters = sampler_graph.get_proba()
                output.add_(filters.data)
                output_loss.add_(loss_gen.data)
            else:
                # NOTE(review): this backward pass runs after
                # optimizer_dis.step() has updated the discriminator in place;
                # confirm this is accepted by the installed PyTorch version.
                loss.backward(retain_graph=True)
                optimizer_gen.step()
                optimizer_graph.step()
                optimizer_neuron.step()

            # Refresh the progress-bar statistics every 50 epochs.
            if epoch % 50 == 0:
                pbar.set_postfix(
                    gen=loss_gen.item()/n_data_col,
                    dis=loss_dis.item(),
                    regul_loss=loss_regul.item(),
                    tot=loss.item()
                )

    return output.cpu().numpy()/epochs_test, output_loss.cpu().numpy()/epochs_test/n_data_col


In [8]:
# Draw the synthetic dataset; `data` keeps the raw values.
data = make_data(n_data=2000)
learning_data = data.copy()
# Standardise every column with sklearn's `scale` before training.
learning_data.loc[:, :] = scale(learning_data.values) 

In [9]:
# Hyperparameters laid out as globals so the body of train_sam_model can be
# replayed step by step in the cells below. Values match that function's
# defaults except for the two epoch counts, which are 1000 here.
data =learning_data.copy()
n_hidden_layer_gen=100
n_hidden_layer_dis=100
n_hidden_layers_dis=2
lr_gen=0.01*0.5
lr_dis=0.01*0.5*2
dag_start_rate=0.5
dag_penalization_increase=0.001*10
epochs_train=1000
epochs_test=1000
lambda1=5.0*20
lambda2=0.005*20

In [10]:
# Same setup as the top of train_sam_model, run at notebook scope so the
# intermediate objects can be inspected.
data_columns = data.columns.tolist() 
n_data_col = len(data_columns)  
data = torch.from_numpy(data.values.astype('float32') )
batch_size = len(data)

# batch_size equals the dataset size, so the loader yields a single batch.
data_iterator = DataLoader(
    data, batch_size=batch_size, shuffle=True, drop_last=True
)

sam = SAMGenerator(n_data_col, n_hidden_layer_gen)
sam.reset_parameters()
sampler_graph = MatrixSampler(n_data_col, mask=None, gumbel=False)
sampler_neuron = MatrixSampler((n_hidden_layer_gen, n_data_col), mask=False, gumbel=True)

sampler_graph.weights.data.fill_(2)

discriminator = SAMDiscriminator(
    n_data_col=n_data_col, n_hidden_layer=n_hidden_layer_dis, n_hidden_layers=n_hidden_layers_dis
)
discriminator.reset_parameters()  

# The generator and both samplers share lr_gen; the discriminator uses lr_dis.
optimizer_gen = optim.Adam(sam.parameters(), lr=lr_gen)
optimizer_dis = optim.Adam(discriminator.parameters(), lr=lr_dis)
optimizer_graph = optim.Adam(sampler_graph.parameters(), lr=lr_gen)
optimizer_neuron = optim.Adam(sampler_neuron.parameters(), lr=lr_gen)

criterion = nn.BCEWithLogitsLoss()

# Scalar real/fake targets, expanded later to match discriminator outputs.
_true = torch.ones(1)
_false = torch.zeros(1)

noise = torch.randn(batch_size, n_data_col)
# Extra adjacency row of ones appended below the sampled graph.
noise_row = torch.ones(1, n_data_col)

# Accumulators for the evaluation phase.
output = torch.zeros(n_data_col, n_data_col)
output_loss = torch.zeros(1, 1)

pbar = tqdm(range(epochs_train + epochs_test))

  0%|          | 0/2000 [00:00<?, ?it/s]

In [11]:
data_batched = list(data_iterator)[0]
n_hidden_layer = n_hidden_layer_gen

In [12]:
data_batched

tensor([[-1.6264,  0.9588, -0.4542,  1.4548,  1.5592,  1.0264],
        [-1.5630,  0.9588, -0.2876,  0.6510,  0.1707,  0.0334],
        [ 1.1218, -1.0429, -0.6249, -0.9565, -1.1501, -1.0986],
        ...,
        [ 1.2443, -1.0429, -0.0804, -0.1527, -0.0903,  0.0187],
        [-1.3658,  0.9588,  0.0247,  1.4548,  1.4087,  0.5823],
        [-1.5699,  0.9588, -0.4240,  1.4548,  0.9753,  0.9582]])

In [15]:
# One manual iteration of the training loop, up to the discriminator outputs.
optimizer_gen.zero_grad()
optimizer_graph.zero_grad()
optimizer_neuron.zero_grad()
optimizer_dis.zero_grad()

# Sample a graph and a neuron mask from the two samplers.
drawn_graph = sampler_graph()
drawn_neurons = sampler_neuron()

# Refresh the noise in place, then generate every variable.
noise.normal_()
generated_variables = sam(
    data=data_batched, 
    noise=noise,
    adj_matrix=torch.cat(
        [drawn_graph, noise_row], 0
    ), 
    drawn_neurons=drawn_neurons
)

# Detached copy for the discriminator loss, attached copy for the generator
# loss, and the real batch on its own.
dis_vars_d = discriminator(generated_variables.detach(), data_batched)
dis_vars_g = discriminator(generated_variables, data_batched)
true_vars_dis = discriminator(data_batched) 

In [17]:
dis_vars_d

[tensor([[-0.1235],
         [ 0.1333],
         [ 0.7230],
         ...,
         [ 0.4232],
         [ 0.0601],
         [ 0.1103]], grad_fn=<AddmmBackward>),
 tensor([[ 0.3398],
         [ 0.3073],
         [ 0.3554],
         ...,
         [-0.0626],
         [ 0.3504],
         [ 0.4491]], grad_fn=<AddmmBackward>),
 tensor([[ 0.5451],
         [ 0.3470],
         [ 0.5242],
         ...,
         [-0.0019],
         [ 0.5945],
         [ 0.7447]], grad_fn=<AddmmBackward>),
 tensor([[ 0.0488],
         [ 0.1127],
         [ 0.5327],
         ...,
         [-0.0641],
         [ 0.0772],
         [ 0.2441]], grad_fn=<AddmmBackward>),
 tensor([[ 0.6313],
         [ 0.1817],
         [ 0.1072],
         ...,
         [-0.2520],
         [ 0.4559],
         [ 0.6192]], grad_fn=<AddmmBackward>),
 tensor([[ 0.2091],
         [ 0.2019],
         [ 0.2617],
         ...,
         [-0.1229],
         [ 0.2435],
         [ 0.3586]], grad_fn=<AddmmBackward>)]

## Generator

In [13]:
# Rebuild the generator's inputs by hand so its forward pass can be replayed.
optimizer_gen.zero_grad()
optimizer_graph.zero_grad()
optimizer_neuron.zero_grad()
optimizer_dis.zero_grad()

drawn_graph = sampler_graph()
drawn_neurons = sampler_neuron()

noise.normal_()
# Sampled graph with the all-ones noise row appended as the last row.
adj_matrix = torch.cat([drawn_graph, noise_row], 0)

In [14]:
adj_matrix

tensor([[0., 1., 1., 1., 1., 1.],
        [1., 0., 1., 1., 1., 1.],
        [1., 1., 0., 1., 0., 0.],
        [1., 1., 1., 0., 1., 1.],
        [1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]], grad_fn=<CatBackward>)

In [15]:
# Stand-alone copies of the layers and buffer that SAMGenerator.__init__
# creates, so each stage can be called individually below.
input_layer = Linear3D((n_data_col, n_data_col + 1, n_hidden_layer))
layers = []
layers.append(ChannelBatchNorm1d(n_data_col, n_hidden_layer))
layers.append(nn.Tanh())
layers = nn.Sequential(*layers)
output_layer = Linear3D((n_data_col, n_hidden_layer, 1))
# Ones with a zero diagonal in the square part; the last row stays all ones.
skeleton = 1 - torch.eye(n_data_col + 1, n_data_col)

In [16]:
adj_matrix * skeleton

tensor([[0., 1., 1., 1., 1., 1.],
        [1., 0., 1., 1., 1., 1.],
        [1., 1., 0., 1., 1., 1.],
        [1., 1., 1., 0., 1., 1.],
        [1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]], grad_fn=<MulBackward0>)

In [17]:
x = input_layer(data, noise, adj_matrix * skeleton)

In [18]:
x.shape

torch.Size([2000, 6, 100])

In [19]:
x = layers(x)

In [20]:
x.shape

torch.Size([2000, 6, 100])

In [21]:
output = output_layer(x, noise=None, adj_matrix=drawn_neurons)

In [22]:
output.shape

torch.Size([2000, 6, 1])

In [23]:
output = output.squeeze(2)

In [24]:
output

tensor([[ 2.2622e-01, -5.5772e-02, -7.1864e-02,  1.8798e-01,  2.8039e-01,
         -8.9512e-02],
        [ 2.4660e-01,  1.3123e-02, -2.2076e-01, -3.3251e-01,  4.4170e-01,
         -2.0561e-01],
        [-3.2613e-01,  1.2684e-01,  6.1894e-01,  4.3970e-01, -4.9755e-01,
         -1.3570e-02],
        ...,
        [ 3.7394e-03,  1.9085e-01,  2.0638e-01, -2.7084e-01, -4.4423e-01,
         -5.1361e-02],
        [ 6.6384e-03,  7.7289e-02, -9.3609e-05, -2.3796e-01, -3.1015e-01,
         -9.7169e-02],
        [ 5.3628e-02,  9.9456e-02,  2.5218e-01,  5.6730e-02, -4.1738e-01,
         -1.1370e-01]], grad_fn=<SqueezeBackward1>)

### Linear3D

In [17]:
weight = torch.nn.Parameter(torch.Tensor(n_data_col, n_data_col + 1, n_hidden_layer))
bias = torch.nn.Parameter(torch.Tensor(n_data_col, n_hidden_layer))

In [18]:
noise.unsqueeze(2).shape

torch.Size([2000, 6, 1])

In [19]:
# Repeat each row of `data` once per variable, then append that variable's
# noise value as an extra last feature (see the shape printed below).
h_tmp = torch.cat([
    data.unsqueeze(1).expand([data.shape[0], n_data_col, n_data_col]),
    noise.unsqueeze(2)
], 2)

In [20]:
h_tmp.shape

torch.Size([2000, 6, 7])

In [21]:
h_tmp[0]

tensor([[-1.0288,  0.9588, -0.0717, -0.9565, -0.7468, -0.8526,  1.2960],
        [-1.0288,  0.9588, -0.0717, -0.9565, -0.7468, -0.8526, -1.3462],
        [-1.0288,  0.9588, -0.0717, -0.9565, -0.7468, -0.8526, -0.8286],
        [-1.0288,  0.9588, -0.0717, -0.9565, -0.7468, -0.8526,  0.8917],
        [-1.0288,  0.9588, -0.0717, -0.9565, -0.7468, -0.8526,  0.6658],
        [-1.0288,  0.9588, -0.0717, -0.9565, -0.7468, -0.8526,  0.4814]])

In [22]:
adj_matrix.t().unsqueeze(0).shape

torch.Size([1, 6, 7])

In [23]:
adj_matrix.t().unsqueeze(0)

tensor([[[0., 1., 1., 1., 1., 1., 1.],
         [1., 0., 1., 1., 1., 1., 1.],
         [1., 1., 0., 1., 1., 1., 1.],
         [1., 1., 1., 0., 1., 1., 1.],
         [1., 1., 1., 1., 0., 1., 1.],
         [1., 1., 1., 1., 1., 0., 1.]]], grad_fn=<UnsqueezeBackward0>)

In [31]:
h_tmp = h_tmp * adj_matrix.t().unsqueeze(0)
h_tmp.shape

torch.Size([2000, 6, 7])

In [32]:
h_tmp[0]

tensor([[-0.0000,  0.9588, -0.0717, -0.9565, -0.7468, -0.8526, -1.8773],
        [-1.0288,  0.0000, -0.0717, -0.9565, -0.7468, -0.8526, -1.0362],
        [-1.0288,  0.9588, -0.0000, -0.9565, -0.7468, -0.8526,  0.3221],
        [-1.0288,  0.9588, -0.0717, -0.0000, -0.7468, -0.8526,  0.0459],
        [-1.0288,  0.9588, -0.0717, -0.9565, -0.0000, -0.8526, -0.2099],
        [-1.0288,  0.9588, -0.0717, -0.9565, -0.7468, -0.0000,  0.9651]],
       grad_fn=<SelectBackward>)

In [33]:
h_tmp[1]

tensor([[ 0.0000,  0.9588,  1.5529, -0.1527,  0.3245,  0.2860, -0.3984],
        [ 0.4281,  0.0000,  1.5529, -0.1527,  0.3245,  0.2860, -0.7596],
        [ 0.4281,  0.9588,  0.0000, -0.1527,  0.3245,  0.2860, -0.5563],
        [ 0.4281,  0.9588,  1.5529, -0.0000,  0.3245,  0.2860, -1.7631],
        [ 0.4281,  0.9588,  1.5529, -0.1527,  0.0000,  0.2860,  0.0732],
        [ 0.4281,  0.9588,  1.5529, -0.1527,  0.3245,  0.0000,  0.1426]],
       grad_fn=<SelectBackward>)

In [34]:
h_tmp.transpose(0, 1)[1]

tensor([[-1.0288,  0.0000, -0.0717,  ..., -0.7468, -0.8526, -1.0362],
        [ 0.4281,  0.0000,  1.5529,  ...,  0.3245,  0.2860, -0.7596],
        [-0.1957, -0.0000, -0.3203,  ...,  0.2314,  0.5084,  0.2556],
        ...,
        [ 1.2523, -0.0000, -0.2571,  ..., -0.0278,  0.3842, -0.1061],
        [ 1.5519, -0.0000, -0.1780,  ...,  0.1190,  0.2537,  0.3969],
        [ 1.2993, -0.0000, -0.4257,  ...,  0.6279,  0.8868,  1.2072]],
       grad_fn=<SelectBackward>)

In [35]:
output = h_tmp.transpose(0, 1).matmul(weight)

In [36]:
output[0]

tensor([[ 0.0000,  1.7978,  0.0000,  ..., -4.9201,  0.0000, -3.1332],
        [ 0.0000,  1.7978,  0.0000,  ..., -0.1387,  0.0000,  4.0563],
        [ 0.0000, -1.9555,  0.0000,  ...,  3.5570,  0.0000,  0.7865],
        ...,
        [ 0.0000, -1.9555,  0.0000,  ..., -1.0816,  0.0000,  0.1863],
        [ 0.0000, -1.9555,  0.0000,  ..., -2.0183,  0.0000,  0.3650],
        [ 0.0000, -1.9555,  0.0000,  ..., -2.3395,  0.0000,  2.0418]],
       grad_fn=<SelectBackward>)

In [37]:
output = h_tmp.transpose(0, 1).matmul(weight)
if bias is not None:
    output += bias.unsqueeze(1)

In [38]:
output = output.transpose(0, 1)

In [39]:
output.shape

torch.Size([2000, 6, 100])

## Discriminator

In [40]:
generated_variables = sam(
    data=data_batched, 
    noise=noise,
    adj_matrix=torch.cat(
        [drawn_graph, noise_row], 0
    ), 
    drawn_neurons=drawn_neurons
)

In [41]:
n_hidden_layer = n_hidden_layer_dis
n_hidden_layers = 2

In [42]:
nn.Linear(n_data_col, n_hidden_layer).weight.shape

torch.Size([100, 6])

In [43]:
# Stand-alone copy of the layer stack built in SAMDiscriminator.__init__:
# (Linear -> BatchNorm1d -> LeakyReLU(0.2)) blocks and a final 1-unit Linear.
layers = []
layers.append(nn.Linear(n_data_col, n_hidden_layer))
layers.append(nn.BatchNorm1d(n_hidden_layer))
layers.append(nn.LeakyReLU(.2))

# Additional hidden blocks beyond the first.
for i in range(n_hidden_layers - 1):
    layers.append(nn.Linear(n_hidden_layer, n_hidden_layer))
    layers.append(nn.BatchNorm1d(n_hidden_layer))
    layers.append(nn.LeakyReLU(.2))

layers.append(nn.Linear(n_hidden_layer, 1))
layers = nn.Sequential(*layers)

In [44]:
mask = torch.eye(n_data_col, n_data_col).unsqueeze(0)

In [45]:
mask

tensor([[[1., 0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0.],
         [0., 0., 0., 0., 1., 0.],
         [0., 0., 0., 0., 0., 1.]]])

In [46]:
generated_variables

tensor([[-0.5002, -0.2319,  0.2539, -0.5178,  0.2953, -0.2955],
        [-0.2911,  0.2191, -0.2386, -0.1601, -0.0403, -0.0851],
        [ 0.2335,  0.3069, -0.0659,  0.2923, -0.0417, -0.1788],
        ...,
        [-0.0857,  0.1863, -0.3762,  0.3394, -0.3351,  0.5045],
        [-0.1296, -0.3794,  0.1099, -0.1596,  0.1346,  0.1124],
        [-0.4049, -0.1997,  0.2635, -0.2337,  0.0892, -0.0347]],
       grad_fn=<SqueezeBackward1>)

In [47]:
obs_data = data_batched.detach()

In [48]:
obs_data

tensor([[-1.5586, -1.0429, -1.6498,  0.6510, -1.2175, -1.5709],
        [ 0.6651, -1.0429, -0.2457, -0.9565, -1.0417, -0.9757],
        [ 1.6284, -1.0429,  0.2841, -0.1527, -0.2247,  0.1411],
        ...,
        [-0.2378,  0.9588,  0.5697, -0.9565,  0.2053,  0.3430],
        [-1.6105,  0.9588, -0.4935,  1.4548,  0.5564,  0.4725],
        [ 0.3310, -1.0429, -1.1649,  0.6510, -0.7197, -0.9972]])

In [49]:
obs_data.unsqueeze(1) * (1 - mask) 

tensor([[[-0.0000, -1.0429, -1.6498,  0.6510, -1.2175, -1.5709],
         [-1.5586, -0.0000, -1.6498,  0.6510, -1.2175, -1.5709],
         [-1.5586, -1.0429, -0.0000,  0.6510, -1.2175, -1.5709],
         [-1.5586, -1.0429, -1.6498,  0.0000, -1.2175, -1.5709],
         [-1.5586, -1.0429, -1.6498,  0.6510, -0.0000, -1.5709],
         [-1.5586, -1.0429, -1.6498,  0.6510, -1.2175, -0.0000]],

        [[ 0.0000, -1.0429, -0.2457, -0.9565, -1.0417, -0.9757],
         [ 0.6651, -0.0000, -0.2457, -0.9565, -1.0417, -0.9757],
         [ 0.6651, -1.0429, -0.0000, -0.9565, -1.0417, -0.9757],
         [ 0.6651, -1.0429, -0.2457, -0.0000, -1.0417, -0.9757],
         [ 0.6651, -1.0429, -0.2457, -0.9565, -0.0000, -0.9757],
         [ 0.6651, -1.0429, -0.2457, -0.9565, -1.0417, -0.0000]],

        [[ 0.0000, -1.0429,  0.2841, -0.1527, -0.2247,  0.1411],
         [ 1.6284, -0.0000,  0.2841, -0.1527, -0.2247,  0.1411],
         [ 1.6284, -1.0429,  0.0000, -0.1527, -0.2247,  0.1411],
         [ 1.6284, -1

In [50]:
generated_variables.unsqueeze(1)

tensor([[[-0.5002, -0.2319,  0.2539, -0.5178,  0.2953, -0.2955]],

        [[-0.2911,  0.2191, -0.2386, -0.1601, -0.0403, -0.0851]],

        [[ 0.2335,  0.3069, -0.0659,  0.2923, -0.0417, -0.1788]],

        ...,

        [[-0.0857,  0.1863, -0.3762,  0.3394, -0.3351,  0.5045]],

        [[-0.1296, -0.3794,  0.1099, -0.1596,  0.1346,  0.1124]],

        [[-0.4049, -0.1997,  0.2635, -0.2337,  0.0892, -0.0347]]],
       grad_fn=<UnsqueezeBackward0>)

In [51]:
(obs_data.unsqueeze(1) * (1 - mask) + generated_variables.unsqueeze(1) * mask)[0]

tensor([[-0.5002, -1.0429, -1.6498,  0.6510, -1.2175, -1.5709],
        [-1.5586, -0.2319, -1.6498,  0.6510, -1.2175, -1.5709],
        [-1.5586, -1.0429,  0.2539,  0.6510, -1.2175, -1.5709],
        [-1.5586, -1.0429, -1.6498, -0.5178, -1.2175, -1.5709],
        [-1.5586, -1.0429, -1.6498,  0.6510,  0.2953, -1.5709],
        [-1.5586, -1.0429, -1.6498,  0.6510, -1.2175, -0.2955]],
       grad_fn=<SelectBackward>)

In [52]:
(obs_data.unsqueeze(1) * (1 - mask) + generated_variables.unsqueeze(1) * mask)[1]

tensor([[-0.2911, -1.0429, -0.2457, -0.9565, -1.0417, -0.9757],
        [ 0.6651,  0.2191, -0.2457, -0.9565, -1.0417, -0.9757],
        [ 0.6651, -1.0429, -0.2386, -0.9565, -1.0417, -0.9757],
        [ 0.6651, -1.0429, -0.2457, -0.1601, -1.0417, -0.9757],
        [ 0.6651, -1.0429, -0.2457, -0.9565, -0.0403, -0.9757],
        [ 0.6651, -1.0429, -0.2457, -0.9565, -1.0417, -0.0851]],
       grad_fn=<SelectBackward>)

In [53]:
tmp = torch.unbind(
    obs_data.unsqueeze(1) * (1 - mask) + generated_variables.unsqueeze(1) * mask,
    1
)

In [54]:
tmp[0]

tensor([[-0.5002, -1.0429, -1.6498,  0.6510, -1.2175, -1.5709],
        [-0.2911, -1.0429, -0.2457, -0.9565, -1.0417, -0.9757],
        [ 0.2335, -1.0429,  0.2841, -0.1527, -0.2247,  0.1411],
        ...,
        [-0.0857,  0.9588,  0.5697, -0.9565,  0.2053,  0.3430],
        [-0.1296,  0.9588, -0.4935,  1.4548,  0.5564,  0.4725],
        [-0.4049, -1.0429, -1.1649,  0.6510, -0.7197, -0.9972]],
       grad_fn=<UnbindBackward>)

In [56]:
output = [layers(i) for i in torch.unbind(
    obs_data.unsqueeze(1) * (1 - mask) + generated_variables.unsqueeze(1) * mask,
    1
)]

In [57]:
np.array([[1, 2, 3]]).dot(np.array([[1, 2], [1, 2], [1, 2]]))

array([[ 6, 12]])

In [58]:
len(output)

6

In [59]:
output[0].shape

torch.Size([2000, 1])

### MatrixSampler

In [81]:
# Hand-built state for the graph sampler experiment below.
# NOTE(review): presumably mirrors what MatrixSampler(n_data_col, mask=None,
# gumbel=False) sets up internally - confirm against cdt.utils.torch.
graph_size = n_data_col
mask = None
gumbel = False

# Square matrix of logits, one per ordered pair of variables, zero-initialised.
graph_size = (graph_size, graph_size)
weights = th.nn.Parameter(th.FloatTensor(*graph_size))
weights.data.zero_()
# Default mask: ones everywhere except the diagonal.
if mask is None:
    mask = 1 - th.eye(*graph_size)

# Constant tensors used by gumbel_sigmoid when hard=True.
ones_tensor = th.ones(graph_size)
zeros_tensor = th.zeros(*graph_size)

In [63]:
def _sample_logistic(shape, out=None):
    U = out.resize_(shape).uniform_() if out is not None else th.rand(shape)
    return th.log(U) - th.log(1-U)

In [64]:
_sample_logistic(10)

tensor([ 2.9600,  3.0563, -0.8097,  0.4350,  1.2195,  0.5548,  1.5671,  4.2093,
        -0.4976, -0.8441])

In [65]:
def _sigmoid_sample(logits, tau=1):
    """Add logistic noise to ``logits`` and squash with a tempered sigmoid.

    Args:
        logits: Tensor of logits.
        tau: Temperature dividing the noisy logits before the sigmoid.

    Returns:
        Tensor of the same shape as ``logits``.
    """
    noise_buffer = logits.data.new()
    perturbed = logits + _sample_logistic(logits.size(), out=noise_buffer)
    return th.sigmoid(perturbed / tau)

ロジットにロジスティックノイズを加え、シグモイドを通して行列をサンプリングする関数<br>
hardをTrueにすると、0.5より大きい値は1に、0.5以下の値は0になる(勾配はソフトな値 y_soft を通して伝播する)

In [111]:
def gumbel_sigmoid(logits, ones_tensor, zeros_tensor, tau=1, hard=False):
    """Draw a noisy sigmoid sample, optionally binarised.

    Args:
        logits: Tensor of logits.
        ones_tensor: Values used where the soft sample exceeds 0.5.
        zeros_tensor: Values used everywhere else.
        tau: Temperature passed to the sampler.
        hard: If True, return the binarised sample while keeping the gradient
            of the soft sample.

    Returns:
        The soft sample, or its binarised version when ``hard`` is True.
    """
    soft = _sigmoid_sample(logits, tau=tau)
    if not hard:
        return soft

    binarised = th.where(soft > 0.5, ones_tensor, zeros_tensor)
    # Value equals `binarised`; only the trailing `soft` carries gradient.
    return binarised.detach() - soft.detach() + soft

#### sample_graph

In [112]:
# Draw one hard (0/1) graph sample from the zero-initialised logits.
tau = 1
drawhard = True
drawn_proba = gumbel_sigmoid(
    2 * weights, ones_tensor, zeros_tensor, tau=tau, hard=drawhard
)

In [113]:
drawn_proba

tensor([[1., 0., 0., 1., 0., 0.],
        [1., 0., 1., 1., 0., 1.],
        [1., 0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 1., 0.],
        [0., 0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0., 1.]], grad_fn=<AddBackward0>)

#### sample_neuron

In [114]:
# Hand-built state for the neuron sampler experiment: one logit per
# (hidden unit, variable) pair, zero-initialised.
# NOTE(review): train_sam_model creates this sampler with mask=False, whereas
# this cell uses mask=None - confirm which one is intended here.
mask = None
gumbel = True
graph_size = (n_hidden_layer_gen, n_data_col)

weights = th.nn.Parameter(th.FloatTensor(*graph_size))
weights.data.zero_()
if mask is None:
    mask = 1 - th.eye(*graph_size)

ones_tensor = th.ones(graph_size)
zeros_tensor = th.zeros(*graph_size)

In [118]:
th.stack([weights.view(-1), -weights.view(-1)], 1).shape

torch.Size([600, 2])

In [124]:
def _sample_gumbel(shape, eps=1e-10, out=None):
    """
    Implementation of pytorch.
    (https://github.com/pytorch/pytorch/blob/e4eee7c2cf43f4edba7a14687ad59d3ed61d9833/torch/nn/functional.py)
    Sample from Gumbel(0, 1)
    based on
    https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb ,
    (MIT license)
    """
    U = out.resize_(shape).uniform_() if out is not None else th.rand(shape)
    return - th.log(eps - th.log(U + eps))

In [122]:
def _gumbel_softmax_sample(logits, tau=1, eps=1e-10):
    """Add Gumbel noise to ``logits`` and apply a tempered softmax.

    Args:
        logits: Tensor of logits.
        tau: Temperature dividing the noisy logits before the softmax.
        eps: Forwarded to the Gumbel sampler.

    Returns:
        Softmax over the last dimension of the perturbed logits.
    """
    last_axis = logits.dim() - 1
    noise_buffer = logits.data.new()
    perturbed = logits + _sample_gumbel(logits.size(), eps=eps, out=noise_buffer)
    return th.softmax(perturbed / tau, last_axis)

In [120]:

def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10):
    """Sample from the Gumbel-softmax distribution over the last axis.

    Args:
        logits: 2-D tensor of logits, one row per categorical variable.
        tau: Softmax temperature.
        hard: If True, return exact one-hot rows that keep the gradient of
            the soft sample.
        eps: Forwarded to the Gumbel sampler.

    Returns:
        Tensor with the same shape as ``logits``.
    """
    dims = logits.size()
    assert len(dims) == 2
    soft = _gumbel_softmax_sample(logits, tau=tau, eps=eps)
    if not hard:
        return soft

    # Index of the largest soft probability in every row.
    _, winner = soft.data.max(-1)
    # Based on
    # https://discuss.pytorch.org/t/stop-gradients-for-st-gumbel-softmax/530/5
    one_hot = logits.data.new(*dims).zero_()
    one_hot.scatter_(-1, winner.view(-1, 1), 1.0)
    # Subtracting soft.data and adding soft back leaves the one-hot value
    # while letting gradients flow through `soft` only.
    return one_hot - soft.data + soft

In [125]:
drawn_proba = gumbel_softmax(
    th.stack([weights.view(-1), -weights.view(-1)], 1),
    tau=tau, hard=drawhard
)[:, 0].view(graph_size)

In [126]:
drawn_proba

tensor([[1., 1., 1., 0., 0., 1.],
        [0., 0., 0., 1., 1., 0.],
        [1., 0., 0., 0., 1., 1.],
        [1., 0., 1., 1., 1., 1.],
        [0., 1., 1., 0., 0., 0.],
        [0., 1., 1., 0., 1., 0.],
        [1., 0., 1., 0., 0., 0.],
        [1., 0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 1., 1., 1., 0.],
        [0., 1., 1., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [1., 0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 1., 0.],
        [0., 1., 1., 1., 1., 0.],
        [0., 0., 1., 0., 1., 0.],
        [0., 1., 1., 1., 1., 1.],
        [1., 0., 0., 1., 0., 0.],
        [0., 1., 1., 1., 1., 1.],
        [1., 1., 1., 0., 0., 1.],
        [1., 1., 0., 1., 1., 0.],
        [0., 0., 1., 0., 1., 1.],
        [0., 0., 1., 1., 0., 0.],
        [0., 0., 1., 0., 0., 1.],
        [0., 1., 0., 1., 0., 1.],
        [0., 0., 1., 0., 1., 0.],
        [1., 1., 1., 0., 1., 0.],
        [1., 0., 0., 0., 0., 1.],
        [1., 1

In [127]:
sampler_graph = MatrixSampler(n_data_col, mask=None, gumbel=False)
filters = sampler_graph.get_proba()

In [None]:
def notears_constr(adj_m, max_pow=None):
    """Sum the traces of successive scaled powers of ``adj_m``.

    The first term is ``adj_m``; each later term is the previous one times
    ``adj_m`` divided by the current number of terms. The series is cut off
    when a term sums to zero or when ``max_pow`` terms have been built.

    Args:
        adj_m: Square torch tensor of edge weights.
        max_pow: Maximum number of terms; defaults to ``adj_m.shape[1]``.

    Returns:
        Scalar tensor holding the summed diagonals of all terms.
    """
    if max_pow is None:
        max_pow = adj_m.shape[1]

    powers = [adj_m]
    while len(powers) < max_pow and powers[-1].sum() > 0:
        powers.append(powers[-1] @ adj_m / len(powers))

    traces = [power.diag().sum() for power in powers]
    return sum(traces)

In [129]:
filters * filters

tensor([[0.0000, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
        [0.2500, 0.0000, 0.2500, 0.2500, 0.2500, 0.2500],
        [0.2500, 0.2500, 0.0000, 0.2500, 0.2500, 0.2500],
        [0.2500, 0.2500, 0.2500, 0.0000, 0.2500, 0.2500],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.2500],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.0000]],
       grad_fn=<MulBackward0>)

In [130]:
notears_constr(filters * filters, max_pow=None)

tensor(3.3814, grad_fn=<AddBackward0>)

In [132]:
adj_m = filters * filters

In [133]:
m_exp = [adj_m]

In [135]:
max_pow = adj_m.shape[1]

In [136]:
max_pow

6

In [142]:
while (m_exp[-1].sum() > 0 and len(m_exp) < max_pow):
    m_exp.append(m_exp[-1] @ adj_m / len(m_exp))

In [147]:
adj_m

tensor([[0.0000, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
        [0.2500, 0.0000, 0.2500, 0.2500, 0.2500, 0.2500],
        [0.2500, 0.2500, 0.0000, 0.2500, 0.2500, 0.2500],
        [0.2500, 0.2500, 0.2500, 0.0000, 0.2500, 0.2500],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.2500],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.0000]],
       grad_fn=<MulBackward0>)

In [146]:
adj_m * adj_m

tensor([[0.0000, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625],
        [0.0625, 0.0000, 0.0625, 0.0625, 0.0625, 0.0625],
        [0.0625, 0.0625, 0.0000, 0.0625, 0.0625, 0.0625],
        [0.0625, 0.0625, 0.0625, 0.0000, 0.0625, 0.0625],
        [0.0625, 0.0625, 0.0625, 0.0625, 0.0000, 0.0625],
        [0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0000]],
       grad_fn=<MulBackward0>)

In [143]:
m_exp

[tensor([[0.0000, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.0000, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.0000, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.0000, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.0000]],
        grad_fn=<MulBackward0>),
 tensor([[0.3125, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.3125, 0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.3125, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.3125, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.3125, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.3125]],
        grad_fn=<DivBackward0>),
 tensor([[0.1562, 0.1641, 0.1641, 0.1641, 0.1641, 0.1641],
         [0.1641, 0.1562, 0.1641, 0.1641, 0.1641, 0.1641],
         [0.1641, 0.1641, 0.1562, 0.1641, 0.1641, 0.1641],
         [0.1641, 0.1641, 0.1641, 0.1562, 0.164