In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR

import torchtext

import pickle

import glob
import numpy as np
import pandas as pd
from sklearn.utils import shuffle
from tqdm import tqdm

with open("./data/DTI/DTI_train.pickle", "rb") as f:
    train_data = pickle.load(f)

with open("./data/DTI/DTI_valid.pickle", "rb") as f:
    valid_data = pickle.load(f)
    
with open("./data/DTI/DTI_test.pickle", "rb") as f:
    test_data = pickle.load(f)
    
with open("./data/molecule_net/MoleculeNet_tokenizer.pickle", "rb") as f:
    molecule_tokenizer = pickle.load(f)
    
with open("./data/DTI/protein_tokenizer.pickle", "rb") as f:
    protein_tokenizer = pickle.load(f)

molecule_vocab_dim     = len(molecule_tokenizer.vocab.itos)
molecule_seq_len       = 256
molecule_embedding_dim = 512

protein_vocab_dim     = len(protein_tokenizer.vocab.itos)
protein_seq_len       = 256
protein_embedding_dim = 512
    
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = 64

In [2]:
class DTIDataset(torch.utils.data.Dataset):
    def __init__(self, data, molecule_tokenizer, molecule_seq_len, protein_tokenizer, protein_seq_len):
        super(DTIDataset, self).__init__()

        self.data = data
        
        self.molecule_tokenizer = molecule_tokenizer
        self.molecule_vocab = molecule_tokenizer.vocab
        self.molecule_seq_len = molecule_seq_len
        
        self.protein_tokenizer = protein_tokenizer
        self.protein_vocab = protein_tokenizer.vocab
        self.protein_seq_len = protein_seq_len
        
        self.cls_token_id  = self.molecule_vocab.stoi[self.molecule_tokenizer.init_token]
        self.sep_token_id  = self.molecule_vocab.stoi[self.molecule_tokenizer.eos_token]
        self.pad_token_id  = self.molecule_vocab.stoi[self.molecule_tokenizer.pad_token]
        self.mask_token_id = self.molecule_vocab.stoi[self.molecule_tokenizer.unk_token]
        
    def __getitem__(self, idx):
        current_data = self.data.loc[idx]
        
        molecule_string = current_data['Drug']
        protein_string = current_data['Target']
        target = current_data['Y']

        molecule = self.molecule_tokenizer.numericalize(molecule_string).squeeze()
        protein = self.protein_tokenizer.numericalize(protein_string).squeeze()
        
        if len(molecule) < self.molecule_seq_len - 2:
            molecule_pad_length = self.molecule_seq_len - len(molecule) - 2
        else:
            molecule = molecule[:self.molecule_seq_len - 2]
            molecule_pad_length = 0
            
        if len(protein) < self.protein_seq_len - 2:
            protein_pad_length = self.protein_seq_len - len(protein) - 2
        else:
            protein = protein[:self.protein_seq_len - 2]
            protein_pad_length = 0
              
        molecule = torch.cat([torch.tensor([self.cls_token_id]), molecule, torch.tensor([self.sep_token_id]), torch.tensor([self.pad_token_id] * molecule_pad_length)]).long().contiguous()
        protein = torch.cat([torch.tensor([self.cls_token_id]), protein, torch.tensor([self.sep_token_id]), torch.tensor([self.pad_token_id] * protein_pad_length)]).long().contiguous()
        
        target = torch.tensor(target).type(torch.FloatTensor).contiguous()

        segment_embedding = torch.zeros(molecule.size(0))

        return molecule, protein, target, segment_embedding

    
    def __len__(self):
        return len(self.data)
    
    
    def __iter__(self):
        for x in self.data:
            yield x
            
    
    def get_vocab(self):
        return self.vocab

    
def collate_fn(batch):
    batch = list(filter(lambda x: x is not None, batch))
    
    return torch.utils.data.dataloader.default_collate(batch)

In [3]:
import torch
import torch.nn as nn

