In [1]:
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from torch.optim import Adam
import torch.nn.functional as F

In [2]:
train_sps = torch.load("data/train_sps.ids76.pt")
train_smile = torch.load("data/train_smile.ids68.pt")
train_log_ic50 = torch.load("data/train_ic50_log.pt")

test_sps = torch.load("data/test_sps.ids76.pt")
test_smile = torch.load("data/test_smile.ids68.pt")
test_log_ic50 = torch.load("data/test_ic50_log.pt")

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

class PositionalEncoding(nn.Module):
    def __init__(self, position, d_model):
        super().__init__()
        self.pos_encoding = self.positional_encoding(position, d_model).to(device)
    
    def get_angles(self, position, i, d_model):
        angles = 1/ torch.pow(10000, (2 * (i//2)) / d_model)
        return position * angles
    
    def positional_encoding(self, position, d_model):
        angle_rads = self.get_angles(
            position = torch.arange(position, device=device).unsqueeze(1),
            i = torch.arange(d_model, device=device).unsqueeze(0),
            d_model = d_model
        )
        sines = torch.sin(angle_rads[:, 0::2])
        cosines = torch.cos(angle_rads[:, 1::2])
        angle_rads = torch.zeros_like(angle_rads)
        angle_rads[:, 0::2] = sines
        angle_rads[:, 1::2] = cosines
        pos_encoding = angle_rads.unsqueeze(0)

        return pos_encoding
    def forward(self, inputs):
        return inputs + self.pos_encoding[:, :inputs.shape[1],:]
    

def scaled_dot_product_attention(query, key, value, mask=None):
    matmul_qk = torch.matmul(query, key.transpose(-2,-1))

    depth = query.size(-1)
    logits = matmul_qk / torch.sqrt(torch.tensor(depth, dtype=torch.float32))

    if mask is not None:
        print("mask shape")
        print(mask.shape)
        print("logits shape")
        print(logits.shape)
        logits += (mask * -1e9)
    
    attention_weights = F.softmax(logits, dim=-1)

    output = torch.matmul(attention_weights, value)
    return output, attention_weights

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model

        self.depth = d_model // self.num_heads

        self.query_dense = nn.Linear(d_model, d_model)
        self.key_dense = nn.Linear(d_model, d_model)
        self.value_dense = nn.Linear(d_model, d_model)

        self.dense = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        x = x.view(batch_size, -1, self.num_heads, self.depth)
        return x.permute(0,2,1,3)

    def forward(self, query, key, value, mask=None):
        batch_size = query.shape[0]
        query = self.query_dense(query)
        key = self.key_dense(key)
        value = self.value_dense(value)

        query = self.split_heads(query, batch_size)
        key = self.split_heads(key, batch_size)
        value = self.split_heads(value, batch_size)

        scaled_attention, _ = scaled_dot_product_attention(query, key, value, mask)
        scaled_attention = scaled_attention.permute(0,2,1,3)
        
        concat_attention = scaled_attention.contiguous().view(batch_size, -1, self.d_model)

        outputs = self.dense(concat_attention)
        return outputs

class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, dff, dropout):
        super().__init__()
    
        self.multi_head_attention = MultiHeadAttention(d_model, num_heads)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model, eps=1e-6)

        self.dense1 = nn.Linear(d_model, dff)
        self.dense2 = nn.Linear(dff, d_model)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model, eps=1e-6)
    def forward(self, x, padding_mask):
        attn_output = self.multi_head_attention(x,x,x,padding_mask)
        attn_output = self.dropout1(attn_output)
        out1 = self.norm1(x+attn_output)
        
        ffn_output = F.relu(self.dense1(out1))
        ffn_output = self.dense2(ffn_output)
        ffn_output = self.dropout2(ffn_output)
        out2 = self.norm2(out1 + ffn_output)
        return out2


