<a href="https://colab.research.google.com/github/kattens/Protein-Interaction-with-LLMs/blob/main/REAL_Second_part_model_after_the_protbert_finetune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
!pip install --upgrade transformers

In [2]:
import pandas as pd
import numpy as np
import torch

In [3]:
import torch
print(torch.cuda.is_available())
print(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CUDA not available")


True
Tesla T4


In [4]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
csv_file = '/content/drive/MyDrive/pairs.csv'
pairs_df = pd.read_csv(csv_file)

In [6]:
# Truncate each specified column to a maximum length of 500 characters
columns = ['masked_sequence_A', 'masked_sequence_B', 'Sequence_A', 'Sequence_B']
for col in columns:
    pairs_df[col] = pairs_df[col].apply(lambda x: x[:500] if len(x) > 500 else x)

# Find the longest string by length
pairs_df['Length'] = pairs_df['masked_sequence_A'].apply(len)
longest_string = pairs_df.loc[pairs_df['Length'].idxmax(), 'masked_sequence_A']
print(len(longest_string))


500


In [7]:
pairs_df.shape[0]

58300

Create dataset class that handles both global sequences and local sequences for protein pairs, and potentially prepares for the inclusion of 3D structural data

#the base Model (without coordinates at this point):

  ### Modeling Interactions:
  The mdel could be trained to recognize which amino acids interact by learning representations of local sequences that highlight these interactions. During training, the MLM objective helps the model learn contextual embeddings that are rich in information about which amino acids tend to be near each other and under what structural contexts they interact.

  ### Attention Mechanism:
   The custom attention mechanism can be used to weigh the importance of different amino acids in the global context when predicting the masked amino acids in the local sequences. This allows your model to focus more on the parts of the global sequences that are relevant to the interactions highlighted by the local sequences.

  ### Utilizing Global Sequences:
  While the local sequences are your primary interest, the global sequences provide the context necessary for your model to understand the broader environment in which the interactions occur. Even during prediction, you should feed the model the global sequences to utilize the learned context.

  #### This modular design not only meets your current requirements but also provides a scalable framework to incorporate additional dimensions of protein sequence data analysis in the future


In [8]:
!ls /content/drive/MyDrive/Checkpoints

added_tokens.json     generation_config.json   tokenizer_config.json
config.json	      model.safetensors        tokenizer.json
final_checkpoint.pth  special_tokens_map.json  vocab.txt


In [9]:
#in this block of code we will intialize our core model for the architecture
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

# Define the checkpoint directory and the specific checkpoint file
checkpoint_path = '/content/drive/MyDrive/Checkpoints/final_checkpoint.pth'

# Load the tokenizer first
tokenizer = AutoTokenizer.from_pretrained('/content/drive/MyDrive/Checkpoints')

# Initialize the model from the pre-trained configuration in the Checkpoints directory
model = AutoModelForMaskedLM.from_pretrained('/content/drive/MyDrive/Checkpoints')
model.resize_token_embeddings(len(tokenizer))  # Important if you've added tokens

# Determine the device to use (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)  # Move model to the appropriate device

# Load the checkpoint
checkpoint = torch.load(checkpoint_path, map_location=device)  # Ensure checkpoint is loaded to the correct device

# Ensure all keys in the checkpoint can be loaded to the model
missing_keys, unexpected_keys = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
if missing_keys or unexpected_keys:
    print(f"Missing keys in state dict: {missing_keys}")
    print(f"Unexpected keys in state dict: {unexpected_keys}")
else:
    print("Model state loaded successfully.")



"""
An Optimizer and a Learning rate scheduler are being set up and their states are loaded from a checkpoint.
This allows for the continuation of model training with the exact parameters and learning rate adjustments that were in use when the
training was last saved, ensuring a seamless transition and consistency in the training process.
"""

# If you need the optimizer and scheduler states
optimizer = Adam(model.parameters(), lr=checkpoint.get('learning_rate', 0.001))  # Fallback to default if not in checkpoint
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

scheduler = StepLR(optimizer, step_size=1, gamma=0.1)
if 'scheduler_state_dict' in checkpoint:
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
else:
    print("No scheduler state found in checkpoint; using default settings.")

print("Model, tokenizer, optimizer, and scheduler loaded from checkpoint.")


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Model state loaded successfully.
Model, tokenizer, optimizer, and scheduler loaded from checkpoint.


In [68]:
import torch
from torch.utils.data import Dataset
import logging

# Setup logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')

class ProteinInteractionDataset(Dataset):
    def __init__(self, dataframe, tokenizer, mask_probability=0.15, modes=None):
        """
        Initializes the dataset.

        Args:
            dataframe (pandas.DataFrame): The dataframe containing protein sequences.
            tokenizer (transformers.BertTokenizer): The tokenizer for encoding sequences.
            mask_probability (float): The probability of masking a token for the masked language model.
            modes (list of str): List of modes to prepare data. Options include:
                                 'global_masked' - Returns sequences with random masking.
                                 'local' - Returns non-masked sequences.
                                 Modes can be combined.
        """
        self.dataframe = dataframe
        self.tokenizer = tokenizer
        self.mask_probability = mask_probability
        self.modes = modes if modes else ['global']  # Default to only global if none specified

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
      row = self.dataframe.iloc[idx]
      data = {}

      # Processing for 'global' mode
      if 'global' in self.modes:
          #local_seq = f"[CLS] {row['masked_sequence_A']} [ENTITY1] [SEP] {row['masked_sequence_B']} [ENTITY2] [SEP]"

          global_seq = "[CLS] {} [ENTITY1] [SEP] {} [ENTITY2] [SEP]".format(
              " ".join(debug_df['Sequence_A']),
              " ".join(debug_df['Sequence_B'])
          )


          #logging.debug(f'Local sequence length before tokenization: {len(local_seq)}')
          input_ids, attention_mask = self.tokenize_sequence(global_seq)
          data['input_ids_local'] = input_ids
          data['attention_mask_local'] = attention_mask

          logging.debug(f'hello world')
          logging.debug(f'Local mode: input_ids dimension {input_ids.shape}, attention mask dimension {attention_mask.shape}')

      # Processing for 'local' mode
      if 'local' in self.modes:
          #local_seq = f"[CLS] {row['masked_sequence_A']} [ENTITY1] [SEP] {row['masked_sequence_B']} [ENTITY2] [SEP]"

          local_seq = "[CLS] {} [ENTITY1] [SEP] {} [ENTITY2] [SEP]".format(
              " ".join(debug_df['masked_sequence_A']),
              " ".join(debug_df['masked_sequence_B'])
          )

          logging.debug(f'hello world')
          #logging.debug(f'Local sequence: {local_seq} (length: {(local_seq)})')


          #logging.debug(f'Local sequence length before tokenization: {len(local_seq)}')
          input_ids, attention_mask = self.tokenize_sequence(local_seq)
          data['input_ids_local'] = input_ids
          data['attention_mask_local'] = attention_mask

          logging.debug(f'hello world')
          logging.debug(f'Local mode: input_ids dimension {input_ids.shape}, attention mask dimension {attention_mask.shape}')



      # Example processing for 'coords' mode, assuming it's similar to 'local' mode but with different data
      if 'coords' in self.modes:
          coords_seq = f"[CLS] {row['coords_A']} [ENTITY1] [SEP] {row['coords_B']} [ENTITY2] [SEP]"
          input_ids, attention_mask = self.tokenize_sequence(coords_seq)
          data['input_ids_coords'] = input_ids
          data['attention_mask_coords'] = attention_mask
          logging.debug(f'Coords mode: input_ids dimension {input_ids.shape}, attention mask dimension {attention_mask.shape}')

      return data


    def tokenize_sequence(self, sequence):
      max_length = 512  # Example fixed max length, adjust as necessary
      encoded = self.tokenizer.encode_plus(
          sequence,
          add_special_tokens=False,
          return_tensors='pt',
          padding=False,
          truncation=True,
          max_length=max_length  # Ensure sequences are truncated to a maximum length
      )
      return encoded['input_ids'].squeeze(0), encoded['attention_mask'].squeeze(0)

    def random_mask_sequence(self, sequence):
        tokens = self.tokenizer.tokenize(sequence)
        input_ids = torch.tensor(self.tokenizer.convert_tokens_to_ids(tokens), dtype=torch.long)
        labels = torch.full(input_ids.shape, -100)  # Use -100 to ignore these indices in loss calculations
        # Decide where to mask tokens
        mask_indices = torch.rand(input_ids.shape) < self.mask_probability
        labels[mask_indices] = input_ids[mask_indices]
        # 80% of the time, replace masked input tokens with tokenizer.mask_token
        actual_mask = mask_indices & (torch.rand(input_ids.shape) < 0.8)
        input_ids[actual_mask] = self.tokenizer.convert_tokens_to_ids([self.tokenizer.mask_token])[0]
        # 10% of the time, replace masked input tokens with a random token
        random_tokens = torch.randint(2, self.tokenizer.vocab_size, input_ids.shape)
        input_ids[mask_indices & ~actual_mask] = random_tokens[mask_indices & ~actual_mask]
        return input_ids, torch.ones_like(input_ids), labels


In [69]:
#this block is for debug deminstrations
import pandas as pd

# Assuming pairs_df is your original DataFrame and tokenizer is already defined

# This block is for debug demonstrations
# Create a DataFrame from the first row
debug_df = pairs_df.iloc[[0]]  # Use double brackets to keep it as a DataFrame


# Print the length of each sequence column, considering each element as a string
print(f"Sequence_A length: {sum(len(seq) for seq in debug_df['Sequence_A'])}")
print(f"Sequence_B length: {sum(len(seq) for seq in debug_df['Sequence_B'])}")
print(f"Masked_sequence_A length: {sum(len(seq) for seq in debug_df['masked_sequence_A'])}")
print(f"Masked_sequence_B length: {sum(len(seq) for seq in debug_df['masked_sequence_B'])}")


global_seq = "[CLS] {} [ENTITY1] [SEP] {} [ENTITY2] [SEP]".format(
    " ".join(debug_df['Sequence_A']),
    " ".join(debug_df['Sequence_B'])
)

local_seq = "[CLS] {} [ENTITY1] [SEP] {} [ENTITY2] [SEP]".format(
    " ".join(debug_df['masked_sequence_A']),
    " ".join(debug_df['masked_sequence_B'])
)


# Calculate the total length of the local sequence
global_seq_length = len(global_seq)
local_seq_length = len(local_seq)

# Print the length
'''
print(f"Global sequence length: {global_seq_length}")
print(f"Local sequence length: {local_seq_length}")

'''
# Initialize the dataset with specified modes
dataset_global = ProteinInteractionDataset(debug_df, tokenizer, modes=['global'])
dataset_local = ProteinInteractionDataset(debug_df, tokenizer, modes=['local'])
dataset_coords = ProteinInteractionDataset(debug_df, tokenizer, modes=['coords'])

# Fetch data for the first entry in the dataset
data_global = dataset_global[0]  # Accessing the first item
data_local = dataset_local[0]  # Accessing the first item
data_coords = dataset_coords[0]  # Accessing the first item

# Print the data for debugging
print("Global Mode Data:")
for key, value in data_global.items():
    print(f"{key}: {value.shape}")

print("\nLocal Mode Data:")
for key, value in data_local.items():
    print(f"{key}: {value.shape}")

print("\ncoords Mode Data:")
for key, value in data_coords.items():
    print(f"{key}: {value.shape}")


Sequence_A length: 60
Sequence_B length: 60
Masked_sequence_A length: 60
Masked_sequence_B length: 60
Global Mode Data:
input_ids_local: torch.Size([7])
attention_mask_local: torch.Size([7])

Local Mode Data:
input_ids_local: torch.Size([125])
attention_mask_local: torch.Size([125])

coords Mode Data:
input_ids_coords: torch.Size([512])
attention_mask_coords: torch.Size([512])


In [55]:
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def collate_fn(batch):
    # Initialize containers for the various data components
    input_ids_global = []
    attention_mask_global = []
    labels_local = []
    labels_coords = []

    # Collect data for each sample in the batch
    for item in batch:
        input_ids_global.append(item['input_ids_global'])
        attention_mask_global.append(item['attention_mask_global'])

        if 'labels_local' in item:
            labels_local.append(item['labels_local'])
        if 'labels_coords' in item:
            labels_coords.append(item['labels_coords'])

    # Pad the sequences in the batch to the same length
    input_ids_global = pad_sequence(input_ids_global, batch_first=True, padding_value=0)
    attention_mask_global = pad_sequence(attention_mask_global, batch_first=True, padding_value=0)

    # Prepare output dictionary
    batch_data = {
        'input_ids_global': input_ids_global,
        'attention_mask_global': attention_mask_global,
    }

    # Only add labels to the batch if they are available
    if labels_local:
        labels_local = pad_sequence(labels_local, batch_first=True, padding_value=-100)
        batch_data['labels_local'] = labels_local
    if labels_coords:
        labels_coords = pad_sequence(labels_coords, batch_first=True, padding_value=-100)
        batch_data['labels_coords'] = labels_coords

    return batch_data


#pretraining the bert model


#main class for training:
  1. Process two sets of sequences (global and local) using BERT to extract contextual embeddings.
  2. Integrate these two sets of embeddings using a custom attention mechanism that focuses on relevant parts of the global features for each part of the local features.
  3. Predict an output (like interaction sites or effects) using the combined features.

In [None]:
import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer

# Define a sequence processor that can handle one channel of input (either global, local, or coords)
class SequenceProcessor(nn.Module):
    def __init__(self, model):
        super(SequenceProcessor, self).__init__()
        self.model = model

    def forward(self, input_ids, attention_mask=None):
        return self.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

# Custom attention mechanism to integrate features from multiple channels
class CustomAttention(nn.Module):
    def __init__(self, hidden_size):
        super(CustomAttention, self).__init__()
        self.key_layer = nn.Linear(hidden_size, hidden_size)
        self.query_layer = nn.Linear(hidden_size, hidden_size)
        self.value_layer = nn.Linear(hidden_size, hidden_size)
        self.softmax = nn.Softmax(dim=-1)
        self.context_layer = nn.Linear(hidden_size, hidden_size)

    def forward(self, global_features, additional_features):
        keys = self.key_layer(global_features)
        queries = self.query_layer(additional_features)
        values = self.value_layer(global_features)
        attention_scores = torch.matmul(queries, keys.transpose(-2, -1)) / (keys.size(-1) ** 0.5)
        attention_weights = self.softmax(attention_scores)
        context = torch.matmul(attention_weights, values)
        processed_context = self.context_layer(context)
        return processed_context

# Main model to process multiple input channels
class ProteinInteractionModel(nn.Module):
    def __init__(self, model_identifier_or_path):
        super(ProteinInteractionModel, self).__init__()
        self.base_model = AutoModel.from_pretrained(model_identifier_or_path)
        hidden_size = self.base_model.config.hidden_size

        self.sequence_processor_global = SequenceProcessor(self.base_model)
        self.sequence_processor_local = SequenceProcessor(self.base_model)
        self.custom_attention = CustomAttention(hidden_size)
        self.mlm_head = nn.Linear(hidden_size, self.base_model.config.vocab_size)

    def forward(self, input_ids_global, attention_mask_global, input_ids_local=None, attention_mask_local=None):
        global_features = self.sequence_processor_global(input_ids_global, attention_mask_global)

        # Handle local input channel if provided
        if input_ids_local is not None and attention_mask_local is not None:
            local_features = self.sequence_processor_local(input_ids_local, attention_mask_local)
            combined_features = self.custom_attention(global_features, local_features)
        else:
            combined_features = global_features  # Use global features directly if local is not provided

        prediction_scores = self.mlm_head(combined_features)
        return prediction_scores


#Training and Validation data creator:

In [None]:
#the most important part to check if the class definition and data management is correctly working
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

# Splitting the dataset into 80% training and 20% testing
train_df, test_df = train_test_split(pairs_df, test_size=0.2, random_state=42)

# Assuming ProteinInteractionDataset is implemented to handle your DataFrame structure
train_dataset = ProteinInteractionDataset(train_df, tokenizer)
test_dataset = ProteinInteractionDataset(test_df, tokenizer)

#since the model isnt runnint we reduced the batch size to half
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16)


# Model Setup

#Training Loop

In [None]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel
from torch.nn.utils.rnn import pad_sequence

# Training Loop
def train(model, data_loader, optimizer, epochs=1):
    model.train()
    for epoch in range(epochs):
        for batch in data_loader:
            optimizer.zero_grad()
            outputs = model(
                batch['input_ids_global'],
                batch['attention_mask_global'],
                batch['input_ids_local'],
                batch['attention_mask_local']
            )
            loss = outputs.loss  # Assume model returns a loss
            loss.backward()
            optimizer.step()
            print(f"Epoch {epoch}, Loss: {loss.item()}")

# Example Usage
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
dataframe = pairs_df
dataset = ProteinInteractionDataset(dataframe, tokenizer, modes=['global', 'local'])
data_loader = DataLoader(dataset, batch_size=32, collate_fn=collate_fn, shuffle=True)
model = ProteinInteractionModel('bert-base-uncased')
optimizer = optim.Adam(model.parameters(), lr=1e-4)

train(model, data_loader, optimizer)


KeyError: 'Interaction_Site_A'