## AMPBAN
This notebook describes the AMPBAN model in detail: building the benchmark dataset and the independent test datasets, building the model, and comparing it with state-of-the-art (SOTA) AMP prediction models.

### Benchmark dataset

In [5]:
from Bio import SeqIO
import matplotlib.pyplot as plt
import matplotlib
import pandas as pd
import os

In [2]:
# Load the Peptipedia AMP export once (it was previously parsed twice) and keep the
# peptides of 5-100 residues; non-string entries (missing sequences) are skipped.
peptipedia_df = pd.read_csv('./data/peptipedia/activities_canon_AMP-20241204.csv', header=0)
peptipedia_id = peptipedia_df['id_sequence'].to_list()
peptipedia = peptipedia_df['sequence'].to_list()
peptipedia_5_100 = [seq for seq in peptipedia if isinstance(seq, str) and 5 <= len(seq) <= 100]

In [4]:
# Number of Peptipedia sequences with 5-100 residues (displayed as the cell output).
len(peptipedia_5_100)

32846

In [9]:
# Input and output file paths
input_file = "./data/Swiss-Prot/uniprotkb_reviewed_true_AND_length_5_TO_2024_12_26.fasta"
temp_file = "temp.fasta"  # intermediate deduplicated FASTA
output_file = "./data/Swiss-Prot/negative_processed_samples.fasta"

# Substrings searched for in the lower-cased FASTA description line. Any match marks the
# entry as a possible antimicrobial / host-defense protein and excludes it from the negatives.
# "fungicide" replaces an accidental second "microbicidal" entry, matching the documented intent.
keywords = ["antimicrobial", "antibacterial", "antifungal", "anticancer", "antimalarial", "anti-protist",
            "antiviral", "antiparasitic", "antibiotic", "antibiofilm", "defense", "defensin", "cathelicidin",
            "effector", "excreted", "bacteriocin", "microbicidal", "fungicide", "histatin"]

# Ambiguous / non-standard amino-acid codes; sequences containing any of them are excluded.
non_standard_residues = set("BJOUXZ")

def deduplicate_fasta(input_file, output_file):
    """
    Write a copy of a FASTA file keeping only the first record of each distinct sequence.

    Sequences are compared as exact strings. The number of records read and the number
    kept are printed, then the kept records are written in their original order.

    Args:
        input_file (str): FASTA file to read.
        output_file (str): Destination path for the deduplicated FASTA.
    """
    # dict preserves insertion order, and setdefault keeps the first record seen per sequence.
    first_record_by_seq = {}
    n_read = 0
    for rec in SeqIO.parse(input_file, "fasta"):
        n_read += 1
        first_record_by_seq.setdefault(str(rec.seq), rec)
    kept = list(first_record_by_seq.values())

    print(f"输入文件中总共有 {n_read} 条序列")
    print(f"Number of unique sequences: {len(kept)}")
    SeqIO.write(kept, output_file, "fasta")
    print(f"Deduplicated FASTA saved to {output_file}")

# Drop exact-duplicate sequences from the raw Swiss-Prot download, writing the result to temp_file.
deduplicate_fasta(input_file, temp_file)

# Predicate that selects negative (non-AMP) samples from the deduplicated Swiss-Prot records.
def is_negative_sample(record):
    """
    Return True when a FASTA record can be used as a negative (non-AMP) sample.

    A record is rejected if its lower-cased description contains any entry of the
    module-level ``keywords`` list, if its sequence length is outside 5-100, or if
    its sequence contains any residue from ``non_standard_residues``.
    """
    description = record.description.lower()
    if any(word in description for word in keywords):
        return False
    return 5 <= len(record.seq) <= 100 and non_standard_residues.isdisjoint(str(record.seq))

# Keep only the deduplicated Swiss-Prot records that pass is_negative_sample().
negative_samples = [record for record in SeqIO.parse(temp_file, "fasta") if is_negative_sample(record)]

# Write filtered sequences to output file
SeqIO.write(negative_samples, output_file, "fasta")

# Delete the intermediate file via temp_file (not a repeated literal) so the cleanup
# always targets the file deduplicate_fasta actually wrote.
os.remove(temp_file)
print(f"Filtered {len(negative_samples)} negative samples saved to {output_file}")

输入文件中总共有 58095 条序列
Number of unique sequences: 43913
Deduplicated FASTA saved to temp.fasta
Filtered 42163 negative samples saved to ./data/Swiss-Prot/negative_processed_samples.fasta