class BERT(nn.Module):
    def __init__(self, vocab_dim, seq_len, embedding_dim, pad_token_id):
        super(BERT, self).__init__()
        self.pad_token_id  = pad_token_id
        self.nhead         = 4
        self.embedding     = BERTEmbedding(vocab_dim, seq_len, embedding_dim)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=self.nhead, batch_first=True)
        self.encoder_block = nn.TransformerEncoder(self.encoder_layer, num_layers=4)
        
        
    def forward(self, data, segment_embedding):
        pad_mask  = BERT.get_attn_pad_mask(data, data, self.pad_token_id).repeat(self.nhead, 1, 1)
        embedding = self.embedding(data, segment_embedding)
        output    = self.encoder_block(embedding, pad_mask) 
        
        return output
    
    @staticmethod
    def get_attn_pad_mask(seq_q, seq_k, i_pad):
        batch_size, len_q = seq_q.size()
        batch_size, len_k = seq_k.size()
        pad_attn_mask = seq_k.data.eq(i_pad)
        pad_attn_mask= pad_attn_mask.unsqueeze(1).expand(batch_size, len_q, len_k)
        
        return pad_attn_mask

In [4]:
class BERTEmbedding(nn.Module):
    def __init__(self, vocab_dim, seq_len, embedding_dim, dropout_rate=0.1, device=device):
        super(BERTEmbedding, self).__init__()
        self.seq_len       = seq_len
        self.vocab_dim     = vocab_dim
        self.embedding_dim = embedding_dim
        self.dropout_rate  = dropout_rate
        
        # vocab --> embedding
        self.token_embedding      = nn.Embedding(self.vocab_dim, self.embedding_dim) 
        self.token_dropout        = nn.Dropout(self.dropout_rate)    
        
        # seq len --> embedding
        self.positional_embedding = nn.Embedding(self.seq_len, self.embedding_dim)
        self.positional_dropout   = nn.Dropout(self.dropout_rate) 
        
        # segment (0, 1) --> embedding
        self.segment_embedding    = nn.Embedding(2, self.embedding_dim)
        self.segment_dropout      = nn.Dropout(self.dropout_rate) 
        
        
    def forward(self, data, segment_embedding):
        token_embedding      = self.token_embedding(data)
        token_embedding      = self.token_dropout(token_embedding)
        
        positional_encoding  = torch.arange(start=0, end=self.seq_len, step=1).long()
        # data의 device 정보 가져와서 처리
        positional_encoding  = positional_encoding.unsqueeze(0).expand(data.size()).to(device)
        positional_embedding = self.positional_embedding(positional_encoding)
        positional_embedding = self.positional_dropout(positional_embedding)
        
        segment_embedding    = self.segment_embedding(segment_embedding)
        segment_embedding    = self.segment_dropout(segment_embedding)
        
        return token_embedding + positional_embedding + segment_embedding

In [5]:
class MoleculeBranch(nn.Module):
    def __init__(self, bert, output_dim):
        super(MoleculeBranch, self).__init__()
        self.bert = bert
        d_model = 256 * 512
        self.fc   = nn.Linear(d_model, output_dim)
    
    def forward(self, x, segment_embedding):
        batch_size = x.shape[0]
        output = self.bert(x, segment_embedding)
        output = output.reshape(batch_size, -1)
        output = self.fc(output)
        
        return output

In [6]:
class ProteinBranch(nn.Module):
    def __init__(self, seq_len, vocab_dim, embedding_dim, dropout_rate):
        super(ProteinBranch, self).__init__()
        self.embedding = ProteinEmbedding(seq_len, vocab_dim, embedding_dim, dropout_rate)
        
        self.conv_block_1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=(5, 5), padding=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(5, 5), padding=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=32, kernel_size=(1, 1), padding=0),
            nn.ReLU(),
            nn.AvgPool2d((5, 5), padding=2)
        )
        
        self.conv_block_2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=128, kernel_size=(5, 5), padding=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(5, 5), padding=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=64, kernel_size=(1, 1), padding=0),
            nn.ReLU(),
            nn.AvgPool2d((5, 5), padding=2)
        )
        
        self.conv_block_3 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=256, kernel_size=(5, 5), padding=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(5, 5), padding=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=256, kernel_size=(1, 1), padding=0),
            nn.ReLU(),
            nn.AvgPool2d((5, 5), padding=2)
        )
        
        self.fc = nn.Linear(256 * 3 * 5, embedding_dim)
    
    def forward(self, x):
        batch_size = x.shape[0]
        x = self.embedding(x)
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        x = self.conv_block_3(x)
        x = x.reshape(batch_size, -1)
        x = self.fc(x)
        
        return x

