In [13]:
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, Dataset
from torch.optim import Adam
import torch.nn.functional as F
from tqdm import tqdm

In [3]:
train_sps = torch.load("../data/train_sps.ids76.pt")
train_smile = torch.load("../data/train_smile.ids68.pt")

In [4]:
class PretrainRNNModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.GRU(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)  # Output layer to predict the next token

    def forward(self, sequence_data):
        embedded = self.embedding(sequence_data)
        rnn_out, _ = self.rnn(embedded)
        output = self.fc(rnn_out)
        return output

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

protein_vocab_size = 76 # 토큰 어휘집의 어휘 개수
compound_vocab_size = 68 # 토큰 어휘집의 어휘 개수
protein_seq_length = 152 # 한 문장의 토큰 수
compound_seq_length = 100 # 한 문장의 토큰 수
embedding_dim = 256
hidden_dim = 128

In [6]:
sps_model = PretrainRNNModel(protein_vocab_size, embedding_dim, hidden_dim)

criterion = nn.MSELoss()
optimizer = Adam(sps_model.parameters(), lr=0.001)

sps_model.train()

PretrainRNNModel(
  (embedding): Embedding(76, 256)
  (rnn): GRU(256, 128, batch_first=True)
  (fc): Linear(in_features=128, out_features=76, bias=True)
)

In [7]:
def pretrain_model(model, device, data_loader, optimizer, criterion, num_epochs):
    model.train()  # Set model to training mode
    model.to(device)
    for epoch in range(num_epochs):
        total_loss = 0.0
        progress_bar = tqdm(enumerate(data_loader), total=len(data_loader), desc=f"Epoch {epoch+1}/{num_epochs}")
        for batch_idx, (seq_data, targets) in progress_bar:
            seq_data, targets = seq_data.to(device), targets.to(device)

            optimizer.zero_grad()
            output = model(seq_data)
            loss = criterion(output.transpose(1, 2), targets)  # Adjust dimensions as needed for CrossEntropyLoss
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            progress_bar.set_postfix({'avg_loss': total_loss / (batch_idx + 1)})

        avg_loss = total_loss / len(data_loader)
        print(f"\nEpoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")

In [11]:
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

protein_vocab_size = 76  # Adjust based on your dataset
embedding_dim = 256
hidden_dim = 128
compound_vocab_size = 68  # Adjust based on your dataset