In [12]:
import pandas as pd
from Bio import SeqIO
from Bio.SeqRecord import SeqRecord
from Bio.Seq import Seq
import os

def generate_positive_samples(csv_path, output_fasta_path, min_len=5, max_len=100, exclude_residues=None):
    """
    Write the AMP (positive) sequences of a CSV export to a FASTA file.

    Rows are kept when their ``sequence`` length lies in ``[min_len, max_len]``; rows
    with a missing sequence are dropped. Each kept row is written as a ``>{id_sequence}``
    header line followed by the sequence on a single line.

    Args:
        csv_path (str): CSV with at least ``id_sequence`` and ``sequence`` columns.
        output_fasta_path (str): Destination FASTA path (overwritten).
        min_len (int): Minimum sequence length, inclusive. Defaults to 5.
        max_len (int): Maximum sequence length, inclusive. Defaults to 100.
        exclude_residues (Iterable[str] | None): If given, sequences containing any of
            these characters are also dropped (e.g. ``"BJOUXZ"`` to mirror the filter
            applied to the Swiss-Prot negatives). Defaults to None (no residue filter).

    Returns:
        str: ``output_fasta_path``.

    Raises:
        ValueError: If the CSV lacks the ``id_sequence`` or ``sequence`` column.
    """
    print(f"Generating positive samples from {csv_path}...")
    df = pd.read_csv(csv_path)

    if 'id_sequence' not in df.columns or 'sequence' not in df.columns:
        raise ValueError("CSV file must contain 'id_sequence' and 'sequence' columns.")

    # .str.len() is NaN for missing sequences and between() treats NaN as False,
    # so empty rows are dropped together with out-of-range lengths.
    keep = df['sequence'].str.len().between(min_len, max_len)
    if exclude_residues:
        excluded = set(exclude_residues)
        keep = keep & df['sequence'].map(
            lambda s: isinstance(s, str) and excluded.isdisjoint(s)
        ).astype(bool)
    filtered_df = df[keep]

    with open(output_fasta_path, "w") as f:
        # zip over the two columns avoids building a Series per row as iterrows() does.
        for seq_id, seq in zip(filtered_df['id_sequence'], filtered_df['sequence']):
            f.write(f">{seq_id}\n{seq}\n")

    print(f"Generated {len(filtered_df)} positive samples. Saved to {output_fasta_path}")
    return output_fasta_path

def load_sequences(file_path):
    """Return the set of distinct sequence strings in a FASTA file (record IDs are discarded)."""
    sequences = set()
    for rec in SeqIO.parse(file_path, "fasta"):
        sequences.add(str(rec.seq))
    return sequences

def load_sequences_with_ids(file_path):
    """Return ``[(record_id, sequence_str), ...]`` for a FASTA file, in file order, duplicates kept."""
    pairs = []
    for rec in SeqIO.parse(file_path, "fasta"):
        pairs.append((rec.id, str(rec.seq)))
    return pairs

# Label encoding happens here: positive IDs get a "_1" suffix, negative IDs a "_0" suffix.
def combine_positive_and_negative(positive_file, negative_file, combined_output_file):
    """
    Merge positive and negative FASTA files into one labelled FASTA file.

    Every positive record ID gets the suffix ``_1`` and every negative record ID the
    suffix ``_0``; descriptions are cleared so each header is just ``>{id}_{label}``.
    Positives are written first, then negatives, each in input order.

    Args:
        positive_file (str): FASTA file of AMP (positive) sequences.
        negative_file (str): FASTA file of non-AMP (negative) sequences.
        combined_output_file (str): Destination FASTA path (overwritten).

    Returns:
        str: ``combined_output_file``.
    """
    print(f"Combining positive samples ({positive_file}) and negative samples ({negative_file})...")

    def _labelled_records():
        # Lazily yield relabelled records: positives (label 1) first, then negatives (label 0).
        for fasta_path, label in ((positive_file, 1), (negative_file, 0)):
            for record in SeqIO.parse(fasta_path, "fasta"):
                # record.seq is already a Seq, so it is reused instead of being copied.
                yield SeqRecord(record.seq, id=f"{record.id}_{label}", description="")

    with open(combined_output_file, "w") as out_f:
        # One SeqIO.write call streams every record instead of creating a writer per record.
        SeqIO.write(_labelled_records(), out_f, "fasta")

    print(f"Combined training data saved to {combined_output_file}")
    return combined_output_file

