In [1]:
import numpy as np
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
import random
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx
import networkx as nx
from networkx.algorithms.shortest_paths.dense import floyd_warshall_numpy

from networkx.generators.random_graphs import *
from networkx.generators.ego import ego_graph
from networkx.generators.geometric import random_geometric_graph
import os

import pickle

In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [3]:
task_names = ["bace", "ctsd", "mmp2", "malaria", "esol", "freesolv", "lipo", "logp"]

In [5]:
import torch
from torch import nn
from torch.nn import Module, Parameter, Sequential
from torch.nn import Linear, Tanh, ReLU, CELU
from torch.utils.data import TensorDataset, DataLoader
import torch.nn.functional as F
from torch.distributions import MultivariateNormal, Categorical, Normal, MultivariateNormal

class MLP_Regressor(nn.Module):
    def __init__(self, latent_dim, hidden_dim):
        super(MLP_Regressor, self).__init__()
        
        self.swish = nn.SiLU()
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()
        
        self.rfc1 = nn.Linear(latent_dim, hidden_dim)
        self.rfc2 = nn.Linear(hidden_dim, hidden_dim)
        self.rfc3 = nn.Linear(hidden_dim, 1)

    def decode(self, z):
        dh = self.tanh(self.rfc1(z))
        dh = self.tanh(self.rfc2(dh))
        return self.rfc3(dh)
     
    def forward(self, z, target):
        y_u = self.decode(z)
        
        log_prob_ys = Normal(y_u, 1).log_prob(target)
        
        return log_prob_ys, log_prob_ys.exp()

In [6]:
def get_waics(model, torch_ys, torch_zs, batch_indices, posterior_params):
    
    burn_in = 2000
    get_posterior_num = len(posterior_params) - burn_in

    nlogp_list = []
    var_list = []
    model.eval()

    for id, indices in enumerate(batch_indices): #observed samples loop

        print("sample :", id)
        ts = torch_ys[indices]
        zs = torch_zs[indices]

        if torch.cuda.is_available():
            ts = ts.to(device)
            zs = zs.to(device)

        prob_per_batch = [] #for Tn
        logps_per_batch = [] #for Vn

        #観測データを固定して、事後パラメータをサンプリング
        for i, posterior_param in enumerate(posterior_params[-get_posterior_num:]): #すでにcudaにのってる

            model.load_state_dict(posterior_param)

            batch_logp_per_posterior, batch_prob_per_posterior = model(zs, ts) #shape = (batch_size, 1)
            
            prob_per_batch.append(batch_prob_per_posterior.to('cpu').detach().numpy().copy().reshape(-1)) #for Tn
            logps_per_batch.append(batch_logp_per_posterior.to('cpu').detach().numpy().copy().reshape(-1)) # for Vn
            
            del posterior_param
            torch.cuda.empty_cache()

        nlogp_list.extend(-1 * np.log(np.mean(prob_per_batch, axis = 0))) # for Tn
        var_list.extend(np.var(np.array(logps_per_batch), axis = 0)) # for Vn
        
        del prob_per_batch, logps_per_batch, ts, zs
        torch.cuda.empty_cache()
        
    return nlogp_list, var_list

In [7]:
root = '/'
load_model_dir = root + 'save_models/zinc/pig-ae_models/'

def get_batch_index(indices, batch_size=None, n_batch=None):

    n_batch = len(indices)//batch_size
    batch_ids = np.array_split(indices, n_batch)
    return(batch_ids), n_batch

In [8]:
batch_size = 50

In [9]:
import pickle
import future, sys, os, datetime, argparse, copy, warnings, time

nlogp_lists = []
var_lists = []

