In [1]:
import numpy as np
import torch
import os
import pickle
import time
from Model import *
from itertools import chain
from torch.optim import lr_scheduler
import torch.optim as optim
from torch.utils.data import DataLoader
from synthetic import simulate_lorenz_96, simulate_var
from utils import build_flags, time_split, save_result, evaluate_result, count_accuracy, loss_sparsity, loss_divergence

In [2]:
parser = build_flags()
args = parser.parse_args(args=[])
args.seed = 2
args.num_nodes = 20
args.dims = 1
args.threshold = 0.5
args.time_length = 500
args.time_step = 10
args.epochs = 3000
args.batch_size = 128
args.lr = 1e-3
args.weight_decay = 1e-3
args.alpha = 0.02
args.beta_sparsity = 1   #log_sum
args.beta_kl = 0.1        #JS散度
args.hidden = 15
args.dropout = 0.4

In [3]:
X_np, GC = simulate_lorenz_96(p=20, F=10, T=500, seed=2)
X_np_ori = X_np
X_np = X_np.transpose(1, 0)
X_np = X_np[:, :, np.newaxis]
X_np = np.array(time_split(X_np, step=10))
X_np = torch.FloatTensor(X_np)
data = X_np
data_loader = DataLoader(data, batch_size=128)

In [4]:
idx = 0
save_file = '/home/jing_xuzijian/crf/Intrer_VAE_result/help_experiment/Inter_help_coef2.pt'
Inter_encoder = encoder(args.dims, args.hidden, args.dims, args.time_step - 1, args.num_nodes, args.dropout, args.alpha)
Inter_encoder = Inter_encoder.cuda()
Inter_decoder = decoder(args.dims, args.hidden, args.time_step - 1, args.num_nodes, args.dropout, args.alpha)
Inter_decoder = Inter_decoder.cuda()
optimizer = optim.Adam(params = chain(Inter_encoder.parameters(), Inter_decoder.parameters()), lr=args.lr, weight_decay=args.weight_decay)
loss_mse = nn.MSELoss()
best_loss = np.inf
for epoch in range(5000):
    t = time.time()
    Loss = []
    MSE_loss = []
    SPA_loss = []
    KL_loss = []
    for batch_idx, data in enumerate(data_loader):
        optimizer.zero_grad()
        data = data.cuda()
        target = data[:, idx, 1:, :]
        inputs = data[:, :, :-1, :]

        mu, log_var = Inter_encoder(inputs)  #Inter_encoder(inputs, adj)
        sigma = torch.exp(log_var / 2)
        gamma = torch.randn(size = mu.size()).cuda()
        gamma = mu + sigma * gamma
        mask = torch.sigmoid(gamma) #* torch.sigmoid(theta + gamma)

        inputs = mask_inputs(mask, inputs)
        pred = Inter_decoder(inputs, idx)   #Inter_decoder(inputs, adj, idx)



        mse_loss = loss_mse(pred, target)
        spa_loss = loss_sparsity(mask, 'log_sum')
        kl_loss = loss_divergence(mask, 'JS')

        loss = mse_loss + args.beta_sparsity * spa_loss + args.beta_kl * kl_loss

        loss.backward()
        optimizer.step()

        Loss.append(loss.item())
        MSE_loss.append(mse_loss.item())
        SPA_loss.append(spa_loss.item())
        KL_loss.append(kl_loss.item())
    
    if epoch == 500:
        optimizer.param_groups[0]['lr'] = args.lr/10

    if epoch % 100 == 0:
        print(  'Feature: {:04d}'.format(idx + 1),
                'Epoch: {:04d}'.format(epoch),
                'Loss: {:.10f}'.format(np.mean(Loss)),
                'MSE_Loss: {:.10f}'.format(np.mean(MSE_loss)),
                'Sparsity_loss: {:.10f}'.format(np.mean(SPA_loss)),
                'KL_loss: {:.10f}'.format(np.mean(KL_loss)),
                'time: {:.4f}s'.format(time.time() - t))
            
    if np.mean(Loss) < best_loss:
        best_loss = np.mean(Loss)
            # M[idx, :] = adj[idx, :]
            # gamma_matrix[idx, :] = gamma
            # theta_matrix[idx, :] = theta
        torch.save({
                    'encoder_state_dict': Inter_encoder.state_dict(),
                    'decoder_state_dict': Inter_decoder.state_dict(),
                        # 'adj' : adj

                    }, save_file)
            # np.save(save_file + str(idx) + '.npy', mask.cpu().detach().numpy())

        # print('Feature: {:04d}'.format(idx + 1),
        #       'Epoch: {:04d}'.format(epoch),
        #       'Loss: {:.10f}'.format(np.mean(Loss)),
        #       'mse_loss: {:.10f}'.format(np.mean(MSE_loss)),
        #       'Sparsity_loss: {:.10f}'.format(np.mean(SPA_loss)),
        #       'KL_loss: {:.10f}'.format(np.mean(KL_loss)),
        #       #'mmd_loss: {:.10f}'.format(np.mean(mmd_loss)),
        #       # 'time: {:.4f}s'.format(time.time() - t), file=log)

  nn.init.xavier_normal(m.weight.data, gain=1.414)