def _label_from_id(seq_id):
    """Return 1 (positive), 0 (negative) or None (no recognisable label) for a training FASTA ID."""
    # Check the "_1" / "_0" label suffix with endswith(): a substring test would also match
    # IDs that merely contain "_1" or "_0" elsewhere and could count one ID in both classes.
    if seq_id.endswith("_1"):
        return 1
    if seq_id.endswith("_0"):
        return 0
    # Fallback for IDs that carry "true" / "false" tokens instead of a numeric suffix.
    lowered = seq_id.lower()
    if "true" in lowered:
        return 1
    if "false" in lowered:
        return 0
    return None


def filter_unique_sequences(test_files, train_file, output_file):
    """
    Remove training sequences that also occur in any test file, to avoid train/test leakage.

    Sequences are compared exactly but case-insensitively (no similarity threshold).
    Surviving records keep their original IDs and sequences and are written to
    ``output_file`` as one header line plus one sequence line each. Per-class counts
    before and after filtering are printed, using the ``_1`` / ``_0`` ID suffix as the label.

    Args:
        test_files (Iterable[str]): FASTA files whose sequences must not appear in training data.
        train_file (str): Labelled training FASTA (IDs ending in ``_1`` / ``_0``).
        output_file (str): Destination FASTA path (overwritten).
    """
    all_test_sequences = set()
    for test_file in test_files:
        test_seqs = load_sequences(test_file)
        # Upper-case so that a lower-case copy of a test peptide is still detected.
        all_test_sequences.update(seq.upper() for seq in test_seqs)
        print(f"Loaded {len(test_seqs)} sequences from test file: {test_file}")

    train_sequences = load_sequences_with_ids(train_file)
    total_train_sequences = len(train_sequences)
    print(f"Total sequences in combined training data: {total_train_sequences}")

    unique_train = [(seq_id, seq) for seq_id, seq in train_sequences
                    if seq.upper() not in all_test_sequences]

    # Before and after counts use the same labelling rule, so they are directly comparable.
    labels_before = [_label_from_id(seq_id) for seq_id, _ in train_sequences]
    labels_after = [_label_from_id(seq_id) for seq_id, _ in unique_train]
    count_true_before, count_false_before = labels_before.count(1), labels_before.count(0)
    count_true, count_false = labels_after.count(1), labels_after.count(0)

    # Save the filtered training set; IDs still carry their _1/_0 labels.
    with open(output_file, "w") as f:
        for seq_id, seq in unique_train:
            f.write(f">{seq_id}\n{seq}\n")

    print("\n--- Filtering Statistics ---")
    print(f"Total sequences before filtering: {total_train_sequences}")
    print(f"Positive sequences (with '_1') before filtering: {count_true_before}")
    print(f"Negative sequences (with '_0') before filtering: {count_false_before}")
    print("-" * 50)
    print(f"Unique sequences saved to {output_file}: {len(unique_train)}")
    print(f"Positive sequences (with '_1') after filtering: {count_true}")
    print(f"Negative sequences (with '_0') after filtering: {count_false}")

if __name__ == "__main__":
    # Path configuration; adjust to your local data layout.
    positive_csv = './data/peptipedia/activities_canon_AMP-20241204.csv'
    negative_fasta = './data/Swiss-Prot/negative_processed_samples.fasta'
    test_files = [
        './data/independent_test_data/MFA_independent.fasta',
        './data/independent_test_data/xiao_independent_new.fasta',
        './data/independent_test_data/xu_independent.fasta',
    ]
    output_file = './data/training_data.fasta'
    combined_train_file = "./data/combined_training_data.fasta"

    # Step 1: export the Peptipedia AMPs to FASTA (IDs carry no label suffix yet).
    positive_fasta = generate_positive_samples(positive_csv, "./data/peptipedia/positive_samples.fasta")

    # Step 2: merge positives and negatives into one file, tagging IDs with _1 / _0.
    combine_positive_and_negative(positive_fasta, negative_fasta, combined_train_file)

    # Step 3: drop training sequences that also appear in any independent test set.
    filter_unique_sequences(test_files, combined_train_file, output_file)

    print("\nProcess completed successfully!")

