In [37]:
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from torch.optim import Adam
import torch.nn.functional as F
from tqdm import tqdm

In [38]:
train_sps = torch.load("data/train_sps.ids76.pt")
train_smile = torch.load("data/train_smile.ids68.pt")
train_log_ic50 = torch.load("data/train_ic50_log.pt")

test_sps = torch.load("data/test_sps.ids76.pt")
test_smile = torch.load("data/test_smile.ids68.pt")
test_log_ic50 = torch.load("data/test_ic50_log.pt")

In [39]:
class RNNModel(nn.Module):
    def __init__(self, protein_vocab_size, drug_vocab_size, embedding_dim, hidden_dim, batch_size):
        super().__init__()
        self.protein_embedding = nn.Embedding(protein_vocab_size, embedding_dim)
        self.drug_embedding = nn.Embedding(drug_vocab_size, embedding_dim)
        self.protein_gru = nn.GRU(embedding_dim, hidden_dim, batch_first = True)
        self.drug_gru = nn.GRU(embedding_dim, hidden_dim, batch_first = True)
        self.fc = nn.Linear(hidden_dim*2, 1)
        self.batch_size = batch_size
    def forward(self, protein_data, drug_data):
        protein_embedded = self.protein_embedding(protein_data)
        drug_embedded = self.drug_embedding(drug_data)

        _, protein_hidden = self.protein_gru(protein_embedded)
        _, drug_hidden = self.drug_gru(drug_embedded)

        combined = torch.cat((protein_hidden, drug_hidden), dim=2)

        output = self.fc(combined.squeeze(0))
        return output

In [40]:
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

protein_vocab_size = 76 # 토큰 어휘집의 어휘 개수
compound_vocab_size = 68 # 토큰 어휘집의 어휘 개수
protein_seq_length = 152 # 한 문장의 토큰 수
compound_seq_length = 100 # 한 문장의 토큰 수
embedding_dim = 256
hidden_dim = 128
batch_size = 64

In [41]:
model = RNNModel(protein_vocab_size, compound_vocab_size, embedding_dim, hidden_dim, batch_size)

# Load the pretrained weights as tensors
protein_embedding_weights = torch.load('pretrain/pretrained_protein_model.pth')['embedding.weight']
compound_embedding_weights = torch.load('pretrain/pretrained_compound_model.pth')['embedding.weight']

# Directly assign the pretrained weights to the embedding layers
model.protein_embedding.weight.data = protein_embedding_weights
model.drug_embedding.weight.data = compound_embedding_weights

criterion = nn.MSELoss()
optimizer = Adam(model.parameters(), lr=0.001)

model.train()


RNNModel(
  (protein_embedding): Embedding(76, 256)
  (drug_embedding): Embedding(68, 256)
  (protein_gru): GRU(256, 128, batch_first=True)
  (drug_gru): GRU(256, 128, batch_first=True)
  (fc): Linear(in_features=256, out_features=1, bias=True)
)

In [42]:
print(train_sps.shape)
print(train_smile.shape)
print(train_log_ic50.shape)

torch.Size([263583, 152])
torch.Size([263583, 100])
torch.Size([263583, 1])


In [43]:
batch_size = 64
shuffle = True

dataset = TensorDataset(train_sps, train_smile, train_log_ic50)
data_loader = DataLoader(dataset, batch_size= batch_size, shuffle=shuffle, pin_memory=True)

num_epochs = 10

# 훈련 시작
for epoch in range(num_epochs):
    model.to(device)
    model.train()  # 모델을 훈련 모드로 설정
    total_loss = 0.0  # 에폭별 총 손실을 추적
    
    # tqdm을 사용하여 진행 상황 막대 표시
    progress_bar = tqdm(enumerate(data_loader), total=len(data_loader), desc=f"Epoch {epoch+1}/{num_epochs}")
    
    for batch_idx, (sps, smile, log_ic50) in progress_bar:
        sps = sps.to(device)
        smile = smile.to(device)
        log_ic50 = log_ic50.to(device)

        optimizer.zero_grad()  # 그라디언트 초기화
        
        output = model(sps, smile) # drug sequence에서 look ahead mask는 필요가 없음
        
        loss = criterion(output.squeeze(), log_ic50.float())
        
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # 진행률 막대에 현재 평균 손실 표시
        progress_bar.set_postfix({'avg_loss': total_loss / (batch_idx + 1)})
    
    avg_loss = total_loss / len(data_loader)
    print(f"\nEpoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")

Epoch 1/10:   0%|          | 0/4119 [00:00<?, ?it/s]

Epoch 1/10:  10%|▉         | 392/4119 [00:08<01:20, 46.03it/s, avg_loss=2.44]


KeyboardInterrupt: 

In [None]:
test_dataset = TensorDataset(test_sps, test_smile, test_log_ic50)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
model.eval()

total_loss = 0.0
total_samples = 0

with torch.no_grad():
    for sps, smile, log_ic50 in test_loader:
        sps = sps.to(device)
        smile = smile.to(device)
        log_ic50 = log_ic50.to(device)

        output = model(sps, smile)
        loss = criterion(output.squeeze(), log_ic50.float())

        total_loss += loss.item() * sps.size(0)
        total_samples += sps.size(0)

avg_loss = total_loss / total_samples

print(f"Test Average Loss: {avg_loss**0.5}")

Test Average Loss: 1.472918256792171


  return F.mse_loss(input, target, reduction=self.reduction)
