<a href="https://colab.research.google.com/github/gokceuludogan/protein-ml-crash-course/blob/main/Chapter_3_Protein_Property_Prediction_from_Traditional_ML_to_pLMs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Overview

In this chapter, we will focus on predicting protein subcellular localization, specifically aiming to solve a **binary classification problem**: determining whether a protein is localized in the **cytoplasm** or not. We will begin with a simple machine learning model, advance to a **Convolutional Neural Network (CNN)** built using PyTorch, and finally, we will fine-tune a large pre-trained **transformer model**, **ESM2-650M**. Each approach offers unique insights and challenges, giving us a well-rounded understanding of the problem.

### Dataset

We'll use the **DeepLoc 2.0** dataset, which provides subcellular localization labels for eukaryotic proteins. The original task is multi-label, with 10 possible localization sites, but we simplify it to a **binary classification**: whether a protein is in the cytoplasm (1) or not (0).

---

### Data Preprocessing

First, we need to preprocess the data, converting the sequences into suitable numeric representations (e.g., one-hot encoding) and extracting features for classification.


In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder

# Load datasets
train_val_url = "https://services.healthtech.dtu.dk/services/DeepLoc-2.0/data/Swissprot_Train_Validation_dataset.csv"
test_url = "https://services.healthtech.dtu.dk/services/DeepLoc-2.0/data/hpa_testset.csv"

train_val_data = pd.read_csv(train_val_url)
test_data = pd.read_csv(test_url)


# Filter out sequences with 'X' or invalid amino acids
train_val_data = train_val_data[~train_val_data['Sequence'].str.contains('X')]
# Extract relevant columns: 'Cytoplasm' and 'Sequence'
train_val_data['label'] = train_val_data['Cytoplasm']   # Binary label: 1 if Cytoplasm, 0 otherwise

# Define the maximum sequence length to truncate
MAX_SEQ_LEN = 1000

# Map amino acids to integers (A -> 0, C -> 1, ..., Y -> 19)
amino_acids = list("ACDEFGHIKLMNPQRSTVWY")
aa_to_int = {aa: idx for idx, aa in enumerate(amino_acids)}

# Function to encode and truncate sequences
def encode_sequence(seq, max_len):
    # Convert sequence to integer encoding
    encoded_seq = [aa_to_int[aa] for aa in seq if aa in aa_to_int]

    # Truncate or pad the sequence
    if len(encoded_seq) > max_len:
        return encoded_seq[:max_len]
    else:
        return encoded_seq + [0] * (max_len - len(encoded_seq))  # Pad with 0s if shorter

# Apply encoding and truncating
train_val_data['encoded_Sequence'] = train_val_data['Sequence'].apply(lambda x: encode_sequence(x, MAX_SEQ_LEN))

# Convert to numpy arrays
X = np.stack(train_val_data['encoded_Sequence'].values)
y = train_val_data['label'].values

# Split dataset into training and validation sets
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
X_train_seq, X_val_seq, y_train_seq, y_val_seq = train_test_split( train_val_data['Sequence'].values, y, test_size=0.2, random_state=42)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  train_val_data['label'] = train_val_data['Cytoplasm']   # Binary label: 1 if Cytoplasm, 0 otherwise



### Step 2: Simple Machine Learning Model

**Model**: Logistic Regression

We start with a logistic regression model, which is a basic yet effective classifier for binary problems. In logistic regression, the model fits a logistic curve to the data and outputs probabilities, which are then used for binary classification. The simplicity of this model makes it interpretable, but it may struggle with highly complex or non-linear patterns, especially when dealing with protein sequences.

### Why Logistic Regression?

- **Simplicity and Interpretability**: Logistic regression is easy to implement and provides a good baseline. It allows us to interpret the model's output as probabilities, which helps us understand its decision-making process.
- **Limitations**: Logistic regression assumes linear relationships, so it might not capture the intricate patterns present in protein sequences.

In [2]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

# Train Logistic Regression model
clf = LogisticRegression(max_iter=1000)
clf.fit(X_train, y_train)

# Evaluate model on validation data
y_pred = clf.predict(X_val)
accuracy = accuracy_score(y_val, y_pred)

print(f"Validation Accuracy: {accuracy:.2f}")

Validation Accuracy: 0.62


## Convolutional Neural Network (CNN)

**Model**: 1D Convolutional Neural Network (CNN)

The next step is building a **CNN** using PyTorch. CNNs are excellent at learning spatial hierarchies in data. For protein sequences, a 1D CNN can capture local patterns in the sequence, which might correspond to specific motifs or domains related to subcellular localization.

### Why a CNN?

- **Ability to Capture Local Patterns**: In biological sequences, local patterns or motifs (e.g., specific amino acid combinations) can play a crucial role in protein function and localization. A CNN can efficiently extract these local features.
- **Hierarchical Learning**: CNNs use multiple convolutional layers to learn complex patterns, making them well-suited for biological sequence data.
- **Drawbacks**: While CNNs can capture local features, they may not be as effective in understanding long-range dependencies in sequences compared to transformers.