Generating positive samples from ./data/peptipedia/activities_canon_AMP-20241204.csv...
Generated 32846 positive samples. Saved to ./data/peptipedia/positive_samples.fasta
Combining positive samples (./data/peptipedia/positive_samples.fasta) and negative samples (./data/Swiss-Prot/negative_processed_samples.fasta)...
Combined training data saved to ./data/combined_training_data.fasta
Loaded 1221 sequences from test file: ./data/independent_test_data/MFA_independent.fasta
Loaded 1756 sequences from test file: ./data/independent_test_data/xiao_independent_new.fasta
Loaded 3072 sequences from test file: ./data/independent_test_data/xu_independent.fasta
Total sequences in combined training data: 75009

--- Filtering Statistics ---
Total sequences before filtering: 75009
Positive sequences (with '_1') before filtering: 32846
Negative sequences (with '_0') before filtering: 42163
--------------------------------------------------
Unique sequences saved to ./data/training_data.fasta: 69646
Po

### Extract features for training data and test data

In [None]:
# Step 1: predict a structure for every peptide with ESMFold.
# The core script predict_structure_ESMfold.py is embedded in the shell script below;
# change the directory paths in that script to process a different dataset.
# It may need to be run in a separate environment ("conda activate django").
!bash ./script/predict_all_fasta_ESMFold_individual_folders.sh

In [None]:
# Step 2: extract structure features with Progres.
# The core script generate_egnn_fea.py is embedded in the shell script below;
# change the directory paths in that script to process a different dataset.
# It may need to be run in a separate environment ("conda activate prog").
!bash ./script/generate_stru_fea_indep.sh

In [None]:
# Structure embeddings of the training data and the independent test sets.
# They are stored as Python strings so this cell is valid code under Restart & Run All.
STRUCTURE_EMBEDDING_PATHS = {
    "training": "/mnt/e/Doctorale/Project/AMPBAN/data/protein_egnn_embeddings.pt",
    "xu": "/mnt/e/Doctorale/Project/AMPBAN/data/independent_test_data/xu_independent_structure.pt",
    "xiao": "/mnt/e/Doctorale/Project/AMPBAN/data/independent_test_data/xiao_independent_new_stru.pt",
    "MFA": "/mnt/e/Doctorale/Project/AMPBAN/data/independent_test_data/MFA_independent_stru.pt",
}

In [14]:
# Third step: Extract sequence feature
import torch
from transformers import AutoModelForMaskedLM
from Bio import SeqIO
import os
from tqdm import tqdm  # For progress bars

def _pool_hidden_states(last_hidden_state, attention_mask, pooling_types):
    """Pool per-token hidden states into one vector per sequence; 'mean' then 'cls', concatenated."""
    pooled = []
    if 'mean' in pooling_types:
        # Masked mean: positions with attention_mask == 0 (padding) are excluded from sum and count.
        mask = attention_mask.unsqueeze(-1).float()
        summed = (last_hidden_state * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)  # guard against division by zero
        pooled.append(summed / counts)
    if 'cls' in pooling_types:
        # Hidden state at position 0; assumes the tokenizer prepends a CLS/BOS token — TODO confirm.
        pooled.append(last_hidden_state[:, 0])
    return torch.cat(pooled, dim=1) if len(pooled) > 1 else pooled[0]


