In [1]:
from tokenizers import Tokenizer, models, pre_tokenizers, decoders, trainers, processors

# Define the alphabet for protein sequences
protein_alphabet = "ACDEFGHIKLMNPQRSTVWY"

# Create a tokenizer with a BPE model
tokenizer = Tokenizer(models.BPE())

# Define a pre-tokenizer that splits on each character
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)

# Define a decoder
tokenizer.decoder = decoders.ByteLevel()

# Define a trainer with the protein alphabet
trainer = trainers.BpeTrainer(
    vocab_size=len(protein_alphabet) + 2,  # +2 for <pad> and <eos>
    special_tokens=["<pad>", "<mask>", "<eos>"],
    initial_alphabet=list(protein_alphabet)
)

# Train the tokenizer on a list of protein sequences
protein_sequences = ["ACDEFGHIKLMNPQRSTVWY", "ACDEFGHIKLMNPQRSTVWY", "ACDEFGHIKLMNPQRSTVWY"]
tokenizer.train_from_iterator(protein_sequences, trainer=trainer)

# Add post-processing to handle special tokens
tokenizer.post_processor = processors.TemplateProcessing(
    single="$A <eos>",
    special_tokens=[
        ("<eos>", tokenizer.token_to_id("<eos>")),

    ],
)

# Save the tokenizer
tokenizer.save("protein_tokenizer.json")

# Example usage
encoded = tokenizer.encode("ACDEFGHIKLMNPQRSTVWY")
print(encoded.tokens)
vocab = tokenizer.get_vocab()
print("Vocabulary:", vocab)




['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y', '<eos>']
Vocabulary: {'F': 7, 'M': 13, 'N': 14, '<eos>': 2, 'T': 19, 'E': 6, 'G': 8, 'L': 12, '<mask>': 1, 'Q': 16, 'V': 20, 'W': 21, 'H': 9, 'C': 4, 'I': 10, '<pad>': 0, 'D': 5, 'A': 3, 'P': 15, 'R': 17, 'Y': 22, 'K': 11, 'S': 18}


In [5]:
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import DataLoader, TensorDataset

BATCH_SIZE = 4

seq_length = 50

df = pd.read_csv("../ML4Proteins/01_Introduction/snake_venoms/Snake_Toxins_with_Function_Classes.csv")
# Delete entries where df['Sequence'] is longer than 600
df = df[df['Sequence'].str.len() <= seq_length-1]

# Tokenize the sequences
df['Tokenized Sequence'] = df['Sequence'].apply(lambda x: tokenizer.encode(x).tokens)

# Split the data into train, validation, and test sets
train_df, temp_df = train_test_split(df, test_size=0.2, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

# Convert tokenized sequences to numerical IDs
def tokenize_sequence(sequence):
    return [tokenizer.token_to_id(token) for token in sequence]

train_sequences = train_df['Tokenized Sequence'].apply(tokenize_sequence).tolist()
val_sequences = val_df['Tokenized Sequence'].apply(tokenize_sequence).tolist()
test_sequences = test_df['Tokenized Sequence'].apply(tokenize_sequence).tolist()

# Convert to PyTorch tensors
train_tensors = [torch.tensor(seq, dtype=torch.long) for seq in train_sequences]
val_tensors = [torch.tensor(seq, dtype=torch.long) for seq in val_sequences]
test_tensors = [torch.tensor(seq, dtype=torch.long) for seq in test_sequences]

# Apply padding to the sequences till length 600
train_tensors = torch.nn.utils.rnn.pad_sequence(train_tensors, batch_first=True, padding_value=tokenizer.token_to_id("<pad>"))
val_tensors = torch.nn.utils.rnn.pad_sequence(val_tensors, batch_first=True, padding_value=tokenizer.token_to_id("<pad>"))
test_tensors = torch.nn.utils.rnn.pad_sequence(test_tensors, batch_first=True, padding_value=tokenizer.token_to_id("<pad>"))

# Ensure all sequences are of length 600
train_tensors = torch.nn.functional.pad(train_tensors, (0, seq_length - train_tensors.size(1)), value=tokenizer.token_to_id("<pad>"))
val_tensors = torch.nn.functional.pad(val_tensors, (0, seq_length - val_tensors.size(1)), value=tokenizer.token_to_id("<pad>"))
test_tensors = torch.nn.functional.pad(test_tensors, (0, seq_length - test_tensors.size(1)), value=tokenizer.token_to_id("<pad>"))

# Create datasets
train_dataset = TensorDataset(train_tensors)
val_dataset = TensorDataset(val_tensors)
test_dataset = TensorDataset(test_tensors)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print("Train dataset size:", len(train_dataset))
print("Validation dataset size:", len(val_dataset))
print("Test dataset size:", len(test_dataset))



Train dataset size: 208
Validation dataset size: 26
Test dataset size: 27


In [6]:
import torch

import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

class TransformerModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, latent_dim, max_seq_length):
        super(TransformerModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        
        # Encoder
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=4)
        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=3)
        
        # Latent layer
        self.latent_layer = nn.Sequential(
            nn.Linear(embedding_dim, latent_dim),
            nn.ReLU()
        )
        
        # Decoder
        self.decoder_layer = nn.TransformerDecoderLayer(d_model=embedding_dim, nhead=4)
        self.decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=3)
        
        self.output_layer = nn.Linear(embedding_dim, vocab_size)
        self.max_seq_length = max_seq_length
        self.criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.token_to_id("<pad>"))

    def encode(self, x):
        x_embedded = self.embedding(x).transpose(0, 1)
        memory = self.encoder(x_embedded)
        memory = self.latent_layer(memory.mean(dim=0))
        return memory.unsqueeze(0)

    def decode(self, memory, max_length, start_token, eos_token):
        batch_size = memory.size(1)
        device = memory.device
        decoded_tokens = torch.full((batch_size, 1), start_token, dtype=torch.long, device=device)
        logit_list = torch.full((batch_size, self.max_seq_length, vocab_size), -float("inf"), device=device)
        #print(logit_list.shape)

        for t in range(max_length - 1):
            tgt_emb = self.embedding(decoded_tokens).transpose(0, 1)
            out = self.decoder(tgt_emb, memory.expand(decoded_tokens.size(1), -1, -1))
            logits = self.output_layer(out[-1])
            logit_list[:, t, :] = logits
            next_token = torch.argmax(logits, dim=1, keepdim=True)
            decoded_tokens = torch.cat([decoded_tokens, next_token], dim=1)

            if torch.all(next_token.eq(eos_token)):
                break

        return decoded_tokens, logit_list

    def forward(self, x, max_length=None, start_token=2, eos_token=2):
        if max_length is None:
            max_length = self.max_seq_length

        memory = self.encode(x)
        decoded_tokens, logits = self.decode(memory, max_length, start_token, eos_token)
        return decoded_tokens, logits