In [7]:
class ProteinEmbedding(nn.Module):
    def __init__(self, seq_len, vocab_dim, embedding_dim, dropout_rate):
        super(ProteinEmbedding, self).__init__()
        self.seq_len = seq_len
        self.vocab_dim = vocab_dim
        self.embedding_dim = embedding_dim
        self.dropout_rate = dropout_rate
        
        self.embedding = nn.Embedding(self.vocab_dim, self.embedding_dim)
        self.dropout = nn.Dropout(self.dropout_rate)
        
    def forward(self, x):
        embedding = self.embedding(x.long())
        embedding = self.dropout(embedding).unsqueeze(1)
        
        return embedding

In [8]:
class PredictionHead(nn.Module):
    def __init__(self, moleucle_branch, protein_branch, output_dim=512, dropout_rate=0.3):
        super(PredictionHead, self).__init__()
        self.molecule_branch = moleucle_branch.to(device)
        self.protein_branch = protein_branch.to(device)
        self.output_dim = output_dim
        
        self.fc_1 = nn.Linear(self.output_dim * 2, 512)
#         self.fc_2 = nn.Linear(512, 256)
#         self.fc_3 = nn.Linear(256, 128)
        self.fc_2 = nn.Linear(512, 128)
        self.fc_out = nn.Linear(128, 1)
        
        self.dropout = nn.Dropout(dropout_rate, inplace=True)
        self.activation = nn.ReLU(inplace=True)
        
    def forward(self, molecule, protein, segment_embedding):
        molecule_embedding = self.molecule_branch(molecule, segment_embedding)
        protein_embedding = self.protein_branch(protein)
        
        merged_embedding = torch.cat([molecule_embedding, protein_embedding], dim=1)
        out = self.dropout(merged_embedding).to(device)
        
        out = self.activation(self.dropout(self.fc_1(out)))
        out = self.activation(self.dropout(self.fc_2(out)))
#         out = self.activation(self.dropout(self.fc_3(out)))
        out = self.fc_out(out).to(device)
        
        return out

In [9]:
def generate_dataloader(data, molecule_tokenizer, molecule_seq_len, protein_tokenizer, protein_seq_len, batch_size, collate_fn, shuffle=True, num_workers=6):
    dataset = DTIDataset(data=data, molecule_tokenizer=molecule_tokenizer, molecule_seq_len=molecule_seq_len, protein_tokenizer=protein_tokenizer, protein_seq_len=protein_seq_len)
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, collate_fn=collate_fn)
    
    return data_loader


In [10]:
def train(model, iterator, optimizer, criterion, device, clip=1):
    model.train()

    epoch_loss = 0
    epoch_corrects = 0
    epoch_num_data = 0

    for molecule, protein, target, segment_emb in tqdm(iterator):
        optimizer.zero_grad()
        
        output = model(molecule.to(device), protein.to(device), segment_emb.long().to(device))
        
        output = output
        target = target.to(device)
        loss = criterion(output, target)
        loss.backward()
                
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        
        optimizer.step()

        epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)


@torch.no_grad()
def evaluate(model, iterator, criterion, device):
    model.eval()
    
    epoch_loss = 0
    epoch_corrects = 0
    epoch_num_data = 0

    for molecule, protein, target, segment_emb in iterator:
        optimizer.zero_grad()
        
        output = model(molecule.to(device), protein.to(device), segment_emb.long().to(device))
        
        output = output
        target = target.to(device)
        loss = criterion(output, target)
        
        epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)


@torch.no_grad()
def predict(model, iterator, device):
    model.eval()
    
    for batch, (molecule, protein, target, segment_emb) in enumerate(iterator):
        output = model(molecule.to(device), protein.to(device), segment_emb.long().to(device))
        
        molecule = molecule.clone().detach().to("cpu").tolist()
        protein = protein.clone().detach().to("cpu").tolist()
        output = output.clone().detach().to("cpu").tolist()
        target = target.clone().detach().to("cpu").tolist()
        
    return molecule, protein, output, target