**Model Architecture**:

- **Convolutional Layer**: Extracts local patterns from the protein sequence.
- **Max Pooling Layer**: Reduces dimensionality, making the model more efficient.
- **Fully Connected Layer**: Combines learned features to make a final binary classification decision.

### Convert X_train and y_train to PyTorch Tensors and Datasets

You need to ensure that your input data (X_train) and labels (y_train) are in a format suitable for the CNN. You can then use TensorDataset to combine X_train_tensor and y_train_tensor and create a DataLoader to iterate over the dataset in mini-batches.




In [3]:
import torch
from torch.utils.data import TensorDataset, DataLoader
# Convert to PyTorch tensors
X_train_tensor = torch.tensor(X_train, dtype=torch.long)
y_train_tensor = torch.tensor(y_train, dtype=torch.long)
X_val_tensor = torch.tensor(X_val, dtype=torch.long)
y_val_tensor = torch.tensor(y_val, dtype=torch.long)

# Create Datasets and DataLoaders
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
val_dataset = TensorDataset(X_val_tensor, y_val_tensor)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

### Define the CNN Model
Now that the data is prepared, we can build a simple CNN for the task.

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# Define constants
BATCH_SIZE = 32
EMBED_DIM = 32  # Dimension of embedding space
SEQ_LEN = 1000  # Fixed sequence length
VOCAB_SIZE = len(amino_acids)  # Assuming you have defined this

# Updated model definition
class CNNProtein(nn.Module):
    def __init__(self, vocab_size, embed_dim):
        super(CNNProtein, self).__init__()
        # Embedding layer to convert integer-encoded sequences to dense vectors
        self.embedding = nn.Embedding(vocab_size, embed_dim)

        # Convolutional layers
        self.conv1 = nn.Conv1d(in_channels=embed_dim, out_channels=64, kernel_size=5, padding=2)
        self.conv2 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=5, padding=2)
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)


        # Dropout for regularization
        self.dropout = nn.Dropout(p=0.5)  # Dropout probability 0.5

        # Fully connected layers
        self.fc1 = nn.Linear(128 * (SEQ_LEN // 4), 256)  # Output from CNN layer after pooling
        self.fc2 = nn.Linear(256, 1)  # Binary output (cytoplasm or not)

    def forward(self, x):
        # Pass the input through the embedding layer
        x = self.embedding(x)  # Shape: (batch_size, seq_len, embed_dim)
        x = x.permute(0, 2, 1)  # Change shape to (batch_size, embed_dim, seq_len)

        # Pass through convolutional layers
        x = F.relu(self.conv1(x))  # Conv1d expects (batch_size, channels, seq_len)
        x = self.pool(x)           # Pooling layer to reduce sequence length
        x = F.relu(self.conv2(x))
        x = self.pool(x)

        # Flatten the tensor
        x = x.view(x.size(0), -1)

        x = self.dropout(x)

        # Pass through fully connected layers
        x = F.relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))  # Sigmoid for binary classification
        return x

# Create the model
model = CNNProtein(vocab_size=VOCAB_SIZE, embed_dim=EMBED_DIM)

# Define loss function and optimizer
criterion = nn.BCELoss()  # Binary cross-entropy for binary classification
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

### Training loop

In [5]:
# Training function
def train(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs.squeeze(), labels.float())

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    return running_loss / len(train_loader)

# Validation function
def validate(model, val_loader, criterion, device):
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs.squeeze(), labels.float())
            val_loss += loss.item()

    return val_loss / len(val_loader)

# Training Loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

num_epochs = 10
for epoch in range(num_epochs):
    train_loss = train(model, train_loader, criterion, optimizer, device)
    val_loss = validate(model, val_loader, criterion, device)

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")


Epoch 1/10, Train Loss: 0.6040, Val Loss: 0.5513
Epoch 2/10, Train Loss: 0.5308, Val Loss: 0.5304
Epoch 3/10, Train Loss: 0.4881, Val Loss: 0.5210
Epoch 4/10, Train Loss: 0.4486, Val Loss: 0.5375
Epoch 5/10, Train Loss: 0.4097, Val Loss: 0.5278
Epoch 6/10, Train Loss: 0.3731, Val Loss: 0.5418
Epoch 7/10, Train Loss: 0.3385, Val Loss: 0.5530
Epoch 8/10, Train Loss: 0.3048, Val Loss: 0.6062
Epoch 9/10, Train Loss: 0.2777, Val Loss: 0.6003
Epoch 10/10, Train Loss: 0.2557, Val Loss: 0.5742


The model begins to overfit after a few epochs, which is a common issue in neural networks, particularly when they lack sufficient regularization. Overfitting occurs when a model learns the noise and details of the training data too well, resulting in poor generalization to new, unseen data. This is evident when the training loss continues to decrease while the validation loss starts to increase.

To mitigate this issue, techniques such as early stopping can be implemented. Early stopping involves monitoring the validation performance during training and saving the model checkpoint that exhibits the best validation performance, preventing the model from continuing to learn the noise of the training data.






