<a href="https://colab.research.google.com/github/kattens/SASA-Calculation-For-LLMs/blob/main/Dummy_Base_Code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Code Agenda:
- Provide global sequence and SASA values as separate inputs.
- Use the local sequence as the label.
- No new tokens are added; it's a simple prediction task.
- Build up from this.

# Code Outline:
1. Import the dataset.
2. Import tokenizer/base model.
3. Build dataset class.
4. Build the main model architecture.
5. Create the DataLoader.
6. Split data into training and testing sets.
7. Train the model.
8. Test the model.
9. Make predictions.

In [None]:
#import the libraries:
import transformers
import pandas as pd
import csv
import numpy as np
import torch
import torch.nn as nn
import os

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Free memory left over from earlier runs of this notebook.
# Collect garbage first: objects that are only kept alive by reference cycles
# still hold on to their memory until the collector runs, and empty_cache()
# can only give back cached blocks that are no longer in use.
import gc
gc.collect()

# Release PyTorch's cached, unused CUDA memory.
torch.cuda.empty_cache()

0

# 1.Import the dataset.

In [None]:
# Path of the main dataset (CSV). Set the SASA_DATASET_PATH environment
# variable to point at your own copy instead of editing this cell; the
# fallback below is the original author's location.
DATASET_PATH = os.environ.get(
    'SASA_DATASET_PATH',
    '/home/k_ensafitakaldani001_umb_edu/Project1/merged1.csv',
)
pairs_df = pd.read_csv(DATASET_PATH)

In [None]:
# Clean-up of the raw pairs table; no need to get too deep into this part.

# Count the number of NaN values in the 'SASA_A' column.
nan_count = pairs_df['SASA_A'].isna().sum()
print(nan_count)

# Drop rows where the 'SASA_A' column is NaN.
df_cleaned = pairs_df.dropna(subset=['SASA_A'])

# Keep rows whose 'Sequence_A' has 50 <= length < 200. The lengths are
# computed once and reused for both bounds.
sequence_a_len = df_cleaned['Sequence_A'].str.len()
filtered_df = df_cleaned[(sequence_a_len >= 50) & (sequence_a_len < 200)]
print(len(filtered_df))

# Confirm that no NaN 'SASA_A' values survived the clean-up.
nan_count = filtered_df['SASA_A'].isna().sum()
print(nan_count)

# Report the final row count again. The length filter is idempotent, so it is
# not re-applied here.
print(len(filtered_df))

pairs_df = df = filtered_df


44475
40595
0
40595


# 2.Import tokenizer/base model

In [None]:
from transformers import AutoTokenizer, AutoModelForMaskedLM
from transformers import BertModel, BertConfig, AutoTokenizer


# Initialize the ProtBERT tokenizer and model.
# BertModel is the plain encoder: unlike AutoModelForMaskedLM it is not tied
# to masked language modelling, so it can back a variety of downstream heads
# (e.g., classification, tagging).
tokenizer = AutoTokenizer.from_pretrained('Rostlab/prot_bert_bfd')
model = BertModel.from_pretrained('Rostlab/prot_bert_bfd')

# Special tokens for the entities: new tokens for separation and SASA
# classification.
special_tokens = ['[ENTITY1]', '[ENTITY2]', '-', 'BR', 'PE', 'EX']

# Add special tokens to the tokenizer.
tokenizer.add_special_tokens({'additional_special_tokens': special_tokens})

# Check that every special token was assigned an ID.
for token in special_tokens:
    print(f"Token '{token}' has ID: {tokenizer.convert_tokens_to_ids(token)}")

# Resize the model's embedding matrix so the new token IDs have rows in it.
model.resize_token_embeddings(len(tokenizer))
print('Token embeddings resized to accommodate new tokens.')

# Helper function to convert numerical token IDs back to their textual representation
def ids_to_text(ids):
    """Return the tokens for ``ids`` joined into one space-separated string.

    The ID-to-token lookup is done by the module-level ``tokenizer``.
    """
    tokens = tokenizer.convert_ids_to_tokens(ids)
    return ' '.join(tokens)

# Fetch the vocabulary once and reuse it for every check below instead of
# calling get_vocab() again for each token.
vocab = tokenizer.get_vocab()