# Initialize the model
vocab_size = 23
embedding_dim = 16
latent_dim = 16
max_seq_length = seq_length

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TransformerModel(vocab_size, embedding_dim, latent_dim, max_seq_length).to(device)

# Define the training parameters
num_epochs = 10
learning_rate = 0.001
accumulation_steps = 4  # Number of steps to accumulate gradients

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    optimizer.zero_grad()
    for i, batch in tqdm(enumerate(train_loader)):
        inputs = batch[0].to(device)
        outputs, logits = model(inputs)
        
        # Reshape outputs and targets for the loss function
        outputs = logits

        targets = torch.full((logits.shape[0], logits.shape[1], logits.shape[2]), fill_value=0, device=device, dtype=torch.float, requires_grad=False)
        for batcn_ix in range(logits.shape[0]):
            for seq_ix in range(logits.shape[1]):
                targets[batcn_ix, seq_ix, inputs[batcn_ix, seq_ix]] = 1

        loss = criterion(outputs, targets)
        loss.backward()

        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        train_loss += loss.item()

    avg_train_loss = train_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_train_loss:.4f}")

    # Validation loop
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            inputs = batch[0].to(device)
            outputs, logits = model(inputs)
            
            
            # Reshape outputs and targets for the loss function
            outputs = logits

            targets = torch.full((logits.shape[0], logits.shape[1], logits.shape[2]), fill_value=0, device=device, dtype=torch.float, requires_grad=False)
            for batcn_ix in range(logits.shape[0]):
                for seq_ix in range(logits.shape[1]):
                    targets[batcn_ix, seq_ix, inputs[batcn_ix, seq_ix]] = 1

            loss = criterion(outputs, targets)
            val_loss += loss.item()

    avg_val_loss = val_loss / len(val_loader)
    print(f"Validation Loss: {avg_val_loss:.4f}")

52it [00:43,  1.19it/s]


Epoch [1/10], Loss: nan


RuntimeError: shape '[-1, 23]' is invalid for input of size 200

In [9]:
with torch.no_grad():
    for batch in val_loader:
        inputs = batch[0].to(device)
        targets = inputs.clone()

        outputs = model(inputs)
        print(outputs.shape)
        # Reshape outputs and targets for the loss function
        outputs = outputs.view(-1, vocab_size)
        targets = targets.view(-1)
        
        # Convert outputs to tokens
        predicted_tokens = [tokenizer.id_to_token(id.item()) for id in outputs.argmax(dim=1)]
        print(predicted_tokens)

        # Convert targets to tokens
        target_tokens = [tokenizer.id_to_token(id.item()) for id in targets]
        print(target_tokens)

torch.Size([32, 518, 23])
['M', 'K', 'T', 'L', 'L', 'L', 'T', 'L', 'V', 'V', 'V', 'T', 'I', 'V', 'C', 'L', 'D', 'L', 'G', 'Y', 'T', 'M', 'T', 'C', 'C', 'N', 'Q', 'Q', 'S', 'S', 'Q', 'P', 'K', 'T', 'I', 'T', 'T', 'C', 'A', 'E', 'S', 'S', 'C', 'Y', 'K', 'K', 'T', 'W', 'K', 'D', 'H', 'H', 'G', 'T', 'R', 'I', 'E', 'R', 'G', 'C', 'G', 'C', 'P', 'P', 'R', 'K', 'P', 'L', 'I', 'D', 'L', 'I', 'C', 'C', 'E', 'T', 'D', 'E', 'C', 'N', 'N', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '