In [1]:
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from torch.optim import Adam
import torch.nn.functional as F
from tqdm import tqdm

In [3]:
train_sps = torch.load("data/train_sps.ids76.pt")
train_smile = torch.load("data/train_smile.ids68.pt")
train_log_ic50 = torch.load("data/train_ic50_log.pt")

test_sps = torch.load("data/test_sps.ids76.pt")
test_smile = torch.load("data/test_smile.ids68.pt")
test_log_ic50 = torch.load("data/test_ic50_log.pt")

In [20]:
class CNNModel(nn.Module):
    def __init__(self, protein_vocab_size, drug_vocab_size, embedding_dim, num_filters, filter_size, batch_size):
        super().__init__()
        self.protein_embedding = nn.Embedding(protein_vocab_size, embedding_dim)
        self.drug_embedding = nn.Embedding(drug_vocab_size, embedding_dim)
        self.protein_conv = nn.Conv1d(in_channels=embedding_dim, out_channels=num_filters, kernel_size=filter_size, padding='same')
        self.drug_conv = nn.Conv1d(in_channels=embedding_dim, out_channels=num_filters, kernel_size=filter_size, padding='same')
        self.fc = nn.Linear(num_filters*2, 1)  # Assuming you're applying some pooling/aggregation that keeps the dimension as num_filters
        self.batch_size = batch_size

    def forward(self, protein_data, drug_data):
        print("initial protein shape: "+str(protein_data.shape))
        print("initial drug shape: "+str(drug_data.shape))
        protein_embedded = self.protein_embedding(protein_data).permute(0, 2, 1)  # Change shape to (batch_size, embedding_dim, seq_length)
        drug_embedded = self.drug_embedding(drug_data).permute(0, 2, 1)  # Change shape to (batch_size, embedding_dim, seq_length)
        print("embedded protein: "+ str(protein_embedded.shape))
        print("embedded drug: "+ str(drug_embedded.shape))

        protein_conv_out = F.relu(self.protein_conv(protein_embedded))
        drug_conv_out = F.relu(self.drug_conv(drug_embedded))
        print("protein after conv: "+str(protein_conv_out.shape))
        print("drug after conv: "+str(drug_conv_out.shape))

        # Apply some form of pooling over the sequence dimension (e.g., max pooling)
        protein_pooled = F.max_pool1d(protein_conv_out, kernel_size=protein_conv_out.shape[2]).squeeze(2)
        drug_pooled = F.max_pool1d(drug_conv_out, kernel_size=drug_conv_out.shape[2]).squeeze(2)
        print("protein after pool: "+str(protein_pooled.shape))
        print("drug after pool: "+str(drug_pooled.shape))

        combined = torch.cat((protein_pooled, drug_pooled), dim=1)

        output = self.fc(combined)
        return output

In [21]:
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

protein_vocab_size = 76 # 토큰 어휘집의 어휘 개수
compound_vocab_size = 68 # 토큰 어휘집의 어휘 개수
protein_seq_length = 152 # 한 문장의 토큰 수
compound_seq_length = 100 # 한 문장의 토큰 수
embedding_dim = 256
num_filters = 128
filter_size = 5
batch_size = 64

In [22]:
model = CNNModel(protein_vocab_size, compound_vocab_size, embedding_dim, num_filters, filter_size, batch_size)

# Load the pretrained weights as tensors
protein_embedding_weights = torch.load('pretrain/pretrained_protein_model.pth')['embedding.weight']
compound_embedding_weights = torch.load('pretrain/pretrained_compound_model.pth')['embedding.weight']

# Directly assign the pretrained weights to the embedding layers
model.protein_embedding.weight.data = protein_embedding_weights
model.drug_embedding.weight.data = compound_embedding_weights

criterion = nn.MSELoss()
optimizer = Adam(model.parameters(), lr=0.001)

model.train()


CNNModel(
  (protein_embedding): Embedding(76, 256)
  (drug_embedding): Embedding(68, 256)
  (protein_conv): Conv1d(256, 128, kernel_size=(5,), stride=(1,), padding=same)
  (drug_conv): Conv1d(256, 128, kernel_size=(5,), stride=(1,), padding=same)
  (fc): Linear(in_features=256, out_features=1, bias=True)
)

In [23]:
print(train_sps.shape)
print(train_smile.shape)
print(train_log_ic50.shape)

torch.Size([263583, 152])
torch.Size([263583, 100])
torch.Size([263583, 1])


In [24]:
batch_size = 64
shuffle = True

dataset = TensorDataset(train_sps, train_smile, train_log_ic50)
data_loader = DataLoader(dataset, batch_size= batch_size, shuffle=shuffle, pin_memory=True)

num_epochs = 10

