In [1]:
import numpy as np
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
import random
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx
import networkx as nx
from networkx.algorithms.shortest_paths.dense import floyd_warshall_numpy

from networkx.generators.random_graphs import *
from networkx.generators.ego import ego_graph
from networkx.generators.geometric import random_geometric_graph
import os

import pickle

In [2]:
# Regression target names. A name's position in this list is used as the
# column index into the loaded targets array, so the order must not change.
task_names = [
    'mu',
    'alpha',
    'homo',
    'lumo',
    'r2',
    'zpve',
    'u0',
    'cv',
]

In [3]:
import torch
from torch import nn
from torch.nn import Module, Parameter, Sequential
from torch.nn import Linear, Tanh, ReLU, CELU
from torch.utils.data import TensorDataset, DataLoader
import torch.nn.functional as F
from torch.distributions import MultivariateNormal, Categorical, Normal, MultivariateNormal

class MLP_Regressor(nn.Module):
    """Small MLP that scores a target under a unit-variance Gaussian.

    ``decode`` maps a latent vector to one predicted value through two
    SiLU-activated hidden layers; ``forward`` evaluates
    ``Normal(prediction, 1)`` at the supplied target.
    """

    def __init__(self, latent_dim, hidden_dim):
        """Build the activation modules and the three linear layers.

        Args:
            latent_dim: Size of the last dimension of the latent input.
            hidden_dim: Width of both hidden layers.
        """
        super().__init__()

        # Activation modules; ``decode`` only uses ``swish``.
        self.swish = nn.SiLU()
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

        # latent_dim -> hidden_dim -> hidden_dim -> 1
        self.rfc1 = nn.Linear(latent_dim, hidden_dim)
        self.rfc2 = nn.Linear(hidden_dim, hidden_dim)
        self.rfc3 = nn.Linear(hidden_dim, 1)

    def decode(self, z):
        """Return the network's prediction for ``z`` (last dimension of size 1)."""
        hidden = z
        for layer in (self.rfc1, self.rfc2):
            hidden = self.swish(layer(hidden))
        return self.rfc3(hidden)

    def forward(self, z, target):
        """Return ``(log_prob, prob)`` of ``target`` under ``Normal(decode(z), 1)``."""
        predicted_mean = self.decode(z)
        log_prob_ys = Normal(predicted_mean, 1).log_prob(target)
        return log_prob_ys, torch.exp(log_prob_ys)

In [4]:
# Use the first CUDA device when PyTorch reports CUDA as available, else the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [5]:
def get_waics(model, torch_ys, torch_zs, batch_indices, posterior_params, burn_in=2000):
    """Compute the per-sample WAIC terms from stored posterior parameter draws.

    Each mini-batch of observed data is held fixed while the model is evaluated
    under every posterior draw kept after burn-in. Two values are produced per
    observed sample:

    * ``Tn`` term: the negative log of the likelihood averaged over draws,
      ``-log(mean_s p_s)``.
    * ``Vn`` term: the variance of the log-likelihood over draws.

    Args:
        model: Module whose ``forward(zs, ts)`` returns ``(log_prob, prob)``
            with ``prob == exp(log_prob)``. Only ``log_prob`` is used here.
            Its weights are overwritten via ``load_state_dict`` for every draw;
            on return it is in eval mode and holds the last draw.
        torch_ys: Tensor of targets, indexed by the entries of ``batch_indices``.
        torch_zs: Tensor of inputs, indexed the same way.
        batch_indices: Iterable of index arrays, one per mini-batch.
        posterior_params: Sequence of state dicts, one per posterior draw.
        burn_in: Number of leading draws to discard (default 2000).

    Returns:
        ``(nlogp_list, var_list)``: two flat lists with one entry per observed
        sample, in the order the batches were visited.

    Raises:
        ValueError: If ``burn_in`` is negative or no draw is left after burn-in.
    """
    if burn_in < 0:
        raise ValueError("burn_in must be non-negative, got {}".format(burn_in))

    # Slice from the front. The former ``[-(len - burn_in):]`` form kept *all*
    # draws when len == burn_in, because ``seq[-0:]`` is the whole sequence.
    kept_params = posterior_params[burn_in:]
    n_draws = len(kept_params)
    if n_draws == 0:
        raise ValueError(
            "no posterior draws left after discarding burn_in={} of {}".format(
                burn_in, len(posterior_params)
            )
        )
    log_n_draws = np.log(n_draws)

    # Put the data wherever the model's weights live instead of relying on a
    # module-level ``device`` global.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    nlogp_list = []
    var_list = []
    model.eval()

    for batch_id, indices in enumerate(batch_indices):  # loop over observed samples

        print("sample :", batch_id)
        ts = torch_ys[indices].to(model_device)
        zs = torch_zs[indices].to(model_device)

        logps_per_batch = []

        # Hold the observed data fixed and sweep over the posterior draws
        # (the stored state dicts are expected to be on the model's device already).
        with torch.no_grad():  # evaluation only; avoid building autograd graphs
            for posterior_param in kept_params:
                model.load_state_dict(posterior_param)
                batch_logp, _ = model(zs, ts)  # shape = (batch_size, 1)
                logps_per_batch.append(
                    batch_logp.cpu().numpy().astype(np.float64).reshape(-1)
                )

        logps = np.stack(logps_per_batch)  # (n_draws, batch_size)

        # Tn: -log(mean_s p_s) = log(S) - logsumexp_s(log p_s). Working in log
        # space avoids exp() underflowing to 0 (and the result becoming inf).
        nlogp_list.extend(log_n_draws - np.logaddexp.reduce(logps, axis=0))
        # Vn: variance of the log-likelihood across draws.
        var_list.extend(np.var(logps, axis=0))

        del logps_per_batch, logps, ts, zs
        if torch.cuda.is_available():
            # Release cached blocks once per batch; doing it per draw is very slow.
            torch.cuda.empty_cache()

    return nlogp_list, var_list