# Check the updated size of the tokenizer's vocabulary
print(f"Updated vocabulary size: {len(tokenizer)}")

# Check if the new tokens are in the tokenizer's vocabulary
if all(token in vocab for token in special_tokens):
    print("All special tokens are in the tokenizer's vocabulary.")
else:
    print("Some special tokens are NOT in the tokenizer's vocabulary.")


# Some checks:
print(len(vocab))

# Number of output classes. Despite the name this is the full vocabulary size
# (amino acids plus special tokens).
num_amino_acids = len(vocab)
print(num_amino_acids)

Token '[ENTITY1]' has ID: 30
Token '[ENTITY2]' has ID: 31
Token '-' has ID: 32
Token 'BR' has ID: 33
Token 'PE' has ID: 34
Token 'EX' has ID: 35
Token embeddings resized to accommodate new tokens.
Updated vocabulary size: 36
All special tokens are in the tokenizer's vocabulary.
36
36


#  3.Build dataset class.

In [None]:
from torch.utils.data import Dataset

class SampleDataset(Dataset):
    """Dataset that turns one row of the pairs table into model tensors.

    Each row must provide the columns ``Sequence_A``, ``Sequence_B``,
    ``masked_sequence_A``, ``masked_sequence_B``, ``SASA_A`` and ``SASA_B``.

    Args:
        dataset: pandas DataFrame holding the rows (indexed with ``.iloc``).
        tokenizer: callable tokenizer; called with ``return_tensors="pt"``,
            ``padding='max_length'``, ``truncation=True`` and ``max_length``.
        max_len: length every tokenized sequence is padded/truncated to.
    """

    def __init__(self, dataset, tokenizer, max_len):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def _as_token_list(value):
        """Normalise one SASA cell to a list of strings.

        A string is split on ``", "``; a list has each item converted with
        ``str`` (so numeric items can be joined); anything else becomes a
        one-element list holding its ``str`` form.
        """
        if isinstance(value, str):
            return value.split(", ")
        if isinstance(value, list):
            return [str(item) for item in value]
        return [str(value)]

    def get_SASA_sequence(self, row):
        """Build the ``[ENTITY1] ... [SEP] [ENTITY2] ...`` SASA string for ``row``.

        Args:
            row: mapping with ``'SASA_A'`` and ``'SASA_B'`` entries.

        Returns:
            The two SASA token lists, space-joined and wrapped in the entity
            markers.
        """
        sasa_a = self._as_token_list(row['SASA_A'])
        sasa_b = self._as_token_list(row['SASA_B'])
        return f"[ENTITY1] {' '.join(sasa_a)} [SEP] [ENTITY2] {' '.join(sasa_b)}"

    def _encode(self, text):
        """Tokenize ``text`` to fixed-length PyTorch tensors."""
        return self.tokenizer(
            text,
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
        )

    def __getitem__(self, idx):
        """Return the tokenized global input, SASA input and labels for row ``idx``.

        Returns:
            dict with ``global_input_ids``, ``global_attention_mask``,
            ``SASA_input_ids``, ``SASA_attention_mask`` and ``labels``; each
            value has the leading batch dimension of size 1 squeezed away.
        """
        row = self.dataset.iloc[idx]

        # Sequences wrapped in the entity markers.
        global_sequence = f"[ENTITY1] {row['Sequence_A']} [SEP] [ENTITY2] {row['Sequence_B']}"
        local_sequence = f"[ENTITY1] {row['masked_sequence_A']} [SEP] [ENTITY2] {row['masked_sequence_B']}"
        sasa_sequence = self.get_SASA_sequence(row)

        global_inputs = self._encode(global_sequence)
        labels = self._encode(local_sequence)
        sasa_inputs = self._encode(sasa_sequence)

        return {
            'global_input_ids': global_inputs['input_ids'].squeeze(0),
            'global_attention_mask': global_inputs['attention_mask'].squeeze(0),
            'SASA_input_ids': sasa_inputs['input_ids'].squeeze(0),  # SASA tokens, second model input
            'SASA_attention_mask': sasa_inputs['attention_mask'].squeeze(0),
            'labels': labels['input_ids'].squeeze(0)  # Local sequences as labels
        }

