In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, f1_score, matthews_corrcoef, accuracy_score
from sklearn.utils import resample
import matplotlib.pyplot as plt
from tqdm import tqdm

# Pick the compute backend: Apple MPS is preferred, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")

# Hyperparameters (tune here; the cells below read these names)
NUM_CLASSES = 2    # binary task: 0 = no signal peptide, 1 = signal peptide
BATCH_SIZE = 64    # sequences per mini-batch
EPOCHS = 20        # full passes over the training split
LR = 0.001         # Adam learning rate
MAX_LENGTH = 100   # padded sequence length; longer sequences are dropped during preprocessing

In [None]:
import os

# The notebook also targets local MPS/CPU runs (see the DEVICE selection), where
# `google.colab` does not exist. Guard the import so this cell survives
# Restart & Run All both on Colab and on a local kernel.
try:
    from google.colab import drive
except ImportError:
    drive = None

if drive is not None:
    drive.mount('/content/drive')
    DRIVE_PATH = "/content/drive/MyDrive/PBLRost/"
else:
    # Local run: point PBLROST_DIR at the project folder (it must contain
    # data/ and models/); defaults to the current working directory.
    DRIVE_PATH = os.environ.get("PBLROST_DIR", os.getcwd())

FASTA_PATH = os.path.join(DRIVE_PATH, "data/complete_set_unpartitioned.fasta")
MODEL_PATH = os.path.join(DRIVE_PATH, "models/2state_ohe_lin_v1.pt")

In [None]:
# The 20 standard amino acids; a residue's position in this string is its one-hot column.
aas = 'ARNDCEQGHILKMFPSTWYV'
aa2idx = dict(zip(aas, range(len(aas))))
num_aa = len(aa2idx)

def one_hot_encode_sequence(seq, max_length):
    """Return a (max_length, num_aa) one-hot matrix for an amino-acid string.

    The sequence is cut to `max_length` residues and right-padded with 'X'.
    Any character that is not one of the 20 letters in `aas` (including the
    'X' padding) is left as an all-zero row.
    """
    window = seq[:max_length].ljust(max_length, 'X')
    onehot = np.zeros((max_length, num_aa))

    # Gather (row, column) pairs first, then set them all in one indexed write.
    hits = [(pos, aa2idx[residue]) for pos, residue in enumerate(window) if residue in aa2idx]
    if hits:
        rows, cols = zip(*hits)
        onehot[list(rows), list(cols)] = 1

    return onehot

In [None]:
def _read_fasta_records(fasta_path):
    """Parse header / sequence / label-string entries from a 3-line FASTA file into dicts."""
    records = []
    current = None

    def _is_complete(record):
        return (
            record is not None
            and record["sequence"] is not None
            and record["label"] is not None
        )

    with open(fasta_path, "r") as handle:
        for raw_line in handle:
            line = raw_line.strip()
            if not line:
                # Blank lines would otherwise be stored as an empty sequence
                # and shift the real sequence into the label slot.
                continue

            if line.startswith(">"):
                if _is_complete(current):
                    records.append(current)

                fields = line[1:].split("|")
                if len(fields) < 3:
                    raise ValueError(
                        f"Malformed FASTA header (expected 'ac|kingdom|type'): {line!r}"
                    )
                # Tolerate extra trailing fields; only the first three are used.
                uniprot_ac, kingdom, type_ = fields[:3]
                current = {
                    "uniprot_ac": uniprot_ac,
                    "kingdom": kingdom,
                    "type": type_,
                    "sequence": None,
                    "label": None,
                }
            elif current is None:
                # Text before the first header belongs to no record.
                continue
            elif current["sequence"] is None:
                current["sequence"] = line
            elif current["label"] is None:
                current["label"] = line

    # The last record has no following header to flush it.
    if _is_complete(current):
        records.append(current)

    return records


