In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

In [20]:
class ProteinIonDataset(Dataset):
    def __init__(self, protein_embeddings_file, ion_embeddings_file):
        self.protein_embeddings = torch.load(protein_embeddings_file)
        self.ion_embeddings = torch.load(ion_embeddings_file)

    def __len__(self):
        return len(self.protein_embeddings)

    def __getitem__(self, idx):
        protein_id = list(self.protein_embeddings.keys())[idx]
        protein_seq = self.protein_embeddings[protein_id]
        ion_seq = self.ion_embeddings[protein_id]
        return protein_seq, ion_seq

# Load the dataset
dataset = ProteinIonDataset('data/embeddings.pt', 'data/target_embeddings.pt')

# Create a DataLoader for batching
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [21]:
for proteins, ions in dataloader:
    print(proteins.shape)
    print(ions.shape)
    print(proteins[0])
    print(ions[0])
    break

torch.Size([32, 1024])
torch.Size([32, 1024])
tensor([ 0.0357, -0.1899, -0.1827,  ...,  0.0174, -0.2099, -0.0339],
       device='cuda:0')
tensor([ 0,  0,  0,  ..., 14, 14, 14])


In [29]:
def create_padding_mask(seq, pad_token):
        # Create a mask for the padding tokens
        return (seq == pad_token)

class MyTransformerDecoder(nn.Module):
    
    def __init__(self, protein_emb_size, ion_vocab_size, d_model, nhead, num_decoder_layers, dim_feedforward):
        super(MyTransformerDecoder, self).__init__()
        
        # Linear layer to adjust protein embedding size if necessary
        self.protein_emb_adjust = nn.Linear(protein_emb_size, d_model)
        
        self.ion_embedding = nn.Embedding(ion_vocab_size, d_model)
        self.transformer_decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(self.transformer_decoder_layer, num_layers=num_decoder_layers)
        self.fc_out = nn.Linear(d_model, ion_vocab_size)


    def forward(self, protein_seq, ion_seq):
        # Adjust protein sequence embedding size if necessary
        protein_seq = self.protein_emb_adjust(protein_seq)

        # Embedding the ion sequence
        ion_seq_emb = self.ion_embedding(ion_seq)
        
        # Create padding mask for the ion sequence
        padding_mask = create_padding_mask(ion_seq, 14)
        
        # Decoding
        output = self.transformer_decoder(tgt=ion_seq_emb, memory=protein_seq, tgt_key_padding_mask=padding_mask)
        # Pass through linear layer for predictions
        output = self.fc_out(output)

        return output

# Example parameters
protein_emb_size = 1024  # Size of your protein embeddings
ion_vocab_size = 16     # 13 metal-ions + 1 padding + 1 EOS + 1 no-ion
d_model = 1024           # Size of the embeddings and transformer model (can be adjusted)
nhead = 8               # Number of heads in multi-head attention models
num_decoder_layers = 3  # Number of sub-encoder-layers in the decoder stack
dim_feedforward = 2048  # Dimension of the feedforward network

In [30]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [31]:
# Model initialization
model = MyTransformerDecoder(protein_emb_size, ion_vocab_size, d_model, nhead, num_decoder_layers, dim_feedforward)
model.to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.0001)

# Number of epochs
num_epochs = 1  # You can adjust this based on your needs
check_interval = 5

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    with tqdm(dataloader, unit='batch') as tepoch:
        for i, (protein_seqs, ion_seqs) in enumerate(tepoch):
            tepoch.set_description(f"Epoch {epoch+1}")
            protein_seqs = protein_seqs.unsqueeze(1)

            protein_seqs = protein_seqs.to(device)
            ion_seqs = ion_seqs.to(device)

            # Reset gradients
            optimizer.zero_grad()

            # Forward pass
            output = model(protein_seqs, ion_seqs)

            # Compute loss
            loss = criterion(output.transpose(1, 2), ion_seqs)  # Adjust dimensions if necessary

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Check loss every 'check_interval' iterations
            if (i + 1) % check_interval == 0:
                print(f'Epoch {epoch+1}, Iteration {i+1}, Loss: {running_loss / check_interval}')
                running_loss = 0.0


    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')

# Save the model
#torch.save(model.state_dict(), 'transformer_model.pth')

print("Training complete")

Epoch 1:   7%|▋         | 5/67 [01:12<14:45, 14.28s/batch]

Epoch 1, Iteration 5, Loss: 0.818752458691597


Epoch 1:  15%|█▍        | 10/67 [02:22<13:17, 13.98s/batch]

Epoch 1, Iteration 10, Loss: 0.05576927736401558


Epoch 1:  22%|██▏       | 15/67 [03:31<12:02, 13.89s/batch]

Epoch 1, Iteration 15, Loss: 0.04280492961406708


Epoch 1:  30%|██▉       | 20/67 [04:41<10:56, 13.97s/batch]

Epoch 1, Iteration 20, Loss: 0.02099681571125984


Epoch 1:  37%|███▋      | 25/67 [05:52<10:00, 14.29s/batch]

Epoch 1, Iteration 25, Loss: 0.009666849300265313


Epoch 1:  45%|████▍     | 30/67 [07:04<08:51, 14.36s/batch]

Epoch 1, Iteration 30, Loss: 0.006367138121277094


Epoch 1:  52%|█████▏    | 35/67 [08:16<07:39, 14.36s/batch]