class Encoder(nn.Module):
    def __init__(self, vocab_size, seq_length, num_layers, dff, d_model, num_heads, dropout):
        super().__init__()
        self.d_model = d_model
        self.num_layers = num_layers

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(seq_length, d_model)
        self.dropout = nn.Dropout(dropout)

        self.enc_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, dff, dropout) for _ in range(num_layers)])
        
    def forward(self, x, padding_mask):
        x = self.embedding(x)
        x *= torch.sqrt(torch.tensor(self.d_model, dtype=torch.float32))
        x = self.pos_encoding(x)
        x = self.dropout(x)

        for i in range(self.num_layers):
            x = self.enc_layers[i](x, padding_mask)
        return x

class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, dff, dropout):
        super().__init__()
        self.mha1 = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model, eps=1e-6)

        self.mha2 = MultiHeadAttention(d_model, num_heads)
        self.norm2 = nn.LayerNorm(d_model, eps=1e-6)

        self.dense1 = nn.Linear(d_model, dff)
        self.dense2 = nn.Linear(dff, d_model)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.norm3 = nn.LayerNorm(d_model, eps=1e-6)
    
    def forward(self, x, enc_output, look_ahead_mask, padding_mask):
        attn1 = self.mha1(x,x,x,look_ahead_mask)
        attn1 = self.norm1(attn1 + x)

        attn2 = self.mha2(attn1, enc_output, enc_output, padding_mask)
        attn2 = self.dropout1(attn2)
        attn2 = self.norm2(attn2 + attn1)

        ffn_output = F.relu(self.dense1(attn2))
        ffn_output = self.dense2(ffn_output)
        ffn_output = self.dropout2(ffn_output)
        out = self.norm3(ffn_output + attn2)
        return out

class Decoder(nn.Module):
    def __init__(self, vocab_size, seq_length, num_layers, dff, d_model, num_heads, dropout):
        super().__init__()
        self.d_model = d_model
        self.num_layers = num_layers

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(seq_length, d_model)
        self.dropout = nn.Dropout(dropout)

        self.dec_layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, dff, dropout)
            for _ in range(num_layers)
        ])
    def forward(self, x, enc_output, look_ahead_mask, padding_mask):

        x = self.embedding(x)
        x *= torch.sqrt(torch.tensor(self.d_model, dtype=torch.float))
        x += self.pos_encoding(x)
        x = self.dropout(x)

        for i in range(self.num_layers):
            x = self.dec_layers[i](x, enc_output, look_ahead_mask, padding_mask)
        return x

class Transformer(nn.Module):
    def __init__(self, protein_vocab_size, compound_vocab_size, protein_seq_length, compound_seq_length, num_layers, dff, protein_embedding_dim, compound_embedding_dim, num_heads, dropout):
        super().__init__()
        self.encoder = Encoder(protein_vocab_size, protein_seq_length, num_layers, dff, protein_embedding_dim, num_heads, dropout)
        self.layer1 = nn.Linear(protein_embedding_dim, compound_embedding_dim)  # 인코더와 디코더의 임베딩 차원 다른 문제 해결
        self.decoder = Decoder(compound_vocab_size, compound_seq_length, num_layers, dff, compound_embedding_dim, num_heads, dropout)
        self.final_layer = nn.Linear(compound_embedding_dim, 1)  # 최종적으로 IC50 값을 예측하기 위한 레이어

    def forward(self, inp, tar, enc_padding_mask, look_ahead_mask, dec_padding_mask):
        enc_output = self.encoder(inp, enc_padding_mask)
        inter_output1 = self.layer1(enc_output)
        dec_output = self.decoder(tar, inter_output1, look_ahead_mask, dec_padding_mask)

        # 디코더 출력에서 시퀀스 길이에 대해 평균을 내어 (batch_size, compound_embedding_dim) 모양을 얻습니다.
        pooled_output = torch.mean(dec_output, dim=1)

        # 평균 풀링된 출력을 최종 레이어에 통과시켜 IC50 값을 예측합니다.
        final_output = self.final_layer(pooled_output)  # (batch_size, 1) 모양의 출력
        return final_output


