In [1]:
import numpy as np
import os
import time
from datetime import datetime

import torch
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import save_image

from utils import *
from models import Generator, Discriminator
from data.sparse_molecular_dataset import SparseMolecularDataset



In [71]:
class Solver(object):
    """Builds and manages the molecular GAN networks and their optimizers.

    The solver loads a ``SparseMolecularDataset``, copies every
    hyper-parameter from ``config`` onto itself, and creates three networks:
    a generator ``G`` and two ``Discriminator`` instances ``D`` and ``V``.
    ``V`` is optimized together with ``G`` by ``g_optimizer``; ``D`` has its
    own ``d_optimizer``.
    """

    def __init__(self, config):
        """Initialize configurations.

        Args:
            config: namespace-like object exposing the attributes read below
                (e.g. the result of ``argparse`` parsing).
        """
        self.data = SparseMolecularDataset()
        self.data.load(config.mol_data_dir)

        # Model configurations.
        self.z_dim = config.z_dim
        self.m_dim = self.data.atom_num_types
        self.b_dim = self.data.bond_num_types
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.g_repeat_num = config.g_repeat_num
        self.d_repeat_num = config.d_repeat_num
        self.lambda_cls = config.lambda_cls
        self.lambda_rec = config.lambda_rec
        self.lambda_gp = config.lambda_gp
        self.post_method = config.post_method

        self.metric = 'validity, sas'

        # Training configurations.
        self.batch_size = config.batch_size
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.dropout = config.dropout
        self.n_critic = config.n_critic
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.resume_iters = config.resume_iters

        # Test configurations.
        self.test_iters = config.test_iters

        # Miscellaneous.
        self.use_tensorboard = config.use_tensorboard
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Directories.
        self.log_dir = config.log_dir
        self.sample_dir = config.sample_dir
        self.model_save_dir = config.model_save_dir
        self.result_dir = config.result_dir

        # Step size.
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.lr_update_step = config.lr_update_step

        # Build the model and tensorboard.
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

    def build_model(self):
        """Create the generator, the two discriminator-type networks and the optimizers."""
        self.G = Generator(self.g_conv_dim, self.z_dim,
                           self.data.vertexes,
                           self.data.bond_num_types,
                           self.data.atom_num_types,
                           self.dropout)
        self.D = Discriminator(self.d_conv_dim, self.m_dim, self.b_dim, self.dropout)
        self.V = Discriminator(self.d_conv_dim, self.m_dim, self.b_dim, self.dropout)

        # G and V share one optimizer; D is optimized separately.
        self.g_optimizer = torch.optim.Adam(list(self.G.parameters())+list(self.V.parameters()),
                                            self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
        self.print_network(self.G, 'G')
        self.print_network(self.D, 'D')

        self.G.to(self.device)
        self.D.to(self.device)
        self.V.to(self.device)

    def print_network(self, model, name):
        """Print the module structure, its name and its total parameter count."""
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def restore_model(self, resume_iters):
        """Load the G, D and V checkpoints saved at step ``resume_iters``.

        Checkpoints are read from ``self.model_save_dir`` as
        ``<resume_iters>-G.ckpt``, ``-D.ckpt`` and ``-V.ckpt``.
        """
        print('Loading the trained models from step {}...'.format(resume_iters))
        G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(resume_iters))
        D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(resume_iters))
        V_path = os.path.join(self.model_save_dir, '{}-V.ckpt'.format(resume_iters))
        # The identity map_location keeps every storage where it was deserialized
        # (CPU); load_state_dict then copies into the already-placed parameters.
        self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
        self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))
        self.V.load_state_dict(torch.load(V_path, map_location=lambda storage, loc: storage))

    def build_tensorboard(self):
        """Build a tensorboard logger writing to ``self.log_dir``."""
        # Imported lazily so the logger module is only needed when tensorboard is enabled.
        from logger import Logger
        self.logger = Logger(self.log_dir)

    def update_lr(self, g_lr, d_lr):
        """Set the learning rate of every param group in both optimizers."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def reset_grad(self):
        """Reset the gradient buffers of both optimizers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def denorm(self, x):
        """Convert the range from [-1, 1] to [0, 1], clamping anything outside."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2.

        Args:
            y: tensor computed from ``x``.
            x: tensor that requires grad.

        Returns:
            Scalar tensor: the mean over dim 0 of the squared distance
            between each sample's flattened gradient norm and 1.
        """
        weight = torch.ones(y.size()).to(self.device)
        # create_graph=True keeps the gradient itself differentiable so the
        # penalty can be back-propagated; retain_graph=True keeps the forward
        # graph alive for the caller's own backward pass.
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        # Flatten everything but the first dimension before taking the norm.
        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
        return torch.mean((dydx_l2norm-1)**2)

