# SAMを実装する 〜Original〜

In [13]:
import random
import numpy as np
import pandas as pd
import scipy
import matplotlib.pyplot as plt
from sklearn.preprocessing import scale
from scipy.special import expit
from tqdm import tqdm
import time
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from cdt.utils.torch import ChannelBatchNorm1d, MatrixSampler, Linear3D

In [2]:
# Seed every RNG the notebook draws from -- Python's `random`, NumPy's global
# legacy RNG (data generation) and PyTorch's global RNG (model initialisation,
# generator noise, graph/neuron sampling, DataLoader shuffling) -- so runs are
# reproducible on CPU.  Previously torch was left unseeded.
SEED = 1234
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)
# Print arrays with 2 fixed decimals and without scientific notation.
np.set_printoptions(precision=2, floatmode='fixed', suppress=True)

## 関数定義

In [28]:
def make_data(n_data=2000):
    """Generate a synthetic dataset with a known causal structure.

    Columns (``n_data`` rows each):

    * ``x``  ~ Uniform[-1, 1)
    * ``Z``  ~ Bernoulli(expit(-5 x + 5 e_z)), e_z ~ N(0, 1), stored as float
    * ``Y``  = 2 + t(x) * Z + 0.3 x + 0.1 e_y, e_y ~ N(0, 1), where
      t(x) = 0.5 for x < 0, 0.7 for 0 <= x < 0.5, 1.0 for x >= 0.5
    * ``Y2`` drawn from {1, 2, 3, 4, 5} with probabilities
      (0.1, 0.2, 0.3, 0.2, 0.2), independent of the other columns
    * ``Y3`` = 3 x + e_y3, e_y3 ~ N(0, 1)

    All draws come from NumPy's global legacy RNG (``np.random``), so seed it
    with ``np.random.seed`` for reproducible data.

    Args:
        n_data: number of rows to generate.

    Returns:
        pandas.DataFrame with columns ``['x', 'Z', 'Y', 'Y2', 'Y3']``.
    """
    # Uniform random numbers in [-1, 1).
    x = np.random.uniform(low=-1, high=1, size=n_data)

    # Noise driving the binary variable Z.
    e_z = np.random.randn(n_data)
    z_prob = scipy.special.expit(-5.0 * x + 5 * e_z)

    # Vectorised Bernoulli draw, equivalent draw-for-draw to calling
    # np.random.choice(2, size=1, p=[1 - p, p]) once per row: that call takes
    # one uniform sample u and returns 1 iff u >= (1 - p) / ((1 - p) + p)
    # (its normalised CDF at 0).  Drawing all uniforms at once consumes the
    # RNG stream identically, while avoiding the quadratic np.append loop.
    q = 1 - z_prob
    Z = (np.random.random_sample(n_data) >= q / (q + z_prob)).astype(float)

    # Piecewise effect of Z on Y depending on x.
    t = np.where(x < 0, 0.5, np.where(x < 0.5, 0.7, 1.0))

    e_y = np.random.randn(n_data)
    Y = 2.0 + t*Z + 0.3*x + 0.1*e_y

    Y2 = np.random.choice(
        [1.0, 2.0, 3.0, 4.0, 5.0],
        n_data, p=[0.1, 0.2, 0.3, 0.2, 0.2]
    )

    e_y3 = np.random.randn(n_data)
    Y3 = 3 * x + e_y3

    data = pd.DataFrame({
        'x': x,
        'Z': Z,
        'Y': Y,
        'Y2': Y2,
        'Y3': Y3
    })

    return data