def load_and_preprocess_data(fasta_path, max_length=None):
    """Load FASTA data and preprocess it for residue-level binary classification.

    Args:
        fasta_path: Path to a FASTA-like file where each entry is a
            ``>uniprot_ac|kingdom|type`` header followed by a sequence line and
            a per-residue label line.
        max_length: Longest sequence to keep. Defaults to the module-level
            ``MAX_LENGTH``.

    Returns:
        A list of dicts with keys ``'sequence'`` (str), ``'labels'`` (list of
        0/1, one per residue) and ``'has_signal_peptide'`` (sequence-level
        class). Returns an empty list if the file holds no complete records.

    Note:
        The positive class is upsampled with replacement, so the returned list
        contains exact duplicates. A plain random train/test split of this list
        puts copies of the same sequence on both sides.
    """
    if max_length is None:
        max_length = MAX_LENGTH

    records = _read_fasta_records(fasta_path)
    print(f"Total records loaded: {len(records)}")
    if not records:
        # An empty DataFrame has no "label" column to filter on.
        return []

    df_raw = pd.DataFrame(records)

    # Drop entries whose label string contains 'P'.
    # .copy() gives an independent frame, so the column assignment below does
    # not write into a slice of df_raw (SettingWithCopyWarning).
    df = df_raw[~df_raw["label"].str.contains("P")].copy()

    # Map signal peptide types to a binary sequence-level class:
    # SP, LIPO, TAT, TATLIPO -> 1 ; NO_SP -> 0. Unlisted types become NaN.
    df["has_signal_peptide"] = df["type"].map({
        "NO_SP": 0,
        "LIPO": 1,
        "SP": 1,
        "TAT": 1,
        "TATLIPO": 1
    })

    # Balance the dataset by upsampling class 1 to the size of class 0.
    df_majority = df[df["has_signal_peptide"] == 0]
    df_minority = df[df["has_signal_peptide"] == 1]

    if not df_minority.empty and not df_majority.empty:
        df_minority_upsampled = resample(
            df_minority,
            replace=True,
            n_samples=len(df_majority),
            random_state=42
        )
        df_balanced = pd.concat([df_majority, df_minority_upsampled])
    else:
        df_balanced = df.copy()

    # Residue-level labels -> binary; characters not in the map count as 0.
    label_map = {'S': 1, 'T': 1, 'L': 1, 'I': 0, 'M': 0, 'O': 0}

    processed_data = []
    for sequence, label_string, seq_class in zip(
        df_balanced["sequence"], df_balanced["label"], df_balanced["has_signal_peptide"]
    ):
        residue_labels = [label_map.get(c, 0) for c in label_string]

        # Keep only sequences that fit and whose label string is aligned.
        if len(sequence) <= max_length and len(residue_labels) == len(sequence):
            processed_data.append({
                'sequence': sequence,
                'labels': residue_labels,
                'has_signal_peptide': seq_class
            })

    return processed_data

In [None]:
class ProteinDataset(Dataset):
    """Dataset of one-hot encoded sequences with per-residue labels and padding masks.

    Everything is encoded once in the constructor. Each item is a dict with
    'sequence' (float32 one-hot matrix), 'labels' (long, padded with 0) and
    'attention_mask' (float32, 1 for real residues and 0 for padding).
    """

    def __init__(self, data, max_length):
        self.data = data
        self.max_length = max_length

        # Parallel lists, one entry per item in `data`.
        self.encoded_sequences = []
        self.padded_labels = []
        self.attention_masks = []

        for record in data:
            self.encoded_sequences.append(
                one_hot_encode_sequence(record['sequence'], max_length)
            )

            residue_labels = record['labels']
            n_real = len(residue_labels)
            n_pad = max_length - n_real  # negative when too long; the slice below trims

            label_row = residue_labels + [0] * n_pad
            self.padded_labels.append(label_row[:max_length])

            mask_row = [1] * n_real + [0] * n_pad
            self.attention_masks.append(mask_row[:max_length])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        features = torch.tensor(self.encoded_sequences[idx], dtype=torch.float32)
        targets = torch.tensor(self.padded_labels[idx], dtype=torch.long)
        mask = torch.tensor(self.attention_masks[idx], dtype=torch.float32)
        return {'sequence': features, 'labels': targets, 'attention_mask': mask}

In [None]:
class SimpleSignalPeptideClassifier(nn.Module):
    """MLP over the flattened one-hot sequence that emits one logit per position.

    Args:
        max_length: Padded sequence length (also the number of output logits).
        num_aa: Width of the one-hot alphabet.
        hidden_dim: Width of every hidden layer.
        num_layers: Number of Linear -> ReLU -> Dropout(0.3) hidden blocks.
            0 gives a single linear map from the flattened input to the logits.
    """

    def __init__(self, max_length, num_aa, hidden_dim=128, num_layers=3):
        super().__init__()
        self.max_length = max_length
        self.num_aa = num_aa

        # Simple approach: flatten the one-hot encoding and use linear layers
        self.flatten_dim = max_length * num_aa

        layers = []
        in_dim = self.flatten_dim

        for _ in range(num_layers):
            layers.extend([
                nn.Linear(in_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.3)
            ])
            in_dim = hidden_dim

        # Output layer: one logit per position. Sized from `in_dim`, not
        # `hidden_dim`, so the network is also valid when num_layers == 0.
        layers.append(nn.Linear(in_dim, max_length))

        self.classifier = nn.Sequential(*layers)

    def forward(self, x, attention_mask=None):
        """Return raw logits of shape [batch_size, max_length].

        Args:
            x: Tensor of shape [batch_size, max_length, num_aa].
            attention_mask: Accepted for interface compatibility; not used by
                this model (masking is applied by the caller on the loss).
        """
        batch_size = x.size(0)

        # reshape (unlike view) also accepts non-contiguous inputs.
        x_flat = x.reshape(batch_size, -1)  # [batch_size, max_length * num_aa]

        logits = self.classifier(x_flat)  # [batch_size, max_length]

        return logits

