In [17]:
import pandas as pd
import numpy as np
from numpy import vectorize as vec
import scipy as sp
import sklearn
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
#import seaborn as sns
from rdkit import Chem
from rdkit.Chem.Draw import IPythonConsole 
from rdkit.Chem import Descriptors,PandasTools
from rdkit.ML.Descriptors import MoleculeDescriptors

In [18]:
# Load the QM9 CSV (`qm9_shuffle.csv`) into a DataFrame; the path is relative
# to the current working directory.
qm9_df = pd.read_csv('datasets/qm9/qm9_shuffle.csv')

In [19]:
# Inspect the schema of the loaded table: per-column dtypes and the column index.
column_type = qm9_df.dtypes
column_name = qm9_df.columns
print(column_type)
print(column_name)

Unnamed: 0      int64
mol_id         object
smiles         object
A             float64
B             float64
C             float64
mu            float64
alpha         float64
homo          float64
lumo          float64
gap           float64
r2            float64
zpve          float64
u0            float64
u298          float64
h298          float64
g298          float64
cv            float64
u0_atom       float64
u298_atom     float64
h298_atom     float64
g298_atom     float64
dtype: object
Index(['Unnamed: 0', 'mol_id', 'smiles', 'A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'u0', 'u298', 'h298', 'g298', 'cv', 'u0_atom', 'u298_atom', 'h298_atom', 'g298_atom'], dtype='object')


# Load Embedding Target Property

In [36]:
# Read only the `homo` column, coerce it to float and convert it to a nested
# list with one single-element row per record ([[value], [value], ...]).
target_vec = (
    pd.read_csv('datasets/qm9/qm9_shuffle.csv', usecols=['homo'])
    .astype('float')
    .values
    .tolist()
)

In [26]:
# Flatten the single-element rows of `target_vec` into a plain list of values
# (no scaling applied here).
target_list = [row[0] for row in target_vec]

In [37]:
# Target scaling for HOMO energy: the raw property values are very small in
# magnitude, so every value is multiplied by 100 while flattening.
target_list = [row[0] * 100 for row in target_vec]

In [38]:
# Show the first ten entries of `target_list` as a quick sanity check.
print(target_list[0:10])

[-23.59, -23.02, -24.169999999999998, -24.52, -25.27, -26.85, -25.46, -24.22, -23.02, -23.400000000000002]


# Model load

In [47]:
from mhg.nn.autoencoder import GrammarSeq2SeqVAE
from mhg.nn.dataset import HRGDataset, batch_padding
from mhg.smi import HGGen, hg_to_mol
from mhg.hrg import HyperedgeReplacementGrammar as HRG

import torch
from torch.utils.data import DataLoader

import os
import logging
import gzip
import pickle
import numpy as np
from copy import deepcopy

from rdkit import Chem

def get_dataloaders(hrg, prod_rule_seq_list, target_val_list=None, batch_size=None, shuffle=False):
    ''' Build the train/val/test dataloaders and return them with the training configuration.

    The data are split into three contiguous ranges, in order: the first
    `num_train` items, the next `num_val` items and the next `num_test` items
    (sizes come from the `Train_params` dictionary defined below).

    Parameters
    ----------
    hrg
        Grammar object forwarded unchanged to `HRGDataset`.
    prod_rule_seq_list : List of lists
        each element corresponds to a sequence of production rules.
    target_val_list : list, optional
        Per-sequence target values, sliced with the same boundaries as
        `prod_rule_seq_list`. When None, the datasets are built without targets.
    batch_size : int, optional
        Batch size of every loader. Defaults to `Train_params['model_params']['batch_size']`.
    shuffle : bool
        Forwarded to every `DataLoader`.

    Returns
    -------
    (train_loader, val_loader, test_loader, train_params)
        `val_loader` / `test_loader` are None when `num_val` / `num_test` is 0.
        `train_params` is the configuration dictionary used for the split.
    '''
    # Parameters for training a variational autoencoder
    Train_params = {
        'model': 'GrammarSeq2SeqVAE', # Model name
        'model_params' : { # Parameter for the model
            'latent_dim': 50, # Dimension of the latent dimension
            'max_len': 12, # maximum length of input sequences (represented as sequences of production rules)
            'batch_size': 64, # batch size for training
            'padding_idx': -1, # integer used for padding
            'start_rule_embedding': False, # Embed the starting rule into the latent dimension explicitly
            'encoder_name': 'GRU', # Type of encoder
            'encoder_params': {'hidden_dim': 384, # hidden dim
                               'num_layers': 3, # the number of layers
                               'bidirectional': True, # use bidirectional one or not
                               'dropout': 0.1}, # dropout probability
            'decoder_name': 'GRU', # Type of decoder
            'decoder_params': {'hidden_dim': 384, # hidden dim
                               'num_layers': 3, # the number of layers
                               'dropout': 0.1}, # dropout probability
            'prod_rule_embed': ['Embedding',
                                'MolecularProdRuleEmbedding',
                                'MolecularProdRuleEmbeddingLastLayer',
                                'MolecularProdRuleEmbeddingUsingFeatures'][0], # Embedding method of a production rule. The simple embedding was the best, but each production rule could be embedded using GNN
            'prod_rule_embed_params': {'out_dim': 900, # embedding dimension
                                       'layer2layer_activation': 'relu', # not used for `Embedding`
                                       'layer2out_activation': 'softmax', # not used for `Embedding`
                                       'num_layers': 4}, # not used for `Embedding`
            'criterion_func': ['VAELoss', 'GrammarVAELoss'][1], # Loss function
            'criterion_func_kwargs': {'beta': 0.01}}, # Parameters for the loss function. `beta` specifies beta-VAE.
        'sgd': 'Adam', # SGD algorithm
        'sgd_params': {'lr': 5e-4 # learning rate of SGD
        },
        #'seed_list': [141, 123, 425, 552, 1004, 50243], # seeds used for restarting training
        'seed_list': [128], # seeds used for restarting training
        'inversed_input': True, # the input sequence is in the reversed order or not.
        'num_train':99968, # the number of training examples
        'num_val': 1000, # the number of validation examples
        'num_test': 28928, # the number of test examples
        'num_early_stop': 220011, # the number of examples used to find better initializations (=seed)
        'num_epochs': 10 # the number of training epochs
    }

    train_params = Train_params
    model_params = train_params['model_params']

    if batch_size is None:
        batch_size = model_params['batch_size']

    # Contiguous split boundaries:
    # [0, train_end) -> train, [train_end, val_end) -> val, [val_end, test_end) -> test.
    train_end = train_params['num_train']
    val_end = train_end + train_params['num_val']
    test_end = val_end + train_params['num_test']

    def _make_loader(start, stop):
        # Slice sequences (and targets, when given) and wrap them in a DataLoader.
        seqs = prod_rule_seq_list[start:stop]
        targets = None if target_val_list is None else target_val_list[start:stop]
        dataset = HRGDataset(hrg,
                             seqs,
                             model_params['max_len'],
                             target_val_list=targets,
                             inversed_input=train_params['inversed_input'])
        return DataLoader(dataset=dataset,
                          batch_size=batch_size,
                          shuffle=shuffle, drop_last=False)

    hrg_dataloader_train = _make_loader(0, train_end)

    # The validation loader previously had two different names in the two
    # branches, so the `num_val == 0` path raised NameError at the return.
    hrg_dataloader_val = None
    if train_params['num_val'] != 0:
        hrg_dataloader_val = _make_loader(train_end, val_end)

    hrg_dataloader_test = None
    if train_params['num_test'] != 0:
        hrg_dataloader_test = _make_loader(val_end, test_end)

    return hrg_dataloader_train, hrg_dataloader_val, hrg_dataloader_test, train_params


def load_output():
    """Load the gzip-compressed pickle written by the data-preparation step.

    Reads `OUTPUT/data_prep_for_qm9/mhg_prod_rules.pklz` (relative to the
    current working directory), which must unpickle to a 2-item sequence.

    Returns
    -------
    (hrg, prod_rule_seq_list)
        The two unpickled objects, in the order they were stored.
    """
    # NOTE: pickle can execute arbitrary code while loading; only read files
    # produced by a trusted run of the preparation step.
    pickle_path = os.path.join("OUTPUT/data_prep_for_qm9", 'mhg_prod_rules.pklz')
    with gzip.open(pickle_path, "rb") as handle:
        payload = pickle.load(handle)
    hrg, prod_rule_seq_list = payload
    return hrg, prod_rule_seq_list

print("Start load input data and model")

# Load the grammar and the production-rule sequences from disk.
hrg, prod_rule_seq_list = load_output()

# Build the train/val/test loaders; `target_list` supplies the per-sequence
# targets. The configuration dictionary used for the split is returned as well.
hrg_dataloader_train, hrg_dataloader_val, hrg_dataloader_test, Train_params \
    = get_dataloaders(hrg, prod_rule_seq_list, target_val_list = target_list)

# Deep copy of the model hyper-parameters, leaving `Train_params` untouched.
model_params = deepcopy(dict(Train_params['model_params']))

mhg_vae = GrammarSeq2SeqVAE(
    hrg=hrg,**model_params, use_gpu=True)

# Requires a CUDA device: the model is moved to GPU 0.
mhg_vae = mhg_vae.to('cuda:0')

print("load completed")

Start load input data and model
load completed


## VAE Optimizer

In [48]:
import torch.optim as optim

# Adam optimizer over all parameters of the VAE.
# NOTE(review): lr=1e-3 here, whereas Train_params['sgd_params'] specifies
# lr=5e-4 -- confirm which value is intended.
vae_optimizer = optim.Adam(mhg_vae.parameters(), lr=1e-3)

## MHG VAE Loss function

In [49]:
import torch
import torch.nn as nn
import torch.utils.data
import torch.optim as optim
from torch.autograd import Variable
from torch.nn import functional as F
from torch.nn.modules.loss import _Loss

from mhg.nn.loss import GrammarVAELoss_beta

# Loss object bound to the loaded grammar. beta=0.01 is the same value as
# Train_params['model_params']['criterion_func_kwargs']['beta'].
mhg_vae_loss_func = GrammarVAELoss_beta(hrg = hrg, beta = 0.01)

# Beta-TCVAE loss function

In [50]:
import math

from tqdm import trange, tqdm
import torch

def matrix_log_density_gaussian(x, mu, logvar):
    """
    Gaussian log density for every (sample, distribution) pair in a batch.

    Row `i` of `x` is evaluated under the Gaussian given by row `j` of
    `mu` / `logvar`, producing a tensor of shape `(batch_size, batch_size, dim)`
    instead of the usual `(batch_size, dim)`.

    Parameters
    ----------
    x: torch.Tensor
        Value at which to compute the density. Shape: (batch_size, dim).

    mu: torch.Tensor
        Mean. Shape: (batch_size, dim).

    logvar: torch.Tensor
        Log variance. Shape: (batch_size, dim).
    """
    n_rows, n_dims = x.shape
    # Samples vary along axis 0, distributions along axis 1; broadcasting in
    # `log_density_gaussian` then forms all pairs.
    samples = x.view(n_rows, 1, n_dims)
    means = mu.view(1, n_rows, n_dims)
    log_variances = logvar.view(1, n_rows, n_dims)
    return log_density_gaussian(samples, means, log_variances)


def log_density_gaussian(x, mu, logvar):
    """Element-wise log density of a Gaussian with diagonal covariance.

    Parameters
    ----------
    x: torch.Tensor
        Value at which to compute the density.

    mu: torch.Tensor
        Mean (broadcastable against `x`).

    logvar: torch.Tensor
        Log variance (broadcastable against `x`); passed to `torch.exp`.

    Returns
    -------
    torch.Tensor
        `-0.5 * (log(2*pi) + logvar) - 0.5 * (x - mu)**2 * exp(-logvar)`.
    """
    log_norm_const = - 0.5 * (math.log(2 * math.pi) + logvar)
    precision = torch.exp(-logvar)
    squared_dev = (x - mu)**2
    return log_norm_const - 0.5 * (squared_dev * precision)


def log_importance_weight_matrix(batch_size, dataset_size):
    """
    Build the log of the `(batch_size, batch_size)` importance-weight matrix
    used by the stratified-sampling branch of the KL decomposition.

    With `N = dataset_size` and `M = batch_size - 1` the matrix is filled with
    `1 / M`, column 0 is set to `1 / N`, column 1 to `(N - M) / (N * M)` and
    finally entry `[M - 1, 0]` to `(N - M) / (N * M)`.

    Parameters
    ----------
    batch_size: int
        Number of samples in the batch (must be at least 2).

    dataset_size: int
        Number of samples in the dataset.
    """
    n_total = dataset_size
    n_others = batch_size - 1
    stratified = (n_total - n_others) / (n_total * n_others)
    weights = torch.Tensor(batch_size, batch_size).fill_(1 / n_others)
    # A stride of `batch_size` over the flattened matrix walks down a column,
    # so the strided writes are plain column assignments.
    weights[:, 0] = 1 / n_total
    weights[:, 1] = stratified
    weights[n_others - 1, 0] = stratified
    return weights.log()

In [51]:
import math

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import optim

def tc_decomposition_kl(n_data, latent_dist, latent_sample, alpha=1., beta=2., gamma=1., is_mss = False):
    """
    Weighted sum of the three terms of the decomposed KL divergence.

    Parameters
    ----------
    n_data: int
        Number of data in the training set

    latent_dist : sequence
        Parameters of q(z|x); unpacked positionally as the mean and
        log-variance arguments of the Gaussian log-density helpers.

    latent_sample : torch.Tensor
        Sampled latent codes, shape (batch_size, latent_dim).

    alpha : float
        Weight of the mutual information term.

    beta : float
        Weight of the total correlation term.

    gamma : float
        Weight of the dimension-wise KL term.

    is_mss : bool
        Whether to use minibatch stratified sampling instead of minibatch
        weighted sampling.

    Returns
    -------
    (all_loss, mi_loss, tc_loss, dw_kl_loss)
        The weighted total followed by the three unweighted batch means.
    """
    log_pz, log_qz, log_prod_qzi, log_q_zCx = _get_log_pz_qz_prodzi_qzCx(
        n_data=n_data,
        latent_sample=latent_sample,
        latent_dist=latent_dist,
        is_mss=is_mss,
    )

    # Mutual information term: mean of log q(z|x) - log q(z).
    mutual_info = (log_q_zCx - log_qz).mean()
    # Total correlation term: mean of log q(z) - sum_i log q(z_i).
    total_corr = (log_qz - log_prod_qzi).mean()
    # Dimension-wise KL term: mean of sum_i log q(z_i) - log p(z).
    dimwise_kl = (log_prod_qzi - log_pz).mean()

    weighted_total = alpha * mutual_info + beta * total_corr + gamma * dimwise_kl
    return weighted_total, mutual_info, total_corr, dimwise_kl
    
# Batch TC specific
# TO-DO: test if mss is better!
def _get_log_pz_qz_prodzi_qzCx(n_data, latent_sample, latent_dist, is_mss=False):
    """
    Compute the four per-sample log densities needed by `tc_decomposition_kl`.

    Parameters
    ----------
    n_data: int
        Dataset size, used only for the stratified importance weights.

    latent_sample: torch.Tensor
        Sampled latent codes, shape (batch_size, hidden_dim).

    latent_dist: sequence
        Unpacked positionally as the mean and log-variance arguments of
        `log_density_gaussian` / `matrix_log_density_gaussian`.

    is_mss: bool
        When True, add the log importance-weight matrix before the
        log-sum-exp reductions.

    Returns
    -------
    (log_pz, log_qz, log_prod_qzi, log_q_zCx)
    """
    batch_size, _ = latent_sample.shape

    # log q(z|x): each sample under its own posterior, summed over latent dims.
    log_q_zCx = log_density_gaussian(latent_sample, *latent_dist).sum(dim=1)

    # log p(z): prior with zero mean and zero log-variance.
    prior_params = torch.zeros_like(latent_sample)
    log_pz = log_density_gaussian(latent_sample, prior_params, prior_params).sum(1)

    # Pairwise log densities, shape (batch_size, batch_size, hidden_dim).
    pairwise = matrix_log_density_gaussian(latent_sample, *latent_dist)
    if is_mss:
        # use stratification
        log_weights = log_importance_weight_matrix(batch_size, n_data).to(latent_sample.device)
        pairwise = pairwise + log_weights.view(batch_size, batch_size, 1)

    joint_over_dims = pairwise.sum(2)
    log_qz = torch.logsumexp(joint_over_dims, dim=1, keepdim=False)
    per_dim_marginals = torch.logsumexp(pairwise, dim=1, keepdim=False)
    log_prod_qzi = per_dim_marginals.sum(1)

    return log_pz, log_qz, log_prod_qzi, log_q_zCx

## beta scheduler

In [52]:
def scheduler_beta(epoch):
    """Linearly decaying beta schedule with a floor.

    Starts at 1.25 for epoch 1, decreases by 0.1 per epoch and never drops
    below 0.75.
    """
    decayed = 1.25 - (epoch - 1) * 0.1
    return decayed if decayed > 0.75 else 0.75

# Train

In [45]:
import os
import numpy as np
import math
import random
import torch
import torch.nn as nn
import torch.utils.data
import torch.optim as optim
from torch.autograd import Variable, Function
from torch.nn import functional as F
from torch.nn.modules.loss import _Loss

from mhg.nn.metric import LogRatio_loss
from scipy.stats import pearsonr

# NOTE(review): `mse` is not referenced by any code visible in this notebook.
mse = nn.MSELoss()

def tgt_cor(latent_vec, tgt_vec):
    """Largest absolute Pearson correlation between any latent dimension and the target.

    Parameters
    ----------
    latent_vec : torch.Tensor
        Latent codes, shape (batch_size, latent_dim).
    tgt_vec : torch.Tensor
        Target values with batch_size elements in total (e.g. (batch_size, 1)).

    Returns
    -------
    float
        max over latent dimensions of |pearsonr(latent[:, i], target)|.
        Dimensions whose correlation is NaN are ignored; the result is 0.0 if
        no dimension yields a positive value.
    """
    latent = latent_vec.detach().cpu().numpy()
    # Flatten to 1-D for any batch size (this used to be hard-coded to 64).
    target = tgt_vec.detach().cpu().numpy().reshape(-1)

    max_pi = 0.0
    for dim_values in latent.T:
        pi = abs(pearsonr(dim_values, target)[0])
        # A NaN correlation compares False and is therefore skipped.
        if pi > max_pi:
            max_pi = pi

    return float(max_pi)

#train loop
def train(epoch):
    """Run one training epoch of `mhg_vae` over `hrg_dataloader_train`.

    Relies on notebook-level state: `mhg_vae`, `vae_optimizer`,
    `mhg_vae_loss_func`, `hrg_dataloader_train`, `Train_params` and
    `num_epochs` (the latter only for the log line).

    For every batch the optimised loss is

        reconst + kl_decomposition_loss + gamma_scale * gamma * log_ratio_loss

    where `gamma = max(1 - tgt_cor(mu, target), 0.01)`. A log line is appended
    every 50 batches and the whole model is checkpointed every 100 batches.

    Parameters
    ----------
    epoch : int
        1-based epoch number; drives `scheduler_beta` and the log line.

    Raises
    ------
    TypeError
        If the dataloader yields a batch that is not a list.
    """
    mhg_vae.train()

    model_cfg = Train_params['model_params']
    full_batch_size = model_cfg['batch_size']
    padding_idx = model_cfg['padding_idx']

    n_data = len(hrg_dataloader_train.dataset)
    n_batches = len(hrg_dataloader_train)

    save_path = "OUTPUT/train/semiLogRatio_train_qm9_tcbeta"
    # Created once per epoch instead of being re-checked on every batch.
    os.makedirs(save_path, exist_ok=True)

    def _loss_text(loss_tensor):
        # Same text as str() of the host-side numpy value.
        return str(loss_tensor.detach().cpu().numpy())

    for each_idx, each_batch in enumerate(hrg_dataloader_train):

        print("epoch:", epoch, "batch:",each_idx + 1,"/",n_batches)

        if not isinstance(each_batch, list):
            # Previously such a batch skipped the encoding step and the code
            # below ran on stale (or undefined) tensors.
            raise TypeError("expected each batch to be a list of [input, output, target], got "
                            + type(each_batch).__name__)

        in_batch, out_batch, tgt_batch = each_batch

        # Pad a short (last) batch up to the full batch size: sequences are
        # padded with `padding_idx`, targets with zeros.
        n_missing = full_batch_size - len(in_batch)
        if n_missing > 0:
            in_batch = torch.cat([in_batch,
                                  padding_idx * torch.ones((n_missing, len(in_batch[0])), dtype=torch.int64)], dim=0)
            out_batch = torch.cat([out_batch,
                                   padding_idx * torch.ones((n_missing, len(out_batch[0])), dtype=torch.int64)], dim=0)
            tgt_batch = torch.cat([tgt_batch, torch.zeros((n_missing))], dim=0)

        # Map indices (including the negative padding index) into [0, vocab_size).
        in_batch = torch.LongTensor(np.mod(in_batch, mhg_vae.vocab_size))
        out_batch = torch.LongTensor(np.mod(out_batch, mhg_vae.vocab_size))
        tgt_batch = torch.FloatTensor(tgt_batch)

        tgt_batch = torch.reshape(tgt_batch, (-1, 1))

        in_batch = in_batch.to("cuda:0")
        out_batch = out_batch.to("cuda:0")
        tgt_batch = tgt_batch.to("cuda:0")

        mhg_vae.init_hidden()

        # encode
        mu, logvar = mhg_vae.encode(in_batch)
        z = mhg_vae.reparameterize(mu, logvar)
        dist = mu, logvar

        ###### log ratio loss ########
        log_ratio_loss_score = LogRatio_loss(mu, tgt_batch)

        ###### mhg vae loss ######
        vae_optimizer.zero_grad()

        decoded = mhg_vae.decode(z, out_batch)

        beta_var = scheduler_beta(epoch)

        # Weight of the log-ratio term: 1 - (best |correlation|), floored at 0.01.
        # max() also covers the exact-equality case that the former
        # `< 0.01` / `> 0.01` branches both missed.
        gamma = max(1 - tgt_cor(mu, tgt_batch), 0.01)
        gamma_scale = 0.1

        default_beta = 0.01
        mhg_vae_loss_score, reconst, _kld = mhg_vae_loss_func(decoded, out_batch, mu, logvar, beta = default_beta)

        kl_decomposition_loss, mi_loss, tc_loss, dw_kl_loss = tc_decomposition_kl(
            n_data=n_data, alpha=0.75, beta=beta_var, gamma=0.75,
            latent_dist=dist, latent_sample=z, is_mss=True)

        #### Loss function ####
        mhg_vae_loss = reconst + kl_decomposition_loss + gamma_scale * gamma * log_ratio_loss_score

        mhg_vae_loss.backward()
        vae_optimizer.step()

        if each_idx % 50 == 0:
            # Losses are only moved to the host when a log line is written.
            train_loss_log = ("epoch:" + str(epoch) + "/" + str(num_epochs)
                              + " " + "vae train loss:" + _loss_text(mhg_vae_loss_score)
                              + " " + "log-ratio loss:" + _loss_text(log_ratio_loss_score)
                              + " " + "Reconst:" + _loss_text(reconst)
                              + " " + "MI(x,z):" + _loss_text(mi_loss)
                              + " " + "TC:" + _loss_text(tc_loss)
                              + " " + "DW KLD:" + _loss_text(dw_kl_loss) + "\n")
            with open(os.path.join(save_path, 'train_loss_log.txt'), mode='a') as f:
                f.write(train_loss_log)
                print("save train loss")

        if each_idx % 100 == 0:
            torch.save(mhg_vae, os.path.join(save_path, "mhg_vae_metric_qm9_save.pt"))

def test(epoch):
    """Evaluate `mhg_vae` on `hrg_dataloader_test` and append one line to the test log.

    Relies on notebook-level state: `mhg_vae`, `mhg_vae_loss_func`,
    `hrg_dataloader_test`, `Train_params` and `num_epochs` (log line only).

    Parameters
    ----------
    epoch : int
        1-based epoch number; drives `scheduler_beta` and the log line.

    Returns
    -------
    (log_ratio_test_loss, mhg_vae_test_loss) : tuple of float
        Accumulated per-batch losses divided by the number of test samples.

    Raises
    ------
    TypeError
        If the dataloader yields a batch that is not a list.
    """
    mhg_vae.eval()

    model_cfg = Train_params['model_params']
    full_batch_size = model_cfg['batch_size']
    padding_idx = model_cfg['padding_idx']

    mhg_vae_test_loss = 0
    reconst_test_loss = 0
    kl_decomposition_test_loss = 0
    mi_test_loss = 0
    tc_test_loss = 0
    dw_kl_test_loss = 0
    log_ratio_test_loss = 0

    n_data = len(hrg_dataloader_test.dataset)
    n_batches = len(hrg_dataloader_test)

    # Evaluation only: no autograd graph is needed, so none is recorded.
    with torch.no_grad():
        for each_idx, each_batch in enumerate(hrg_dataloader_test):

            print("epoch:", epoch, "batch:", each_idx + 1, "/", n_batches)

            if not isinstance(each_batch, list):
                # Previously such a batch skipped the encoding step and the
                # code below ran on stale (or undefined) tensors.
                raise TypeError("expected each batch to be a list of [input, output, target], got "
                                + type(each_batch).__name__)

            in_batch, out_batch, tgt_batch = each_batch

            # Pad a short (last) batch up to the full batch size: sequences
            # are padded with `padding_idx`, targets with zeros.
            n_missing = full_batch_size - len(in_batch)
            if n_missing > 0:
                in_batch = torch.cat([in_batch,
                                      padding_idx * torch.ones((n_missing, len(in_batch[0])), dtype=torch.int64)], dim=0)
                out_batch = torch.cat([out_batch,
                                       padding_idx * torch.ones((n_missing, len(out_batch[0])), dtype=torch.int64)], dim=0)
                tgt_batch = torch.cat([tgt_batch, torch.zeros((n_missing))], dim=0)

            # Map indices (including the negative padding index) into [0, vocab_size).
            in_batch = torch.LongTensor(np.mod(in_batch, mhg_vae.vocab_size))
            out_batch = torch.LongTensor(np.mod(out_batch, mhg_vae.vocab_size))
            tgt_batch = torch.FloatTensor(tgt_batch)

            tgt_batch = torch.reshape(tgt_batch, (-1, 1))

            in_batch = in_batch.to("cuda:0")
            out_batch = out_batch.to("cuda:0")
            tgt_batch = tgt_batch.to("cuda:0")

            mhg_vae.init_hidden()

            mu, logvar = mhg_vae.encode(in_batch)
            z = mhg_vae.reparameterize(mu, logvar)
            dist = mu, logvar

            log_ratio_test_loss_score = LogRatio_loss(mu, tgt_batch)

            decoded = mhg_vae.decode(z, out_batch)

            default_beta = 0.01
            beta_var = scheduler_beta(epoch)

            mhg_vae_loss, reconst, _kld = mhg_vae_loss_func(decoded, out_batch, mu, logvar, beta = default_beta)

            kl_decomposition_loss, mi_loss, tc_loss, dw_kl_loss = tc_decomposition_kl(
                n_data=n_data, alpha=0.75, beta=beta_var, gamma=0.75,
                latent_dist=dist, latent_sample=z, is_mss=True)

            torch.cuda.empty_cache()

            # .item() copies the value to the host as a Python float.
            mhg_vae_test_loss += mhg_vae_loss.item()
            reconst_test_loss += reconst.item()
            kl_decomposition_test_loss += kl_decomposition_loss.item()
            mi_test_loss += mi_loss.item()
            tc_test_loss += tc_loss.item()
            dw_kl_test_loss += dw_kl_loss.item()
            log_ratio_test_loss += log_ratio_test_loss_score.item()

    # NOTE(review): the sums are over per-batch values (the MI/TC/DW-KL terms
    # are batch means), yet they are divided by the number of samples rather
    # than the number of batches. Kept as-is so logged values stay comparable;
    # confirm the intended normalisation.
    mhg_vae_test_loss /= n_data
    reconst_test_loss /= n_data
    kl_decomposition_test_loss /= n_data
    mi_test_loss /= n_data
    tc_test_loss /= n_data
    dw_kl_test_loss /= n_data
    log_ratio_test_loss /= n_data

    save_path = "OUTPUT/train/semiLogRatio_train_qm9_tcbeta"
    # Make sure the log directory exists even if no training step created it.
    os.makedirs(save_path, exist_ok=True)

    test_loss_log = ("epoch:" + str(epoch) + "/" + str(num_epochs)
                     + " " + "vae val losso:" + str(mhg_vae_test_loss)
                     + " " + "log-ratio loss:" + str(log_ratio_test_loss)
                     + " " + "Reconst:" + str(reconst_test_loss)
                     + " " + "MI(x,z):" + str(mi_test_loss)
                     + " " + "TC:" + str(tc_test_loss)
                     + " " + "DW KLD:" + str(dw_kl_test_loss) + "\n")

    with open(os.path.join(save_path, 'test_loss_log.txt'), mode='a') as f:
        f.write(test_loss_log)
        print("save test loss")

    return log_ratio_test_loss, mhg_vae_test_loss

In [46]:
import time

# Directory that receives every checkpoint written by this loop.
save_path = "OUTPUT/train/semiLogRatio_train_qm9_tcbeta"
os.makedirs(save_path, exist_ok=True)

# Start from +inf so the first finite evaluation always becomes the best
# model; a fixed 10000 threshold would never save a "best" checkpoint if the
# losses stayed above it.
best_test_loss = float("inf")
num_epochs = 20  # also read by train()/test() for their log lines

for epoch in range(1, num_epochs + 1):

    train(epoch)

    metric_loss, vae_loss = test(epoch)

    # Model-selection criterion: log-ratio loss plus VAE loss on the test split.
    test_loss = metric_loss + vae_loss

    # Per-epoch checkpoint.
    save_model = os.path.join(save_path, "mhg_vae_metric_qm9_" + str(epoch) + "_model" + ".pt")
    torch.save(mhg_vae, save_model)

    # Best-so-far checkpoint.
    if test_loss < best_test_loss:
        best_test_loss = test_loss
        torch.save(mhg_vae, os.path.join(save_path, "mhg_vae_metric_qm9.pt"))

epoch: 1 batch: 1 / 1562
save train loss
epoch: 1 batch: 2 / 1562
epoch: 1 batch: 3 / 1562
epoch: 1 batch: 4 / 1562
epoch: 1 batch: 5 / 1562
epoch: 1 batch: 6 / 1562


KeyboardInterrupt: 