In [6]:
# Base path that every data/model location in this notebook is built from.
root = "/"
# Saved-model directory; the sample arrays are read from a subfolder of it.
load_model_dir = "".join((root, "save_models/qm9/pig-e3ae_models/"))

def get_batch_index(indices, batch_size=None, n_batch=None):
    """Split ``indices`` into consecutive, nearly equal-sized batches.

    The batch count is ``len(indices) // batch_size`` (but never less than 1)
    when ``batch_size`` is given, otherwise ``n_batch``. Because the count is
    rounded down, batches hold *at least* ``batch_size`` items when the length
    is not an exact multiple (e.g. 250 items with ``batch_size=100`` gives two
    batches of 125).

    Args:
        indices: Sequence/array of indices to split.
        batch_size: Target number of items per batch. Takes precedence over
            ``n_batch`` when both are supplied.
        n_batch: Number of batches to create; used only if ``batch_size`` is None.

    Returns:
        ``(batch_ids, n_batch)``: the list of index arrays produced by
        ``np.array_split`` and the number of batches. Empty ``indices`` yield
        ``([], 0)``.

    Raises:
        ValueError: If neither ``batch_size`` nor ``n_batch`` is given, or if
            the one that is used is not positive.
    """
    n_items = len(indices)

    if batch_size is not None:
        if batch_size <= 0:
            raise ValueError("batch_size must be positive, got {}".format(batch_size))
        # At least one batch, so inputs shorter than batch_size still work.
        n_batch = max(1, n_items // batch_size)
    elif n_batch is None:
        raise ValueError("either batch_size or n_batch must be given")
    elif n_batch <= 0:
        raise ValueError("n_batch must be positive, got {}".format(n_batch))

    if n_items == 0:
        return [], 0

    batch_ids = np.array_split(indices, n_batch)
    return batch_ids, n_batch

In [7]:
# Target number of observed samples per mini-batch (passed to get_batch_index).
batch_size = 100

# Number of leading rows kept from the loaded embedding and target arrays.
n_samples = 5_000

In [8]:
import pickle
import future, sys, os, datetime, argparse, copy, warnings, time

# Only these targets are evaluated; every other name in task_names is skipped.
task_list = ["homo", "lumo"]

# One entry per evaluated task, appended in task_names order.
nlogp_lists = []
var_lists = []

# The sample arrays live in one folder shared by all tasks.
sample_load_dir = load_model_dir + "samples_for_mcmc/"

# ``task_idx`` (rather than ``id``) avoids rebinding the builtin at module level;
# it doubles as the column index of this task in the targets array.
for task_idx, task in enumerate(task_names):

    if task not in task_list:
        continue

    posterior_save_dir = root + "posteriors/waic/qm9/" + task + "/ae/"

    print(posterior_save_dir)

    # Keep the first n_samples rows; pick this task's target column.
    z_mus = np.load(sample_load_dir + "embs_mu.npy")[0:n_samples]
    targets = np.load(sample_load_dir + "all_targets.npy")[0:n_samples, task_idx]
    targets = np.reshape(targets, (targets.shape[0], 1))

    torch_mus = torch.from_numpy(z_mus.astype(np.float32)).clone()
    torch_ys = torch.from_numpy(targets.astype(np.float32)).clone()
    torch_zs = torch_mus

    batch_indices, n_batch = get_batch_index(np.arange(len(torch_ys)), batch_size=batch_size)

    # NOTE(security): pickle.load can execute arbitrary code; only load
    # posterior files that were produced by a trusted run.
    # The context manager closes the file handle, which was previously leaked.
    with open(posterior_save_dir + "posterior_params.pickle", 'rb') as posterior_file:
        posterior_params = pickle.load(posterior_file)

    # The layer sizes must match the state dicts stored in posterior_params.
    model = MLP_Regressor(50, 128).to(device)
    nlogp_list, var_list = get_waics(model, torch_ys, torch_zs, batch_indices, posterior_params)

    nlogp_lists.append(nlogp_list)
    var_lists.append(var_list)

    del model

lmc/qm9/homo/ae/
sample : 0
sample : 1


KeyboardInterrupt: 

In [None]:
# Per evaluated task, print three numbers: the sum of the two means (which
# get_waics' naming suggests is the WAIC estimate), the mean Tn term, and the
# mean Vn term.
for nlogp_list, var_list in zip(nlogp_lists, var_lists):
    mean_nlogp = np.mean(nlogp_list)
    mean_var = np.mean(var_list)
    print(mean_nlogp + mean_var, mean_nlogp, mean_var)