In [None]:
def _masked_bce_loss(criterion, logits, labels, attention_mask):
    """Average the unreduced per-position loss over non-padded positions only."""
    per_position = criterion(logits, labels)
    # clamp keeps an all-padding batch from producing 0/0 = NaN; for any batch
    # with at least one real residue the value is unchanged.
    n_real = attention_mask.sum().clamp(min=1.0)
    return (per_position * attention_mask).sum() / n_real


def _mean_or_nan(total, count):
    """Return total / count, or NaN when nothing was accumulated."""
    return total / count if count else float("nan")


def train_model(model, train_loader, val_loader, num_epochs, device, lr=None):
    """Train `model` with Adam and a padding-masked binary cross-entropy loss.

    Args:
        model: Module called as ``model(sequences, attention_mask)`` that
            returns per-position logits.
        train_loader: Iterable of batches (dicts with 'sequence', 'labels',
            'attention_mask').
        val_loader: Iterable of validation batches with the same keys.
        num_epochs: Number of passes over `train_loader`.
        device: Device the batches are moved to.
        lr: Learning rate for Adam. Defaults to the module-level ``LR``.

    Returns:
        ``(train_losses, val_losses)``: one mean loss per epoch each. An epoch
        with an empty loader records NaN instead of raising ZeroDivisionError.
    """
    if lr is None:
        lr = LR

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss(reduction='none')  # No reduction for masking

    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        # Training
        model.train()
        train_loss = 0
        train_batches = 0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            sequences = batch['sequence'].to(device)
            labels = batch['labels'].to(device).float()
            attention_mask = batch['attention_mask'].to(device)

            optimizer.zero_grad()

            logits = model(sequences, attention_mask)
            masked_loss = _masked_bce_loss(criterion, logits, labels, attention_mask)

            masked_loss.backward()
            optimizer.step()

            train_loss += masked_loss.item()
            train_batches += 1

        avg_train_loss = _mean_or_nan(train_loss, train_batches)
        train_losses.append(avg_train_loss)

        # Validation
        model.eval()
        val_loss = 0
        val_batches = 0

        with torch.no_grad():
            for batch in val_loader:
                sequences = batch['sequence'].to(device)
                labels = batch['labels'].to(device).float()
                attention_mask = batch['attention_mask'].to(device)

                logits = model(sequences, attention_mask)
                masked_loss = _masked_bce_loss(criterion, logits, labels, attention_mask)

                val_loss += masked_loss.item()
                val_batches += 1

        avg_val_loss = _mean_or_nan(val_loss, val_batches)
        val_losses.append(avg_val_loss)

        print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

    return train_losses, val_losses

In [None]:
# Fraction of sequences whose per-residue predictions are ALL correct.
def _as_numpy(values):
    """Convert a list, NumPy array or torch tensor (any device) to a NumPy array."""
    if hasattr(values, "detach"):
        values = values.detach().cpu().numpy()
    return np.asarray(values)


def sequence_level_accuracy(predictions, labels):
    """Return the share of sequences predicted without a single residue error.

    Args:
        predictions: Iterable with one array-like of residue predictions per
            sequence.
        labels: Iterable with one array-like of residue labels per sequence,
            in the same order.

    Returns:
        ``correct / total`` as a float in [0, 1]; 0.0 when there are no
        sequences.

    Raises:
        TypeError: If an element is a scalar, i.e. the inputs were flattened
            across sequences instead of being grouped per sequence.
    """
    correct = 0
    total = 0
    for pred, label in zip(predictions, labels):
        pred = _as_numpy(pred)
        label = _as_numpy(label)
        if pred.ndim == 0 or label.ndim == 0:
            raise TypeError(
                "sequence_level_accuracy expects one array per sequence, "
                "not flat per-residue values"
            )
        if len(pred) != len(label):
            # Compare only the common prefix when lengths disagree.
            min_len = min(len(pred), len(label))
            pred = pred[:min_len]
            label = label[:min_len]
        total += 1
        if np.array_equal(pred, label):
            correct += 1
    if total == 0:
        return 0.0
    return correct / total

# No module-level call here: predictions only exist inside evaluate_model(),
# which computes and prints this metric. Calling it at cell level raised
# NameError on a fresh kernel.