# 훈련 시작
for epoch in range(num_epochs):
    model.to(device)
    model.train()  # 모델을 훈련 모드로 설정
    total_loss = 0.0  # 에폭별 총 손실을 추적
    
    # tqdm을 사용하여 진행 상황 막대 표시
    progress_bar = tqdm(enumerate(data_loader), total=len(data_loader), desc=f"Epoch {epoch+1}/{num_epochs}")
    
    for batch_idx, (sps, smile, log_ic50) in progress_bar:
        sps = sps.to(device)
        smile = smile.to(device)
        log_ic50 = log_ic50.to(device)

        optimizer.zero_grad()  # 그라디언트 초기화
        
        output = model(sps, smile) # drug sequence에서 look ahead mask는 필요가 없음
        
        loss = criterion(output.squeeze(), log_ic50.float())
        
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # 진행률 막대에 현재 평균 손실 표시
        progress_bar.set_postfix({'avg_loss': total_loss / (batch_idx + 1)})
    
    avg_loss = total_loss / len(data_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}\n")

  return F.mse_loss(input, target, reduction=self.reduction)
Epoch 1/10:   0%|          | 9/4119 [00:00<00:48, 84.95it/s, avg_loss=8.22]

initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
d

Epoch 1/10:   1%|          | 26/4119 [00:00<00:59, 69.29it/s, avg_loss=5.98]

initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
d

Epoch 1/10:   1%|          | 40/4119 [00:00<01:03, 64.21it/s, avg_loss=4.92]

initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
d

Epoch 1/10:   1%|▏         | 54/4119 [00:00<01:05, 62.41it/s, avg_loss=4.36]

protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embe

Epoch 1/10:   2%|▏         | 68/4119 [00:01<01:05, 61.96it/s, avg_loss=3.95]

initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
d

Epoch 1/10:   2%|▏         | 75/4119 [00:01<01:05, 61.81it/s, avg_loss=3.68]

initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
d

Epoch 1/10:   2%|▏         | 89/4119 [00:01<01:05, 61.28it/s, avg_loss=3.47]

initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
d

Epoch 1/10:   3%|▎         | 103/4119 [00:01<01:05, 61.33it/s, avg_loss=3.36]

initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
d

Epoch 1/10:   3%|▎         | 117/4119 [00:01<01:04, 61.67it/s, avg_loss=3.23]

initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
d

Epoch 1/10:   3%|▎         | 131/4119 [00:02<01:04, 61.52it/s, avg_loss=3.12]

initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
d

Epoch 1/10:   4%|▎         | 145/4119 [00:02<01:04, 61.65it/s, avg_loss=3.04]

initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
d

Epoch 1/10:   4%|▍         | 159/4119 [00:02<01:04, 61.27it/s, avg_loss=2.99]

initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
d

Epoch 1/10:   4%|▍         | 166/4119 [00:02<01:04, 61.33it/s, avg_loss=2.92]

initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
d

Epoch 1/10:   4%|▍         | 180/4119 [00:02<01:04, 61.22it/s, avg_loss=2.87]

initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
d

Epoch 1/10:   5%|▍         | 194/4119 [00:03<01:03, 61.43it/s, avg_loss=2.83]

initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
d

Epoch 1/10:   5%|▌         | 208/4119 [00:03<01:03, 61.44it/s, avg_loss=2.8] 

initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
d

Epoch 1/10:   5%|▌         | 222/4119 [00:03<01:03, 61.37it/s, avg_loss=2.77]

initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
d

Epoch 1/10:   6%|▌         | 236/4119 [00:03<01:03, 61.14it/s, avg_loss=2.74]

initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
d

Epoch 1/10:   6%|▌         | 250/4119 [00:04<01:03, 61.16it/s, avg_loss=2.71]

initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
d

Epoch 1/10:   6%|▌         | 257/4119 [00:04<01:03, 61.08it/s, avg_loss=2.7] 

initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
d

Epoch 1/10:   7%|▋         | 271/4119 [00:04<01:02, 61.23it/s, avg_loss=2.68]

initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
d

Epoch 1/10:   7%|▋         | 285/4119 [00:04<01:02, 61.55it/s, avg_loss=2.65]

initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
d

Epoch 1/10:   7%|▋         | 297/4119 [00:04<01:01, 61.95it/s, avg_loss=2.65]


initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
drug after pool: torch.Size([64, 128])
initial protein shape: torch.Size([64, 152])
initial drug shape: torch.Size([64, 100])
embedded protein: torch.Size([64, 256, 152])
embedded drug: torch.Size([64, 256, 100])
protein after conv: torch.Size([64, 128, 152])
drug after conv: torch.Size([64, 128, 100])
protein after pool: torch.Size([64, 128])
d

KeyboardInterrupt: 

In [14]:
test_dataset = TensorDataset(test_sps, test_smile, test_log_ic50)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
model.eval()

total_loss = 0.0
total_samples = 0

with torch.no_grad():
    for sps, smile, log_ic50 in test_loader:
        sps = sps.to(device)
        smile = smile.to(device)
        log_ic50 = log_ic50.to(device)

        output = model(sps, smile)
        loss = criterion(output.squeeze(), log_ic50.float())

        total_loss += loss.item() * sps.size(0)
        total_samples += sps.size(0)

avg_loss = total_loss / total_samples

print(f"Test Average Loss: {avg_loss**0.5}")

Test Average Loss: 1.47486864818892


  return F.mse_loss(input, target, reduction=self.reduction)