In [4]:
class SAMDiscriminator(nn.Module):
    """SAM discriminator: an MLP that scores rows of data with one logit.

    Architecture: ``max(n_hidden_layers, 1)`` blocks of
    ``Linear -> BatchNorm1d -> LeakyReLU(0.2)`` followed by a final
    ``Linear(n_hidden_layer, 1)``.  No sigmoid is applied to the output.

    Args:
        n_data_col: number of input variables.
        n_hidden_layer: width of every hidden layer.
        n_hidden_layers: number of hidden blocks (one is always built).
    """

    def __init__(self, n_data_col, n_hidden_layer, n_hidden_layers):
        super().__init__()
        self.n_data_col = n_data_col

        # Feature widths along the network: input first, then one per block.
        widths = [n_data_col] + [n_hidden_layer] * max(n_hidden_layers, 1)
        blocks = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            blocks.extend([
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                nn.LeakyReLU(.2),
            ])
        blocks.append(nn.Linear(n_hidden_layer, 1))
        self.layers = nn.Sequential(*blocks)

        # Identity mask of shape (1, n_data_col, n_data_col): slice k picks
        # out column k only.
        self.register_buffer('mask', torch.eye(n_data_col).unsqueeze(0))

    def forward(self, input, obs_data=None):
        """Score ``input``; optionally splice it column-wise into ``obs_data``.

        Without ``obs_data``, returns ``self.layers(input)``.  With
        ``obs_data``, returns a list of ``n_data_col`` outputs where entry
        ``k`` scores ``obs_data`` with its column ``k`` replaced by column
        ``k`` of ``input`` (both assumed (batch, n_data_col) -- TODO confirm).
        """
        if obs_data is None:
            return self.layers(input)
        # Broadcast (batch, 1, n) against the (1, n, n) mask; along dim 1,
        # slice k keeps input's column k and obs_data's other columns.
        spliced = obs_data.unsqueeze(1) * (1 - self.mask) + input.unsqueeze(1) * self.mask
        return [self.layers(variant) for variant in spliced.unbind(1)]

    def reset_parameters(self):
        """Re-initialise every layer in ``self.layers`` that supports it."""
        resettable = [m for m in self.layers if hasattr(m, 'reset_parameters')]
        for module in resettable:
            module.reset_parameters()

In [5]:
class SAMGenerator(nn.Module):
    """SAM generator: one small network per variable, run as parallel channels.

    The input layer is a ``Linear3D`` built with shape
    ``(n_data_col, n_data_col + 1, n_hidden_layer)`` -- presumably
    (channels, inputs, hidden) with one channel per variable and the extra
    input being noise; confirm against ``cdt.utils.torch.Linear3D``.  It is
    followed by ``ChannelBatchNorm1d`` + ``Tanh`` and a ``Linear3D`` output
    layer with one output per channel.

    Args:
        n_data_col: number of variables (channels).
        n_hidden_layer: hidden width of each channel.
    """

    def __init__(self, n_data_col, n_hidden_layer):

        super(SAMGenerator, self).__init__()

        # (n+1, n) matrix of ones with zeros at (i, i) for every variable i;
        # the extra last row stays all ones.  It is multiplied into the
        # adjacency mask in forward(), presumably so a variable never feeds
        # its own generator while the noise row stays connected.
        skeleton = 1 - torch.eye(n_data_col + 1, n_data_col)

        self.register_buffer('skeleton', skeleton)

        self.input_layer = Linear3D((n_data_col, n_data_col + 1, n_hidden_layer))

        layers = []
        layers.append(ChannelBatchNorm1d(n_data_col, n_hidden_layer))
        layers.append(nn.Tanh())
        self.layers = nn.Sequential(*layers)

        self.output_layer = Linear3D((n_data_col, n_hidden_layer, 1))

    def forward(self, data, noise, adj_matrix, drawn_neurons=None):
        """Generate one value per variable for every row of ``data``.

        Args:
            data: observed batch fed to the input layer.
            noise: noise batch fed to the input layer alongside ``data``.
            adj_matrix: (n_data_col + 1, n_data_col) mask of which inputs
                feed which channel; it is combined with ``self.skeleton``.
            drawn_neurons: optional mask passed to the output layer as its
                ``adj_matrix`` (presumably selecting active hidden units).

        Returns:
            Output of the last layer with its size-1 dim 2 squeezed away.
        """

        x = self.input_layer(data, noise, adj_matrix * self.skeleton)

        x = self.layers(x)

        output = self.output_layer(x, noise=None, adj_matrix=drawn_neurons)

        return output.squeeze(2)

    def reset_parameters(self):
        """Re-initialise the input/output layers and every resettable inner layer."""

        self.input_layer.reset_parameters()
        self.output_layer.reset_parameters()

        # Bug fix: the attribute was misspelled ('reset_parametres'), so the
        # check never matched and ChannelBatchNorm1d was never reset; had it
        # matched, the nonexistent `register_parameters` would have been called.
        for layer in self.layers:
            if hasattr(layer, 'reset_parameters'):
                layer.reset_parameters()

In [6]:
def notears_constr(adj_m, max_pow=None):
    """Acyclicity penalty from a truncated power series of ``adj_m``.

    Returns ``sum_{k=1..K} trace(adj_m^k) / (k - 1)!`` (matrix powers), i.e.
    a truncated series for ``trace(adj_m @ expm(adj_m))``.  This is a variant
    of the NOTEARS constraint ``trace(expm(A)) - d``.  ``K`` is ``max_pow``
    (default ``adj_m.shape[1]``).  The series stops early as soon as a term
    sums to zero.  For a non-negative matrix every term is >= 0, so the
    result is 0 exactly when there is no cycle of length <= K.

    Args:
        adj_m: square 2-D torch tensor (weighted adjacency matrix).
        max_pow: maximum number of series terms; defaults to the matrix size.

    Returns:
        0-d tensor, differentiable with respect to ``adj_m``.
    """
    if max_pow is None:
        max_pow = adj_m.shape[1]

    # `term` holds adj_m^(n_terms) / (n_terms - 1)!.
    term = adj_m
    n_terms = 1
    traces = [term.diag().sum()]
    while term.sum() > 0 and n_terms < max_pow:
        term = term @ adj_m / n_terms
        n_terms += 1
        traces.append(term.diag().sum())
    return sum(traces)

