In [1]:
###################################################################
# This is an official PyTorch implementation for Target-specific
# Generation of Molecules (TagMol)
# Author: Junde Li, The Pennsylvania State University
# Date: Aug 1, 2022
###################################################################

import os, time
import re
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import pickle
from rdkit import Chem
from dataloader import *

  from .autonotebook import tqdm as notebook_tqdm


In [44]:
# --------------------------
# Hyperparameters
# --------------------------

# Optimisation settings.
# NOTE(review): no optimizer is built in the visible cells; beta1/beta2 are
# presumably Adam betas — confirm against the training cell.
lr             = 1e-4
beta1          = 0.0
beta2          = 0.9
batch_size     = 16
max_epoch      = 400
num_workers    = 2

# Model / data dimensions. `x_dim` is the output width handed to
# PointNetEncoder further down in this notebook.
ligand_size    = 32
x_dim          = 1024
z_dim          = 128
save_step      = 100

# Output locations; both directories are created at the end of this cell.
name           = "model/tagmol"
log_dir        = f"{name}"
models_dir     = f"{name}/saved_models"

# Prefer the first GPU, but fall back to the CPU when CUDA is unavailable so
# that later `.to(device)` calls do not fail on CPU-only machines.
device         = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# exist_ok=True keeps the cell safe to re-run.
os.makedirs(log_dir, exist_ok=True)
os.makedirs(models_dir, exist_ok=True)

In [45]:
# Training set built from the tiny PDBbind subset. Every sample goes through
# Normalize -> RandomRotateJitter -> ToTensor (these names, like
# PDBbindPLDataset / transforms / DataLoader, presumably arrive via the
# `from dataloader import *` at the top of the notebook — verify).
train_dataset = PDBbindPLDataset(
    root_dir='data/pdbbind/tiny-set',
    n_points=5000,
    lig_size=ligand_size,  # single source of truth: the hyperparameter cell
    train=True,
    transform=transforms.Compose([
        Normalize(),
        RandomRotateJitter(),
        ToTensor(),
    ]),
)

# NOTE(review): the loader uses a batch of 8 while the hyperparameter cell sets
# `batch_size = 16`; kept at 8 so the recorded outputs below stay valid —
# confirm which value is intended for training.
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# Pull a single batch so the following cells have concrete tensors to probe
# the encoder with; the loop exits after the first iteration.
for i_batch, sample_batched in enumerate(train_dataloader):
    protein = sample_batched['protein']
    atoms, bonds = sample_batched['ligand']
    break

In [46]:
# --------------------------
# Define Encoder Network
# --------------------------

class STN(nn.Module):
    """Spatial transformer network for alignment.

    Maps an input of shape ``(B, k, N)`` to one ``k x k`` matrix per sample.
    Point-wise 1x1 convolutions lift each point to 1024 features, a max over
    the point axis pools them, and three linear layers regress the ``k * k``
    matrix entries. The identity matrix is added to the regressed values, so
    a zero regression output yields the identity transform.

    Args:
        k: Number of input channels; also the side length of the output matrix.
    """

    def __init__(self, k):
        super().__init__()
        # Shared per-point feature extractor (kernel size 1 == per-point MLP).
        self.conv1 = torch.nn.Conv1d(k, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        # Regression head applied to the pooled feature vector.
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k * k)
        # Not used by forward() (which calls F.relu); kept so the module's
        # attributes stay unchanged for any external code that touches it.
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

        self.k = k

    def forward(self, x):
        """Return a ``(B, k, k)`` transform for a ``(B, k, N)`` input."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        # Max over the point axis, then drop that axis: (B, 1024, 1) -> (B, 1024).
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        # Build the identity on the same device and dtype as the activations
        # so the module works on any device (not only the default CUDA one);
        # the (k*k,) vector broadcasts over the batch dimension.
        identity = torch.eye(self.k, dtype=x.dtype, device=x.device).flatten()
        x = x + identity
        return x.view(-1, self.k, self.k)


class PointNetEncoder(nn.Module):
    """PointNet Encoder Network for protein embedding.

    Takes a point cloud of shape ``(B, D, N)``. The first three channels
    (presumably xyz coordinates — confirm against the dataset) are aligned
    with a 3x3 transform predicted by an ``STN``; any remaining channels are
    passed through untouched and concatenated back. Three 1x1 convolutions
    then lift each point to ``x_dim`` features, which are max-pooled over the
    point axis.

    Args:
        x_dim: Width of the pooled output feature vector.
        channel: Number of input channels expected by the first convolution;
            the input's ``D`` must equal this value.
        feature_transform: When True, a second 64x64 ``STN`` re-aligns the
            64-channel features produced by the first convolution.
    """

    def __init__(self, x_dim, channel=4, feature_transform=False):
        super().__init__()
        self.stn = STN(k=3)

        self.conv1 = torch.nn.Conv1d(channel, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, x_dim, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(x_dim)
        self.feature_transform = feature_transform
        if self.feature_transform:
            self.fstn = STN(k=64)

    def forward(self, x):
        """Encode a ``(B, D, N)`` point cloud.

        Returns:
            A tuple ``(features, trans, trans_feat)`` where ``features`` has
            shape ``(B, x_dim)``, ``trans`` is the ``(B, 3, 3)`` input
            transform and ``trans_feat`` is the ``(B, 64, 64)`` feature
            transform, or ``None`` when ``feature_transform`` is disabled.
        """
        # Unpacking also enforces that the input is exactly 3-dimensional.
        _, D, _ = x.size()
        trans = self.stn(x[:, :3, :])  # only the first 3 channels are aligned
        x = x.transpose(2, 1)  # shape=(B, N, D)
        if D > 3:
            feature = x[:, :, 3:]
            x = x[:, :, :3]
        x = torch.bmm(x, trans)  # shape=(B, N, 3)

        if D > 3:
            # Re-attach the channels that bypassed the alignment.
            x = torch.cat([x, feature], dim=2)
        x = x.transpose(2, 1)  # shape=(B, D, N)
        x = F.relu(self.bn1(self.conv1(x)))  # shape=(B, 64, N)

        if self.feature_transform:
            trans_feat = self.fstn(x)
            x = x.transpose(2, 1)
            x = torch.bmm(x, trans_feat)
            x = x.transpose(2, 1)
        else:
            trans_feat = None

        x = F.relu(self.bn2(self.conv2(x)))  # shape=(B, 128, N)
        x = self.bn3(self.conv3(x))  # shape=(B, x_dim, N)
        # Aggregate point features by max pooling. Reducing without keepdim
        # gives (B, x_dim) directly, so the result no longer depends on a
        # notebook-level `x_dim` global matching the constructor argument.
        x = torch.max(x, 2)[0]  # shape=(B, x_dim)

        return x, trans, trans_feat


In [47]:
# Smoke test: build an encoder with the 64x64 feature transform switched on.
# The instance is bound to the module-level name `self`, as in the original cell.
self = PointNetEncoder(x_dim=x_dim, channel=4, feature_transform=True)
# The encoder unpacks its input as (B, D, N), so swap the last two axes of the
# sampled protein batch before the forward pass.
x, trans, trans_feat = self(protein.transpose(1, 2))

In [48]:
# Pooled encoder output: one row per protein in the batch, x_dim columns.
x.shape

torch.Size([8, 1024])