Feature: 0001 Epoch: 0000 Loss: 35.3649740219 MSE_Loss: 26.9113945961 Sparsity_loss: 8.4376585484 KL_loss: 0.1592129655 time: 0.1109s
Feature: 0001 Epoch: 0100 Loss: 16.5646128654 MSE_Loss: 13.7968297005 Sparsity_loss: 2.7627740502 KL_loss: 0.0500849960 time: 0.0192s
Feature: 0001 Epoch: 0200 Loss: 11.2060635090 MSE_Loss: 9.8673775196 Sparsity_loss: 1.3362399042 KL_loss: 0.0244614482 time: 0.0184s
Feature: 0001 Epoch: 0300 Loss: 8.6542916298 MSE_Loss: 7.4600173235 Sparsity_loss: 1.1922337711 KL_loss: 0.0204063314 time: 0.0192s
Feature: 0001 Epoch: 0400 Loss: 7.2643269300 MSE_Loss: 6.0709468126 Sparsity_loss: 1.1913311481 KL_loss: 0.0204906212 time: 0.0184s
Feature: 0001 Epoch: 0500 Loss: 6.0865253210 MSE_Loss: 5.1390103102 Sparsity_loss: 0.9462039471 KL_loss: 0.0131115990 time: 0.0183s
Feature: 0001 Epoch: 0600 Loss: 5.6702266932 MSE_Loss: 4.7516417503 Sparsity_loss: 0.9173251390 KL_loss: 0.0125989979 time: 0.0252s
Feature: 0001 Epoch: 0700 Loss: 5.8422636986 MSE_Loss: 4.8944616318 Spa

In [8]:
print(Inter_decoder.adj[idx, :])
print(Inter_decoder.adj.grad)
print(GC[idx, :])

tensor([ 3.6874e+00,  1.5899e-01,  4.2904e-02, -4.9666e-02, -3.7316e-02,
         1.5473e-03,  1.3185e-02, -1.2098e-02,  9.3266e-03, -6.6691e-03,
         1.7035e-03,  2.3604e-02,  1.6201e-02, -1.5379e-02,  9.8692e-03,
         6.3649e-03,  2.5113e-02,  1.4795e-02, -3.8504e-01, -7.8076e-02],
       device='cuda:0', grad_fn=<SliceBackward0>)