def collate_fn(batch):
    """Combine per-sample dicts into a single dict of stacked tensors.

    The keys are taken from the first sample; for each key the tensors of all
    samples are stacked along a new leading dimension.
    """
    field_names = batch[0]
    stacked = {}
    for name in field_names:
        stacked[name] = torch.stack([sample[name] for sample in batch])
    return stacked


# 4. Build the main model architecture.

In [None]:
# Gated Mechanism with Improvements

class ProtBertSeq2Seq(nn.Module):
    """Per-token classifier that fuses a global-sequence and a SASA encoding.

    Both inputs are encoded by the same base ``model``. The global encoding
    attends over the SASA encoding, an MLP gate mixes the two encodings, and a
    small classifier maps every position to ``num_amino_acids`` logits.

    Args:
        model: base encoder; must expose ``config.hidden_size`` and return an
            object with ``last_hidden_state`` when called with ``input_ids``
            and ``attention_mask``.
        num_amino_acids: number of output classes per position.
        seq_len: stored on the instance; not used by ``forward``.
        dropout_rate: dropout applied to the encoder outputs and inside the
            classifier.
    """

    def __init__(self, model, num_amino_acids, seq_len, dropout_rate=0.1):
        super().__init__()
        self.model = model
        self.seq_len = seq_len
        self.num_amino_acids = num_amino_acids
        self.dropout = nn.Dropout(dropout_rate)

        hidden_size = model.config.hidden_size

        # Gate MLP: maps [attention output ; SASA encoding] to per-feature
        # mixing weights in (0, 1).
        self.gate_mlp = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.Sigmoid(),
        )

        # Learnable scalars that scale the global and SASA encodings.
        self.alpha = nn.Parameter(torch.randn(1))
        self.beta = nn.Parameter(torch.randn(1))

        # Cross-attention from the global encoding (query) to the SASA
        # encoding (key/value). batch_first=True is required because the
        # encoder outputs are laid out as (batch, seq, hidden); with the
        # default layout the batch axis would be treated as the sequence axis
        # and samples of one batch would attend to each other.
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_size, num_heads=4, batch_first=True
        )

        # Layer normalization to stabilize training.
        self.layer_norm = nn.LayerNorm(hidden_size)

        # Set to False to disable the residual connections in forward().
        self.use_residual = True

        # Per-position classifier head.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, 512),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, num_amino_acids),
        )

    def forward(self, global_input_ids, global_attention_mask, SASA_input_ids, SASA_attention_mask):
        """Return per-position logits.

        Args:
            global_input_ids: token IDs of the global sequence, (batch, seq).
            global_attention_mask: 1 for real tokens, 0 for padding.
            SASA_input_ids: token IDs of the SASA sequence, (batch, seq).
            SASA_attention_mask: 1 for real tokens, 0 for padding.

        Returns:
            Tensor of logits with ``num_amino_acids`` entries in the last
            dimension, one row per input position.
        """
        # Encode both inputs with the shared base model.
        global_outputs = self.model(
            input_ids=global_input_ids, attention_mask=global_attention_mask
        ).last_hidden_state
        global_outputs = self.dropout(global_outputs)

        SASA_outputs = self.model(
            input_ids=SASA_input_ids, attention_mask=SASA_attention_mask
        ).last_hidden_state
        SASA_outputs = self.dropout(SASA_outputs)

        # Global positions attend over SASA positions. Padding positions of
        # the SASA input are excluded as keys; the attention weights are not
        # needed, so they are not computed.
        attention_output, _ = self.attention(
            global_outputs,
            SASA_outputs,
            SASA_outputs,
            key_padding_mask=(SASA_attention_mask == 0),
            need_weights=False,
        )

        # Gate input: attention output concatenated with the SASA encoding.
        combined_inputs = torch.cat((attention_output, SASA_outputs), dim=-1)

        # Gated mix of the two (scaled) encodings.
        gate_output = self.gate_mlp(combined_inputs)
        mixed_outputs = gate_output * (self.alpha * global_outputs) + (1 - gate_output) * (self.beta * SASA_outputs)

        # Residual connections, if enabled.
        if self.use_residual:
            mixed_outputs = mixed_outputs + (global_outputs + SASA_outputs)

        mixed_outputs = self.layer_norm(mixed_outputs)

        logits = self.classifier(mixed_outputs)

        return logits