pretrain_protein_model = PretrainRNNModel(protein_vocab_size, embedding_dim, hidden_dim)
pretrain_compound_model = PretrainRNNModel(compound_vocab_size, embedding_dim, hidden_dim)
optimizer = Adam(pretrain_protein_model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()  # Suitable for next token prediction tasks

In [14]:
num_epochs = 10
batch_size=64
shuffle = True

class NextTokenPredictionDataset(Dataset):
    def __init__(self, sequences, sequence_length):
        # sequences should be a list of lists or a 2D tensor where each inner list is a sequence
        self.sequences = [seq[:sequence_length + 1] for seq in sequences]

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = self.sequences[idx]
        # Input sequence is all but the last token
        input_seq = seq[:-1]
        # Target sequence is all but the first token
        target_seq = seq[1:]
        return torch.tensor(input_seq, dtype=torch.long), torch.tensor(target_seq, dtype=torch.long)

# Assuming train_sps is a list of lists or a 2D tensor of your sequence data
# and you have defined an appropriate sequence_length
pretraining_dataset = NextTokenPredictionDataset(train_sps, sequence_length=100)  # Adjust sequence_length as needed

pretraining_data_loader = DataLoader(pretraining_dataset, batch_size=batch_size, shuffle=shuffle, pin_memory=True)

pretrain_model(pretrain_protein_model, device, pretraining_data_loader, optimizer, criterion, num_epochs)

  return torch.tensor(input_seq, dtype=torch.long), torch.tensor(target_seq, dtype=torch.long)
Epoch 1/10: 100%|██████████| 4119/4119 [01:04<00:00, 64.10it/s, avg_loss=0.577] 



Epoch 1/10, Average Loss: 0.5773089828313801


Epoch 2/10: 100%|██████████| 4119/4119 [00:58<00:00, 70.03it/s, avg_loss=0.354] 



Epoch 2/10, Average Loss: 0.354042480078158


Epoch 3/10: 100%|██████████| 4119/4119 [01:01<00:00, 66.86it/s, avg_loss=0.35]  



Epoch 3/10, Average Loss: 0.34995898694834626


Epoch 4/10: 100%|██████████| 4119/4119 [01:05<00:00, 62.47it/s, avg_loss=0.352]



Epoch 4/10, Average Loss: 0.3518570097550862


Epoch 5/10: 100%|██████████| 4119/4119 [01:03<00:00, 65.01it/s, avg_loss=0.361]



Epoch 5/10, Average Loss: 0.3611665142229327


Epoch 6/10: 100%|██████████| 4119/4119 [01:03<00:00, 64.97it/s, avg_loss=0.378] 



Epoch 6/10, Average Loss: 0.3783399680917429


Epoch 7/10: 100%|██████████| 4119/4119 [01:05<00:00, 62.56it/s, avg_loss=0.386]



Epoch 7/10, Average Loss: 0.3856317714764731


Epoch 8/10: 100%|██████████| 4119/4119 [01:04<00:00, 64.24it/s, avg_loss=0.384]



Epoch 8/10, Average Loss: 0.38446381396172435


Epoch 9/10: 100%|██████████| 4119/4119 [01:04<00:00, 63.87it/s, avg_loss=0.394]



Epoch 9/10, Average Loss: 0.39378383127657296


Epoch 10/10: 100%|██████████| 4119/4119 [01:04<00:00, 63.52it/s, avg_loss=0.424]


Epoch 10/10, Average Loss: 0.4238755286253607





In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

embedding_dim = 256
hidden_dim = 128
compound_vocab_size = 68  # Adjust based on your dataset

pretrain_compound_model = PretrainRNNModel(compound_vocab_size, embedding_dim, hidden_dim)
optimizer = Adam(pretrain_compound_model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()  # Suitable for next token prediction tasks

In [18]:
# Assuming train_sps is a list of lists or a 2D tensor of your sequence data
# and you have defined an appropriate sequence_length
pretraining_dataset = NextTokenPredictionDataset(train_smile, sequence_length=100)  # Adjust sequence_length as needed

pretraining_data_loader = DataLoader(pretraining_dataset, batch_size=batch_size, shuffle=shuffle, pin_memory=True)

pretrain_model(pretrain_compound_model, device, pretraining_data_loader, optimizer, criterion, num_epochs)

  return torch.tensor(input_seq, dtype=torch.long), torch.tensor(target_seq, dtype=torch.long)
Epoch 1/10: 100%|██████████| 4119/4119 [01:05<00:00, 62.43it/s, avg_loss=0.474]



Epoch 1/10, Average Loss: 0.47414459942555826


Epoch 2/10: 100%|██████████| 4119/4119 [00:44<00:00, 91.66it/s, avg_loss=0.393] 



Epoch 2/10, Average Loss: 0.3928944478989458


Epoch 3/10: 100%|██████████| 4119/4119 [01:06<00:00, 61.93it/s, avg_loss=0.378]



Epoch 3/10, Average Loss: 0.37776238061207307


Epoch 4/10: 100%|██████████| 4119/4119 [01:06<00:00, 61.86it/s, avg_loss=0.37] 



Epoch 4/10, Average Loss: 0.36985174314898756


Epoch 5/10: 100%|██████████| 4119/4119 [01:04<00:00, 63.66it/s, avg_loss=0.368]



Epoch 5/10, Average Loss: 0.36753168150069904


Epoch 6/10: 100%|██████████| 4119/4119 [00:57<00:00, 71.93it/s, avg_loss=0.361] 



Epoch 6/10, Average Loss: 0.3607235517814976


Epoch 7/10: 100%|██████████| 4119/4119 [01:04<00:00, 63.47it/s, avg_loss=0.358]



Epoch 7/10, Average Loss: 0.3580589554037218


Epoch 8/10: 100%|██████████| 4119/4119 [01:06<00:00, 62.26it/s, avg_loss=0.355]



Epoch 8/10, Average Loss: 0.355190786285926


Epoch 9/10: 100%|██████████| 4119/4119 [01:05<00:00, 63.13it/s, avg_loss=0.353] 



Epoch 9/10, Average Loss: 0.35313728367075464


Epoch 10/10: 100%|██████████| 4119/4119 [01:05<00:00, 62.71it/s, avg_loss=0.36] 


Epoch 10/10, Average Loss: 0.3595956486386852





In [19]:
torch.save(pretrain_protein_model.state_dict(), 'pretrained_protein_model.pth')
torch.save(pretrain_compound_model.state_dict(), 'pretrained_compound_model.pth')

In [20]:
pretrain_protein_model.state_dict()

OrderedDict([('embedding.weight',
              tensor([[ 0.1055,  1.2616, -1.3919,  ...,  0.3435,  2.3587,  1.6831],
                      [-0.2234,  1.3583,  0.8768,  ..., -2.2165,  0.8472,  0.9924],
                      [ 0.4026, -1.0388,  1.0159,  ..., -0.3748,  0.9654,  0.9290],
                      ...,
                      [-2.6123, -0.5088,  1.0939,  ..., -1.2293,  1.2940,  0.0108],
                      [ 0.3689, -0.4273, -1.8101,  ...,  0.2532,  0.2451, -0.4428],
                      [-0.7244, -0.0718, -1.0373,  ...,  0.4078,  0.9507,  3.3401]],
                     device='cuda:0')),
             ('rnn.weight_ih_l0',
              tensor([[-0.0991,  0.0993,  0.4668,  ...,  0.1583, -0.0726, -0.1771],
                      [-0.0750,  0.1386, -0.1966,  ...,  0.2892,  0.1875, -0.1490],
                      [ 0.1319,  0.0021,  0.0430,  ...,  0.0715,  0.1804,  0.0706],
                      ...,
                      [-0.1801,  0.0707, -0.0700,  ...,  0.0426, -0.0005,  0.1849