In [1]:
from tokenizers import Tokenizer, models, pre_tokenizers, decoders, trainers, processors

# Define the alphabet for protein sequences
protein_alphabet = "ACDEFGHIKLMNPQRSTVWY"

# Create a tokenizer with a BPE model
tokenizer = Tokenizer(models.BPE())

# Define a pre-tokenizer that splits on each character
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)

# Define a decoder
tokenizer.decoder = decoders.ByteLevel()

# Define a trainer with the protein alphabet
trainer = trainers.BpeTrainer(
    vocab_size=len(protein_alphabet) + 3,  # +3 for <pad>, <eos>, and <s>
    special_tokens=["<pad>", "<mask>", "<eos>", "<s>"],
    initial_alphabet=list(protein_alphabet)
)

# Train the tokenizer on a list of protein sequences
protein_sequences = ["ACDEFGHIKLMNPQRSTVWY", "ACDEFGHIKLMNPQRSTVWY", "ACDEFGHIKLMNPQRSTVWY"]
tokenizer.train_from_iterator(protein_sequences, trainer=trainer)

# Add post-processing to handle special tokens
tokenizer.post_processor = processors.TemplateProcessing(
    single="<s> $A <eos>",
    special_tokens=[
        ("<s>", tokenizer.token_to_id("<s>")),
        ("<eos>", tokenizer.token_to_id("<eos>")),
    ],
)

# Save the tokenizer
tokenizer.save("protein_tokenizer.json")

# Example usage
encoded = tokenizer.encode("ACDEFGHIKLMNPQRSTVWY")
print(encoded.tokens)
vocab = tokenizer.get_vocab()
print("Vocabulary:", vocab)
print("Vocabulary size:", len(vocab))


['<s>', 'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y', '<eos>']
Vocabulary: {'A': 4, 'I': 11, 'Y': 23, 'R': 18, 'V': 21, '<mask>': 1, '<eos>': 2, 'E': 7, 'G': 9, 'M': 14, 'Q': 17, 'C': 5, 'F': 8, 'L': 13, 'P': 16, 'D': 6, 'N': 15, 'W': 22, 'H': 10, 'S': 19, 'T': 20, 'K': 12, '<pad>': 0, '<s>': 3}
Vocabulary size: 24




In [2]:
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import DataLoader, TensorDataset

BATCH_SIZE = 32

seq_length = 50

df = pd.read_csv("../ML4Proteins/01_Introduction/snake_venoms/Snake_Toxins_with_Function_Classes.csv")
# Delete entries where df['Sequence'] is longer than 600
df = df[df['Sequence'].str.len() <= seq_length-1]

# Tokenize the sequences
df['Tokenized Sequence'] = df['Sequence'].apply(lambda x: tokenizer.encode(x).tokens)

