# **AMP Classification using ProtBERT Embeddings + Fast MLP**
This notebook extracts ProtBERT embeddings for peptide sequences and trains a simple multi-layer perceptron (MLP) to classify antimicrobial peptides (AMPs) vs. non-AMPs.

In [None]:
# Install the runtime dependencies. The leading "!" is IPython/Jupyter
# shell-escape syntax; remove this cell when running the code as a plain script.
!pip install torch transformers scikit-learn numpy pandas tqdm

In [None]:
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch import nn, optim
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, average_precision_score
import sys

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Device:', device)

## Load Dataset

In [None]:
# Locate the CSV: on Colab it is read from the mounted Google Drive,
# otherwise from the current working directory.
IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    file_path = '/content/drive/MyDrive/ampData.csv'
else:
    file_path = 'ampData.csv'

# Drop rows with missing values BEFORE casting to str: astype(str) turns NaN
# into the literal string 'nan' ('NAN' after upper()), which dropna() would
# then no longer recognise as missing. assign() returns a fresh frame, so the
# column rewrite does not trigger pandas' SettingWithCopy warning.
df = (
    pd.read_csv(file_path)
    .dropna(subset=['sequence', 'label'])
    .assign(sequence=lambda d: d['sequence'].astype(str).str.upper().str.strip())
)
# Whitespace-only sequences become '' after strip(); they contain no residues
# to embed, so discard them as well.
df = df[df['sequence'].str.len() > 0].reset_index(drop=True)
df.head()

## Extract ProtBERT Embeddings

In [None]:
# ProtBERT was trained on upper-case amino-acid letters only, so the tokenizer
# must not lower-case its input (the model card passes do_lower_case=False).
tokenizer = AutoTokenizer.from_pretrained('Rostlab/prot_bert', do_lower_case=False)
# Used for inference only. from_pretrained already returns the model in eval
# mode; calling eval() explicitly keeps dropout off even if this cell is
# re-run after something switched the model to train mode.
model = AutoModel.from_pretrained('Rostlab/prot_bert').to(device).eval()

# ProtBERT's pre-training data mapped the rare residues U, Z, O and B to X.
_RARE_AA_TO_X = str.maketrans('UZOB', 'XXXX')


def get_embedding(sequence):
    """Return a fixed-size ProtBERT embedding for one amino-acid sequence.

    The sequence is upper-cased, whitespace is removed, and the rare residues
    U/Z/O/B are replaced by X, as in ProtBERT's pre-training. Residues are
    then space-separated, because ProtBERT's tokenizer emits one token per
    residue. The last-layer hidden states are mean-pooled over the residue
    tokens only. The [CLS] and [SEP] special tokens are excluded so they do
    not dominate the embedding of short peptides.

    Args:
        sequence: amino-acid string, e.g. "GIGKFLHSAKKFGKAFVGEIMNS".

    Returns:
        1-D numpy array with one value per hidden unit of the model's last layer.

    Raises:
        ValueError: if ``sequence`` contains no residues.
    """
    residues = ''.join(str(sequence).upper().split()).translate(_RARE_AA_TO_X)
    if not residues:
        raise ValueError('get_embedding() received an empty sequence')

    tokens = tokenizer(
        ' '.join(residues), return_tensors='pt', truncation=True
    ).to(device)
    with torch.no_grad():  # inference only: no autograd graph needed
        hidden = model(**tokens).last_hidden_state[0]  # (num_tokens, hidden)
    # Drop [CLS] (first token) and [SEP] (last token). The tokenizer re-appends
    # [SEP] after truncation, so the slice is correct for long inputs too.
    return hidden[1:-1].mean(dim=0).cpu().numpy()

# Embed every peptide (one ProtBERT forward pass per sequence) and stack the
# per-sequence vectors into a 2-D feature matrix, one row per peptide.
embeddings = [
    get_embedding(seq)
    for seq in tqdm(df['sequence'], desc='Extracting Embeddings')
]

X = np.array(embeddings)
y = df['label'].values

# Persist features and labels to disk, e.g. so they can be reloaded without
# re-running the slow extraction step.
for out_path, arr in (('X_embeddings.npy', X), ('y_labels.npy', y)):
    np.save(out_path, arr)

## Train-Test Split

