In [2]:
%load_ext autoreload
%autoreload 2

In [None]:
import time
import argparse
import os
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from torch.utils import data
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from sklearn.utils import shuffle
from sklearn.metrics import mean_absolute_error
from tensorboardX import SummaryWriter


from dataloader import *
from utils import *
from metric import *

%matplotlib inline
parser = argparse.ArgumentParser()

# Attention

In [None]:
class NoamOpt:
    """Optimizer wrapper that applies the "Noam" warm-up learning-rate schedule.

    The learning rate for a 0-based iteration index ``step`` is

        factor * model_size ** -0.5 * min(s ** -0.5, s * warmup ** -1.5),  s = step + 1

    Args:
        model_size: model width used to scale the rate.
        factor: global multiplier applied to the schedule.
        warmup: number of warm-up steps of the schedule.
        optimizer: wrapped optimizer; must expose ``param_groups``, ``step()``,
            ``zero_grad()`` and ``state_dict()``.
    """

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._step = 0  # last iteration index passed to step(); 0 before any update
        self._rate = 0  # learning rate applied by the most recent step()

    def step(self, step):
        """Set the scheduled rate for iteration ``step`` and update the parameters."""
        rate = self.rate(step)
        for group in self.optimizer.param_groups:
            group['lr'] = rate
        self._step = step
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        """Return the scheduled learning rate for the 0-based iteration ``step``.

        When ``step`` is omitted, the last index given to :meth:`step` is used
        (0 if no update happened yet).  Previously the ``None`` default raised
        a ``TypeError`` because it was incremented directly.
        """
        if step is None:
            step = self._step
        # Shift to 1-based so that the first iteration does not evaluate 0 ** -0.5.
        step = step + 1
        return self.factor * (self.model_size ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5)))

    def zero_grad(self):
        """Clear the gradients held by the wrapped optimizer."""
        self.optimizer.zero_grad()

    def state_dict(self):
        """Return the wrapped optimizer's state dict (schedule settings are not included)."""
        return self.optimizer.state_dict()
    
class Attention(nn.Module):
    """Multi-head attention whose scores are gated element-wise by the matrix ``A``.

    Every head projects the input to ``output_dim // num_attn_head`` features,
    scores node pairs with a learned bilinear form, multiplies the scores by
    ``A``, squashes them with tanh and applies dropout.  The head outputs are
    concatenated along the last axis and passed through ReLU.

    NOTE: a class with the same name and body is defined again in a later cell
    of this notebook; whichever cell runs last provides the definition in use.
    """

    def __init__(self, input_dim, output_dim, num_attn_head, dropout=0.1):
        super(Attention, self).__init__()

        self.num_attn_heads = num_attn_head
        self.attn_dim = output_dim // num_attn_head

        head_range = range(self.num_attn_heads)
        # One linear projection and one bilinear coefficient matrix per head.
        self.projection = nn.ModuleList(
            [nn.Linear(input_dim, self.attn_dim) for _ in head_range])
        self.coef_matrix = nn.ParameterList(
            [nn.Parameter(torch.FloatTensor(self.attn_dim, self.attn_dim)) for _ in head_range])
        self.tanh = nn.Tanh()
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout)
        # The coefficient matrices are allocated uninitialised, so this call is required.
        self.param_initializer()

    def forward(self, X, A):
        """Return the ReLU of the concatenated per-head outputs."""
        head_outputs = []
        for project, coef in zip(self.projection, self.coef_matrix):
            projected = project(X)
            weights = self.attn_coeff(projected, A, coef)
            head_outputs.append(torch.matmul(weights, projected))
        return self.relu(torch.cat(head_outputs, dim=2))

    def attn_coeff(self, X_projected, A, C):
        """Compute ``dropout(tanh(A * (X_projected C^T) X_projected^T))`` for one head."""
        bilinear = torch.einsum('akj,ij->aki', (X_projected, C))
        scores = torch.matmul(bilinear, torch.transpose(X_projected, 1, 2))
        gated = torch.mul(A, scores)
        return self.dropout(self.tanh(gated))

    def param_initializer(self):
        """Xavier-normal initialise every head's projection weight and coefficient matrix."""
        for project, coef in zip(self.projection, self.coef_matrix):
            nn.init.xavier_normal_(project.weight.data)
            nn.init.xavier_normal_(coef.data)
            

# GConv, Readout, BN1d, ResBlock, Encoder

In [5]:
def gelu(x):
    """Gaussian Error Linear Unit in its exact erf form: ``0.5 * x * (1 + erf(x / sqrt(2)))``.

    Ref: https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/modeling.py
    For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    erf_term = torch.erf(x / math.sqrt(2.0))
    half_x = x * 0.5
    return half_x * (1.0 + erf_term)


# Maps an activation name (e.g. the value of ``args.act``) to the callable implementing it.
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu}

class Attention(nn.Module):
    """Multi-head attention whose scores are gated element-wise by the matrix ``A``.

    Every head projects the input to ``output_dim // num_attn_head`` features,
    scores node pairs with a learned bilinear form, multiplies the scores by
    ``A``, squashes them with tanh and applies dropout.  The head outputs are
    concatenated along the last axis and passed through ReLU.

    NOTE: a class with the same name and body is also defined in an earlier
    cell of this notebook; whichever cell runs last provides the definition in use.
    """

    def __init__(self, input_dim, output_dim, num_attn_head, dropout=0.1):
        super(Attention, self).__init__()

        self.num_attn_heads = num_attn_head
        self.attn_dim = output_dim // num_attn_head

        head_range = range(self.num_attn_heads)
        # One linear projection and one bilinear coefficient matrix per head.
        self.projection = nn.ModuleList(
            [nn.Linear(input_dim, self.attn_dim) for _ in head_range])
        self.coef_matrix = nn.ParameterList(
            [nn.Parameter(torch.FloatTensor(self.attn_dim, self.attn_dim)) for _ in head_range])
        self.tanh = nn.Tanh()
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout)
        # The coefficient matrices are allocated uninitialised, so this call is required.
        self.param_initializer()

    def forward(self, X, A):
        """Return the ReLU of the concatenated per-head outputs."""
        head_outputs = []
        for project, coef in zip(self.projection, self.coef_matrix):
            projected = project(X)
            weights = self.attn_coeff(projected, A, coef)
            head_outputs.append(torch.matmul(weights, projected))
        return self.relu(torch.cat(head_outputs, dim=2))

    def attn_coeff(self, X_projected, A, C):
        """Compute ``dropout(tanh(A * (X_projected C^T) X_projected^T))`` for one head."""
        bilinear = torch.einsum('akj,ij->aki', (X_projected, C))
        scores = torch.matmul(bilinear, torch.transpose(X_projected, 1, 2))
        gated = torch.mul(A, scores)
        return self.dropout(self.tanh(gated))

    def param_initializer(self):
        """Xavier-normal initialise every head's projection weight and coefficient matrix."""
        for project, coef in zip(self.projection, self.coef_matrix):
            nn.init.xavier_normal_(project.weight.data)
            nn.init.xavier_normal_(coef.data)