In [None]:
def evaluate_model(model, test_loader, device):
    """Evaluate the model and print residue-level and sequence-level metrics.

    Args:
        model: Module called as ``model(sequences, attention_mask)`` that
            returns per-position logits.
        test_loader: Iterable of batches (dicts with 'sequence', 'labels',
            'attention_mask').
        device: Device the batches are moved to.

    Returns:
        ``(all_preds, all_labels)``: flat lists of 0/1 ints over every
        non-padded residue. Both are empty if the loader yields nothing.
    """
    model.eval()
    # One array per sequence: sequence-level accuracy needs the grouping.
    seq_preds = []
    seq_labels = []

    with torch.no_grad():
        for batch in test_loader:
            sequences = batch['sequence'].to(device)
            labels = batch['labels'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            logits = model(sequences, attention_mask)
            predictions = (torch.sigmoid(logits) > 0.5).long()

            # Move each batch to the CPU once and select real residues with a
            # boolean mask, instead of calling .item() per element.
            predictions = predictions.cpu()
            labels = labels.cpu()
            keep = attention_mask.bool().cpu()
            for pred_row, label_row, keep_row in zip(predictions, labels, keep):
                seq_preds.append(pred_row[keep_row].numpy())
                seq_labels.append(label_row[keep_row].numpy())

    if not seq_preds:
        print("No sequences to evaluate.")
        return [], []

    all_preds = np.concatenate(seq_preds).tolist()
    all_labels = np.concatenate(seq_labels).tolist()

    # Calculate metrics
    print("Classification Report:")
    # labels=[0, 1] keeps target_names valid even when only one class occurs.
    print(classification_report(
        all_labels,
        all_preds,
        labels=[0, 1],
        target_names=['No Signal', 'Signal'],
        zero_division=0,
    ))

    f1_weighted = f1_score(all_labels, all_preds, average='weighted')
    f1_macro = f1_score(all_labels, all_preds, average='macro')
    mcc = matthews_corrcoef(all_labels, all_preds)
    accuracy = accuracy_score(all_labels, all_preds)
    # Pass the per-sequence arrays; the flat lists would lose sequence boundaries.
    seq_lev = sequence_level_accuracy(seq_preds, seq_labels)

    print(f"F1 Score (weighted): {f1_weighted:.4f}")
    print(f"F1 Score (macro): {f1_macro:.4f}")
    print(f"Matthews Correlation Coefficient: {mcc:.4f}")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Sequence level accuracy {seq_lev:.4f}")

    return all_preds, all_labels

In [None]:

# Seed the RNGs so weight init, dropout and batch order are reproducible
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Load and preprocess data
data = load_and_preprocess_data(FASTA_PATH)

# Split data by UNIQUE sequence. load_and_preprocess_data upsamples the
# positive class with replacement, so `data` holds exact duplicates; a plain
# random split would place copies of one sequence in train, val and test and
# inflate the reported metrics. Val/test still contain the upsampled copies.
unique_sequences = sorted({item['sequence'] for item in data})
trainval_seqs, test_seqs = train_test_split(unique_sequences, test_size=0.2, random_state=SEED)
train_seqs, val_seqs = train_test_split(trainval_seqs, test_size=0.2, random_state=SEED)
train_seqs, val_seqs, test_seqs = set(train_seqs), set(val_seqs), set(test_seqs)

train_data = [item for item in data if item['sequence'] in train_seqs]
val_data = [item for item in data if item['sequence'] in val_seqs]
test_data = [item for item in data if item['sequence'] in test_seqs]
assert len(train_data) + len(val_data) + len(test_data) == len(data)

print(f"Train set size: {len(train_data)}")
print(f"Validation set size: {len(val_data)}")
print(f"Test set size: {len(test_data)}")

# Create datasets and loaders
train_dataset = ProteinDataset(train_data, MAX_LENGTH)
val_dataset = ProteinDataset(val_data, MAX_LENGTH)
test_dataset = ProteinDataset(test_data, MAX_LENGTH)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Initialize model
model = SimpleSignalPeptideClassifier(MAX_LENGTH, num_aa).to(DEVICE)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# Train model
train_losses, val_losses = train_model(model, train_loader, val_loader, EPOCHS, DEVICE)

# Evaluate model
print("\nFinal Evaluation:")
predictions, labels = evaluate_model(model, test_loader, DEVICE)

# Plot training curves
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(train_losses, label='Train Loss')
ax.plot(val_losses, label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend()
ax.set_title('Training Curves')
plt.show()

# Save model. Create the target folder first so a missing models/ directory
# does not fail the save after a full training run.
model_dir = os.path.dirname(MODEL_PATH)
if model_dir:
    os.makedirs(model_dir, exist_ok=True)
torch.save(model.state_dict(), MODEL_PATH)
print(f"Model saved to {MODEL_PATH}")