In [None]:
# Hold out 20% of the peptides for evaluation, stratified on the label so
# both splits keep the same class ratio; fixed seed for reproducibility.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Convert to float32 tensors on the training device. Labels get a trailing
# singleton axis so they match the MLP's (batch, 1) output.
X_train, X_test = (
    torch.tensor(part, dtype=torch.float32).to(device)
    for part in (X_train, X_test)
)
y_train, y_test = (
    torch.tensor(part, dtype=torch.float32).unsqueeze(1).to(device)
    for part in (y_train, y_test)
)

## Define MLP Classifier

In [None]:
class MLPClassifier(nn.Module):
    """Two-hidden-layer perceptron that maps an embedding to P(AMP).

    Architecture: input_dim -> 512 (ReLU, dropout 0.3) -> 128 (ReLU) -> 1,
    followed by a sigmoid, so ``forward`` returns probabilities in (0, 1)
    with shape (batch, 1).
    """

    def __init__(self, input_dim):
        super().__init__()
        wide, narrow = 512, 128
        # Module order is kept fixed: it determines the state_dict keys
        # (layers.0 / layers.3 / layers.5) that saved checkpoints rely on.
        stages = [
            nn.Linear(input_dim, wide), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(wide, narrow), nn.ReLU(),
            nn.Linear(narrow, 1), nn.Sigmoid(),
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Return per-sample probabilities for a (batch, input_dim) input."""
        return self.layers(x)

# Size the classifier to the embedding width and train it with binary
# cross-entropy on its sigmoid output, optimised by Adam.
input_dim = X_train.shape[1]
model_mlp = MLPClassifier(input_dim).to(device)
criterion, optimizer = nn.BCELoss(), optim.Adam(model_mlp.parameters(), lr=1e-4)

print(model_mlp)

## Train MLP

In [None]:
epochs = 20
batch_size = 64

n_train = X_train.size(0)
for epoch in range(epochs):
    model_mlp.train()  # enable dropout for training
    # Reshuffle every epoch. The permutation is created on the data's device
    # so CPU indices are not mixed with (possibly) GPU tensors.
    perm = torch.randperm(n_train, device=X_train.device)
    total_loss = 0.0
    for start in range(0, n_train, batch_size):
        idx = perm[start:start + batch_size]
        x_batch, y_batch = X_train[idx], y_train[idx]
        optimizer.zero_grad()
        outputs = model_mlp(x_batch)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()
        # BCELoss returns the batch *mean*; weight it by the batch size so the
        # epoch figure is a true per-sample average, even when the final batch
        # is shorter than batch_size.
        total_loss += loss.item() * idx.size(0)
    avg_loss = total_loss / max(n_train, 1)
    print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

## Evaluate

In [None]:
model_mlp.eval()  # disable dropout for deterministic inference
with torch.no_grad():
    preds = model_mlp(X_test).cpu().numpy().flatten()

# sklearn expects a 1-D label vector. y_test is a (n, 1) float tensor of
# 0./1. targets; cast it back to ints so the classification report and
# confusion matrix are labelled 0/1 rather than 0.0/1.0.
y_true = y_test.cpu().numpy().ravel().astype(int)

threshold = 0.5  # probability cut-off for predicting the positive class
pred_labels = (preds >= threshold).astype(int)
print('ROC-AUC:', roc_auc_score(y_true, preds))
print('PR-AUC:', average_precision_score(y_true, preds))
print('\nClassification Report:\n', classification_report(y_true, pred_labels))
print('Confusion Matrix:\n', confusion_matrix(y_true, pred_labels))

## Save Model

In [None]:
# Save only the learned parameters (state_dict). To reuse them, rebuild
# MLPClassifier with the same input_dim and call load_state_dict().
mlp_path = 'fast_mlp_amp.pt'
torch.save(model_mlp.state_dict(), mlp_path)
print(f'Model saved as {mlp_path}')

In [None]:
# Browser download is a Colab-only helper. Importing google.colab elsewhere
# raises ModuleNotFoundError, so reuse the IN_COLAB flag computed when the
# dataset was loaded; outside Colab the checkpoint stays on local disk.
if IN_COLAB:
    from google.colab import files
    files.download('fast_mlp_amp.pt')
else:
    print('Not running in Colab; fast_mlp_amp.pt is in the working directory.')