In [None]:
'''
This section is still part of building the model's main architecture.
We're just defining the training and evaluation functions.
'''


from torch.amp import autocast, GradScaler

# Function to save checkpoint -> DONT FORGET TO MODIFY THE PATH IN YOUR SYSTEM
def save_checkpoint(state, filename="checkpoint.pth.tar"):
    """Serialize ``state`` with ``torch.save`` without clobbering a good file.

    When ``filename`` is a path, the data is first written to
    ``<filename>.tmp`` and then renamed over the target, so an interruption
    during the write leaves the previous checkpoint intact.

    Args:
        state: object to serialize (here: a dict of state dicts and histories).
        filename: destination path, or a file-like object accepted by
            ``torch.save``.
    """
    if not isinstance(filename, (str, os.PathLike)):
        # File-like object: there is nothing to rename, write straight into it.
        torch.save(state, filename)
        return

    target = os.fspath(filename)
    tmp_path = target + ".tmp"
    torch.save(state, tmp_path)
    os.replace(tmp_path, target)

# Training function with updated autocast and GradScaler
def train_model(model, train_loader, val_loader, optimizer, criterion, device, epochs, accumulation_steps=2, checkpoint_path="checkpoint240.pth.tar", scheduler=None):
    """Train ``model`` with gradient accumulation, validating after each epoch.

    Training resumes from ``checkpoint_path`` when that file exists, and a
    checkpoint is written after every epoch. Mixed precision (autocast plus
    gradient scaling) is enabled only when ``device`` is a CUDA device.

    Args:
        model: network called as ``model(global_input_ids,
            global_attention_mask, SASA_input_ids, SASA_attention_mask)``;
            must expose ``num_amino_acids``.
        train_loader: iterable of batches (dicts of tensors) with a length.
        val_loader: validation batches, passed to ``evaluate_model``.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss called with flattened logits and flattened labels.
        device: device (or device string) the batches are moved to.
        epochs: total number of epochs, counted from epoch 0 even when
            resuming.
        accumulation_steps: number of batches whose gradients are summed
            before one optimizer step.
        checkpoint_path: file used for both resuming and saving.
        scheduler: optional learning-rate scheduler; stepped once per epoch
            and stored in the checkpoint when given.

    Returns:
        ``(loss_history, val_loss_history, val_accuracy_history)``: one entry
        per completed epoch, including epochs restored from the checkpoint.
    """
    device_type = torch.device(device).type
    use_amp = device_type == 'cuda'

    start_epoch = 0
    loss_history = []
    val_loss_history = []
    val_accuracy_history = []
    scaler = torch.amp.GradScaler(device_type, enabled=use_amp)

    # Load checkpoint if it exists
    if os.path.isfile(checkpoint_path):
        # map_location lets a checkpoint saved on one device resume on another.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        loss_history = checkpoint['loss_history']
        val_loss_history = checkpoint.get('val_loss_history', [])
        val_accuracy_history = checkpoint.get('val_accuracy_history', [])
        if scheduler is not None and 'scheduler' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler'])
        print(f"Loaded checkpoint '{checkpoint_path}' (epoch {checkpoint['epoch']})")

    num_batches = len(train_loader)

    for epoch in range(start_epoch, epochs):
        total_loss = 0
        model.train()  # evaluate_model() leaves the model in eval mode

        # Gradients are cleared only here and right after each optimizer
        # step; clearing them on every batch would discard what has been
        # accumulated so far.
        optimizer.zero_grad()

        for i, batch in enumerate(train_loader):
            global_input_ids = batch['global_input_ids'].to(device)
            global_attention_mask = batch['global_attention_mask'].to(device)
            SASA_input_ids = batch['SASA_input_ids'].to(device)
            SASA_attention_mask = batch['SASA_attention_mask'].to(device)
            labels = batch['labels'].to(device)

            with torch.autocast(device_type=device_type, enabled=use_amp):
                outputs = model(global_input_ids, global_attention_mask, SASA_input_ids, SASA_attention_mask)
                # Divide so the summed gradients of one accumulation group
                # correspond to the mean loss over that group.
                loss = criterion(outputs.view(-1, model.num_amino_acids), labels.view(-1)) / accumulation_steps

            scaler.scale(loss).backward()  # Scaled backward pass

            # Step every `accumulation_steps` batches, and also on the final
            # batch so a trailing partial group is not dropped.
            is_last_batch = (i + 1) == num_batches
            if (i + 1) % accumulation_steps == 0 or is_last_batch:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            # Undo the accumulation scaling for reporting.
            batch_loss = loss.item() * accumulation_steps
            total_loss += batch_loss

            if i % 10 == 0:
                print(f'Epoch {epoch + 1}, Batch {i + 1}, Loss: {batch_loss}')

            # Freeing memory after each batch
            del global_input_ids, global_attention_mask, SASA_input_ids, SASA_attention_mask, labels, loss
            torch.cuda.empty_cache()

        # Average loss for the epoch
        average_loss = total_loss / num_batches if num_batches else 0.0
        loss_history.append(average_loss)
        print(f'End of Epoch {epoch + 1}, Training Loss: {average_loss:.4f}')

        # --- Evaluate the model on the validation set after each epoch ---
        avg_val_loss, val_accuracy = evaluate_model(model, val_loader, criterion, device)
        val_loss_history.append(avg_val_loss)
        val_accuracy_history.append(val_accuracy)
        print(f'Epoch {epoch + 1} - Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy * 100:.2f}%')

        if scheduler is not None:
            scheduler.step()

        # Save checkpoint after every epoch
        state = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'loss_history': loss_history,
            'val_loss_history': val_loss_history,
            'val_accuracy_history': val_accuracy_history
        }
        if scheduler is not None:
            state['scheduler'] = scheduler.state_dict()
        save_checkpoint(state, filename=checkpoint_path)

    return loss_history, val_loss_history, val_accuracy_history