## Fine-tuning ESM2-150M

**Model**: **ESM2-150M (Evolutionary Scale Modeling)**

Finally, we will fine-tune **ESM2-150M**, a transformer model specifically designed for protein sequences.  Fine-tuning involves adapting the pre-trained model to the specific classification problem by continuing the training process on the new dataset, adjusting its weights for the task at hand. ESM2 uses attention mechanisms to capture both local and global relationships in protein sequences, making it especially powerful for tasks like subcellular localization prediction.

### Why ESM2?

- **Global Context**: Transformers excel at capturing long-range dependencies in sequences, which is important for proteins where distant residues can interact to determine function and localization.
- **Pre-training on Protein Data**: ESM2 has been pre-trained on massive protein datasets, making it particularly adept at understanding the language of proteins (i.e., amino acid sequences).
- **State-of-the-Art Performance**: Transformer models like ESM2 are known to provide cutting-edge performance on a variety of protein prediction tasks, including structure and function prediction.

**Key Components**:

- **Self-Attention Mechanism**: Allows the model to weigh the importance of different parts of the sequence, making it highly effective at capturing global patterns.
- **Fine-tuning**: We can fine-tune ESM2 on the specific task of binary classification for cytoplasmic localization by updating the model weights on our labeled data.

In [None]:
from transformers import EsmForSequenceClassification, EsmTokenizer
import torch.optim as optim
# Load pre-trained ESM2 model and tokenizer
model_name = "facebook/esm2_t30_150M_UR50D"
tokenizer = EsmTokenizer.from_pretrained(model_name)
model = EsmForSequenceClassification.from_pretrained(model_name, num_labels=2)

# Limit training samples for the sake of training time
X_train_seq, X_val_seq, y_train_seq, y_val_seq = X_train_seq[:2000], X_val_seq, y_train[:2000], y_val

# Tokenize input sequences for training and validation
train_inputs = tokenizer(list(X_train_seq), return_tensors="pt", padding=True, truncation=True, max_length=1000)
val_inputs = tokenizer(list(X_val_seq), return_tensors="pt", padding=True, truncation=True, max_length=1000)

# Convert labels to tensors
train_labels = torch.tensor(y_train_seq, dtype=torch.long)
val_labels = torch.tensor(y_val_seq, dtype=torch.long)

# Create DataLoader for batching
train_data = TensorDataset(train_inputs['input_ids'], train_inputs['attention_mask'], train_labels)
val_data = TensorDataset(val_inputs['input_ids'], val_inputs['attention_mask'], val_labels)

train_loader = DataLoader(train_data, batch_size=2, shuffle=True)
val_loader = DataLoader(val_data, batch_size=2, shuffle=False)

# Set up optimizer
optimizer = optim.AdamW(model.parameters(), lr=5e-5)

# Move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)


# Function to train the model
def train(model, train_loader, optimizer, device):
    model.train()
    total_train_loss = 0
    for batch in train_loader:
        input_ids, attention_mask, labels = [b.to(device) for b in batch]

        # Zero gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        total_train_loss += loss.item()

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

    avg_train_loss = total_train_loss / len(train_loader)
    return avg_train_loss

# Function to evaluate the model on validation set
def validate(model, val_loader, device):
    model.eval()
    total_val_loss = 0
    correct_preds = 0
    with torch.no_grad():
        for batch in val_loader:
            input_ids, attention_mask, labels = [b.to(device) for b in batch]

            # Forward pass
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            val_loss = outputs.loss
            total_val_loss += val_loss.item()

            # Get predictions
            preds = torch.argmax(outputs.logits, dim=1)
            correct_preds += torch.sum(preds == labels).item()

    avg_val_loss = total_val_loss / len(val_loader)
    val_accuracy = correct_preds / len(y_val)
    return avg_val_loss, val_accuracy

# Fine-tuning loop
for epoch in range(3):  # Fine-tuning for 3 epochs
    train_loss = train(model, train_loader, optimizer, device)
    val_loss, val_accuracy = validate(model, val_loader, device)

    print(f"Epoch {epoch + 1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t30_150M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### ESM2 Advantages

- **Pre-trained Knowledge**: The model benefits from knowledge gained during its pre-training phase, which was conducted on a massive protein dataset.
- **Attention Mechanism**: ESM2 can capture long-range dependencies within protein sequences, which is critical for correctly predicting properties like localization.
- **Drawbacks**: Large models like ESM2 are computationally expensive and require more resources for fine-tuning compared to simpler models like CNNs or logistic regression.

---

### Conclusion

In summary, we have explored various approaches to protein sequence analysis, from traditional machine learning methods to advanced deep learning techniques like Convolutional Neural Networks and the ESM2 model. These methods can provide significant insights into protein subcellular localization and other critical tasks. For further exploration of different models and challenges in protein understanding, I recommend visiting [TorchProtein's benchmark](https://torchprotein.ai/benchmark). This resource offers a comprehensive overview of existing models and their performance on a variety of protein-related tasks.