In [7]:
def train_sam_model(
    data,
    n_hidden_layer_gen=100, n_hidden_layer_dis=100,
    n_hidden_layers_dis=2,
    lr_gen=0.01*0.5, lr_dis=0.01*0.5*2,
    dag_start_rate=0.5, dag_penalization_increase=0.001*10,
    epochs_train=100, epochs_test=100,
    lambda1=5.0*20, lambda2=0.005*20
):
    """Train a SAM model on ``data`` and return averaged graph probabilities.

    The first ``epochs_train`` epochs update the discriminator, the generator
    and both samplers.  The following ``epochs_test`` epochs make no updates;
    they only accumulate ``sampler_graph.get_proba()`` and the generator loss,
    which are then averaged.

    Args:
        data: pandas DataFrame; each column is one variable.  The whole frame
            is used as a single batch (``batch_size = len(data)``).
        n_hidden_layer_gen: hidden width of the generator (also the first
            dimension of the neuron sampler).
        n_hidden_layer_dis: hidden width of the discriminator.
        n_hidden_layers_dis: number of hidden blocks in the discriminator.
        lr_gen: Adam learning rate for the generator and both samplers.
        lr_dis: Adam learning rate for the discriminator.
        dag_start_rate: fraction of ``epochs_train`` after which the
            acyclicity penalty is added to the generator loss.
        dag_penalization_increase: per-epoch growth of the acyclicity
            penalty weight once it is active.
        epochs_train: number of training epochs.
        epochs_test: number of evaluation epochs used for averaging.
        lambda1: weight of the sparsity penalty on the sampled graph.
        lambda2: weight of the sparsity penalty on the sampled neurons.

    Returns:
        Tuple ``(graph_proba, gen_loss)`` of numpy arrays.  ``graph_proba``
        has shape (n_cols, n_cols) and is the mean of
        ``sampler_graph.get_proba()`` over the test epochs.  ``gen_loss`` has
        shape (1, 1) and is the mean generator loss over the test epochs,
        divided by the number of columns.
    """
    n_data_col = len(data.columns)
    data = torch.from_numpy(data.values.astype('float32'))
    batch_size = len(data)

    # Full-batch training: exactly one (shuffled) batch per epoch.
    data_iterator = DataLoader(
        data, batch_size=batch_size, shuffle=True, drop_last=True
    )

    sam = SAMGenerator(n_data_col, n_hidden_layer_gen)
    sam.reset_parameters()
    sampler_graph = MatrixSampler(n_data_col, mask=None, gumbel=False)
    sampler_neuron = MatrixSampler((n_hidden_layer_gen, n_data_col), mask=False, gumbel=True)

    # Start every candidate edge with a high probability.
    sampler_graph.weights.data.fill_(2)

    discriminator = SAMDiscriminator(
        n_data_col=n_data_col, n_hidden_layer=n_hidden_layer_dis, n_hidden_layers=n_hidden_layers_dis
    )
    discriminator.reset_parameters()

    optimizer_gen = optim.Adam(sam.parameters(), lr=lr_gen)
    optimizer_dis = optim.Adam(discriminator.parameters(), lr=lr_dis)
    optimizer_graph = optim.Adam(sampler_graph.parameters(), lr=lr_gen)
    optimizer_neuron = optim.Adam(sampler_neuron.parameters(), lr=lr_gen)

    # The discriminator's last layer is a plain Linear, so it outputs logits.
    criterion = nn.BCEWithLogitsLoss()

    _true = torch.ones(1)
    _false = torch.zeros(1)

    # Noise buffer, refilled in place every iteration.
    noise = torch.randn(batch_size, n_data_col)
    # Row of ones appended to the sampled graph (the noise input's row).
    noise_row = torch.ones(1, n_data_col)

    # Test-phase accumulators.
    output = torch.zeros(n_data_col, n_data_col)
    output_loss = torch.zeros(1, 1)

    pbar = tqdm(range(epochs_train + epochs_test))
    for epoch in pbar:
        training = epoch < epochs_train
        for data_batched in data_iterator:

            optimizer_gen.zero_grad()
            optimizer_graph.zero_grad()
            optimizer_neuron.zero_grad()
            optimizer_dis.zero_grad()

            drawn_graph = sampler_graph()
            drawn_neurons = sampler_neuron()

            noise.normal_()
            generated_variables = sam(
                data=data_batched,
                noise=noise,
                adj_matrix=torch.cat([drawn_graph, noise_row], 0),
                drawn_neurons=drawn_neurons,
            )

            # --- Discriminator step (generated data detached) ---
            dis_vars_d = discriminator(generated_variables.detach(), data_batched)
            true_vars_dis = discriminator(data_batched)

            loss_dis = sum(
                [criterion(gen, _false.expand_as(gen)) for gen in dis_vars_d]
            ) / n_data_col + criterion(
                true_vars_dis, _true.expand_as(true_vars_dis)
            )

            if training:
                loss_dis.backward()
                optimizer_dis.step()

            # --- Generator step ---
            # Run the discriminator on the generated data only AFTER its
            # optimizer step.  Running this forward before the step made the
            # generator's backward pass go through discriminator weights that
            # optimizer_dis.step() had modified in place: recent PyTorch raises
            # "modified by an inplace operation", older versions did not detect
            # it and backpropagated through a stale graph.
            dis_vars_g = discriminator(generated_variables, data_batched)
            loss_gen = sum([criterion(gen, _true.expand_as(gen)) for gen in dis_vars_g])

            loss_struc = lambda1 / batch_size * drawn_graph.sum()
            loss_func = lambda2 / batch_size * drawn_neurons.sum()
            loss_regul = loss_struc + loss_func

            # Needed by the DAG penalty and by the test-phase average; computed
            # unconditionally so it is never unbound (e.g. dag_start_rate >= 1).
            filters = sampler_graph.get_proba()

            if epoch <= epochs_train * dag_start_rate:
                loss = loss_gen + loss_regul
            else:
                dag_weight = (epoch - epochs_train * dag_start_rate) * dag_penalization_increase
                loss = loss_gen + loss_regul + dag_weight * notears_constr(filters * filters)

            if training:
                # Only one backward pass goes through this graph, so it does
                # not need to be retained.
                loss.backward()
                optimizer_gen.step()
                optimizer_graph.step()
                optimizer_neuron.step()
            else:
                output.add_(filters.detach())
                output_loss.add_(loss_gen.detach())

            # Show progress.
            if epoch % 50 == 0:
                pbar.set_postfix(
                    gen=loss_gen.item()/n_data_col,
                    dis=loss_dis.item(),
                    egul_loss=loss_regul.item(),
                    tot=loss.item()
                )

    return output.cpu().numpy()/epochs_test, output_loss.cpu().numpy()/epochs_test/n_data_col