def evaluate_model(model, dataloader, criterion, device):
    """Compute the average loss and token accuracy of ``model`` on ``dataloader``.

    The model is switched to eval mode and left in it. Positions whose label
    equals ``criterion.ignore_index`` (when the criterion has that attribute)
    are excluded from the accuracy, so accuracy and loss are measured over the
    same positions. Autocast is enabled only on CUDA devices.

    Args:
        model: network called as ``model(global_input_ids,
            global_attention_mask, SASA_input_ids, SASA_attention_mask)``;
            must expose ``num_amino_acids``.
        dataloader: iterable of batches (dicts of tensors).
        criterion: loss called with flattened logits and flattened labels.
        device: device (or device string) the batches are moved to.

    Returns:
        ``(average_loss, accuracy)``: the mean per-batch loss and the fraction
        of scored positions predicted correctly. Both are 0 when there is
        nothing to evaluate.
    """
    model.eval()  # Set the model to evaluation mode
    device_type = torch.device(device).type
    ignore_index = getattr(criterion, 'ignore_index', None)

    total_loss = 0
    correct_predictions = 0
    total_labels = 0
    num_batches = 0

    with torch.no_grad():  # No gradient computation during evaluation
        for batch in dataloader:
            global_input_ids = batch['global_input_ids'].to(device)
            global_attention_mask = batch['global_attention_mask'].to(device)
            SASA_input_ids = batch['SASA_input_ids'].to(device)
            SASA_attention_mask = batch['SASA_attention_mask'].to(device)
            labels = batch['labels'].to(device)

            with torch.autocast(device_type=device_type, enabled=(device_type == 'cuda')):
                outputs = model(global_input_ids, global_attention_mask, SASA_input_ids, SASA_attention_mask)
                loss = criterion(outputs.view(-1, model.num_amino_acids), labels.view(-1))

            total_loss += loss.item()
            num_batches += 1

            # Accuracy over the positions the loss also looks at.
            predictions = torch.argmax(outputs, dim=-1)
            if ignore_index is None:
                scored = torch.ones_like(labels, dtype=torch.bool)
            else:
                scored = labels != ignore_index
            correct_predictions += ((predictions == labels) & scored).sum().item()
            total_labels += scored.sum().item()

            # Free memory after each batch
            del global_input_ids, global_attention_mask, SASA_input_ids, SASA_attention_mask, labels, loss
            torch.cuda.empty_cache()

    average_loss = total_loss / num_batches if num_batches else 0.0
    accuracy = correct_predictions / total_labels if total_labels > 0 else 0

    print(f'Evaluation completed, Average Loss: {average_loss:.4f}, Accuracy: {accuracy * 100:.2f}%')

    return average_loss, accuracy