#####################################################
# ===== Gconv, Readout, BN1D, ResBlock, Encoder =====#
#####################################################

class GConv(nn.Module):
    """Graph-convolution step: either a supplied attention module or ``A @ act(fc(X))``.

    Args:
        input_dim: input feature width (only used when ``attn`` is None).
        output_dim: output feature width (only used when ``attn`` is None).
        attn: callable ``attn(X, A)`` to delegate to, or None to use the
            plain linear layer followed by multiplication with ``A``.
        act: activation applied after the linear layer when ``attn`` is None.
    """

    def __init__(self, input_dim, output_dim, attn, act=ACT2FN['relu']):
        super(GConv, self).__init__()
        self.attn = attn
        if attn is not None:
            # The attention module does all the work; no linear layer is created.
            return
        self.fc = nn.Linear(input_dim, output_dim)
        self.act = act
        nn.init.xavier_normal_(self.fc.weight.data)

    def forward(self, X, A):
        """Return ``(new_features, A)``; ``A`` is passed through untouched."""
        if self.attn is not None:
            return self.attn(X, A), A
        hidden = self.act(self.fc(X))
        return torch.matmul(A, hidden), A


class Readout(nn.Module):
    """Linear map of each row to ``molvec_dim`` followed by a mean over axis 1."""

    def __init__(self, out_dim, molvec_dim):
        super(Readout, self).__init__()
        fc = nn.Linear(out_dim, molvec_dim)
        nn.init.xavier_normal_(fc.weight.data)
        self.readout_fc = fc

    def forward(self, output_H):
        """Project ``output_H`` with the linear layer, then average over ``dim=1``."""
        projected = self.readout_fc(output_H)
        return projected.mean(dim=1)


class BN1d(nn.Module):
    """Batch normalisation over the last axis of a tensor with any number of leading axes.

    The input is flattened to ``(-1, out_dim)``, normalised with
    ``nn.BatchNorm1d`` and restored to its original shape.

    Args:
        out_dim: size of the last (feature) axis.
        use_bn: when False the module is an identity function.  The
            ``BatchNorm1d`` layer is still created, so the module's parameters
            and state-dict keys do not depend on this flag.
    """

    def __init__(self, out_dim, use_bn=True):
        super(BN1d, self).__init__()
        self.use_bn = use_bn
        self.bn = nn.BatchNorm1d(out_dim)

    def forward(self, x):
        """Return ``x`` normalised over its last axis (or ``x`` itself if disabled)."""
        if not self.use_bn:
            return x
        origin_shape = x.shape
        # reshape (not view): view raises on non-contiguous inputs such as
        # transposed tensors, whereas reshape copies only when it has to.
        flat = x.reshape(-1, origin_shape[-1])
        flat = self.bn(flat)
        return flat.reshape(origin_shape)


class ResBlock(nn.Module):
    """Graph-convolution block with an optional (gated) skip connection.

    Args:
        in_dim: input feature width.
        out_dim: output feature width.
        use_bn: forwarded to the ``BN1d`` layers (False makes them identities).
        use_attn: build the ``GConv`` around an ``Attention`` module if True,
            otherwise around its plain linear layer.
        dp_rate: probability of the ``nn.Dropout2d`` applied to the output.
        sc_type: ``'no'`` (no skip connection), ``'sc'`` (additive skip
            connection) or ``'gsc'`` (gated skip connection).
        n_attn_head: number of attention heads; required when ``use_attn``.
        act: activation applied after the batch-norm layers.

    Raises:
        ValueError: if ``sc_type`` is not one of the three supported values.
    """

    def __init__(self, in_dim, out_dim, use_bn, use_attn, dp_rate, sc_type, n_attn_head=None, act=ACT2FN['relu']):
        super(ResBlock, self).__init__()
        # Validate before building any layer so a typo fails fast with a
        # message.  ValueError is still an Exception, so existing handlers work.
        if sc_type not in ('no', 'gsc', 'sc'):
            raise ValueError(
                "sc_type must be one of 'no', 'sc' or 'gsc', got {!r}".format(sc_type))

        self.use_bn = use_bn
        self.sc_type = sc_type

        attn = Attention(in_dim, out_dim, n_attn_head) if use_attn else None
        self.gconv = GConv(in_dim, out_dim, attn)

        self.bn1 = BN1d(out_dim, use_bn)
        # NOTE(review): Dropout2d drops whole channels; the tensors here look
        # 3-D, so confirm channel-wise (rather than element-wise) dropout is intended.
        self.dropout = nn.Dropout2d(p=dp_rate)
        self.act = act

        if self.sc_type != 'no':
            self.bn2 = BN1d(out_dim, use_bn)
            # Identity shortcut unless the widths differ, in which case a
            # bias-free linear layer matches them.
            self.shortcut = nn.Sequential()
            if in_dim != out_dim:
                self.shortcut.add_module('shortcut', nn.Linear(in_dim, out_dim, bias=False))

        if self.sc_type == 'gsc':
            self.g_fc1 = nn.Linear(out_dim, out_dim, bias=True)
            self.g_fc2 = nn.Linear(out_dim, out_dim, bias=True)
            self.sigmoid = nn.Sigmoid()

    def forward(self, X, A):
        """Return ``(features, A)`` after graph convolution, skip connection and dropout."""
        x, A = self.gconv(X, A)
        x = self.act(self.bn1(x))

        if self.sc_type == 'no':  # no skip-connection
            return self.dropout(x), A

        if self.sc_type == 'sc':  # basic skip-connection
            x = x + self.shortcut(X)
            return self.dropout(self.act(self.bn2(x))), A

        # 'gsc': gated skip-connection (the only value left after validation).
        x1 = self.g_fc1(self.shortcut(X))
        x2 = self.g_fc2(x)
        gate_coef = self.sigmoid(x1 + x2)
        # Convex combination of the shortcut path and the convolved path.
        x = torch.mul(x1, gate_coef) + torch.mul(x2, 1.0 - gate_coef)
        return self.dropout(self.act(self.bn2(x))), A