In [31]:
# Hyperparameters for the step-by-step SAM run.
# Network sizes.
n_hidden_layer_gen, n_hidden_layer_dis = 200, 200
n_hidden_layers_dis = 2
# Adam learning rates (generator + samplers, discriminator).
lr_gen, lr_dis = 0.01 * 0.5, 0.01 * 0.5 * 2
# Acyclicity (DAG) penalty schedule.
dag_start_rate, dag_penalization_increase = 0.5, 0.001 * 10
# Epoch budget: training epochs, then averaging epochs.
epochs_train, epochs_test = 10000, 1000
# Sparsity penalty weights on the sampled graph and sampled neurons.
lambda1, lambda2 = 5.0 * 20, 0.005 * 20

In [32]:
# Generate the synthetic dataset and standardise every column in place
# (sklearn.preprocessing.scale: zero mean, unit variance per column).
learning_data = make_data(n_data=2000)
learning_data.loc[:, :] = scale(learning_data.values)

# `data` is an independent copy that the following setup cell turns into a tensor.
data = learning_data.copy()

In [33]:
# Build every object used by the step-by-step training loop.
# NOTE: statement order matters for reproducibility.  nn.Linear initialisation,
# the reset_parameters() calls and torch.randn draw from torch's global RNG
# (and presumably so do the cdt layers/samplers), so reordering changes results.
n_data_col = data.shape[1]
# float32 tensor holding the whole (standardised) dataset.
data = torch.from_numpy(data.values.astype('float32'))
# Full-batch training: a single batch holding every row.
batch_size = len(data)

# With batch_size == len(data) and drop_last=True this yields exactly one
# shuffled batch per epoch.
data_iterator = DataLoader(
    data, batch_size=batch_size, shuffle=True, drop_last=True
)