# 5 & 6. Split the data and create the DataLoaders.

In [None]:
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split #split the train and test

# Split the dataset: 80% training, 20% held out for validation.
train_df, val_df = train_test_split(pairs_df, test_size=0.2, random_state=42)

# Keep only half of the held-out rows (frac=0.5), i.e. about 10% of the full
# dataset - presumably to shorten the per-epoch validation pass.
val_df = val_df.sample(frac=0.5, random_state=42)

print(f"Total dataset size: {len(pairs_df)}")
print(f"Training data size: {len(train_df)}")

# Datasets for training and validation (sequences padded/truncated to 500 tokens)
train_dataset = SampleDataset(train_df, tokenizer, 500)
val_dataset = SampleDataset(val_df, tokenizer, 500)

# pin_memory speeds up host-to-device transfer; num_workers loads batches in
# background worker processes.
loader_kwargs = dict(batch_size=16, collate_fn=collate_fn, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
# Validation metrics are averages over the whole set, so the order does not
# matter; a fixed order keeps the evaluation reproducible.
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

# Log the setup
print(f"Number of batches in train_loader: {len(train_loader)}, Each batch has {train_loader.batch_size} samples.")
print(f"Number of batches in val_loader: {len(val_loader)}, Each batch has {val_loader.batch_size} samples.")


Total dataset size: 40595
Training data size: 32476
Number of batches in train_loader: 2030, Each batch has 16 samples.
Number of batches in val_loader: 254, Each batch has 16 samples.


# 7. Train the model.

In [None]:
from torch.optim import Adam, lr_scheduler

# Hyperparameters
# NOTE: BATCH_SIZE is not read by the cells shown here; the DataLoaders are
# created with their own batch size.
BATCH_SIZE = 32
LEARNING_RATE = 2e-4
WEIGHT_DECAY = 1e-5
SEQ_LEN = 500
LR_STEP_SIZE = 2
LR_GAMMA = 0.5
ACCUMULATION_STEPS = 2
NUM_EPOCHS = 10
DROPOUT = 0.1


# Base encoder: the ProtBERT BertModel loaded (and resized) earlier.
base_model = model

# Number of output classes per position = size of the tokenizer vocabulary.
num_amino_acids = len(vocab)

# Custom model on top of the base encoder, configured from the
# hyperparameters above instead of repeating their values.
my_model = ProtBertSeq2Seq(
    model=base_model,
    num_amino_acids=num_amino_acids,
    seq_len=SEQ_LEN,
    dropout_rate=DROPOUT,
)

# Move the model to the appropriate device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
my_model.to(device)

# Optimizer and scheduler setup
# NOTE: nothing in this cell steps the scheduler; the learning rate only
# decays if the training loop calls scheduler.step().
optimizer = Adam(my_model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
scheduler = lr_scheduler.StepLR(optimizer, step_size=LR_STEP_SIZE, gamma=LR_GAMMA)

# Loss function: padding positions in the labels do not contribute.
criterion = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)


In [None]:
from torch.cuda.amp import GradScaler, autocast

# Number of epochs, taken from the hyperparameter cell.
num_epochs = NUM_EPOCHS

# Call the training function, which will also evaluate after each epoch
train_loss_history, val_loss_history, val_accuracy_history = train_model(
    my_model,  # ProtBertSeq2Seq model
    train_loader,  # Training data loader
    val_loader,  # Validation data loader
    optimizer,  # Optimizer over my_model's parameters
    criterion,  # Loss function
    device,  # Device (CPU or GPU)
    num_epochs,  # Number of epochs
    accumulation_steps=ACCUMULATION_STEPS,  # Batches per optimizer step
)


  scaler = GradScaler()  # Updated mixed precision scaler
  checkpoint = torch.load(checkpoint_path)


Loaded checkpoint 'checkpoint240.pth.tar' (epoch 10)
