In [1]:
import numpy as np
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
import random
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx
import networkx as nx
from networkx.algorithms.shortest_paths.dense import floyd_warshall_numpy

from networkx.generators.random_graphs import *
from networkx.generators.ego import ego_graph
from networkx.generators.geometric import random_geometric_graph
import os

import pickle

In [2]:
# Pick the compute device once: the first CUDA GPU when PyTorch can see one,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [3]:
import torch
from torch import nn
from torch.nn import Module, Parameter, Sequential
from torch.nn import Linear, Tanh, ReLU, CELU
from torch.utils.data import TensorDataset, DataLoader
import torch.nn.functional as F
from torch.distributions import MultivariateNormal, Categorical, Normal, MultivariateNormal

class MLP_Regressor(nn.Module):
    """Small MLP regressor scored with a unit-variance Gaussian likelihood.

    ``decode`` maps a latent vector to a single output through two
    tanh-activated hidden layers. ``forward`` returns the summed log-density
    of ``target`` under ``Normal(decode(z), 1)``.
    """

    def __init__(self, latent_dim, hidden_dim):
        """Create the layers.

        Args:
            latent_dim: size of the last dimension of the input ``z``.
            hidden_dim: width of both hidden layers.
        """
        super().__init__()

        # Activation modules. Only ``tanh`` is used by ``decode``; the other
        # two are kept as attributes of the module.
        self.swish = nn.SiLU()
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

        # latent -> hidden -> hidden -> scalar output
        self.rfc1 = nn.Linear(latent_dim, hidden_dim)
        self.rfc2 = nn.Linear(hidden_dim, hidden_dim)
        self.rfc3 = nn.Linear(hidden_dim, 1)

    def decode(self, z):
        """Return the network output (last dimension of size 1) for ``z``."""
        hidden = self.rfc1(z)
        hidden = self.tanh(hidden)
        hidden = self.rfc2(hidden)
        hidden = self.tanh(hidden)
        prediction = self.rfc3(hidden)
        return prediction

    def forward(self, z, target):
        """Return the summed log-probability of ``target`` given ``z``.

        The likelihood is a Normal distribution centred on ``decode(z)`` with
        a fixed scale of 1; the per-element log-probabilities are summed into
        a single scalar tensor.
        """
        mean = self.decode(z)
        likelihood = Normal(mean, 1)
        return likelihood.log_prob(target).sum()

In [4]:
# Task identifiers; each name is used to build the per-task sample and
# posterior directory paths in the evaluation loop.
task_names = [
    "bace",
    "ctsd",
    "mmp2",
    "malaria",
    "esol",
    "freesolv",
    "logp",
]

In [5]:
def get_wbic(model, torch_ys, torch_zs, batch_indices, posterior_params):
    
    burn_in = 1000
    get_posterior_num = len(posterior_params) - burn_in

    nlogp_list = []
    model.eval()

    for id, indices in enumerate(batch_indices): #observed samples loop

        print("sample :", id)
        ts = torch_ys[indices]
        zs = torch_zs[indices]

        if torch.cuda.is_available():
            ts, zs = ts.to(device), zs.to(device)

        nlogp_per_batch = []

        #観測データを固定して、事後パラメータをサンプリング
        for i, posterior_param in enumerate(posterior_params[-get_posterior_num:]): #すでにcudaにのってる

            model.load_state_dict(posterior_param)
            batch_nlogp_per_posterior = - model(zs, ts)
            nlogp_per_batch.append(batch_nlogp_per_posterior.data.item())

            del posterior_param
            torch.cuda.empty_cache()

        nlogp_list.append(np.mean(nlogp_per_batch))

        del zs, ts
        torch.cuda.empty_cache()
        
    return nlogp_list

In [6]:
# Base path under which the saved-model and posterior directories are resolved.
root = '/'
# Directory of the saved models; per-task sample folders are looked up beneath it.
load_model_dir = '{}save_models/zinc/pig-ae_models/'.format(root)