def generate_epoch_dataloader(data, molecule_tokenizer, molecule_seq_len, protein_tokenizer, protein_seq_len, batch_size, collate_fn, shuffle=True, num_workers=6):
    dataset = DTIDataset(data=data, molecule_tokenizer=molecule_tokenizer, molecule_seq_len=molecule_seq_len, protein_tokenizer=protein_tokenizer, protein_seq_len=protein_seq_len)
    dataloader = torch.utils.data.DataLoader(dataset,batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, collate_fn=collate_fn)
    
    return dataloader


In [11]:
molecule_bert = BERT(vocab_dim=100, seq_len=256, embedding_dim=512, pad_token_id=1).to(device)
for param in molecule_bert.parameters():
    param.requires_grad = False
    
molecule_branch = MoleculeBranch(molecule_bert, output_dim=512).to(device)
protein_branch = ProteinBranch(seq_len=256, vocab_dim=100, embedding_dim=512, dropout_rate=0.3).to(device)
model = PredictionHead(molecule_branch, protein_branch).to(device)

optimizer = optim.AdamW(model.parameters(), lr=1e-4, betas=[0.9, 0.999], weight_decay=0.01)
# scheduler = CosineAnnealingLR(optimizer, T_max=10)
scheduler = ReduceLROnPlateau(optimizer)
criterion = nn.MSELoss()


In [None]:
N_EPOCHS  = 1000
PAITIENCE = 30

start_epoch = 0
if len(glob.glob("output/DTI/*.tsv")) != 0:
    print("load pretrained model ... ")
    start_epoch = len(glob.glob("output/DTI/*.tsv"))
    model.load_state_dict(torch.load('weights/DTI_single_bert_best.pt'))

n_paitience = 0
best_valid_loss = float('inf')
optimizer.zero_grad()
optimizer.step()
    
for epoch in range(start_epoch, N_EPOCHS):
    train_data_loader = generate_epoch_dataloader(train_data, 
                                                 molecule_tokenizer=molecule_tokenizer, 
                                                 molecule_seq_len=256,
                                                 protein_tokenizer=protein_tokenizer,
                                                 protein_seq_len=256,
                                                 batch_size=batch_size, 
                                                 collate_fn=collate_fn,
                                                 num_workers=8)
    
    valid_data_loader = generate_epoch_dataloader(valid_data, 
                                                 molecule_tokenizer=molecule_tokenizer, 
                                                 molecule_seq_len=256,
                                                 protein_tokenizer=protein_tokenizer,
                                                 protein_seq_len=256,
                                                 batch_size=batch_size, 
                                                 collate_fn=collate_fn,
                                                 num_workers=8)
    
    print(f'Epoch: {epoch:04}')
    
    train_loss = train(model, train_data_loader, optimizer, criterion, device)
    valid_loss = evaluate(model, valid_data_loader, criterion, device)
    
    scheduler.step(valid_loss)
    
    print(f'Train MSE: {train_loss:.4f} | Train RMSE: {np.sqrt(train_loss):.4f}\nValid MSE: {valid_loss:.4f} | Valid RMSE: {np.sqrt(valid_loss):.4f}')

    with open("output/DTI/log.txt", "a") as f:
        f.write(f"Epoch: {epoch:04d} Train MSE: {train_loss:.4f}, Train RMSE: {np.sqrt(train_loss):.4f}, Valid MSE: {valid_loss:.4f}, Valid RMSE: {np.sqrt(valid_loss):.4f}\n")

    if epoch % 5 == 0:
        print("Predictions ...\n")
        test_data_loader = generate_epoch_dataloader(test_data, 
                                                          molecule_tokenizer=molecule_tokenizer, 
                                                          molecule_seq_len=256,
                                                          protein_tokenizer=protein_tokenizer,
                                                          protein_seq_len=256,
                                                          batch_size=batch_size, 
                                                          collate_fn=collate_fn,
                                                          num_workers=8)

        test_loss = evaluate(model, test_data_loader, criterion, device)
        print(f"test loss: {test_loss:.4f}")
        molecule, protein, prediction, target = predict(model, test_data_loader, device)
        prediction_results = pd.DataFrame({"molecule": molecule,
                                           "protein": protein,
                                           "output": prediction, 
                                           "target": target})
        
        prediction_results.to_csv(f"output/DTI/prediction_results_epoch-{epoch:04d}_mse-{np.round(test_loss, 4)}.tsv", sep="\t", index=False)            
        
    if n_paitience < PAITIENCE:
        if best_valid_loss > valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), 'weights/DTI_single_bert_best.pt')
            n_paitience = 0
        elif best_valid_loss <= valid_loss:
            n_paitience += 1
    else:
        print("Early stop!")
        model.load_state_dict(torch.load('weights/DTI_single_bert_best.pt'))
        model.eval()
        break

