In [2]:
import os
import shutil
import sys
import copy
import json
import random
import numpy as np
import pandas as pd
from time import time
from sklearn.metrics import mean_squared_error, roc_auc_score, average_precision_score, f1_score
from lifelines.utils import concordance_index
from scipy.stats import pearsonr
import pickle 

import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import torch.nn.init as init
from torch.distributions import Normal

# Seed the global torch and NumPy RNGs. Python's built-in `random` module is
# not seeded here.
torch.manual_seed(2)  
np.random.seed(3)
from prettytable import PrettyTable
import scikitplot as skplt
import matplotlib.pyplot as plt
plt.style.use("ggplot")

# Show the working directory; the relative data/checkpoint paths used in later
# cells resolve against it.
print(os.path.abspath('.'))
import warnings
# Silences every warning for the rest of the session, including deprecation notices.
warnings.filterwarnings("ignore")

from DeepPurpose import utils, models, dataset
from DeepPurpose.models import *

#from modules import *
from time import time

/home/gdp/data/DeepPurpose0


In [3]:
class Bottleneck(nn.Module):
    """Three pre-activation 1-D convolution stages: init_dim -> 32 -> 64 -> 128 channels.

    Every stage runs BatchNorm1d, then leaky ReLU, then a kernel-5 / padding-2
    convolution, so the length axis is left unchanged. A final leaky ReLU is
    applied to the last convolution's output.

    Args:
        init_dim: number of channels of the input tensor.
    """

    def __init__(self, init_dim):
        super().__init__()
        widths = (init_dim, 32, 64, 128)
        conv_opts = {"kernel_size": 5, "padding": 2}

        # Registration order (convs first, then norms) is kept so that
        # state_dict keys and weight-init RNG consumption stay the same.
        self.conv1 = nn.Conv1d(widths[0], widths[1], **conv_opts)
        self.conv2 = nn.Conv1d(widths[1], widths[2], **conv_opts)
        self.conv3 = nn.Conv1d(widths[2], widths[3], **conv_opts)

        self.bn1 = nn.BatchNorm1d(widths[0])
        self.bn2 = nn.BatchNorm1d(widths[1])
        self.bn3 = nn.BatchNorm1d(widths[2])

    def forward(self, x_onehot):
        """Map (batch, init_dim, length) to (batch, 128, length)."""
        stages = (
            (self.bn1, self.conv1),
            (self.bn2, self.conv2),
            (self.bn3, self.conv3),
        )
        out = x_onehot
        for norm, conv in stages:
            out = conv(F.leaky_relu(norm(out)))
        return F.leaky_relu(out)

In [4]:
class EncoderCNN(nn.Module):
    """Convolutional encoder: a Bottleneck stack followed by global max pooling.

    Args:
        init_dim: number of input channels, forwarded to Bottleneck.
    """

    def __init__(self, init_dim):
        super().__init__()
        self.resnet = Bottleneck(init_dim)

    def forward(self, x):
        """Return the per-channel maximum over the length axis, keeping a size-1 last dim."""
        feature_map = self.resnet(x)
        # (batch, channels, length) -> (batch, channels, 1); the caller squeezes the last axis.
        return F.adaptive_max_pool1d(feature_map, 1)