def create_padding_mask(x):
    # x와 0이 같은지 비교하여 마스크 생성 (x가 0이면 True, 아니면 False)
    mask = torch.eq(x, 0).float()
    # (batch_size, 1, 1, key의 문장 길이) 형태로 차원 변경
    return mask.unsqueeze(1).unsqueeze(2)


In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

protein_vocab_size = 76 # 토큰 어휘집의 어휘 개수
compound_vocab_size = 68 # 토큰 어휘집의 어휘 개수
num_layers = 1 # 인코더, 디코더 layer 수
dff = 128
num_heads = 2
dropout = 0
protein_seq_length = 152 # 한 문장의 토큰 수
compound_seq_length = 100 # 한 문장의 토큰 수
protein_embedding_dim = 256
compound_embedding_dim = 256
batch_size = 64

transformer_model = Transformer(protein_vocab_size, compound_vocab_size, protein_seq_length, compound_seq_length, num_layers, dff, protein_embedding_dim, compound_embedding_dim, num_heads, dropout).to(device)

# Load the pretrained weights as tensors
protein_embedding_weights = torch.load('pretrain/pretrained_protein_model.pth')['embedding.weight']
compound_embedding_weights = torch.load('pretrain/pretrained_compound_model.pth')['embedding.weight']

# Directly assign the pretrained weights to the embedding layers
transformer_model.encoder.embedding.data = protein_embedding_weights
transformer_model.decoder.embedding.data = compound_embedding_weights

criterion = nn.MSELoss()
optimizer = Adam(transformer_model.parameters(), lr=0.0001)

transformer_model.train()

Transformer(
  (encoder): Encoder(
    (embedding): Embedding(76, 256)
    (pos_encoding): PositionalEncoding()
    (dropout): Dropout(p=0, inplace=False)
    (enc_layers): ModuleList(
      (0): EncoderLayer(
        (multi_head_attention): MultiHeadAttention(
          (query_dense): Linear(in_features=256, out_features=256, bias=True)
          (key_dense): Linear(in_features=256, out_features=256, bias=True)
          (value_dense): Linear(in_features=256, out_features=256, bias=True)
          (dense): Linear(in_features=256, out_features=256, bias=True)
        )
        (dropout1): Dropout(p=0, inplace=False)
        (norm1): LayerNorm((256,), eps=1e-06, elementwise_affine=True)
        (dense1): Linear(in_features=256, out_features=128, bias=True)
        (dense2): Linear(in_features=128, out_features=256, bias=True)
        (dropout2): Dropout(p=0, inplace=False)
        (norm2): LayerNorm((256,), eps=1e-06, elementwise_affine=True)
      )
    )
  )
  (layer1): Linear(in_feat

In [10]:
print(train_sps.shape)
print(train_smile.shape)
print(train_log_ic50.shape)

torch.Size([263583, 152])
torch.Size([263583, 100])
torch.Size([263583, 1])


In [11]:
from tqdm import tqdm

# DataLoader 설정
batch_size = 64
shuffle = True

dataset = TensorDataset(train_sps, train_smile, train_log_ic50)  # 종속 변수를 포함합니다.
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, pin_memory=True)

# 에폭 수 설정
num_epochs = 10

# 훈련 시작
for epoch in range(num_epochs):
    transformer_model.to(device)
    transformer_model.train()  # 모델을 훈련 모드로 설정
    total_loss = 0.0  # 에폭별 총 손실을 추적
    
    # tqdm을 사용하여 진행 상황 막대 표시
    progress_bar = tqdm(enumerate(data_loader), total=len(data_loader), desc=f"Epoch {epoch+1}/{num_epochs}")
    
    for batch_idx, (sps, smile, log_ic50) in progress_bar:
        sps = sps.to(device)
        smile = smile.to(device)
        log_ic50 = log_ic50.to(device)

        sps_mask = create_padding_mask(sps)
        smile_mask = create_padding_mask(sps)

        optimizer.zero_grad()  # 그라디언트 초기화
        
        output = transformer_model(sps, smile, sps_mask, None, smile_mask) # drug sequence에서 look ahead mask는 필요가 없음
        
        loss = criterion(output.squeeze(), log_ic50.float())
        
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # 진행률 막대에 현재 평균 손실 표시
        progress_bar.set_postfix({'avg_loss': total_loss / (batch_idx + 1)})
    
    avg_loss = total_loss / len(data_loader)
    print(f"\nEpoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")


Epoch 1/10:   0%|          | 0/4119 [00:00<?, ?it/s]

Epoch 1/10:   0%|          | 6/4119 [00:00<02:29, 27.47it/s, avg_loss=41.2]

mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch

Epoch 1/10:   0%|          | 12/4119 [00:00<02:25, 28.15it/s, avg_loss=34.5]

mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])