# Split the data into train, validation, and test sets
train_df, temp_df = train_test_split(df, test_size=0.2, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

# Convert tokenized sequences to numerical IDs
def tokenize_sequence(sequence):
    return [tokenizer.token_to_id(token) for token in sequence]

train_sequences = train_df['Tokenized Sequence'].apply(tokenize_sequence).tolist()
val_sequences = val_df['Tokenized Sequence'].apply(tokenize_sequence).tolist()
test_sequences = test_df['Tokenized Sequence'].apply(tokenize_sequence).tolist()

# Convert to PyTorch tensors
train_tensors = [torch.tensor(seq, dtype=torch.long) for seq in train_sequences]
val_tensors = [torch.tensor(seq, dtype=torch.long) for seq in val_sequences]
test_tensors = [torch.tensor(seq, dtype=torch.long) for seq in test_sequences]

# Apply padding to the sequences till length 600
train_tensors = torch.nn.utils.rnn.pad_sequence(train_tensors, batch_first=True, padding_value=tokenizer.token_to_id("<pad>"))
val_tensors = torch.nn.utils.rnn.pad_sequence(val_tensors, batch_first=True, padding_value=tokenizer.token_to_id("<pad>"))
test_tensors = torch.nn.utils.rnn.pad_sequence(test_tensors, batch_first=True, padding_value=tokenizer.token_to_id("<pad>"))

# Ensure all sequences are of length 600
train_tensors = torch.nn.functional.pad(train_tensors, (0, seq_length - train_tensors.size(1)), value=tokenizer.token_to_id("<pad>"))
val_tensors = torch.nn.functional.pad(val_tensors, (0, seq_length - val_tensors.size(1)), value=tokenizer.token_to_id("<pad>"))
test_tensors = torch.nn.functional.pad(test_tensors, (0, seq_length - test_tensors.size(1)), value=tokenizer.token_to_id("<pad>"))

# Create datasets
train_dataset = TensorDataset(train_tensors)
val_dataset = TensorDataset(val_tensors)
test_dataset = TensorDataset(test_tensors)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print("Train dataset size:", len(train_dataset))
print("Validation dataset size:", len(val_dataset))
print("Test dataset size:", len(test_dataset))



Train dataset size: 208
Validation dataset size: 26
Test dataset size: 27


In [3]:
import torch

import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

class TransformerModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, latent_dim, max_seq_length):
        super(TransformerModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        
        # Encoder
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=8)
        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=6)
        
        # Latent layer
        self.latent_layer = nn.Sequential(
            nn.Linear(embedding_dim, latent_dim),
            nn.ReLU()
        )

        self.relatent_layer = nn.Sequential(
            nn.Linear(latent_dim, embedding_dim),
            nn.ReLU()
        )
        
        # Decoder
        self.decoder_layer = nn.TransformerDecoderLayer(d_model=embedding_dim, nhead=8)
        self.decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=6)
        
        self.output_layer = nn.Linear(embedding_dim, vocab_size)
        self.max_seq_length = max_seq_length
        self.criterion = nn.CrossEntropyLoss(ignore_index=0)

    def encode(self, x):
        x_embedded = self.embedding(x).transpose(0, 1)
        memory = self.encoder(x_embedded)
        memory = self.latent_layer(memory.mean(dim=0))
        return memory.unsqueeze(0)

    def decode(self, memory, max_length, start_token, eos_token):
        memory = self.relatent_layer(memory)
        batch_size = memory.size(1)
        device = memory.device
        decoded_tokens = torch.full((batch_size, 1), start_token, dtype=torch.long, device=device)
        logit_list = torch.full((batch_size, self.max_seq_length, vocab_size), -float('inf'), dtype=torch.float16, device=device)
        #logit_list[:, :, 0] = 
        #print(logit_list.shape)

        for t in range(max_length):
            tgt_emb = self.embedding(decoded_tokens).transpose(0, 1)
            out = self.decoder(tgt_emb, memory)
            logits = self.output_layer(out[-1])
            logit_list[:, t, :] = logits
            next_token = torch.argmax(logits, dim=1, keepdim=True)
            decoded_tokens = torch.cat([decoded_tokens, next_token], dim=1)

            if torch.all(next_token.eq(eos_token)):
                break

        return decoded_tokens, logit_list

    def forward(self, x, max_length=None, start_token=3, eos_token=2):
        if max_length is None:
            max_length = self.max_seq_length

        memory = self.encode(x)
        decoded_tokens, logits = self.decode(memory, max_length, start_token, eos_token)
        return decoded_tokens, logits

# Initialize the model
vocab_size = 24
embedding_dim = 128
latent_dim = 32
max_seq_length = 50  # Set a default value for max_seq_length

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TransformerModel(vocab_size, embedding_dim, latent_dim, max_seq_length).to(device)