In [69]:
# Show the learning rate of the first generator param group only.
for g in self.g_optimizer.param_groups[:1]:
    print(g['lr'])

0.0001


In [72]:
# Instantiate the solver under the global name `self`, so code written
# against `self.<attr>` can be run directly at top level in other cells.
self = Solver(config)


Generator(
  (layers): Sequential(
    (0): Linear(in_features=8, out_features=128, bias=True)
    (1): Tanh()
    (2): Dropout(p=0.0, inplace=True)
    (3): Linear(in_features=128, out_features=256, bias=True)
    (4): Tanh()
    (5): Dropout(p=0.0, inplace=True)
    (6): Linear(in_features=256, out_features=512, bias=True)
    (7): Tanh()
    (8): Dropout(p=0.0, inplace=True)
  )
  (edges_layer): Linear(in_features=512, out_features=405, bias=True)
  (nodes_layer): Linear(in_features=512, out_features=45, bias=True)
  (dropoout): Dropout(p=0.0, inplace=False)
)
G
The number of parameters: 396610
Discriminator(
  (gcn_layer): GraphConvolution(
    (linear1): Linear(in_features=5, out_features=128, bias=True)
    (linear2): Linear(in_features=128, out_features=64, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (agg_layer): GraphAggregation(
    (sigmoid_linear): Sequential(
      (0): Linear(in_features=69, out_features=128, bias=True)
      (1): Sigmoid()
    )
    (tan

In [8]:
# Pull the discriminator constructor arguments off the solver.
conv_dim = self.d_conv_dim
m_dim = self.m_dim
b_dim = self.b_dim
dropout = self.dropout

# conv_dim is a 3-item sequence: graph-conv dims, auxiliary dim, linear dims.
graph_conv_dim, aux_dim, linear_dim = conv_dim
in_features = graph_conv_dim[-1]
out_features = aux_dim

In [9]:
# Build stand-alone generator / discriminator instances from the same
# arguments the solver holds (bond types are passed as the edge count and
# atom types as the node count).
G = Generator(self.g_conv_dim, self.z_dim,
                           self.data.vertexes,
                           self.data.bond_num_types,
                           self.data.atom_num_types,
                           self.dropout)
D = Discriminator(self.d_conv_dim, self.m_dim, self.b_dim, self.dropout)

In [10]:
# Print the shape of every parameter tensor of G.
# NOTE(review): `num_params` is initialised here but never updated in this cell.
num_params = 0
for p in G.parameters():
    print(p.size())

torch.Size([128, 8])
torch.Size([128])
torch.Size([256, 128])
torch.Size([256])
torch.Size([512, 256])
torch.Size([512])
torch.Size([405, 512])
torch.Size([405])
torch.Size([45, 512])
torch.Size([45])


In [11]:
# Pull the generator constructor arguments off the solver, one per name.
conv_dims = self.g_conv_dim
z_dim = self.z_dim
vertexes = self.data.vertexes
edges = self.data.bond_num_types
nodes = self.data.atom_num_types
dropout = self.dropout

In [17]:
# Display the module currently stored as `self.nodes_layer`.
self.nodes_layer

Linear(in_features=512, out_features=45, bias=True)

In [35]:
# Scratch replay of the generator constructor body, executed against the
# global `self` so each layer can be inspected interactively.
# NOTE(review): uses `nn`, which this notebook imports only in the cell that
# defines `Generator` - confirm that cell (or another import of `nn`) ran first.
self.vertexes = vertexes
self.edges = edges
self.nodes = nodes

# One Linear -> Tanh -> Dropout triple per hidden size; the first Linear
# takes z_dim inputs, each following one takes the previous hidden size.
layers = []
for c0, c1 in zip([z_dim]+conv_dims[:-1], conv_dims):
    layers.append(nn.Linear(c0, c1))
    layers.append(nn.Tanh())
    layers.append(nn.Dropout(p=dropout, inplace=True))
self.layers = nn.Sequential(*layers)
# Output heads: one logit per (edge type, vertex, vertex) and per (vertex, node type).
self.edges_layer = nn.Linear(conv_dims[-1], edges * vertexes * vertexes)
self.nodes_layer = nn.Linear(conv_dims[-1], vertexes * nodes)
# NOTE(review): if `self` is still the Solver instance, this rebinds its
# numeric `dropout` attribute to an nn.Dropout module - confirm this is intended.
self.dropout = nn.Dropout(p=dropout)

In [36]:
# Display the hidden-layer stack stored as `self.layers`.
self.layers

Sequential(
  (0): Linear(in_features=8, out_features=128, bias=True)
  (1): Tanh()
  (2): Dropout(p=0.0, inplace=True)
  (3): Linear(in_features=128, out_features=256, bias=True)
  (4): Tanh()
  (5): Dropout(p=0.0, inplace=True)
  (6): Linear(in_features=256, out_features=512, bias=True)
  (7): Tanh()
  (8): Dropout(p=0.0, inplace=True)
)

In [38]:
# Scratch replay of the generator forward pass on a random input with two
# leading dimensions (7, 7) and a last dimension of 8.
x = torch.randn(7, 7, 8)
output = self.layers(x)
# Reshape the flat edge logits to (-1, edges, vertexes, vertexes); the -1
# absorbs every leading dimension of the input.
edges_logits = self.edges_layer(output)\
               .view(-1,self.edges,self.vertexes,self.vertexes)
# Average with the transpose over the two vertex axes to make the logits symmetric.
edges_logits = (edges_logits + edges_logits.permute(0,1,3,2))/2
# Move the edge-type axis last before applying dropout.
edges_logits = self.dropout(edges_logits.permute(0,2,3,1))

# Node logits are reshaped to (-1, vertexes, nodes).
nodes_logits = self.nodes_layer(output)
nodes_logits = self.dropout(nodes_logits.view(-1,self.vertexes,self.nodes))

In [53]:
# Shape of the node head's output before any reshaping.
self.nodes_layer(output).size()

torch.Size([7, 7, 45])

In [52]:
# Shape after folding all leading dimensions into one with view(-1, ...).
nodes_logits.view(-1,self.vertexes,self.nodes).size()

torch.Size([49, 9, 5])

In [51]:
import torch.nn as nn
class Generator(nn.Module):
    """Generator network mapping an input vector to edge and node logits.

    Args:
        conv_dims: sequence of hidden layer sizes (list or tuple).
        z_dim: size of the last dimension of the input.
        vertexes: number of vertices per generated graph.
        edges: number of edge types.
        nodes: number of node types.
        dropout: dropout probability used after every hidden layer and on
            both outputs.
    """
    def __init__(self, conv_dims, z_dim, vertexes, edges, nodes, dropout):
        super(Generator, self).__init__()

        self.vertexes = vertexes
        self.edges = edges
        self.nodes = nodes

        # list() lets callers pass a tuple as well as a list.
        conv_dims = list(conv_dims)

        layers = []
        for c0, c1 in zip([z_dim]+conv_dims[:-1], conv_dims):
            layers.append(nn.Linear(c0, c1))
            layers.append(nn.Tanh())
            # Dropout must NOT be in-place here: Tanh's backward pass reuses
            # its own output, and overwriting that tensor makes backward()
            # fail whenever dropout > 0 in training mode.
            layers.append(nn.Dropout(p=dropout))
        self.layers = nn.Sequential(*layers)

        self.edges_layer = nn.Linear(conv_dims[-1], edges * vertexes * vertexes)
        self.nodes_layer = nn.Linear(conv_dims[-1], vertexes * nodes)
        # Attribute name (sic) is kept as-is so existing references to
        # `dropoout` keep working; the module holds no parameters.
        self.dropoout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Return ``(edges_logits, nodes_logits)`` for input ``x``.

        ``edges_logits`` has shape ``(-1, vertexes, vertexes, edges)`` and is
        symmetric over the two vertex axes before dropout is applied;
        ``nodes_logits`` has shape ``(-1, vertexes, nodes)``. All leading
        dimensions of ``x`` are folded into the first output dimension.
        """
        output = self.layers(x)
        edges_logits = self.edges_layer(output)\
                       .view(-1,self.edges,self.vertexes,self.vertexes)
        # Average with the transpose over the vertex axes to symmetrize.
        edges_logits = (edges_logits + edges_logits.permute(0,1,3,2))/2
        # Move the edge-type axis to the end.
        edges_logits = self.dropoout(edges_logits.permute(0,2,3,1))

        nodes_logits = self.nodes_layer(output)
        nodes_logits = self.dropoout(nodes_logits.view(-1,self.vertexes,self.nodes))

        return edges_logits, nodes_logits

In [6]:
import argparse

def str2bool(v):
    """Parse a command-line string as a boolean.

    Returns True only when ``v`` equals ``'true'`` ignoring case; every
    other string is False.
    """
    # Compare for equality: `in ('true')` is a substring test against the
    # string 'true', which would also accept '', 't', 'ru', ...
    return v.lower() == 'true'
# json is only needed to parse the list-valued options below.
import json

parser = argparse.ArgumentParser()

# Model configuration.
parser.add_argument('--z_dim', type=int, default=8, help='dimension of the generator input vector z')
# List-valued options are given on the command line as JSON, e.g.
#   --g_conv_dim "[128, 256, 512]"
# argparse applies `type` only to strings, so the list defaults pass through untouched.
parser.add_argument('--g_conv_dim', type=json.loads, default=[128,256,512],
                    help='hidden layer sizes of G, as a JSON list')
parser.add_argument('--d_conv_dim', type=json.loads, default=[[128, 64], 128, [128, 64]],
                    help='dimensions of D as a JSON list: [graph_conv_dims, aux_dim, linear_dims]')
parser.add_argument('--g_repeat_num', type=int, default=6, help='number of residual blocks in G')
parser.add_argument('--d_repeat_num', type=int, default=6, help='number of strided conv layers in D')
parser.add_argument('--lambda_cls', type=float, default=1, help='weight for domain classification loss')
parser.add_argument('--lambda_rec', type=float, default=10, help='weight for reconstruction loss')
parser.add_argument('--lambda_gp', type=float, default=10, help='weight for gradient penalty')
parser.add_argument('--post_method', type=str, default='softmax', choices=['softmax', 'soft_gumbel', 'hard_gumbel'])

# Training configuration.
parser.add_argument('--batch_size', type=int, default=16, help='mini-batch size')
parser.add_argument('--num_iters', type=int, default=2000, help='number of total iterations for training D')
parser.add_argument('--num_iters_decay', type=int, default=1000, help='number of iterations for decaying lr')
parser.add_argument('--g_lr', type=float, default=0.0001, help='learning rate for G')
parser.add_argument('--d_lr', type=float, default=0.0001, help='learning rate for D')
parser.add_argument('--dropout', type=float, default=0., help='dropout rate')
parser.add_argument('--n_critic', type=int, default=5, help='number of D updates per each G update')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for Adam optimizer')
parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer')
parser.add_argument('--resume_iters', type=int, default=None, help='resume training from this step')

# Test configuration.
parser.add_argument('--test_iters', type=int, default=200000, help='test model from this step')

# Miscellaneous.
parser.add_argument('--num_workers', type=int, default=1)
parser.add_argument('--mode', type=str, default='train', choices=['train', 'test'])
parser.add_argument('--use_tensorboard', type=str2bool, default=False)

# Directories.
parser.add_argument('--mol_data_dir', type=str, default='data/gdb9_9nodes.sparsedataset')
parser.add_argument('--log_dir', type=str, default='molgan/logs')
parser.add_argument('--model_save_dir', type=str, default='molgan/models')
parser.add_argument('--sample_dir', type=str, default='molgan/samples')
parser.add_argument('--result_dir', type=str, default='molgan/results')

# Step size.
parser.add_argument('--log_step', type=int, default=10)
parser.add_argument('--sample_step', type=int, default=1000)
parser.add_argument('--model_save_step', type=int, default=10000)
parser.add_argument('--lr_update_step', type=int, default=1000)

# parse_known_args ignores arguments it does not recognise (such as those
# the notebook kernel is launched with) instead of exiting.
config = parser.parse_known_args()[0]