In [4]:
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from torch.optim import Adam
import torch.nn.functional as F
from tqdm import tqdm

In [5]:
train_sps = torch.load("data/train_sps.ids76.pt")
train_smile = torch.load("data/train_smile.ids68.pt")
train_log_ic50 = torch.load("data/train_ic50_log.pt")

test_sps = torch.load("data/test_sps.ids76.pt")
test_smile = torch.load("data/test_smile.ids68.pt")
test_log_ic50 = torch.load("data/test_ic50_log.pt")

In [6]:
class Attention(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.key_layer = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.query_layer = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.value_layer = nn.Linear(hidden_dim, hidden_dim, bias=False)
    def forward(self, query, key, value):
        scores = torch.matmul(self.query_layer(query), self.key_layer(key).transpose(-2,-1)) / (self.hidden_dim ** 0.5)
        weights = F.softmax(scores, dim=-1)
        weighted_values = torch.matmul(weights, self.value_layer(value))
        return weighted_values, weights

class RNNAttentionModel(nn.Module):
    def __init__(self, protein_vocab_size, drug_vocab_size, embedding_dim, hidden_dim, batch_size):
        super().__init__()
        self.protein_embedding = nn.Embedding(protein_vocab_size, embedding_dim)
        self.drug_embedding = nn.Embedding(drug_vocab_size, embedding_dim)
        self.protein_gru = nn.GRU(embedding_dim, hidden_dim, batch_first = True)
        self.drug_gru = nn.GRU(embedding_dim, hidden_dim, batch_first = True)
        self.attention = Attention(hidden_dim)
        self.fc = nn.Linear(hidden_dim, 1)
        self.batch_size = batch_size
    def forward(self, protein_data, drug_data):
        protein_embedded = self.protein_embedding(protein_data)
        drug_embedded = self.drug_embedding(drug_data)

        _, protein_hidden = self.protein_gru(protein_embedded)
        _, drug_hidden = self.drug_gru(drug_embedded)

        attention_output, _ = self.attention(query= protein_hidden, key= drug_hidden, value = drug_hidden)

        output = self.fc(attention_output.squeeze(0))
        return output

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

protein_vocab_size = 76 # 토큰 어휘집의 어휘 개수
compound_vocab_size = 68 # 토큰 어휘집의 어휘 개수
protein_seq_length = 152 # 한 문장의 토큰 수
compound_seq_length = 100 # 한 문장의 토큰 수
embedding_dim = 256
hidden_dim = 128
batch_size = 64

In [8]:
model = RNNAttentionModel(protein_vocab_size, compound_vocab_size, embedding_dim, hidden_dim, batch_size)

criterion = nn.MSELoss()
optimizer = Adam(model.parameters(), lr=0.001)

model.train()

RNNAttentionModel(
  (protein_embedding): Embedding(76, 256)
  (drug_embedding): Embedding(68, 256)
  (protein_gru): GRU(256, 128, batch_first=True)
  (drug_gru): GRU(256, 128, batch_first=True)
  (attention): Attention(
    (key_layer): Linear(in_features=128, out_features=128, bias=False)
    (query_layer): Linear(in_features=128, out_features=128, bias=False)
    (value_layer): Linear(in_features=128, out_features=128, bias=False)
  )
  (fc): Linear(in_features=128, out_features=1, bias=True)
)

In [7]:
print(train_sps.shape)
print(train_smile.shape)
print(train_log_ic50.shape)

torch.Size([263583, 152])
torch.Size([263583, 100])
torch.Size([263583, 1])


In [8]:
batch_size = 64
shuffle = True

dataset = TensorDataset(train_sps, train_smile, train_log_ic50)
data_loader = DataLoader(dataset, batch_size= batch_size, shuffle=shuffle, pin_memory=True)

num_epochs = 10

# 훈련 시작
for epoch in range(num_epochs):
    model.to(device)
    model.train()  # 모델을 훈련 모드로 설정
    total_loss = 0.0  # 에폭별 총 손실을 추적
    
    # tqdm을 사용하여 진행 상황 막대 표시
    progress_bar = tqdm(enumerate(data_loader), total=len(data_loader), desc=f"Epoch {epoch+1}/{num_epochs}")
    
    for batch_idx, (sps, smile, log_ic50) in progress_bar:
        sps = sps.to(device)
        smile = smile.to(device)
        log_ic50 = log_ic50.to(device)

        optimizer.zero_grad()  # 그라디언트 초기화
        
        output = model(sps, smile) # drug sequence에서 look ahead mask는 필요가 없음
        
        loss = criterion(output.squeeze(), log_ic50.float())
        
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # 진행률 막대에 현재 평균 손실 표시
        progress_bar.set_postfix({'avg_loss': total_loss / (batch_idx + 1)})
    
    avg_loss = total_loss / len(data_loader)
    print(f"\nEpoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")

  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)
Epoch 1/10: 100%|██████████| 4119/4119 [01:38<00:00, 41.65it/s, avg_loss=2.24]



Epoch 1/10, Average Loss: 2.237226895031115


Epoch 2/10: 100%|██████████| 4119/4119 [01:39<00:00, 41.57it/s, avg_loss=2.21]



Epoch 2/10, Average Loss: 2.2109129904369382


Epoch 3/10: 100%|██████████| 4119/4119 [01:41<00:00, 40.68it/s, avg_loss=2.2] 



Epoch 3/10, Average Loss: 2.197214842710891


Epoch 4/10: 100%|██████████| 4119/4119 [01:35<00:00, 43.07it/s, avg_loss=2.19]



Epoch 4/10, Average Loss: 2.192933060787754


Epoch 5/10: 100%|██████████| 4119/4119 [01:40<00:00, 41.18it/s, avg_loss=2.19]



Epoch 5/10, Average Loss: 2.191176624979953


Epoch 6/10: 100%|██████████| 4119/4119 [01:38<00:00, 41.79it/s, avg_loss=2.19]



Epoch 6/10, Average Loss: 2.189303324129374


Epoch 7/10: 100%|██████████| 4119/4119 [01:23<00:00, 49.38it/s, avg_loss=2.19]



Epoch 7/10, Average Loss: 2.194268701897862


Epoch 8/10: 100%|██████████| 4119/4119 [01:34<00:00, 43.65it/s, avg_loss=2.19]



Epoch 8/10, Average Loss: 2.186499913417763


Epoch 9/10: 100%|██████████| 4119/4119 [01:40<00:00, 40.83it/s, avg_loss=2.19]



Epoch 9/10, Average Loss: 2.1875531939112696


Epoch 10/10: 100%|██████████| 4119/4119 [01:37<00:00, 42.39it/s, avg_loss=2.19]


Epoch 10/10, Average Loss: 2.185867531550223





In [10]:
test_dataset = TensorDataset(test_sps, test_smile, test_log_ic50)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
model.eval()

total_loss = 0.0
total_samples = 0

with torch.no_grad():
    for sps, smile, log_ic50 in test_loader:
        sps = sps.to(device)
        smile = smile.to(device)
        log_ic50 = log_ic50.to(device)

        output = model(sps, smile)
        loss = criterion(output.squeeze(), log_ic50.float())

        total_loss += loss.item() * sps.size(0)
        total_samples += sps.size(0)

avg_loss = total_loss / total_samples

print(f"Test Average Loss: {avg_loss**0.5}")

  return F.mse_loss(input, target, reduction=self.reduction)


Test Average Loss: 1.4563189104590484


  return F.mse_loss(input, target, reduction=self.reduction)