# Define the training parameters
num_epochs = 100
learning_rate = 1e-4
accumulation_steps = 4  # Number of steps to accumulate gradients

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    optimizer.zero_grad()
    for i, batch in tqdm(enumerate(train_loader)):
        inputs = batch[0].to(device)
        outputs, logits = model(inputs)
        


        # targets = torch.full((logits.shape[0], logits.shape[1], logits.shape[2]), fill_value=0, device=device, dtype=torch.float, requires_grad=False)
        # for batcn_ix in range(logits.shape[0]):
        #     for seq_ix in range(logits.shape[1]):
        #         targets[batcn_ix, seq_ix, inputs[batcn_ix, seq_ix]] = 1
        #print(outputs.shape)
        loss = criterion(logits.view(-1, vocab_size), inputs.view(-1))
        loss.backward()


        optimizer.step()
        optimizer.zero_grad()

        train_loss += loss.item()

    avg_train_loss = train_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_train_loss:.4f}")

    # Validation loop
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            inputs = batch[0].to(device)
            outputs, logits = model(inputs)
            
            # Reshape outputs and targets for the loss function

            # targets = torch.full((logits.shape[0], logits.shape[1], logits.shape[2]), fill_value=0, device=device, dtype=torch.float, requires_grad=False)
            # for batcn_ix in range(logits.shape[0]):
            #     for seq_ix in range(logits.shape[1]):
            #         targets[batcn_ix, seq_ix, inputs[batcn_ix, seq_ix]] = 1

            loss = criterion(logits.view(-1,vocab_size), inputs.view(-1))
            val_loss += loss.item()

    avg_val_loss = val_loss / len(val_loader)
    print(f"Validation Loss: {avg_val_loss:.4f}")

7it [00:10,  1.50s/it]


Epoch [1/100], Loss: 3.1164
Validation Loss: 3.0137


7it [00:10,  1.45s/it]


Epoch [2/100], Loss: 3.0117
Validation Loss: 3.0684


7it [00:07,  1.09s/it]


Epoch [3/100], Loss: 3.0028
Validation Loss: 3.0762


7it [00:07,  1.07s/it]


Epoch [4/100], Loss: 2.9813
Validation Loss: 3.0645


7it [00:06,  1.03it/s]


Epoch [5/100], Loss: 2.9727
Validation Loss: 3.0332


7it [00:07,  1.03s/it]


Epoch [6/100], Loss: 2.9478
Validation Loss: 3.0195


7it [00:07,  1.09s/it]


Epoch [7/100], Loss: 2.9392
Validation Loss: 3.0293


7it [00:09,  1.33s/it]


Epoch [8/100], Loss: 2.9280
Validation Loss: 3.0059


1it [00:02,  2.92s/it]


KeyboardInterrupt: 

In [46]:
with torch.no_grad():
    for batch in val_loader:
        inputs = batch[0].to(device)
        targets = inputs.clone()

        outputs, logits = model(inputs)
        print(outputs.shape)
        # Reshape outputs and targets for the loss function
        outputs = outputs
        targets = targets.view(-1)
        
        # Convert outputs to tokens
        predicted_tokens = [tokenizer.id_to_token(id.item()) for id in outputs.argmax(dim=1)]
        print(predicted_tokens)

        # Convert targets to tokens
        target_tokens = [tokenizer.id_to_token(id.item()) for id in targets]
        print(target_tokens)

torch.Size([26, 51])
['<mask>', 'C', 'E', None, 'F', 'N', None, '<mask>', 'A', 'S', 'H', None, 'G', 'L', 'E', None, '<s>', '<mask>', 'A', None, 'D', '<mask>', 'A', None, 'L', None]
['<s>', 'C', 'T', 'T', 'G', 'P', 'C', 'C', 'R', 'Q', 'C', 'K', 'L', 'K', 'P', 'A', 'G', 'T', 'T', 'C', 'W', 'K', 'T', 'S', 'R', 'T', 'S', 'H', 'Y', 'C', 'T', 'G', 'K', 'S', 'C', 'D', 'C', 'P', 'V', 'Y', 'Q', 'G', '<eos>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<s>', 'L', 'V', 'S', 'V', 'S', 'P', 'A', 'F', 'N', 'G', 'N', 'Y', 'F', 'V', 'E', '<eos>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<s>', 'N', 'L', 'L', 'Q', 'F', 'A', 'F', 'M', 'I', 'R', 'Q', 'A', 'N', 'K', 'R', 'R', 'R', 'P', 'V', 'I', 'P', 'Y', 'E', 'E', 'Y', 'G', 'L', 'Y',