Epoch 1, Iteration 35, Loss: 0.004679392022080719


Epoch 1:  60%|█████▉    | 40/67 [09:28<06:27, 14.33s/batch]

Epoch 1, Iteration 40, Loss: 0.003164233732968569


Epoch 1:  67%|██████▋   | 45/67 [10:38<05:09, 14.06s/batch]

Epoch 1, Iteration 45, Loss: 0.0015086819650605321


Epoch 1:  75%|███████▍  | 50/67 [11:48<03:57, 14.00s/batch]

Epoch 1, Iteration 50, Loss: 0.0009136674692854285


Epoch 1:  82%|████████▏ | 55/67 [12:57<02:46, 13.84s/batch]

Epoch 1, Iteration 55, Loss: 0.0007454292615875601


Epoch 1:  90%|████████▉ | 60/67 [14:06<01:37, 13.93s/batch]

Epoch 1, Iteration 60, Loss: 0.000898661173414439


Epoch 1:  97%|█████████▋| 65/67 [15:18<00:28, 14.23s/batch]

Epoch 1, Iteration 65, Loss: 0.00016990129224723206


Epoch 1: 100%|██████████| 67/67 [15:46<00:00, 14.13s/batch]

Epoch [1/1], Loss: 0.00010160612873733044
Training complete





In [34]:
PAD_TOKEN = '<pad>'
EOS_TOKEN = '<eos>'
PAD_TOKEN_ID = 14
EOS_TOKEN_ID = 15

In [49]:
input_emb, _ = next(iter(dataloader))

In [52]:
input_emb = input_emb.unsqueeze(1)
input_emb.shape

torch.Size([32, 1, 1024])

In [53]:
input_emb = input_emb[0]
input_emb.shape

torch.Size([1, 1024])

In [54]:
input_emb = input_emb.unsqueeze(0)
input_emb.shape

torch.Size([1, 1, 1024])

In [57]:
input_emb = input_emb.to(device)

In [69]:
for input_emb, target_emb in dataloader:
    input_emb = input_emb.to(device)
    target_emb = target_emb.to(device)
    input_emb = input_emb.unsqueeze(1)
    print(input_emb.shape)
    target_emb2 = torch.zeros(32, 1024, dtype=torch.long).to(device)
    print(target_emb2.shape)
    output = model(input_emb, target_emb2)
    break

torch.Size([32, 1, 1024])
torch.Size([32, 1024])


In [64]:
def generate_sequence(model, input_emb, max_length=1024):
    model.eval()  # Set the model to evaluation mode

    # Initialize the target sequence with a start token or placeholder
    # Assuming get_initial_target_embedding() returns a tensor of shape [1024]
    target_seq = get_initial_target_embedding().unsqueeze(0)  # Shape: [1, 1, 1024]
    target_seq = target_seq.to(device)

    for _ in range(max_length - 1):
        with torch.no_grad():
            # Generate output using current target sequence
            output = model(input_emb, target_seq)

            # Assuming the model outputs the next token embedding
            next_token_emb = output[:, -1, :].unsqueeze(1)  # Shape: [1, 1, 1024]

            # Append the next token embedding to the sequence
            target_seq = torch.cat([target_seq, next_token_emb], dim=1)

            # Include any break conditions if necessary

    return target_seq

# Function to get the initial target embedding
def get_initial_target_embedding():
    # Define how to create the initial target embedding
    return torch.zeros(1024, dtype=torch.long)  # Example placeholder, shape: [1024]

# Example usage for inference
# input_emb: A single protein sequence embedding, shaped [1, 1, 1024]
generated_sequence = generate_sequence(model, input_emb)

RuntimeError: shape '[1, 8, 128]' is invalid for input of size 32768

In [46]:
def generate_sequence(model, input_emb, max_length=1024):
    model.eval()  # Set the model to evaluation mode

    # Start with an initial target token, e.g., a start-of-sequence token
    # Assuming token '0' is the start token
    target_seq = torch.zeros(1024, dtype=torch.long).to(device)  # [1, 1]

    for _ in range(max_length):
        with torch.no_grad():
            # Generate output using current target sequence
            print(input_emb.shape)
            print(target_seq.unsqueeze(0).shape)
            output = model(input_emb.unsqueeze(0), target_seq.unsqueeze(0))

            # Get the next token (e.g., take the token with the highest probability)
            next_token = output.argmax(dim=-1)[-1, :].unsqueeze(0)

            # Append the next token to the sequence
            target_seq = torch.cat([target_seq, next_token], dim=0)

            # Check for end-of-sequence token (e.g., if 'eos_token_id' is reached)
            if next_token.item() == EOS_TOKEN_ID:
                 break

    return target_seq

# Assuming test_dataloader is your DataLoader for test data
for input_emb, _ in dataloader:
    # Reshape input_emb to match the expected shape [1, seq_len, emb_size]
    input_emb = input_emb[0].unsqueeze(0)
    print(f"input_emb shape: {input_emb.shape}")
    # Generate sequence
    generated_sequence = generate_sequence(model, input_emb)

    # Handle generated_sequence as needed for your task
    # ...

input_emb shape: torch.Size([1, 1024])
torch.Size([1, 1024])
torch.Size([1, 1024])


RuntimeError: Tensors must have same number of dimensions: got 1 and 2