for task in task_names:

    posterior_save_dir = "lmc/zinc/" + task + "/ae/"
    posterior_save_dir = root + "posteriors/waic/zinc/" + task + "/ae/"
    
    print(posterior_save_dir)
    
    z_mus = np.load(sample_load_dir + "embs_mu.npy")
    targets = np.load(sample_load_dir + "targets.npy")
    targets = np.reshape(targets, (targets.shape[0], 1))

    torch_mus = torch.from_numpy(z_mus.astype(np.float32)).clone()
    torch_zs = torch_mus
    torch_ys = torch.from_numpy(targets.astype(np.float32)).clone()

    batch_indices, n_batch = get_batch_index(np.arange(len(torch_ys)), batch_size=batch_size)

    f1 = open(posterior_save_dir + "posterior_params.pickle",'rb')
    posterior_params = pickle.load(f1)

    model = MLP_Regressor(90, 256).to(device)

    nlogp_list, var_list  = get_waics(model, torch_ys, torch_zs, batch_indices, posterior_params)
    
    nlogp_lists.append(nlogp_list)
    var_lists.append(var_list)
    
    del model

lmc/zinc/bace/ae/
sample : 0
sample : 1
sample : 2
sample : 3
sample : 4
sample : 5
sample : 6
sample : 7
sample : 8
sample : 9
lmc/zinc/ctsd/ae/
sample : 0
lmc/zinc/mmp2/ae/
sample : 0
sample : 1
sample : 2
sample : 3
sample : 4
sample : 5
sample : 6
sample : 7
sample : 8
sample : 9
sample : 10
sample : 11
sample : 12
sample : 13
sample : 14
sample : 15
sample : 16
sample : 17
sample : 18
sample : 19
lmc/zinc/malaria/ae/
sample : 0
sample : 1
sample : 2
sample : 3
sample : 4
sample : 5
sample : 6
sample : 7
sample : 8
sample : 9
sample : 10
sample : 11
sample : 12
sample : 13
sample : 14
sample : 15
sample : 16
sample : 17
sample : 18
sample : 19
sample : 20
sample : 21
sample : 22
sample : 23
sample : 24
sample : 25
sample : 26
sample : 27
sample : 28
sample : 29
sample : 30
sample : 31
sample : 32
sample : 33
sample : 34
sample : 35
sample : 36
sample : 37
sample : 38
sample : 39
sample : 40
sample : 41
sample : 42
sample : 43
sample : 44
sample : 45
sample : 46
sample : 47
sample :

  nlogp_list.extend(-1 * np.log(np.mean(prob_per_batch, axis = 0))) # for Tn


sample : 7
sample : 8
sample : 9
sample : 10
sample : 11
lmc/zinc/lipo/ae/
sample : 0
sample : 1
sample : 2
sample : 3
sample : 4
sample : 5
sample : 6
sample : 7
sample : 8
sample : 9
sample : 10
sample : 11
sample : 12
sample : 13
sample : 14
sample : 15
sample : 16
sample : 17
sample : 18
sample : 19
sample : 20
sample : 21
sample : 22
sample : 23
sample : 24
sample : 25
sample : 26
sample : 27
sample : 28
sample : 29
sample : 30
sample : 31
sample : 32
sample : 33
sample : 34
sample : 35
sample : 36
sample : 37


In [10]:
for nlogp_list, var_list in zip(nlogp_lists, var_lists):
    print(np.mean(nlogp_list) + np.mean(var_list), np.mean(nlogp_list), np.mean(var_list))

1.652771 1.6270512 0.025719708
1.3126745 1.2788761 0.03379849
1.9250054 1.8946797 0.030325755
1.9307127 1.8981124 0.032600235
1.9693598 1.9125398 0.056819867
inf inf 2.9009142
1.5807693 1.5678978 0.0128715485


In [None]:
no_ids = [44, 215, 362]
#no_ids = [362]

#nlogp_lists[0], var_lists[0]

nlogp_list = []
var_list = []

for id, (nlogp, var) in enumerate(zip(nlogp_lists[-2], var_lists[-2])):
    
    if id not in no_ids:
        nlogp_list.append(nlogp)
        var_list.append(var)

In [19]:
print(np.mean(nlogp_list) + np.mean(var_list), np.mean(nlogp_list), np.mean(var_list))

4.3850403 3.5036266 0.8814136