def embed_sequences(fasta_file, output_path, model_name='./script/Synthyra/ESMplusplus_small', batch_size=128, max_len=512,
                    pooling_types=('mean', 'cls'), num_workers=4, use_gpu=True):
    """
    Embed protein sequences from a FASTA file with an ESM++ masked-LM checkpoint and save them.

    Sequences are processed in batches. For each sequence, the last hidden states are
    pooled into a single vector, which is stored under the FASTA record ID. The resulting
    ``{record_id: 1-D float32 CPU tensor}`` dict is written with ``torch.save``.

    Args:
        fasta_file (str): Path to the FASTA file containing protein sequences.
        output_path (str): Path where the embedding dict is saved (e.g. a ``.pth`` file).
        model_name (str, optional): Model ID or local directory passed to
            ``AutoModelForMaskedLM.from_pretrained``. Defaults to
            ``'./script/Synthyra/ESMplusplus_small'``.
        batch_size (int, optional): Sequences per forward pass; lower it if GPU memory
            runs out. Defaults to 128.
        max_len (int, optional): Maximum tokenized length; longer inputs are truncated.
            Defaults to 512.
        pooling_types (Sequence[str], optional): Any of ``'mean'`` and ``'cls'``. When both
            are given, the vectors are concatenated in the order mean, cls.
            Defaults to ``('mean', 'cls')``.
        num_workers (int, optional): Currently unused; kept for backward compatibility.
        use_gpu (bool, optional): Use CUDA when available. Defaults to True.

    Returns:
        None. Errors while loading the model, reading the FASTA file or saving are
        printed and the function returns early instead of raising.

    Raises:
        ValueError: If ``pooling_types`` contains neither ``'mean'`` nor ``'cls'``.
    """
    # Validate up front so a bad pooling_types fails before the (slow) model load
    # instead of with an IndexError on the first batch.
    if not {'mean', 'cls'}.intersection(pooling_types):
        raise ValueError(f"pooling_types must contain 'mean' and/or 'cls', got {pooling_types!r}")

    # 1. Load the model; the tokenizer is taken from the model's `tokenizer` attribute.
    try:
        model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)
        tokenizer = model.tokenizer
    except Exception as e:
        print(f"Error loading model: {e}")
        return

    device = torch.device('cuda' if use_gpu and torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()  # evaluation mode (disables dropout)

    # 2. Read sequences and their IDs from the FASTA file.
    try:
        records = list(SeqIO.parse(fasta_file, "fasta"))
    except Exception as e:
        print(f"Error reading FASTA file: {e}")
        return
    sequences = [str(record.seq) for record in records]
    sequence_ids = [record.id for record in records]

    # 3. Embed batch by batch without building autograd graphs.
    embedding_dict = {}
    n_duplicate_ids = 0
    with torch.no_grad():
        for start in tqdm(range(0, len(sequences), batch_size), desc="Embedding Sequences"):
            batch_sequences = sequences[start:start + batch_size]
            batch_ids = sequence_ids[start:start + batch_size]

            inputs = tokenizer(batch_sequences, return_tensors="pt", truncation=True, padding=True, max_length=max_len)
            inputs = {key: val.to(device) for key, val in inputs.items()}

            last_hidden_state = model(**inputs).last_hidden_state

            # 4. Pool, then move to CPU as float32 for saving.
            batch_embeddings = _pool_hidden_states(last_hidden_state, inputs.get('attention_mask'), pooling_types)
            batch_embeddings = batch_embeddings.to(torch.float32).cpu()

            for seq_id, embedding in zip(batch_ids, batch_embeddings):
                # A repeated ID overwrites the earlier embedding; count it so this is not silent.
                if seq_id in embedding_dict:
                    n_duplicate_ids += 1
                embedding_dict[seq_id] = embedding

    if n_duplicate_ids:
        print(f"Warning: {n_duplicate_ids} duplicate sequence IDs; only the last embedding for each was kept.")

    # 5. Save the embeddings.
    try:
        torch.save(embedding_dict, output_path)
        print(f"Embeddings saved to: {output_path}")
    except Exception as e:
        print(f"Error saving embeddings: {e}")

if __name__ == "__main__":
    # Example run on a single FASTA file; replace both paths with your own input and output.
    # NOTE(review): these paths point to the PlantAMP project rather than to the AMPBAN
    # data used elsewhere in this notebook — confirm this is the intended example.
    fasta_file = "/mnt/e/Doctorale/Project/PlantAMP/plant_amp/amp_test_pos.fa"
    output_path = "/mnt/e/Doctorale/Project/PlantAMP/plant_amp/amp_test_pos.pth"
    embed_sequences(fasta_file=fasta_file, output_path=output_path)
    print("Done!")

modeling_esm_plusplus.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/Synthyra/ESMplusplus_small:
- modeling_esm_plusplus.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
Embedding Sequences: 100%|████████████████████████████████████████████████████████████████| 4/4 [00:08<00:00,  2.06s/it]

Embeddings saved to: /mnt/e/Doctorale/Project/PlantAMP/plant_amp/amp_test_pos.pth
Done!





In [None]:
# Sequence embeddings of the training data (positives and negatives) and the independent test sets.
# They are stored as Python strings so this cell is valid code under Restart & Run All.
SEQUENCE_EMBEDDING_PATHS = {
    "training_pos": "/mnt/e/Doctorale/Project/AMPBAN/data/training_pos_aug.pth",
    "training_neg": "/mnt/e/Doctorale/Project/AMPBAN/data/training_neg_aug.pth",
    "xu": "/mnt/e/Doctorale/Project/AMPBAN/data/independent_test_data/xu_embeddings_by_id.pth",
    "xiao": "/mnt/e/Doctorale/Project/AMPBAN/data/independent_test_data/Xiao_embeddings_by_id.pth",
    "MFA": "/mnt/e/Doctorale/Project/AMPBAN/data/independent_test_data/MFA_embeddings_by_id.pth",
}

### Training model