# Generator plus two samplers: one over the variable graph, and one whose
# samples are passed to the generator as `drawn_neurons`.
sam = SAMGenerator(n_data_col, n_hidden_layer_gen)
sam.reset_parameters()
sampler_graph = MatrixSampler(n_data_col, mask=None, gumbel=False)
sampler_neuron = MatrixSampler((n_hidden_layer_gen, n_data_col), mask=False, gumbel=True)

# Start every candidate edge at a high probability (the matrix printed at the
# start of training shows 0.98 off the diagonal).
sampler_graph.weights.data.fill_(2)

discriminator = SAMDiscriminator(
    n_data_col=n_data_col, n_hidden_layer=n_hidden_layer_dis, n_hidden_layers=n_hidden_layers_dis
)
discriminator.reset_parameters()  

# Separate Adam optimisers; the generator and both samplers share lr_gen.
optimizer_gen = optim.Adam(sam.parameters(), lr=lr_gen)
optimizer_dis = optim.Adam(discriminator.parameters(), lr=lr_dis)
optimizer_graph = optim.Adam(sampler_graph.parameters(), lr=lr_gen)
optimizer_neuron = optim.Adam(sampler_neuron.parameters(), lr=lr_gen)

# BCE on logits: the discriminator's last layer is a plain Linear.
criterion = nn.BCEWithLogitsLoss()

# Scalar "real"/"fake" targets, expanded to each discriminator output's shape.
_true = torch.ones(1)
_false = torch.zeros(1)

# Noise buffer, refilled in place (noise.normal_()) every iteration.
noise = torch.randn(batch_size, n_data_col)
# Row of ones appended below the sampled graph before it reaches the generator
# (presumably the noise input's row, keeping noise connected to every variable).
noise_row = torch.ones(1, n_data_col)

# Accumulators for the test-phase averages of edge probabilities and generator loss.
output = torch.zeros(n_data_col, n_data_col)
output_loss = torch.zeros(1, 1)

In [34]:
# Step-by-step SAM training loop.  Logs losses every 10 epochs and the current
# edge-probability matrix every 100 epochs.  The last `epochs_test` epochs make
# no updates and only accumulate `output` / `output_loss`.
start_time = time.time()
for epoch in range(epochs_train + epochs_test):
    training = epoch < epochs_train
    for data_batched in data_iterator:

        optimizer_gen.zero_grad()
        optimizer_graph.zero_grad()
        optimizer_neuron.zero_grad()
        optimizer_dis.zero_grad()

        drawn_graph = sampler_graph()
        drawn_neurons = sampler_neuron()

        noise.normal_()
        generated_variables = sam(
            data=data_batched,
            noise=noise,
            adj_matrix=torch.cat([drawn_graph, noise_row], 0),
            drawn_neurons=drawn_neurons,
        )

        # Discriminator step: generated data is detached, so only D is trained here.
        dis_vars_d = discriminator(generated_variables.detach(), data_batched)
        true_vars_dis = discriminator(data_batched)

        loss_dis = sum(
            [criterion(gen, _false.expand_as(gen)) for gen in dis_vars_d]
        ) / n_data_col + criterion(
            true_vars_dis, _true.expand_as(true_vars_dis)
        )

        if training:
            loss_dis.backward()
            optimizer_dis.step()

        loss_struc = lambda1 / batch_size * drawn_graph.sum()
        loss_func = lambda2 / batch_size * drawn_neurons.sum()
        loss_regul = loss_struc + loss_func

        # Generator loss, evaluated with the discriminator AFTER its update so
        # the backward pass never goes through weights optimizer_dis.step()
        # has just modified in place.
        dis_vars_g = discriminator(generated_variables, data_batched)
        loss_gen = sum([criterion(gen, _true.expand_as(gen)) for gen in dis_vars_g])

        # Always recompute the edge probabilities.  Otherwise, whenever the DAG
        # branch is skipped (e.g. dag_start_rate >= 1), the test-phase
        # accumulation would silently reuse a `filters` global left over from
        # an earlier cell or iteration.
        filters = sampler_graph.get_proba()

        if epoch <= epochs_train * dag_start_rate:
            loss = loss_gen + loss_regul
        else:
            dag_weight = (epoch - epochs_train * dag_start_rate) * dag_penalization_increase
            loss = loss_gen + loss_regul + dag_weight * notears_constr(filters * filters)

        if training:
            # Only one backward pass goes through this graph; no need to retain it.
            loss.backward()
            optimizer_gen.step()
            optimizer_graph.step()
            optimizer_neuron.step()
        else:
            # Test phase: accumulate for averaging, no parameter updates.
            output.add_(filters.detach())
            output_loss.add_(loss_gen.detach())

        if (epoch + 1) % 10 == 0:

            print('[{}/ {}] Loss Dis: {:.2f}, Loss Gen: {:.2f} , {:.0f}[s]'.format(
                epoch + 1, epochs_train + epochs_test, loss_dis, loss_gen, time.time() - start_time
            ))
            start_time = time.time()

        if epoch % 100 == 0:
            print('')
            print(sampler_graph.get_proba().detach().numpy())
            print('')