tensor([[-0.0245,  0.1380, -0.8680,  1.0253, -2.4208, -2.8583,  1.6075, -0.5942,
          0.0422, -5.0181, -4.2502, -1.2880,  3.8038, -3.3679, -2.8081, -1.0867,
          1.1769,  0.8595,  0.1737, -2.9083],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.000

In [17]:
help_encoder_net = encoder(args.dims, args.hidden, args.dims, args.time_step - 1, args.num_nodes, args.dropout, args.alpha)
help_decoder_net = decoder(args.dims, args.hidden, args.time_step -1, args.num_nodes, args.dropout, args.alpha)
checkpoint = torch.load('/home/jing_xuzijian/crf/Intrer_VAE_result/help_experiment/Inter_help_coef2.pt')
help_encoder_net.load_state_dict(checkpoint['encoder_state_dict'])
help_decoder_net.load_state_dict(checkpoint['decoder_state_dict'])
help_encoder_net.eval()
help_decoder_net.eval()
mu, log_var = help_encoder_net(X_np[:, :, :-1, :])
sigma = torch.exp(log_var / 2)
gamma = torch.randn(size = mu.size())
gamma = mu + sigma * gamma
mask = torch.sigmoid(gamma)
mask = mask.squeeze()

  nn.init.xavier_normal(m.weight.data, gain=1.414)


In [18]:
print(mask.shape)

torch.Size([491, 20, 9])


In [19]:
for para in help_decoder_net.named_parameters():
    print(para[0])

adj
mlp1.fc1.weight
mlp1.fc1.bias
mlp1.fc2.weight
mlp1.fc2.bias
mlp1.bn.weight
mlp1.bn.bias
mlp2.fc1.weight
mlp2.fc1.bias
mlp2.fc2.weight
mlp2.fc2.bias
mlp2.bn.weight
mlp2.bn.bias
gat.W
gat.a
fc.weight
fc.bias
bn.weight
bn.bias


In [9]:
print(help_decoder_net.bn.weight.grad)

None


In [19]:
print(help_decoder_net.adj)

Parameter containing:
tensor([[8.1529e-01, 7.5017e-01, 8.0229e-01, 8.1132e-01, 8.0489e-01, 8.2746e-01,
         8.4336e-01, 7.7390e-01, 8.5451e-01, 8.1560e-01, 7.7644e-01, 8.4206e-01,
         8.1700e-01, 8.0074e-01, 8.5518e-01, 7.7297e-01, 8.6488e-01, 7.9932e-01,
         8.4872e-01, 7.5807e-01],
        [4.9340e-42, 4.9340e-42, 4.9340e-42, 4.9340e-42, 4.9340e-42, 4.9340e-42,
         4.9340e-42, 4.9340e-42, 4.9340e-42, 4.9340e-42, 4.9340e-42, 4.9340e-42,
         4.9340e-42, 4.9340e-42, 4.9340e-42, 4.9340e-42, 4.9340e-42, 4.9340e-42,
         4.9340e-42, 4.9340e-42],
        [4.9340e-42, 4.9340e-42, 4.9340e-42, 4.9340e-42, 4.9340e-42, 4.9340e-42,
         4.9340e-42, 4.9340e-42, 4.9340e-42, 4.9340e-42, 4.9340e-42, 4.9340e-42,
         4.9340e-42, 4.9340e-42, 4.9340e-42, 4.9340e-42, 4.9340e-42, 4.9340e-42,
         4.9340e-42, 4.9340e-42],
        [4.9340e-42, 4.9340e-42, 4.9340e-42, 4.9340e-42, 4.9340e-42, 4.9340e-42,
         4.9340e-42, 4.9340e-42, 4.9340e-42, 4.9340e-42, 4.9340e-4

In [7]:
eva = mask.mean(dim=0).mean(dim=1)
print(eva)
print(GC[0, :])

tensor([0.0420, 0.0410, 0.0423, 0.0397, 0.0426, 0.0422, 0.0425, 0.0454, 0.0425,
        0.0435, 0.0448, 0.0433, 0.0417, 0.0458, 0.0447, 0.0428, 0.0420, 0.0408,
        0.0453, 0.0447], grad_fn=<MeanBackward1>)
[1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1]


In [4]:
idx = 1
print('Begin training feature: {:04d}'.format(idx + 1))
# decoder_file = os.path.join(save_folder, 'decoder_' + str(idx ) + '.pt')
Inter_decoder = decoder(args.dims, args.hidden, args.time_step - 1, args.num_nodes, args.dropout, args.alpha)
Inter_decoder = Inter_decoder.cuda()
optimizer = optim.Adam(params = Inter_decoder.parameters(), lr=args.lr, weight_decay=args.weight_decay)
scheduler = lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.5)
loss_val = nn.MSELoss()
best_loss = np.Inf
for epoch in range(5000):
    scheduler.step()
    t = time.time()
    Loss = []
    mse_loss = []
    for batch_idx, data in enumerate(data_loader):
        data = data.cuda()
        target = data[:, idx, 1:, :]
        optimizer.zero_grad()
        inputs = data[:, :, :-1, :]
        pred = Inter_decoder(inputs, idx)
        mse = loss_val(pred, target)
        loss = mse

        loss.backward()
        optimizer.step()


        Loss.append(loss.item())
        mse_loss.append(mse.item())

    if epoch % 100 == 0:
        print('Feature: {:04d}'.format(idx + 1),
              'Epoch: {:04d}'.format(epoch),
              'Loss: {:.10f}'.format(np.mean(Loss)),
              'MSE_Loss: {:.10f}'.format(np.mean(mse_loss)),
              'time: {:.4f}s'.format(time.time() - t))

        
    # if np.mean(mse_loss) < best_loss:
    #     best_loss = np.mean(mse_loss)
    #     torch.save(Inter_decoder.state_dict(), decoder_file)
    #     print('Feature: {:04d}'.format(idx + 1),
    #           'Epoch: {:04d}'.format(epoch),
    #           'Loss: {:.10f}'.format(np.mean(Loss)),
    #           'mse_loss: {:.10f}'.format(np.mean(mse_loss)),
    #           'mmd_loss: {:.10f}'.format(np.mean(mmd_loss)),
    #           'time: {:.4f}s'.format(time.time() - t), file=log)

Begin training feature: 0002


  nn.init.xavier_normal(m.weight.data, gain=1.414)


Feature: 0002 Epoch: 0000 Loss: 38.1945438385 MSE_Loss: 38.1945438385 time: 0.0859s
Feature: 0002 Epoch: 0100 Loss: 22.2226710320 MSE_Loss: 22.2226710320 time: 0.0085s
Feature: 0002 Epoch: 0200 Loss: 16.2656733990 MSE_Loss: 16.2656733990 time: 0.0090s
Feature: 0002 Epoch: 0300 Loss: 12.2572782040 MSE_Loss: 12.2572782040 time: 0.0090s
Feature: 0002 Epoch: 0400 Loss: 9.9791700840 MSE_Loss: 9.9791700840 time: 0.0090s
Feature: 0002 Epoch: 0500 Loss: 8.6275345087 MSE_Loss: 8.6275345087 time: 0.0090s
Feature: 0002 Epoch: 0600 Loss: 8.8666521311 MSE_Loss: 8.8666521311 time: 0.0090s
Feature: 0002 Epoch: 0700 Loss: 7.7709715366 MSE_Loss: 7.7709715366 time: 0.0090s
Feature: 0002 Epoch: 0800 Loss: 7.8987344503 MSE_Loss: 7.8987344503 time: 0.0090s
Feature: 0002 Epoch: 0900 Loss: 8.0406353474 MSE_Loss: 8.0406353474 time: 0.0090s
Feature: 0002 Epoch: 1000 Loss: 7.7239035368 MSE_Loss: 7.7239035368 time: 0.0090s
Feature: 0002 Epoch: 1100 Loss: 7.4051609039 MSE_Loss: 7.4051609039 time: 0.0084s
Feature:

In [5]:
print(Inter_decoder.adj[idx, :])
print(GC[idx, :])

tensor([-1.1446e-01,  3.5145e+00,  4.9145e-02, -2.6867e-02, -2.2781e-02,
        -1.9490e-02, -4.3720e-03, -3.7241e-03, -2.5528e-03, -5.8850e-03,
        -2.0816e-03, -1.3717e-02,  1.0860e-02, -5.9835e-03,  7.0344e-03,
         2.2493e-02, -5.7498e-03,  3.7945e-02,  2.3060e-02, -7.1532e-01],
       device='cuda:0', grad_fn=<SliceBackward0>)
[1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1]