Epoch 1/10:   0%|          | 18/4119 [00:00<02:31, 27.06it/s, avg_loss=28]  

mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])


Epoch 1/10:   1%|          | 24/4119 [00:00<02:29, 27.37it/s, avg_loss=22.7]

mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])


Epoch 1/10:   1%|          | 30/4119 [00:01<02:30, 27.20it/s, avg_loss=18.7]

mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])


Epoch 1/10:   1%|          | 36/4119 [00:01<02:31, 26.94it/s, avg_loss=16]  

mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])


Epoch 1/10:   1%|          | 42/4119 [00:01<02:27, 27.56it/s, avg_loss=14.1]

mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])


Epoch 1/10:   1%|          | 48/4119 [00:01<02:27, 27.54it/s, avg_loss=12.7]

mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])


Epoch 1/10:   1%|▏         | 54/4119 [00:01<02:28, 27.38it/s, avg_loss=11.5]

mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])


Epoch 1/10:   1%|▏         | 60/4119 [00:02<02:29, 27.14it/s, avg_loss=10.6]

mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])


Epoch 1/10:   2%|▏         | 66/4119 [00:02<02:27, 27.53it/s, avg_loss=9.8] 

mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])


Epoch 1/10:   2%|▏         | 72/4119 [00:02<02:32, 26.57it/s, avg_loss=9.18]

mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])


Epoch 1/10:   2%|▏         | 79/4119 [00:02<02:15, 29.72it/s, avg_loss=8.56]

mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch

Epoch 1/10:   2%|▏         | 85/4119 [00:03<02:20, 28.66it/s, avg_loss=8.11]

mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])


Epoch 1/10:   2%|▏         | 91/4119 [00:03<02:23, 28.02it/s, avg_loss=7.7] 

mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])


Epoch 1/10:   2%|▏         | 97/4119 [00:03<02:25, 27.62it/s, avg_loss=7.35]

mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])


Epoch 1/10:   3%|▎         | 103/4119 [00:03<02:27, 27.29it/s, avg_loss=7.06]

mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])


Epoch 1/10:   3%|▎         | 109/4119 [00:03<02:25, 27.50it/s, avg_loss=6.78]

mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])


Epoch 1/10:   3%|▎         | 113/4119 [00:04<02:25, 27.50it/s, avg_loss=6.62]


mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 152, 152])
mask shape
torch.Size([64, 1, 1, 152])
logits shape
torch.Size([64, 2, 100, 152])


KeyboardInterrupt: 

In [None]:
test_dataset = TensorDataset(test_sps, test_smile, test_log_ic50)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
transformer_model.eval()

total_loss = 0.0
total_samples = 0

with torch.no_grad():
    for sps, smile, log_ic50 in test_loader:
        sps = sps.to(device)
        smile = smile.to(device)
        log_ic50 = log_ic50.to(device)

        sps_mask = create_padding_mask(sps)
        smile_mask = create_padding_mask(sps)

        output = transformer_model(sps, smile, sps_mask, None, smile_mask)
        loss = criterion(output.squeeze(), log_ic50.float())

        total_loss += loss.item() * sps.size(0)
        total_samples += sps.size(0)

avg_loss = total_loss / total_samples

print(f"Test Average Loss: {avg_loss**0.5}")

Test Average Loss: 1.4736469710761064