def get_batch_index(indices, batch_size=None, n_batch=None):
    """Split ``indices`` into nearly equal consecutive batches.

    Args:
        indices: sequence or array of indices to partition.
        batch_size: approximate number of indices per batch. When given, the
            number of batches is ``len(indices) // batch_size`` (at least 1)
            and takes precedence over ``n_batch``.
        n_batch: number of batches to produce; used only when ``batch_size``
            is ``None``.

    Returns:
        ``(batch_ids, n_batch)`` where ``batch_ids`` is the list of arrays
        produced by ``np.array_split`` and ``n_batch`` is the number of
        batches. Empty ``indices`` yields ``([], 0)``.

    Raises:
        ValueError: if neither ``batch_size`` nor ``n_batch`` is given, or if
            the one in use is not positive.
    """
    n_items = len(indices)

    if batch_size is not None:
        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer")
        # Fewer items than batch_size would give 0 sections, which
        # np.array_split rejects; fall back to a single batch.
        n_batch = max(1, n_items // batch_size)
    elif n_batch is None:
        raise ValueError("either batch_size or n_batch must be given")
    elif n_batch <= 0:
        raise ValueError("n_batch must be a positive integer")

    if n_items == 0:
        return [], 0

    batch_ids = np.array_split(indices, n_batch)
    return batch_ids, n_batch

In [7]:
# Collects one list of per-batch values per task; filled by the evaluation loop.
nlogp_list = list()

# Approximate number of observations per batch passed to get_batch_index.
batch_size = 50

In [8]:
import pickle
import future, sys, os, datetime, argparse, copy, warnings, time

# Start from an empty result list so that re-running this cell does not append
# a second copy of every task's results (which would misalign nlogp_list with
# task_names when the two are zipped together for reporting).
nlogp_list = []

for task in task_names:

    # Per-task locations of the stored posterior samples and of the inputs.
    posterior_save_dir = root + "posteriors/wbic/zinc/" + task + "/ae/"
    sample_load_dir = load_model_dir + "samples_for_mcmc/" + task + "/"

    print(posterior_save_dir)

    z_mus = np.load(sample_load_dir + "embs_mu.npy")

    targets = np.load(sample_load_dir + "targets.npy")
    # Make the targets a column so they line up with the model's (N, 1) output.
    targets = np.reshape(targets, (targets.shape[0], 1))

    torch_mus = torch.from_numpy(z_mus.astype(np.float32)).clone()
    torch_zs = torch_mus
    torch_ys = torch.from_numpy(targets.astype(np.float32)).clone()

    batch_indices, n_batch = get_batch_index(np.arange(len(torch_ys)), batch_size=batch_size)

    # NOTE(security): pickle.load can execute arbitrary code; only load
    # posterior files that were produced by a trusted run.
    # The context manager closes the file even if unpickling fails.
    with open(posterior_save_dir + "posterior_params.pickle", 'rb') as f1:
        posterior_params = pickle.load(f1)

    # Architecture must match the shapes stored in the posterior state dicts.
    model = MLP_Regressor(90, 256).to(device)
    nlogps = get_wbic(model, torch_ys, torch_zs, batch_indices, posterior_params)

    nlogp_list.append(nlogps)

    del model, nlogps

lmc/zinc/bace/ae/
sample : 0
sample : 1
sample : 2
sample : 3
sample : 4
sample : 5
sample : 6
sample : 7
sample : 8
sample : 9
lmc/zinc/ctsd/ae/
sample : 0
lmc/zinc/mmp2/ae/
sample : 0
sample : 1
sample : 2
sample : 3
sample : 4
sample : 5
sample : 6
sample : 7
sample : 8
sample : 9
sample : 10
sample : 11
sample : 12
sample : 13
sample : 14
sample : 15
sample : 16
sample : 17
sample : 18
sample : 19
lmc/zinc/malaria/ae/
sample : 0
sample : 1
sample : 2
sample : 3
sample : 4
sample : 5
sample : 6
sample : 7
sample : 8
sample : 9
sample : 10
sample : 11
sample : 12
sample : 13
sample : 14
sample : 15
sample : 16
sample : 17
sample : 18
sample : 19
sample : 20
sample : 21
sample : 22
sample : 23
sample : 24
sample : 25
sample : 26
sample : 27
sample : 28
sample : 29
sample : 30
sample : 31
sample : 32
sample : 33
sample : 34
sample : 35
sample : 36
sample : 37
sample : 38
sample : 39
sample : 40
sample : 41
sample : 42
sample : 43
sample : 44
sample : 45
sample : 46
sample : 47
sample :

In [9]:
# Report, for each task, the sum of the per-batch values collected in nlogp_list.
for task, nlogps in zip(task_names, nlogp_list):
    task_total = np.sum(nlogps)
    print("task :", task, task_total)

task : bace 852.1536924819948
task : ctsd 113.69796212387085
task : mmp2 2095.0939081077577
task : malaria 6010.996668292999
task : esol 2238.24011863327
task : freesolv 2955.6315377616884
task : lipo 3073.680873889923
task : vp 7663.324683347702