class Encoder(nn.Module):
    """Embeds categorical atom features, applies a stack of ``ResBlock`` layers
    and reads out a molecule vector that is refined by a small MLP.

    ``args`` must provide: batch_size, molvec_dim, vocab_size, degree_size,
    numH_size, valence_size, isarom_size, emb_train, in_dim, out_dim, n_layer,
    use_bn, use_attn, dp_rate, sc_type, n_attn_heads and act.
    """

    def __init__(self, args):
        super(Encoder, self).__init__()
        self.bs = args.batch_size
        self.molvec_dim = args.molvec_dim
        self.embedding = self.create_emb_layer([args.vocab_size, args.degree_size,
                                                args.numH_size, args.valence_size,
                                                args.isarom_size],  args.emb_train)
        self.out_dim = args.out_dim

        # Graph Convolution Layers with Readout Layer
        self.gconvs = nn.ModuleList()
        for i in range(args.n_layer):
            # Only the first block maps in_dim -> out_dim; the rest keep out_dim.
            block_in_dim = args.in_dim if i == 0 else self.out_dim
            self.gconvs.append(
                ResBlock(block_in_dim, self.out_dim, args.use_bn, args.use_attn, args.dp_rate, args.sc_type,
                         args.n_attn_heads, ACT2FN[args.act]))
        self.readout = Readout(self.out_dim, self.molvec_dim)

        # Molecular Vector Transformation
        self.fc1 = nn.Linear(self.molvec_dim, self.molvec_dim)
        self.fc2 = nn.Linear(self.molvec_dim, self.molvec_dim)
        self.fc3 = nn.Linear(self.molvec_dim, self.molvec_dim)
        self.bn1 = BN1d(self.molvec_dim)
        self.bn2 = BN1d(self.molvec_dim)
        self.act = ACT2FN[args.act]
        self.dropout = nn.Dropout(p=args.dp_rate)

    def forward(self, input_X, A):
        """Return ``(node_features, A, molvec)`` with ``molvec`` passed through the MLP."""
        x, A, molvec = self.encoder(input_X, A)
        molvec = self.dropout(self.bn1(self.act(self.fc1(molvec))))
        molvec = self.dropout(self.bn2(self.act(self.fc2(molvec))))
        molvec = self.fc3(molvec)
        return x, A, molvec

    def encoder(self, input_X, A):
        """Embed the input, run every graph block and read out the raw molecule vector."""
        x = self._embed(input_X)
        for module in self.gconvs:
            x, A = module(x, A)
        molvec = self.readout(x)
        return x, A, molvec

    def _embed(self, x):
        """Embed column ``i`` of the last axis with table ``i`` and concatenate the results.

        One column is consumed per embedding table (five for the tables built
        in ``__init__``); any further columns of ``x`` are ignored.
        """
        list_embed = [emb(x[:, :, i].long()) for i, emb in enumerate(self.embedding)]
        return torch.cat(list_embed, 2)

    def create_emb_layer(self, list_vocab_size, emb_train=False):
        """Build one identity-initialised (one-hot) embedding table per vocabulary.

        Each table has ``vocab_size + 1`` rows and columns.  ``emb_train``
        decides whether the weights receive gradients.
        """
        list_emb_layer = nn.ModuleList()
        for vocab_size in list_vocab_size:
            num_embeddings = vocab_size + 1
            emb_layer = nn.Embedding(num_embeddings, num_embeddings)
            emb_layer.load_state_dict({'weight': torch.eye(num_embeddings)})
            emb_layer.weight.requires_grad = emb_train
            list_emb_layer.append(emb_layer)
        return list_emb_layer

# Compute Loss

In [6]:
def Compute_loss(pred_x, ground_x, vocab_size, degree_size=6, numH_size=5, valence_size=6):
    """Compute the per-feature and total losses of the masked-atom prediction task.

    Column layout of ``pred_x`` (rows are masked atoms):
        [0, vocab_size)            symbol logits
        next ``degree_size``       degree logits
        next ``numH_size``         numH logits
        next ``valence_size``      valence logits
        column -2                  aromaticity logit
        column -1                  regression output compared with ``ground_x[:, 5]``

    ``ground_x`` holds, per row, the class indices of symbol, degree, numH and
    valence, the 0/1 aromaticity flag and the regression target (columns 0-5).

    The three ``*_size`` keyword arguments default to the widths that used to
    be hard-coded, so existing three-argument calls behave as before.

    Returns:
        tuple ``(symbol_loss, degree_loss, numH_loss, valence_loss,
        isarom_loss, partial_loss, total_loss)`` where ``total_loss`` is the
        unweighted sum of the other six.
    """
    degree_end = vocab_size + degree_size
    numH_end = degree_end + numH_size
    valence_end = numH_end + valence_size

    symbol_loss = F.cross_entropy(pred_x[:, :vocab_size], ground_x[:, 0].detach().long())
    degree_loss = F.cross_entropy(pred_x[:, vocab_size:degree_end], ground_x[:, 1].detach().long())
    numH_loss = F.cross_entropy(pred_x[:, degree_end:numH_end], ground_x[:, 2].detach().long())
    valence_loss = F.cross_entropy(pred_x[:, numH_end:valence_end], ground_x[:, 3].detach().long())
    # Fused sigmoid + BCE: a separate sigmoid saturates to exactly 0/1 in
    # float32 for large logits, which clamps the loss and zeroes the gradient.
    isarom_loss = F.binary_cross_entropy_with_logits(pred_x[:, -2], ground_x[:, 4].detach().float())
    partial_loss = F.mse_loss(pred_x[:, -1], ground_x[:, 5])
    total_loss = symbol_loss + degree_loss + numH_loss + valence_loss + isarom_loss + partial_loss
    return symbol_loss, degree_loss, numH_loss, valence_loss, isarom_loss, partial_loss, total_loss
    