[[0.00 0.98 0.98 0.98 0.98]
 [0.98 0.00 0.98 0.98 0.98]
 [0.98 0.98 0.00 0.98 0.98]
 [0.98 0.98 0.98 0.00 0.98]
 [0.98 0.98 0.98 0.98 0.00]]

[10/ 11000] Loss Dis: 1.40, Loss Gen: 4.00 , 2[s]
[20/ 11000] Loss Dis: 1.39, Loss Gen: 3.32 , 2[s]
[30/ 11000] Loss Dis: 1.38, Loss Gen: 3.47 , 2[s]
[40/ 11000] Loss Dis: 1.38, Loss Gen: 3.57 , 2[s]
[50/ 11000] Loss Dis: 1.38, Loss Gen: 3.46 , 2[s]
[60/ 11000] Loss Dis: 1.38, Loss Gen: 3.48 , 2[s]
[70/ 11000] Loss Dis: 1.38, Loss Gen: 3.47 , 2[s]
[80/ 11000] Loss Dis: 1.38, Loss Gen: 3.50 , 2[s]
[90/ 11000] Loss Dis: 1.38, Loss Gen: 3.52 , 2[s]
[100/ 11000] Loss Dis: 1.38, Loss Gen: 3.47 , 2[s]

[[0.00 0.96 0.97 0.97 0.96]
 [0.96 0.00 0.96 0.97 0.96]
 [0.96 0.97 0.00 0.97 0.97]
 [0.97 0.97 0.97 0.00 0.97]
 [0.97 0.96 0.97 0.96 0.00]]