Epoch: 0000


  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 2165/2165 [14:47<00:00,  2.44it/s]


Train MSE: 3.0925 | Train RMSE: 1.7586
Valid MSE: 2.4874 | Valid RMSE: 1.5771
Predictions ...



  return F.mse_loss(input, target, reduction=self.reduction)


test loss: 2.4884


  return F.mse_loss(input, target, reduction=self.reduction)


Epoch: 0001


100%|██████████| 2165/2165 [14:38<00:00,  2.46it/s]


Train MSE: 2.5712 | Train RMSE: 1.6035
Valid MSE: 1.8110 | Valid RMSE: 1.3457
Epoch: 0002


100%|██████████| 2165/2165 [14:35<00:00,  2.47it/s]


Train MSE: 2.3999 | Train RMSE: 1.5492
Valid MSE: 1.7759 | Valid RMSE: 1.3326
Epoch: 0003


100%|██████████| 2165/2165 [14:35<00:00,  2.47it/s]


Train MSE: 2.3819 | Train RMSE: 1.5433
Valid MSE: 1.9545 | Valid RMSE: 1.3981
Epoch: 0004


100%|██████████| 2165/2165 [14:34<00:00,  2.48it/s]


Train MSE: 2.3702 | Train RMSE: 1.5396
Valid MSE: 1.7722 | Valid RMSE: 1.3312
Epoch: 0005


100%|██████████| 2165/2165 [14:34<00:00,  2.47it/s]


Train MSE: 2.3572 | Train RMSE: 1.5353
Valid MSE: 1.7835 | Valid RMSE: 1.3355
Predictions ...

test loss: 1.7891
Epoch: 0006


100%|██████████| 2165/2165 [14:32<00:00,  2.48it/s]


Train MSE: 2.3676 | Train RMSE: 1.5387
Valid MSE: 1.9581 | Valid RMSE: 1.3993
Epoch: 0007


100%|██████████| 2165/2165 [14:32<00:00,  2.48it/s]


Train MSE: 2.3391 | Train RMSE: 1.5294
Valid MSE: 1.8756 | Valid RMSE: 1.3695
Epoch: 0008


100%|██████████| 2165/2165 [14:28<00:00,  2.49it/s]


Train MSE: 2.3269 | Train RMSE: 1.5254
Valid MSE: 2.0609 | Valid RMSE: 1.4356
Epoch: 0009


100%|██████████| 2165/2165 [14:31<00:00,  2.49it/s]


Train MSE: 2.3308 | Train RMSE: 1.5267
Valid MSE: 1.8758 | Valid RMSE: 1.3696
Epoch: 0010


100%|██████████| 2165/2165 [14:32<00:00,  2.48it/s]


Train MSE: 2.2969 | Train RMSE: 1.5155
Valid MSE: 1.9795 | Valid RMSE: 1.4070
Predictions ...

test loss: 1.9830
Epoch: 0011


100%|██████████| 2165/2165 [14:32<00:00,  2.48it/s]


Train MSE: 2.2762 | Train RMSE: 1.5087
Valid MSE: 2.0100 | Valid RMSE: 1.4178
Epoch: 0012


 33%|███▎      | 720/2165 [04:50<09:42,  2.48it/s]