# Classifier & Regressor

In [7]:
class Classifier(nn.Module):
    """MLP head predicting the features of masked atoms.

    For each molecule the encoded rows selected by ``idx_M`` are concatenated
    with that molecule's vector and pushed through three linear layers; the
    final layer has ``in_dim`` outputs.
    """

    def __init__(self, out_dim, molvec_dim, classifier_dim, in_dim, dropout_rate=0.3, act=ACT2FN['relu']):
        super(Classifier, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.molvec_dim = molvec_dim
        self.classifier_dim = classifier_dim

        half_dim = self.classifier_dim // 2
        self.fc1 = nn.Linear(self.molvec_dim + self.out_dim, self.classifier_dim)
        self.fc2 = nn.Linear(self.classifier_dim, half_dim)
        self.fc3 = nn.Linear(half_dim, self.in_dim)
        self.bn1 = BN1d(self.classifier_dim)
        self.bn2 = BN1d(half_dim)
        self.act = act
        self.dropout = nn.Dropout(p=dropout_rate)
        self.param_initializer()

    def forward(self, X, molvec, idx_M):
        """Return predictions of shape ``(batch_size * num_masking, -1)``.

        ``idx_M[b]`` lists the row indices of ``X[b]`` to predict.
        """
        batch_size, num_masking = X.shape[0], idx_M.shape[1]

        # Repeat each molecule vector once per masked position.
        tiled_molvec = torch.unsqueeze(molvec, 1)
        tiled_molvec = tiled_molvec.expand(batch_size, num_masking, tiled_molvec.shape[-1])

        per_molecule = [
            torch.cat((torch.index_select(X[b], 0, idx_M[b]), tiled_molvec[b]), dim=1)
            for b in range(batch_size)
        ]

        pred_x = self.classify(torch.stack(per_molecule))
        return pred_x.view(batch_size * num_masking, -1)

    def classify(self, concat_x):
        """Apply fc1 and fc2 (each followed by act, batch norm, dropout) and then fc3."""
        hidden = self.dropout(self.bn1(self.act(self.fc1(concat_x))))
        hidden = self.dropout(self.bn2(self.act(self.fc2(hidden))))
        return self.fc3(hidden)

    def param_initializer(self):
        """Xavier-normal initialise fc1 and fc2; fc3 keeps PyTorch's default initialisation."""
        for layer in (self.fc1, self.fc2):
            nn.init.xavier_normal_(layer.weight.data)


class Regressor(nn.Module):
    """MLP head that regresses ``num_aux_task`` values from a molecule vector.

    Args:
        molvec_dim: width of the incoming molecule vector.
        classifier_dim: width of the first hidden layer (the second is half of it).
        num_aux_task: number of regression outputs.
        dropout_rate: dropout probability after each hidden layer.
        act: activation applied after the first two linear layers.
    """

    def __init__(self, molvec_dim, classifier_dim, num_aux_task=5, dropout_rate=0.3, act=ACT2FN['relu']):
        super(Regressor, self).__init__()

        self.molvec_dim = molvec_dim
        self.classifier_dim = classifier_dim
        self.fc1 = nn.Linear(self.molvec_dim, self.classifier_dim)
        self.fc2 = nn.Linear(self.classifier_dim, self.classifier_dim // 2)
        self.fc3 = nn.Linear(self.classifier_dim // 2, num_aux_task)
        self.bn1 = nn.BatchNorm1d(self.classifier_dim)
        self.bn2 = nn.BatchNorm1d(self.classifier_dim // 2)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.act = act
        self.param_initializer()

    def forward(self, molvec):
        """Return predictions with the same leading axes as ``molvec`` and
        ``num_aux_task`` values on the last axis.

        The output used to go through an argument-less ``torch.squeeze``.  That
        is a no-op in the usual case (several molecules, several tasks) but it
        removed the batch axis for a single molecule and the task axis for a
        single task, which breaks callers that index ``pred[:, i]``.
        """
        x = self.dropout(self.bn1(self.act(self.fc1(molvec))))
        x = self.dropout(self.bn2(self.act(self.fc2(x))))
        return self.fc3(x)

    def param_initializer(self):
        """Xavier-normal initialise the weights of all three linear layers."""
        nn.init.xavier_normal_(self.fc1.weight.data)
        nn.init.xavier_normal_(self.fc2.weight.data)
        nn.init.xavier_normal_(self.fc3.weight.data)

# Experiment

In [15]:
def train(models, optimizer, dataloader, epoch, cnt_iter, args):
    """Run the masked-atom training loop with optional auxiliary regression.

    Every iteration performs one update of the mask task (encoder +
    classifier) and, when ``args.aux_task`` is non-empty, one update of the
    auxiliary regression task (encoder + regressor) on the unmasked input.

    Args:
        models: dict with 'encoder' and 'classifier', plus 'regressor' when
            auxiliary tasks are configured.
        optimizer: dict with 'mask' and 'auxiliary' entries, each exposing
            ``zero_grad()``, ``step(step)`` and ``rate(step)``.
        dataloader: dict with 'train' and 'val' loaders yielding batches
            ``(predict_idx, X, mask_X, true_X, A, C)``.
        epoch: epoch index to start (or resume) from.
        cnt_iter: global iteration count to start (or resume) from.
        args: namespace; reads epoch, device, vocab_size, aux_task, r_lambda,
            batch_size, log_every, validate_every and save_every, and is
            updated with ``epoch_now`` / ``iter_now``.

    Uses the notebook-level ``train_writer``, ``logger``, ``Compute_loss``,
    ``validate`` and ``save_checkpoint``.
    """
    t = time.time()
    reg_loss = nn.MSELoss()
    has_aux = len(args.aux_task) > 0

    for epoch in range(epoch, args.epoch+1):
        for batch_idx, batch in enumerate(dataloader['train']):
            # Setting Train Mode (validate() switches the models to eval mode)
            for model in models.values():
                model.train()

            optimizer['mask'].zero_grad()
            optimizer['auxiliary'].zero_grad()

            # Get Batch Sample from DataLoader
            predict_idx, X, mask_X, true_X, A, C = batch

            # Move the batch to the target device with the dtypes the models expect
            mask_idx = predict_idx.to(args.device).long()
            input_X = X.to(args.device).float()
            mask_X = mask_X.to(args.device).float()
            true_X = true_X.to(args.device).float()
            input_A = A.to(args.device).float()
            input_C = C.to(args.device).float()

            # Encoding Masked Molecule
            encoded_X, _, molvec = models['encoder'](mask_X, input_A)
            pred_X = models['classifier'](encoded_X, molvec, mask_idx)

            # Compute Mask Task Loss
            symbol_loss, degree_loss, numH_loss, valence_loss, isarom_loss, partial_loss, mask_loss = Compute_loss(pred_X, true_X, args.vocab_size)

            # Backpropagating and Updating Parameters
            mask_loss.backward()
            optimizer['mask'].step(cnt_iter)
            train_writer.add_scalar('1.status/lr', optimizer['mask'].rate(cnt_iter), cnt_iter)
            torch.cuda.empty_cache()

            # Compute Loss of Original Molecule Property
            if has_aux:
                # The encoder is shared by both optimizers and still holds the
                # mask-task gradients; clear them so the auxiliary step does
                # not apply the mask gradient a second time.
                optimizer['auxiliary'].zero_grad()
                _, _, molvec = models['encoder'](input_X, input_A)
                pred_C = models['regressor'](molvec)
                list_loss = [reg_loss(pred_C[:, i], input_C[:, i]) for i in range(len(args.aux_task))]
                auxiliary_loss = args.r_lambda * sum(list_loss)
                auxiliary_loss.backward()
                optimizer['auxiliary'].step(cnt_iter)
                torch.cuda.empty_cache()

            cnt_iter += 1
            args.epoch_now = epoch
            args.iter_now = cnt_iter

            # Prompting Status
            if cnt_iter % args.log_every == 0:
                train_writer.add_scalar('2.mask_loss/symbol', symbol_loss, cnt_iter)
                train_writer.add_scalar('2.mask_loss/degree', degree_loss, cnt_iter)
                train_writer.add_scalar('2.mask_loss/numH', numH_loss, cnt_iter)
                train_writer.add_scalar('2.mask_loss/valence', valence_loss, cnt_iter)
                train_writer.add_scalar('2.mask_loss/isarom', isarom_loss, cnt_iter)
                train_writer.add_scalar('1.status/mask', mask_loss, cnt_iter)
                if has_aux:
                    train_writer.add_scalar('1.status/auxiliary', auxiliary_loss, cnt_iter)
                    for i, task in enumerate(args.aux_task):
                        train_writer.add_scalar('3.auxiliary_loss/{}'.format(task), list_loss[i], cnt_iter)

                # The labels now match the values (mask loss first, auxiliary
                # loss second); NaN is reported when no auxiliary task exists.
                output = "[TRAIN] E:{:3}. P:{:>2.1f}%. Mask Loss:{:>9.3}. Aux Loss:{:>9.3}. {:4.1f} mol/sec. Iter:{:6}.  Elapsed:{:6.1f} sec."
                aux_value = auxiliary_loss.item() if has_aux else float('nan')
                elapsed = time.time() - t
                process_speed = (args.batch_size * args.log_every) / elapsed
                output = output.format(epoch, batch_idx / len(dataloader['train']) * 100.0,
                                       mask_loss.item(), aux_value, process_speed, cnt_iter, elapsed)
                t = time.time()
                logger.info(output)

            # Validate Model
            if cnt_iter % args.validate_every == 0:
                optimizer['mask'].zero_grad()
                optimizer['auxiliary'].zero_grad()
                validate(models, dataloader['val'], args, cnt_iter=cnt_iter, epoch=epoch)
                # Restart the timer so validation time is not counted as training time.
                t = time.time()

            # Save Model
            if cnt_iter % args.save_every == 0:
                filename = save_checkpoint(epoch, cnt_iter, models, optimizer, args)
                logger.info('Saved Model as {}'.format(filename))
            del batch

    logger.info('Training Completed')
    

In [9]:
def validate(models, data_loader, args, **kwargs):
    """Evaluate the models on ``data_loader`` and log the results.

    Computes the mask-task losses, accuracies and confusion matrices and, when
    ``args.aux_task`` is non-empty, the auxiliary regression loss and MAE.
    Scalars and confusion-matrix figures go to ``val_writer``; a summary line
    goes to ``logger``.

    Args:
        models: dict with 'encoder' and 'classifier', plus 'regressor' when
            auxiliary tasks are configured.  All are switched to eval mode.
        data_loader: iterable of batches ``(predict_idx, X, mask_X, true_X, A, C)``.
        args: namespace; reads device, vocab_size, degree_size, numH_size,
            valence_size, isarom_size, aux_task, log_every and test_batch_size.
        **kwargs: must contain ``cnt_iter`` (global step used for logging) and
            ``epoch``.

    Uses the notebook-level ``val_writer``, ``logger``, ``Compute_loss``,
    ``compute_metric``, ``compute_confusion``, ``plot_confusion_matrix``,
    ``f1_macro``, ``log_histogram`` and ``LIST_SYMBOLS``.
    """
    t_start = time.time()
    t = t_start
    cnt_iter = kwargs['cnt_iter']
    epoch = kwargs['epoch']
    temp_iter = 0
    reg_loss = nn.MSELoss()
    has_aux = len(args.aux_task) > 0
    num_aux = len(args.aux_task)
    progress_every = args.log_every * 10

    # For Masking Task Loss / Accuracy (keys double as TensorBoard tag suffixes)
    mask_tasks = ('symbol', 'degree', 'numH', 'valence', 'isarom')
    task_losses = {name: [] for name in mask_tasks}
    task_accs = {name: [] for name in mask_tasks}
    list_mask_loss = []

    # For Auxiliary Task Loss
    list_aux_loss = []
    list_aux_mae = []

    # For F1-Score Metric & Confusion Matrix
    confusion = {
        'symbol': np.zeros((args.vocab_size+1, args.vocab_size+1)),
        'degree': np.zeros((args.degree_size+1, args.degree_size+1)),
        'numH': np.zeros((args.numH_size+1, args.numH_size+1)),
        'valence': np.zeros((args.valence_size+1, args.valence_size+1)),
        'isarom': np.zeros((args.isarom_size+1, args.isarom_size+1)),
    }

    # Put every model in evaluation mode
    for model in models.values():
        model.eval()

    with torch.no_grad():
        for batch_idx, batch in enumerate(data_loader):
            # Get Batch Sample from DataLoader
            predict_idx, X, mask_X, true_X, A, C = batch

            # Move the batch to the target device with the dtypes the models expect
            mask_idx = predict_idx.to(args.device).long()
            input_X = X.to(args.device).float()
            mask_X = mask_X.to(args.device).float()
            true_X = true_X.to(args.device).float()
            input_A = A.to(args.device).float()
            input_C = C.to(args.device).float()

            # Encoding Masked Molecule
            encoded_X, _, molvec = models['encoder'](mask_X, input_A)
            pred_X = models['classifier'](encoded_X, molvec, mask_idx)

            # Compute Mask Task Loss.  Compute_loss returns
            # (symbol, degree, numH, valence, isarom, partial, total).
            batch_losses = Compute_loss(pred_X, true_X, args.vocab_size)
            for name, loss in zip(mask_tasks, batch_losses[:5]):
                task_losses[name].append(loss.item())
            list_mask_loss.append(batch_losses[-1].item())

            # Compute Mask Task Accuracy
            for name, acc in zip(mask_tasks, compute_metric(pred_X, true_X)):
                task_accs[name].append(acc)

            # Accumulate Mask Task Confusion Matrix for F1-Metric
            confusions = compute_confusion(pred_X, true_X, args)
            for k, name in enumerate(mask_tasks):
                confusion[name] += confusions[k]

            # Property Regression Loss & MAE on the unmasked molecule
            if has_aux:
                _, _, molvec = models['encoder'](input_X, input_A)
                pred_C = models['regressor'](molvec)
                list_aux_loss.append([reg_loss(pred_C[:, i], input_C[:, i]).item() for i in range(num_aux)])

                pred_np = pred_C.cpu().detach().numpy()
                target_np = input_C.cpu().detach().numpy()
                list_aux_mae.append([mean_absolute_error(pred_np[:, i], target_np[:, i]) for i in range(num_aux)])
                torch.cuda.empty_cache()

            temp_iter += 1

            # Prompting Status
            if temp_iter % progress_every == 0:
                output = "[VALID] E:{:3}. P:{:>2.1f}%. {:4.1f} mol/sec. Iter:{:6}.  Elapsed:{:6.1f} sec."
                elapsed = time.time() - t
                # progress_every batches were processed since the last progress line
                process_speed = (args.test_batch_size * progress_every) / elapsed
                output = output.format(epoch, batch_idx / len(data_loader) * 100.0, process_speed, temp_iter, elapsed, )
                t = time.time()
                logger.info(output)
            del batch

    if temp_iter == 0:
        # Nothing to average or plot; the code below would fail on empty lists.
        logger.warning('[VALID] validation data loader yielded no batches; skipping validation logging')
        return
    eval_elapsed = time.time() - t_start

    val_writer.add_figure('symbol/confusion',
                         plot_confusion_matrix(
                             confusion['symbol'], range(args.vocab_size),
                             classes=LIST_SYMBOLS, title="Symbol CM @ {}".format(cnt_iter), figsize=(10, 10)),
                         cnt_iter)
    val_writer.add_figure('degree/confusion',
                         plot_confusion_matrix(confusion['degree'][1:, 1:], range(args.degree_size), title="Degree CM @ {}".format(cnt_iter)),
                         cnt_iter)
    val_writer.add_figure('numH/confusion',
                         plot_confusion_matrix(confusion['numH'][1:, 1:], range(args.numH_size), title="NumH CM @ {}".format(cnt_iter)),
                         cnt_iter)
    val_writer.add_figure('valence/confusion',
                         plot_confusion_matrix(confusion['valence'][1:, 1:], range(args.valence_size), title="Valence CM @ {}".format(cnt_iter)),
                         cnt_iter)
    val_writer.add_figure('isarom/confusion',
                         plot_confusion_matrix(confusion['isarom'][1:, 1:], range(args.isarom_size),
                                               title="isAromatic CM @ {}".format(cnt_iter), figsize=(2,2)),
                         cnt_iter)

    # Averaging Loss across the batches
    mask_loss = np.mean(np.array(list_mask_loss))
    for name in mask_tasks:
        val_writer.add_scalar('2.mask_loss/{}'.format(name), np.mean(np.array(task_losses[name])), cnt_iter)
    for name in mask_tasks:
        val_writer.add_scalar('4.mask_metric/acc_{}'.format(name), np.mean(np.array(task_accs[name])), cnt_iter)
    for name in mask_tasks:
        # Row/column 0 is excluded from the F1 computation.
        val_writer.add_scalar('4.mask_metric/f1_{}'.format(name), f1_macro(confusion[name][1:, 1:]), cnt_iter)

    # NaN marks "not computed" in the summary line when no auxiliary task exists.
    auxiliary_loss = float('nan')
    if has_aux:
        mean_aux_loss = np.mean(list_aux_loss, axis=0)
        mean_aux_mae = np.mean(list_aux_mae, axis=0)

        for i, task in enumerate(args.aux_task):
            val_writer.add_scalar('3.auxiliary_loss/{}'.format(task), mean_aux_loss[i], cnt_iter)
            val_writer.add_scalar('5.auxiliary_mae/{}'.format(task), mean_aux_mae[i], cnt_iter)

        auxiliary_loss = np.mean(mean_aux_loss)
        val_writer.add_scalar('1.status/auxiliary', auxiliary_loss, cnt_iter)
    val_writer.add_scalar('1.status/mask', mask_loss, cnt_iter)

    # Log model weight histogram
    log_histogram(models, val_writer, cnt_iter)

    output = "[VALID] E:{:3}. P:{:>2.1f}%. Mask Loss:{:>9.3}. Aux Loss:{:>9.3}. {:4.1f} mol/sec. Iter:{:6}.  Elapsed:{:6.1f} sec."
    elapsed = time.time() - t_start
    # Throughput over the whole evaluation loop (excludes plotting/logging time).
    process_speed = (args.test_batch_size * temp_iter) / eval_elapsed
    output = output.format(epoch, batch_idx / len(data_loader) * 100.0, mask_loss, auxiliary_loss, process_speed, cnt_iter, elapsed)
    logger.info(output)
    torch.cuda.empty_cache()
    
    
def experiment(dataloader, args):
    """Build the models and optimizers, optionally resume a checkpoint, then
    run one validation pass followed by training.

    Args:
        dataloader: dict with 'train' and 'val' data loaders.
        args: experiment namespace.  ``args.optim`` must be 'ADAM', 'RMSProp'
            or 'SGD'.  The function adds ``<model>_param`` (parameter count per
            model) and ``elapsed`` (wall-clock seconds) to it.

    Raises:
        ValueError: if ``args.optim`` names an unsupported optimizer.

    Returns nothing; results are reported through the loggers/writers.
    """
    ts = time.time()

    # Construct Model
    num_aux_task = len(args.aux_task)
    encoder = Encoder(args)
    classifier = Classifier(args.out_dim, args.molvec_dim, args.classifier_dim, args.in_dim, args.cdp_rate, ACT2FN[args.act])

    models = {'encoder': encoder, 'classifier': classifier}
    if num_aux_task > 0:
        regressor = Regressor(args.molvec_dim, args.regressor_dim, num_aux_task, args.rdp_rate, ACT2FN[args.act])
        models.update({'regressor': regressor})

    # Collect trainable parameters.  The encoder belongs to both groups
    # because both tasks update it.
    logger.info('####### Model Constructed #######')
    mask_trainable_parameters = list()
    auxiliary_trainable_parameters = list()
    for key, model in models.items():
        model.to(args.device)
        trainable = [p for p in model.parameters() if p.requires_grad]
        if key in ['encoder', 'classifier']:
            mask_trainable_parameters += trainable
        if key in ['encoder', 'regressor']:
            auxiliary_trainable_parameters += trainable
        num_params = sum(p.numel() for p in model.parameters())
        logger.info('{:10}: {:>10} parameters'.format(key, num_params))
        setattr(args, '{}_param'.format(key), num_params)
    logger.info('#################################')

    # Initialize Optimizer.  lr=0 is a placeholder: NoamOpt overwrites the
    # learning rate of every parameter group at each step.
    if args.optim == 'ADAM':
        mask_optimizer = optim.Adam(mask_trainable_parameters, lr=0, betas=(0.9, 0.98), eps=1e-9)
        auxiliary_optimizer = optim.Adam(auxiliary_trainable_parameters, lr=0, betas=(0.9, 0.98), eps=1e-9)
    elif args.optim == 'RMSProp':
        mask_optimizer = optim.RMSprop(mask_trainable_parameters, lr=0)
        auxiliary_optimizer = optim.RMSprop(auxiliary_trainable_parameters, lr=0)
    elif args.optim == 'SGD':
        mask_optimizer = optim.SGD(mask_trainable_parameters, lr=0)
        auxiliary_optimizer = optim.SGD(auxiliary_trainable_parameters, lr=0)
    else:
        # An explicit exception instead of `assert False`: assertions are
        # removed under `python -O`, which would defer the failure to an
        # unbound-variable error further down.
        raise ValueError("Undefined Optimizer Type: {!r}".format(args.optim))
    optimizers = {'mask':mask_optimizer, 'auxiliary':auxiliary_optimizer}

    # Reload Checkpoint Model
    epoch = 0
    cnt_iter = 0
    if args.ck_filename:
        epoch, cnt_iter, models, optimizers = load_checkpoint(models, optimizers, args.ck_filename, args)
        logger.info('Loaded Model from {}'.format(args.ck_filename))

    # Wrap the raw optimizers only after the checkpoint has been restored into them.
    mask_optimizer = NoamOpt(args.out_dim, args.lr_factor, args.lr_step, optimizers['mask'])
    auxiliary_optimizer = NoamOpt(args.out_dim, args.lr_factor, args.lr_step, optimizers['auxiliary'])
    optimizers = {'mask':mask_optimizer, 'auxiliary':auxiliary_optimizer}

    # Train Model (one validation pass first gives a baseline at the starting step)
    validate(models, dataloader['val'], args, cnt_iter=cnt_iter, epoch=epoch)
    train(models, optimizers, dataloader, epoch, cnt_iter, args)

    # Logging Experiment Result
    te = time.time()
    args.elapsed = te-ts
    logger.info('Training Completed')

In [12]:
# Experiment configuration cell: seeds the RNGs, fills the global `args`
# namespace read by Encoder / experiment / train / validate, and creates the
# TensorBoard writers and the logger those functions use as globals.
seed = 123
np.random.seed(seed)
torch.manual_seed(seed)
# Parse an empty argument string so that `args` starts as an empty namespace;
# every option is assigned below instead of being read from the command line.
args = parser.parse_args("")

##### SIZE #####
# Encoder.create_emb_layer builds one-hot embeddings of width (size + 1) for
# the five categorical features below and concatenates them, so in_dim has to
# equal their sum: (40+1) + (6+1) + (5+1) + (6+1) + (2+1) = 64.
args.vocab_size = 40
args.degree_size = 6
args.numH_size = 5
args.valence_size = 6
args.isarom_size = 2
args.in_dim = 64
args.out_dim = 256
args.molvec_dim = 256
args.classifier_dim = 500
args.regressor_dim = 500

##### MODEL #####
args.n_layer = 6
args.use_attn = True
args.n_attn_heads = 8
# Must be a key of ACT2FN ('gelu' or 'relu').
args.act = 'gelu'
args.use_bn = True
# One of 'no', 'sc', 'gsc' (checked in ResBlock).
args.sc_type = 'sc'
args.emb_train = True
# NOTE(review): the train_* flags are not read by any code visible in this
# notebook; presumably a helper such as make_model_comment uses them — confirm.
args.train_logp = True
args.train_mr = True
args.train_tpsa = True
args.train_sas = True
args.train_mw = True
# One regression output is created per entry; an empty list disables the regressor.
args.aux_task = ['logP', 'mr', 'tpsa', 'sas', 'mw']

##### HYPERPARAMETERS #####
# One of 'ADAM', 'RMSProp', 'SGD' (see experiment()).
args.optim = 'ADAM'
# NOTE(review): lr and l2_coef are not read by the code visible here —
# experiment() creates the optimizers with lr=0 and NoamOpt derives the rate
# from lr_factor, lr_step and out_dim.
args.lr = 0.001
args.l2_coef = 0.001
# Dropout rates of the encoder (dp_rate), classifier (cdp_rate) and regressor (rdp_rate).
args.dp_rate = 0.1
args.cdp_rate = 0.3
args.rdp_rate = 0.3
# Passed to NoamOpt as `factor` and `warmup` respectively.
args.lr_factor = 1.0
args.lr_step = 4000
# Weight applied to the summed auxiliary regression losses in train().
args.r_lambda = 1.0

##### EXP #####
args.epoch = 100
args.batch_size = 512
args.test_batch_size = 512
# The three intervals below are counted in training iterations.
args.save_every = 100
args.validate_every = 100
args.log_every = 20


##### DEVICE #####
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

##### LOGGING #####
# NOTE(review): `join`, `make_model_comment` and `get_logger` are not imported
# explicitly; presumably they come from the star imports at the top — confirm.
args.log_path = 'runs'
args.model_name = 'exp_test3'
args.model_explain = make_model_comment(args)
train_writer = SummaryWriter(join(args.log_path, args.model_name+'_train'))
val_writer = SummaryWriter(join(args.log_path, args.model_name+'_val'))
train_writer.add_text(tag='model', text_string='{}:{}'.format(args.model_name, args.model_explain), global_step= 0)
logger = get_logger(join(args.log_path, args.model_name+'_train'))

##### RESUME TRAINING #####
# A falsy value starts from scratch; otherwise experiment() passes it to load_checkpoint.
args.ck_filename = None # Example: 'model_ck_000_000000100.tar'

In [13]:
# Data-loading cell: builds the {'train', 'val'} loader dict consumed by experiment().
# NOTE(review): zincDataset, get_dir_files, SequentialSampler, BatchSampler and
# postprocess_batch are not imported explicitly; presumably they come from the
# star imports (`dataloader` / `utils`) — confirm.
train_dataset_path = './dataset/xxs/train'
val_dataset_path = './dataset/xxs/val'

list_trains = get_dir_files(train_dataset_path)
list_vals = get_dir_files(val_dataset_path)

# Only the first file returned for each directory is used.
# NOTE(review): the meaning of the third argument (8) is not visible here — TODO confirm.
train_dataset = zincDataset(train_dataset_path, list_trains[0], 8)
sampler = SequentialSampler(train_dataset)
# NOTE(review): `shuffle_batch` is not an argument of torch's BatchSampler, so
# this is presumably a project-specific sampler — confirm.
SortedBatchSampler = BatchSampler(sampler=sampler, batch_size=args.batch_size, drop_last=True, shuffle_batch=True)
train_dataloader = DataLoader(train_dataset,
                                  num_workers=8,
                                  collate_fn=postprocess_batch,
                                  batch_sampler=SortedBatchSampler)

# `sampler` and `SortedBatchSampler` are rebound here for the validation set;
# the training loader keeps its own reference to the earlier objects.
val_dataset = zincDataset(val_dataset_path, list_vals[0], 8)
sampler = SequentialSampler(val_dataset)
SortedBatchSampler = BatchSampler(sampler=sampler, batch_size=args.test_batch_size, drop_last=True, shuffle_batch=False)
val_dataloader = DataLoader(val_dataset,
                                num_workers=8,
                                collate_fn=postprocess_batch,
                                batch_sampler=SortedBatchSampler)

dataloader = {'train': train_dataloader, 'val': val_dataloader}

In [16]:
logger.info("######## Starting Training ########")
result = experiment(dataloader, args)

2018-12-04 02:11:42,018 [INFO] ####### Model Constructed #######
2018-12-04 02:11:42,018 [INFO] ####### Model Constructed #######
2018-12-04 02:11:42,028 [INFO] encoder   :     547985 parameters
2018-12-04 02:11:42,028 [INFO] encoder   :     547985 parameters
2018-12-04 02:11:42,030 [INFO] classifier:      48911 parameters
2018-12-04 02:11:42,030 [INFO] classifier:      48911 parameters
2018-12-04 02:11:42,032 [INFO] logP      :     131585 parameters
2018-12-04 02:11:42,032 [INFO] logP      :     131585 parameters
2018-12-04 02:11:42,034 [INFO] mr        :     131585 parameters
2018-12-04 02:11:42,034 [INFO] mr        :     131585 parameters
2018-12-04 02:11:42,035 [INFO] tpsa      :     131585 parameters
2018-12-04 02:11:42,035 [INFO] tpsa      :     131585 parameters
2018-12-04 02:11:42,036 [INFO] #################################
2018-12-04 02:11:42,036 [INFO] #################################
2018-12-04 02:11:48,356 [INFO] [T] E:  0. P:2.3%. Loss:     10.7. Mask Loss:     10.7. 162

KeyboardInterrupt: 