In [5]:
class DecoderCNN(nn.Module):
    """Transposed-convolution decoder mapping a latent vector to per-position logits.

    Pipeline: Linear(latent -> 128) + BatchNorm + LeakyReLU, then
    Linear(128 -> 128 * seq_len) reshaped to (batch, 128, seq_len), followed by
    three BatchNorm / LeakyReLU / ConvTranspose1d stages (128 -> 64 -> 32 -> init_dim).

    Args:
        latent_feature_dim: size of the latent vector fed to the decoder.
        init_dim: number of output channels. Previously this argument was
            ignored and 26 was hard-coded; passing 26 reproduces the old layout.
        kernel_size: kernel width of every transposed convolution. Previously
            only the padding followed this argument while the kernel stayed at
            5. With padding ``kernel_size // 2`` an odd kernel keeps the length
            unchanged; an even kernel shortens it by one per layer.
        seq_len: length of the reconstructed sequence axis (default 1000, the
            value that used to be hard-coded).

    With the defaults and ``init_dim=26`` the parameter names and shapes are
    identical to the previous implementation, so existing checkpoints load.
    """

    class View(nn.Module):
        """Reshape helper usable inside nn.Sequential."""

        def __init__(self, shape):
            super().__init__()
            self.shape = shape

        def forward(self, x):
            return x.view(*self.shape)

    def __init__(self, latent_feature_dim, init_dim, kernel_size=5, seq_len=1000):
        super(DecoderCNN, self).__init__()
        pad = kernel_size // 2

        self.fc1 = nn.Sequential(
            nn.Linear(latent_feature_dim, 128),
            nn.BatchNorm1d(128),
            nn.LeakyReLU())

        self.network = nn.Sequential(
            nn.Linear(128, 128 * seq_len),
            DecoderCNN.View(shape=(-1, 128, seq_len)),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(),
            nn.ConvTranspose1d(in_channels=128, out_channels=64, kernel_size=kernel_size, padding=pad),

            nn.BatchNorm1d(64),
            nn.LeakyReLU(),
            nn.ConvTranspose1d(in_channels=64, out_channels=32, kernel_size=kernel_size, padding=pad),

            nn.BatchNorm1d(32),
            nn.LeakyReLU(),
            nn.ConvTranspose1d(in_channels=32, out_channels=init_dim, kernel_size=kernel_size, padding=pad))

    def forward(self, in_x):
        """Decode (batch, latent_feature_dim) into raw logits of shape (batch, init_dim, seq_len)."""
        in_x = self.fc1(in_x)
        return self.network(in_x)

In [6]:
def ONE_HOT_Encoder(train, batch_size, **config):
    """Encode every row of `train` and stack the per-batch results into three tensors.

    The frame is wrapped in `data_process_loader` (not defined in this notebook;
    presumably provided by the DeepPurpose star-import) and iterated in order,
    without shuffling or dropping the last partial batch.

    Args:
        train: DataFrame exposing `.index` and a `Label` column, as consumed by
            `data_process_loader`.
        batch_size: batch size used while iterating; it does not affect the result.
        **config: forwarded to `data_process_loader`. If it carries
            'drug_encoding' / 'target_encoding' those values decide the dtype
            handling below; otherwise the notebook-level globals of the same
            names are used, as before.

    Returns:
        (drug_features, target_features, labels) where `labels` is a float
        tensor with a trailing singleton dimension.
    """
    # Prefer the encodings recorded in config so the function does not silently
    # depend on notebook globals; fall back to the globals when they are absent.
    drug_enc = config['drug_encoding'] if 'drug_encoding' in config else drug_encoding
    target_enc = config['target_encoding'] if 'target_encoding' in config else target_encoding

    # Use the DataLoader class imported at the top of the notebook. The old
    # `data.DataLoader` relied on a star-imported `data` name that later loops
    # (`for batch_idx, data in ...`) rebind, which broke re-running this function.
    loader = DataLoader(
        data_process_loader(train.index.values, train.Label.values, train, **config),
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
    )

    # Each list starts with an empty float tensor so that an empty loader yields
    # empty float tensors, exactly as the incremental concatenation did.
    drug_chunks = [torch.tensor([], dtype=torch.float)]
    target_chunks = [torch.tensor([], dtype=torch.float)]
    label_chunks = [torch.tensor([], dtype=torch.float)]

    for v_d, v_p, label in loader:
        if target_enc != 'Transformer':
            v_p = v_p.float()
        if drug_enc not in ('MPNN', 'Transformer'):
            v_d = v_d.float()

        drug_chunks.append(v_d.detach())
        target_chunks.append(v_p.detach())
        label_chunks.append(torch.from_numpy(np.array(label)).float())

    # Concatenate once at the end instead of re-copying the growing tensor per batch.
    v_f_D = torch.cat(drug_chunks, dim=0)
    v_f_P = torch.cat(target_chunks, dim=0)
    y_label = torch.cat(label_chunks, dim=0)
    return v_f_D, v_f_P, y_label.float().unsqueeze(1)