[110/ 11000] Loss Dis: 1.38, Loss Gen: 3.47 , 2[s]
[120/ 11000] Loss Dis: 1.38, Loss Gen: 3.49 , 2[s]
[130/ 11000] Loss Dis: 1.37, Loss Gen: 3.59 , 2[s]
[140/ 11000] Loss Dis: 1.38, Loss Gen: 3.50 , 2[s]
[150/ 110

In [29]:
# Train SAM five times from scratch and average the edge-probability matrices.
sam_kwargs = dict(
    n_hidden_layer_gen=100, n_hidden_layer_dis=100,
    n_hidden_layers_dis=2,
    lr_gen=0.01*0.5, lr_dis=0.01*0.5*2,
    dag_start_rate=0.5,
    dag_penalization_increase=0.001*10,
    epochs_train=1000, epochs_test=1000,
    lambda1=5.0 * 20, lambda2=0.005 * 20,
)

m_list, loss_list = [], []
for i in range(5):
    m, loss = train_sam_model(data=learning_data.copy(), **sam_kwargs)

    # Per-run average generator loss and edge-probability matrix.
    print(loss)
    print(m)

    m_list.append(m)
    loss_list.append(loss)

# Network structure (mean over the 5 runs).
print(sum(m_list) / len(m_list))

  0%|          | 0/2000 [00:00<?, ?it/s]

2000


100%|██████████| 2000/2000 [08:38<00:00,  3.86it/s, dis=0.681, egul_loss=0.559, gen=3.26, tot=32.5]
  0%|          | 0/2000 [00:00<?, ?it/s]

[[3.95]]
[[0.00 0.20 0.15 0.03 0.05 0.04]
 [0.91 0.00 0.53 0.42 0.11 0.47]
 [0.41 0.62 0.00 0.83 0.30 0.58]
 [0.05 0.02 0.02 0.00 0.06 0.02]
 [0.14 0.05 0.25 0.94 0.00 0.52]
 [0.07 0.05 0.19 0.39 0.90 0.00]]
2000


100%|██████████| 2000/2000 [10:14<00:00,  3.25it/s, dis=1.11, egul_loss=0.561, gen=5.09, tot=46.6] 
  0%|          | 0/2000 [00:00<?, ?it/s]

[[4.40]]
[[0.00 0.05 0.63 0.05 0.04 0.05]
 [0.89 0.00 0.65 0.52 0.17 0.56]
 [0.04 0.49 0.00 0.32 0.25 0.77]
 [0.04 0.03 0.03 0.00 0.07 0.07]
 [0.64 0.01 0.07 0.89 0.00 0.63]
 [0.40 0.07 0.04 0.83 0.84 0.00]]
2000


100%|██████████| 2000/2000 [08:44<00:00,  3.81it/s, dis=0.896, egul_loss=0.408, gen=2.58, tot=21.9]
  0%|          | 0/2000 [00:00<?, ?it/s]

[[2.98]]
[[0.00 0.20 0.42 0.04 0.03 0.08]
 [0.54 0.00 0.22 0.49 0.12 0.17]
 [0.15 0.85 0.00 0.85 0.73 0.18]
 [0.07 0.02 0.02 0.00 0.17 0.02]
 [0.23 0.03 0.05 0.89 0.00 0.90]
 [0.14 0.07 0.07 0.45 0.33 0.00]]
2000


100%|██████████| 2000/2000 [10:02<00:00,  3.32it/s, dis=0.907, egul_loss=0.56, gen=1.84, tot=26.5] 
  0%|          | 0/2000 [00:00<?, ?it/s]

[[2.64]]
[[0.00 0.85 0.63 0.18 0.05 0.05]
 [0.09 0.00 0.39 0.58 0.09 0.02]
 [0.08 0.93 0.00 0.81 0.48 0.16]
 [0.03 0.02 0.04 0.00 0.13 0.03]
 [0.11 0.03 0.21 0.87 0.00 0.82]
 [0.12 0.04 0.17 0.72 0.69 0.00]]
2000


100%|██████████| 2000/2000 [09:51<00:00,  3.38it/s, dis=0.512, egul_loss=0.458, gen=4.67, tot=41.7]

[[4.17]]
[[0.00 0.25 0.12 0.18 0.03 0.05]
 [0.89 0.00 0.51 0.10 0.12 0.07]
 [0.29 0.93 0.00 0.86 0.56 0.03]
 [0.04 0.03 0.05 0.00 0.10 0.03]
 [0.38 0.02 0.14 0.88 0.00 0.90]
 [0.22 0.16 0.11 0.59 0.40 0.00]]
[[0.00 0.31 0.39 0.09 0.04 0.06]
 [0.66 0.00 0.46 0.42 0.12 0.26]
 [0.19 0.76 0.00 0.73 0.47 0.35]
 [0.05 0.02 0.03 0.00 0.11 0.03]
 [0.30 0.03 0.14 0.89 0.00 0.75]
 [0.19 0.08 0.12 0.59 0.63 0.00]]





In [25]:
# Work on a fresh copy of the standardised data.
data = learning_data.copy()
# Network sizes.
n_hidden_layer_gen, n_hidden_layer_dis = 200, 200
n_hidden_layers_dis = 2
# Adam learning rates (generator + samplers, discriminator).
lr_gen, lr_dis = 0.01*0.5, 0.01*0.5*2
# Acyclicity (DAG) penalty schedule.
dag_start_rate, dag_penalization_increase = 0.5, 0.001*10
# Epoch budget: training epochs, then averaging epochs.
epochs_train, epochs_test = 10000, 1000
# Sparsity penalty weights on the sampled graph and sampled neurons.
lambda1, lambda2 = 5.0*20, 0.005*20

In [26]:
# Build every object used by the tqdm training loop.
# NOTE: statement order matters for reproducibility.  nn.Linear initialisation,
# the reset_parameters() calls and torch.randn draw from torch's global RNG
# (and presumably so do the cdt layers/samplers), so reordering changes results.
data_columns = data.columns.tolist() 
n_data_col = len(data_columns)  
# float32 tensor holding the whole (standardised) dataset.
data = torch.from_numpy(data.values.astype('float32') )
# Full-batch training: a single batch holding every row.
batch_size = len(data)

# With batch_size == len(data) and drop_last=True this yields exactly one
# shuffled batch per epoch.
data_iterator = DataLoader(
    data, batch_size=batch_size, shuffle=True, drop_last=True
)

# Generator plus two samplers: one over the variable graph, and one whose
# samples are passed to the generator as `drawn_neurons`.
sam = SAMGenerator(n_data_col, n_hidden_layer_gen)
sam.reset_parameters()
sampler_graph = MatrixSampler(n_data_col, mask=None, gumbel=False)
sampler_neuron = MatrixSampler((n_hidden_layer_gen, n_data_col), mask=False, gumbel=True)

# Start every candidate edge at a high probability.
sampler_graph.weights.data.fill_(2)

discriminator = SAMDiscriminator(
    n_data_col=n_data_col, n_hidden_layer=n_hidden_layer_dis, n_hidden_layers=n_hidden_layers_dis
)
discriminator.reset_parameters()  

# Separate Adam optimisers; the generator and both samplers share lr_gen.
optimizer_gen = optim.Adam(sam.parameters(), lr=lr_gen)
optimizer_dis = optim.Adam(discriminator.parameters(), lr=lr_dis)
optimizer_graph = optim.Adam(sampler_graph.parameters(), lr=lr_gen)
optimizer_neuron = optim.Adam(sampler_neuron.parameters(), lr=lr_gen)

# BCE on logits: the discriminator's last layer is a plain Linear.
criterion = nn.BCEWithLogitsLoss()

# Scalar "real"/"fake" targets, expanded to each discriminator output's shape.
_true = torch.ones(1)
_false = torch.zeros(1)

# Noise buffer, refilled in place (noise.normal_()) every iteration.
noise = torch.randn(batch_size, n_data_col)
# Row of ones appended below the sampled graph before it reaches the generator
# (presumably the noise input's row, keeping noise connected to every variable).
noise_row = torch.ones(1, n_data_col)

# Accumulators for the test-phase averages of edge probabilities and generator loss.
output = torch.zeros(n_data_col, n_data_col)
output_loss = torch.zeros(1, 1)

In [27]:
# Step-by-step SAM training loop with a tqdm progress bar.  The last
# `epochs_test` epochs make no updates and only accumulate `output` /
# `output_loss`.
pbar = tqdm(range(epochs_train + epochs_test))
for epoch in pbar:
    training = epoch < epochs_train
    for data_batched in data_iterator:

        optimizer_gen.zero_grad()
        optimizer_graph.zero_grad()
        optimizer_neuron.zero_grad()
        optimizer_dis.zero_grad()

        drawn_graph = sampler_graph()
        drawn_neurons = sampler_neuron()

        noise.normal_()
        generated_variables = sam(
            data=data_batched,
            noise=noise,
            adj_matrix=torch.cat([drawn_graph, noise_row], 0),
            drawn_neurons=drawn_neurons,
        )

        # Discriminator step: generated data is detached, so only D is trained here.
        dis_vars_d = discriminator(generated_variables.detach(), data_batched)
        true_vars_dis = discriminator(data_batched)

        loss_dis = sum(
            [criterion(gen, _false.expand_as(gen)) for gen in dis_vars_d]
        ) / n_data_col + criterion(
            true_vars_dis, _true.expand_as(true_vars_dis)
        )

        if training:
            loss_dis.backward()
            optimizer_dis.step()

        # Generator loss, evaluated only AFTER the discriminator's optimizer
        # step.  Running this forward before the step made the generator's
        # backward pass go through discriminator weights that
        # optimizer_dis.step() had modified in place: recent PyTorch raises
        # "modified by an inplace operation", older versions did not detect it
        # and backpropagated through a stale graph.
        dis_vars_g = discriminator(generated_variables, data_batched)
        loss_gen = sum([criterion(gen, _true.expand_as(gen)) for gen in dis_vars_g])

        loss_struc = lambda1 / batch_size * drawn_graph.sum()
        loss_func = lambda2 / batch_size * drawn_neurons.sum()
        loss_regul = loss_struc + loss_func

        # Always recompute the edge probabilities.  Otherwise, whenever the DAG
        # branch is skipped (e.g. dag_start_rate >= 1), the test-phase
        # accumulation would silently reuse a `filters` global left over from
        # an earlier cell or iteration.
        filters = sampler_graph.get_proba()

        if epoch <= epochs_train * dag_start_rate:
            loss = loss_gen + loss_regul
        else:
            dag_weight = (epoch - epochs_train * dag_start_rate) * dag_penalization_increase
            loss = loss_gen + loss_regul + dag_weight * notears_constr(filters * filters)

        if training:
            # Only one backward pass goes through this graph; no need to retain it.
            loss.backward()
            optimizer_gen.step()
            optimizer_graph.step()
            optimizer_neuron.step()
        else:
            # Test phase: accumulate for averaging, no parameter updates.
            output.add_(filters.detach())
            output_loss.add_(loss_gen.detach())

        # Show progress.
        if epoch % 50 == 0:
            pbar.set_postfix(
                gen=loss_gen.item()/n_data_col,
                dis=loss_dis.item(),
                egul_loss=loss_regul.item(),
                tot=loss.item()
            )


  0%|          | 0/11000 [00:00<?, ?it/s]

11000


  1%|          | 133/11000 [01:34<2:08:22,  1.41it/s, dis=1.39, egul_loss=1.42, gen=0.694, tot=5.59]


KeyboardInterrupt: 

In [None]:
for epoch in pbar:
    for i_batch, data_batched in enumerate(data_iterator):


        

        

        

        