In [7]:
class VAE(nn.Module):
    """Variational auto-encoder wrapper around an encoder/decoder pair.

    The encoder must return a (batch, 128, 1) or (batch, 128) tensor; two linear
    heads turn it into the mean and log-variance of a diagonal Gaussian.

    Args:
        encoder: module mapping the input to a 128-feature summary.
        decoder: module mapping a latent vector to reconstruction logits.
        init_dim: unused here; kept for interface compatibility.
        latent_feature_dim: dimensionality of the latent space.
    """

    def __init__(self, encoder, decoder, init_dim, latent_feature_dim):
        super(VAE, self).__init__()

        self.encoder = encoder
        self.decoder = decoder

        self.h2mu = nn.Linear(128, latent_feature_dim)
        self.h2log_var = nn.Linear(128, latent_feature_dim)

        # Xavier-normal weights and zero biases for Linear and Conv1d layers of
        # the whole model. nn.ConvTranspose1d is not an nn.Conv1d instance, so
        # transposed convolutions keep their default initialisation.
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv1d)):
                # In-place variant; `init.xavier_normal` is deprecated.
                init.xavier_normal_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()

    def _encode(self, x):
        """Return (mu, log_var) of the approximate posterior for input x."""
        h = self.encoder(x).squeeze(-1)
        return self.h2mu(h), self.h2log_var(h)

    def reparameterize(self, mu, log_var):
        """Sample z = mu + std * eps with eps ~ N(0, I)."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + (std * eps)

    def loss_function(self, features_reconstruction, features, mu, log_var):
        """Return (reconstruction_loss, KLD), both averaged over the batch.

        The reconstruction term is binary cross-entropy on logits, summed over
        all elements and divided by the batch size. Inputs are flattened per
        sample rather than reshaped to a fixed width of 26000, so any feature
        shape with matching element counts works.
        """
        re_loss = nn.BCEWithLogitsLoss(reduction='sum')
        batch = features_reconstruction.shape[0]
        reconstruction_loss = re_loss(
            features_reconstruction.flatten(1), features.flatten(1)) / batch
        KLD = -0.5 * torch.sum(1. + log_var - torch.pow(mu, 2) - torch.exp(log_var)) / mu.shape[0]
        return reconstruction_loss, KLD

    def forward(self, x):
        """Encode, sample, decode and compute both loss terms.

        Returns:
            (z, features_reconstruction, mu, log_var, re_loss, KLD)
        """
        mu, log_var = self._encode(x)
        z = self.reparameterize(mu, log_var)
        features_reconstruction = self.decoder(z)
        re_loss, KLD = self.loss_function(features_reconstruction, x, mu, log_var)
        return z, features_reconstruction, mu, log_var, re_loss, KLD

    def predict(self, x, sample=True):
        """Return a latent code for x.

        Args:
            x: input batch.
            sample: when True (default, the previous behaviour) draw a random
                sample from the posterior; when False return the posterior
                mean, which is deterministic for a model in eval mode.
        """
        mu, log_var = self._encode(x)
        if not sample:
            return mu
        return self.reparameterize(mu, log_var)

In [8]:
# DTC subset: read SMILES / target sequence / label columns from CSV and run
# them through DeepPurpose's data_process with CNN encodings on both sides.
# These two encoding names are notebook-level globals that later cells read.
drug_encoding, target_encoding = 'CNN', 'CNN'

DTC_sub=pd.read_csv('VAE_data/df_DTC_undrugs_targets.csv')
Smile = DTC_sub['SMILES'].values
Target = DTC_sub['Target_Sequence'].values
y = DTC_sub['Label'].values
# NOTE(review): 'train_full' presumably places every pair in `train`, making
# `frac` irrelevant — confirm against utils.data_process.
train, val, test = utils.data_process(Smile, Target, y,
                                drug_encoding, target_encoding, 
                                split_method='train_full',frac=[0.7,0.1,0.2],
                                random_seed = 1)

print('DTC subset Done!----------------------------------')
# DB dataset: same processing, but the three columns come from .npy files.
# allow_pickle=True unpickles object arrays, so only load trusted files.
Smile2 = np.load("VAE_data/DB_smiles.npy", allow_pickle=True)
Target2 = np.load("VAE_data/DB_targets.npy", allow_pickle=True)
y2 = np.load("VAE_data/DB_y.npy", allow_pickle=True)
train2, val2, test2 = utils.data_process(Smile2, Target2, y2,
                                drug_encoding, target_encoding, 
                                split_method='train_full',frac=[0.7,0.1,0.2],
                                random_seed = 1)
print('DB Done!-----------------------------------')

in total: 51863 drug-target pairs
encoding drug...
unique drugs: 6677
drug encoding finished...
encoding protein...
unique target sequence: 766
protein encoding finished...
splitting dataset...
Done.
DTC subset Done!----------------------------------
in total: 66434 drug-target pairs
encoding drug...
unique drugs: 10661
drug encoding finished...
encoding protein...
unique target sequence: 1413
protein encoding finished...
splitting dataset...
Done.
DB Done!-----------------------------------


In [9]:
# Encode the DTC training frame into stacked drug / target / label tensors,
# iterating in batches of 1024.
# Note: `y` is rebound here from the raw label array of the loading cell to a tensor.
config = utils.generate_config2(drug_encoding = drug_encoding, target_encoding = target_encoding)
v_f_D, v_f_P, y = ONE_HOT_Encoder(train, 1024, **config)

# Same for the DB training frame; `config` is regenerated with identical
# arguments and `y2` is likewise rebound to a tensor.
config = utils.generate_config2(drug_encoding = drug_encoding, target_encoding = target_encoding)
v_f_D2, v_f_P2, y2 = ONE_HOT_Encoder(train2, 1024, **config)

In [10]:
# Pool the DTC and DB encodings along the sample axis (DTC rows first).
v_f_D_all = torch.cat((v_f_D, v_f_D2), 0)
v_f_P_all = torch.cat((v_f_P, v_f_P2), 0)

In [11]:
# Draw a random 10% of the pooled rows as the validation subset.
# A dedicated, seeded random.Random instance is used because Python's global
# `random` module is never seeded in this notebook, so the old
# random.sample(...) call picked different rows on every run. The local
# generator also leaves the global torch / NumPy RNG streams untouched.
# NOTE(review): the selected rows are copied, not removed, from the pooled
# tensors, so the validation subset overlaps the data the training loader
# is built from.
VAL_FRACTION = 0.1
VAL_SEED = 1

n, d, c = v_f_P_all.shape
s = int(VAL_FRACTION * n)
index = torch.LongTensor(random.Random(VAL_SEED).sample(range(n), s))

# The same row indices are used for drugs and proteins so the pairs stay aligned.
v_f_D_val = torch.index_select(v_f_D_all, 0, index)
v_f_P_val = torch.index_select(v_f_P_all, 0, index)

In [18]:
# Model hyperparameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
latent_feature_dim = 10  # size of the VAE latent vector
init_dim=26  # input channel count; matches the 26 output channels of the decoder

# Build the VAE from the CNN encoder/decoder pair and move it to `device`.
model = VAE(EncoderCNN(init_dim), DecoderCNN(latent_feature_dim, init_dim, kernel_size=5),init_dim, latent_feature_dim).to(device)
#print(model)

# Training hyperparameters
LR = 1e-4
epochs = 200
batch_size=256
# Adam without weight decay; the LR scheduler below is left disabled.
optimizer = torch.optim.Adam(model.parameters(), lr=LR)#, weight_decay= 1e-5 
#scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min',factor=0.5, patience=4, verbose=True)

In [19]:
# Protein-side datasets and shuffled loaders: all pooled rows for training,
# the sampled subset for validation.
train_ids, val_ids = (TensorDataset(t) for t in (v_f_P_all, v_f_P_val))
train_loader, val_loader = (
    DataLoader(ds, batch_size=batch_size, shuffle=True)
    for ds in (train_ids, val_ids)
)

In [20]:
# Drug-side datasets and shuffled loaders, built from the drug tensors in the
# same way (all pooled rows for training, the sampled subset for validation).
train_ids2, val_ids2 = (TensorDataset(t) for t in (v_f_D_all, v_f_D_val))
train_loader2, val_loader2 = (
    DataLoader(ds, batch_size=batch_size, shuffle=True)
    for ds in (train_ids2, val_ids2)
)

In [None]:
# Train the VAE on `train_loader`, validate on `val_loader` after every epoch,
# keep a copy of the model with the lowest validation loss, then plot both
# loss curves.
loss_epoch = []
loss_val_epoch = []
best_model = copy.deepcopy(model)
best_loss = 1e8

t_start = time()
for epoch in range(1, epochs + 1):
    model.train()
    train_loss = []
    # The loop variable is called `batch` (not `data`) so it does not shadow the
    # star-imported `data` name that other code in this notebook relies on.
    for batch_idx, batch in enumerate(train_loader):
        x = batch[0].to(device)
        embedding, features_reconstruction, mu, log_var, re_loss, KLD_loss = model(x)
        # Total loss = reconstruction term + KL term with weight 1.
        loss = re_loss + 1*KLD_loss
        train_loss.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        t_now = time()
        # print statistics every 100 batches
        if (batch_idx + 1) % 100 == 0:
            print('Epoch [{}/{}], loss = {:.4f}, loss_re = {:.4f}, loss_KLD = {:.4f}, Total time = {:.4f} hours'.format(epoch, epochs, loss.item(), re_loss.item(), KLD_loss.item(), int(t_now - t_start)/3600))

    loss_epoch.append(np.mean(train_loss))

    # Validate and keep the best model seen so far.
    model.eval()
    val_losses = []
    with torch.no_grad():
        for val_batch_idx, val_data in enumerate(val_loader):
            val_x = val_data[0].to(device)
            val_embedding, val_features_reconstruction, val_mu, val_log_var, val_re_loss, val_KLD_loss = model(val_x)
            val_loss = val_re_loss + val_KLD_loss
            val_losses.append(val_loss.item())

    # Average the per-batch losses to get the epoch's validation loss.
    val_avg_loss = np.mean(val_losses)
    loss_val_epoch.append(val_avg_loss)
    # Best-model tracking was initialised above but its update was commented
    # out (and referred to an undefined `val_avg_loss`); it is enabled here so
    # `best_model` / `best_loss` actually reflect the best validation epoch.
    if val_avg_loss < best_loss:
        best_model = copy.deepcopy(model)
        best_loss = val_avg_loss
    print('Valerate at Epoch ' + str(epoch) + ' with loss ' + str(val_avg_loss)[:7])
#----------------------------------------
# Per-epoch loss curves: figure 0 = training, figure 1 = validation.
fontsize = 16
iter_num = list(range(1,len(loss_epoch)+1))
plt.figure(0)
plt.plot(iter_num, loss_epoch, "bo-")
plt.xlabel("epoch", fontsize = fontsize)
plt.ylabel("train_loss", fontsize = fontsize)

iter_num = list(range(1,len(loss_val_epoch)+1))
plt.figure(1)
plt.plot(iter_num, loss_val_epoch, "mo-")
plt.xlabel("epoch", fontsize = fontsize)
plt.ylabel("val_loss", fontsize = fontsize)

In [22]:
# Persist the weights of `model` (its state as of the last executed training step), relative to the working directory.
torch.save(model.state_dict(), 'vae_protein_dtc_db.pt')

In [None]:
# Restore the saved weights into `model`.
# map_location=device remaps the stored tensors onto the device chosen in the
# setup cell, so a checkpoint written during a CUDA run can also be loaded on a
# CPU-only machine (without it torch.load tries to restore onto the saving device).
state_dict = torch.load('./vae_protein_dtc_db.pt', map_location=device)
model.load_state_dict(state_dict)

In [23]:
############################# unique targets #####################################
# Build one encoded array per *unique* target sequence, so each protein is
# re-encoded once instead of once per drug-target pair.
# `trans_protein` and `protein_2_embed` are not defined in this notebook;
# presumably they come from the DeepPurpose star-import — verify.

# DTC: unique targets of `train`
AA = pd.Series(train['Target Sequence'].unique()).apply(trans_protein)
AA2 = AA.apply(protein_2_embed)

# Stack the per-sequence arrays along a new leading axis and wrap as a tensor.
target_onehot_np = np.stack(AA2.values, axis=0)
target_onehot_th = torch.from_numpy(target_onehot_np)

# DB: unique targets of `train2`.
# NOTE: despite the `_dtc` suffix, AA_dtc / AA2_dtc hold the DB (train2) data.
AA_dtc = pd.Series(train2['Target Sequence'].unique()).apply(trans_protein)
AA2_dtc = AA_dtc.apply(protein_2_embed)

# `target_onehot_db` is first a NumPy array and is then rebound to the tensor.
target_onehot_db = np.stack(AA2_dtc.values, axis=0)
target_onehot_db = torch.from_numpy(target_onehot_db)

In [24]:
# Ordered (unshuffled) loaders over the unique-target tensors, so the encoded
# rows keep the order of the unique sequences: 1 = DTC, 2 = DB.
retrain_ids1, retrain_ids2 = (
    TensorDataset(t) for t in (target_onehot_th, target_onehot_db)
)
retrain_loader1, retrain_loader2 = (
    DataLoader(ds, batch_size=batch_size, shuffle=False)
    for ds in (retrain_ids1, retrain_ids2)
)

In [28]:
#re-encoder
features_vae = torch.tensor([], dtype=torch.float)
for batch_idx, data in enumerate(retrain_loader2):
    x = data[0]
    x = x.float().to(device)
    feature_vae = model.predict(x)
    features_vae = torch.cat([features_vae, feature_vae.cpu().detach()], dim=0)

In [30]:
# Persist the latent features collected in `features_vae`, relative to the working directory.
torch.save(features_vae, 'features_target_db_dtcdb_vae.pt')