In [None]:
# Necessary imports

from concurrent.futures import ThreadPoolExecutor, as_completed
from Bio import SeqIO
import pandas as pd
import os, uuid
from collections import defaultdict
import re
from tempfile import NamedTemporaryFile
import time
import datetime
import subprocess

# User-defined configuration - edit here!
BATCH_SIZE = 5000         # Number of peptides sent to NetMHC in a single predictor call
SPECIES = "human"        # 'human' (keeps HLA- alleles) or 'mouse' (keeps H-2- alleles); any other value raises ValueError when alleles are filtered
MHC_CLASS = "I"          # 'I' (runs netMHCpan) or 'II' (runs netMHCIIpan)
KMER_LENGTHS = [9]       # Peptide lengths to generate; typical choices are 8-11 for MHC class I and 13-25 for MHC class II, so change this together with MHC_CLASS
STRIDE = 1              # Step between consecutive k-mer start positions; 1 = every overlapping k-mer
MAX_THREADS = 6         # Number of alleles processed in parallel (one worker thread per allele)
FASTA_FILE = ""         # Path to the FASTA file containing protein sequences (required; the empty default is not a valid path)
ALLELE_FILE = {
    "I": "",
    "II": ""
}   # Paths to allele list files (one allele per line, '#' starts a comment line) for MHC class I and II; the entry for MHC_CLASS is used

# Optional to edit:

# Output directory, built as hostdb/<SPECIES>/mhc<MHC_CLASS> and created right away; change it if you prefer another location.
OUTPUT_DIR = os.path.join("hostdb", SPECIES, f"mhc{MHC_CLASS}")
os.makedirs(OUTPUT_DIR, exist_ok=True)


# Do not edit below this line unless you know what you're doing!

# Helper functions!

def normalize_allele_format(allele):
    """Return a file-name-safe spelling of an allele name.

    Asterisks and colons are removed and slashes become hyphens, e.g.
    ``HLA-A*02:01`` -> ``HLA-A0201`` and
    ``HLA-DPA1*01:03/DPB1*04:01`` -> ``HLA-DPA10103-DPB10401``.

    NOTE(review): load_alleles_from_file stores these normalized names and
    they are later passed to the predictor via ``-a``; confirm the class I
    NetMHCpan version in use accepts the colon-free form.
    """
    # Map each character to its replacement; None deletes the character.
    unsafe_to_safe = str.maketrans({"*": None, ":": None, "/": "-"})
    return allele.translate(unsafe_to_safe)

def load_alleles_from_file(path):
    """Read allele names from a text file; return them normalized, unique and sorted.

    The file holds one allele per line. Blank lines and lines whose first
    non-whitespace character is ``#`` are ignored. Every remaining name is
    passed through ``normalize_allele_format`` before de-duplication.

    Returns:
        A sorted list of unique normalized allele names, or ``[]`` (after
        printing a message) when ``path`` does not exist.
    """
    try:
        with open(path, 'r') as f:
            stripped_lines = [line.strip() for line in f]
    except FileNotFoundError:
        print(f"File not found: {path}")
        return []

    # Test for '#' on the stripped text so indented comment lines are skipped
    # too (checking the raw line let "   # ..." through as an allele name).
    normalized = {
        normalize_allele_format(name)
        for name in stripped_lines
        if name and not name.startswith('#')
    }
    return sorted(normalized)
    
def generate_kmers(seq, kmer_lengths, stride):
    """Lazily yield sub-sequences of ``seq`` for each requested length.

    Lengths are handled in the order given. For each length, windows start at
    0, ``stride``, ``2*stride``, ... as long as the whole window fits inside
    ``seq``. Lengths longer than ``seq`` yield nothing.
    """
    for width in kmer_lengths:
        final_start = len(seq) - width
        yield from (seq[offset:offset + width] for offset in range(0, final_start + 1, stride))

def run_netmhcpan(peptide_file, alleles, output_file, mhc_class='I'):
    """Run netMHCpan (class 'I') or netMHCIIpan (class 'II') on a peptide file.

    Args:
        peptide_file: path to a file with one peptide per line.
        alleles: iterable of allele names, joined with commas for ``-a``.
        output_file: path the predictor writes its tab-separated ``-xls`` table to.
        mhc_class: 'I' or 'II'.

    Raises:
        ValueError: ``mhc_class`` is neither 'I' nor 'II'.
        RuntimeError: the predictor exited with a non-zero status.
        FileNotFoundError: the predictor succeeded but ``output_file`` is missing.
    """
    if mhc_class != 'I' and mhc_class != 'II':
        raise ValueError("Unsupported MHC class")

    allele_list = ",".join(alleles)
    if mhc_class == 'I':
        cmd = ["netMHCpan", "-p", peptide_file, "-a", allele_list, "-BA"]
    else:
        # -inptype 1 tells netMHCIIpan the input is a peptide list.
        cmd = ["netMHCIIpan", "-f", peptide_file, "-a", allele_list, "-inptype", "1"]
    cmd += ["-xls", "-xlsfile", output_file]

    completed = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    def diagnostics():
        # Only decoded when an error is actually reported.
        return (
            f"Command: {' '.join(cmd)}\n"
            f"Stdout:\n{completed.stdout.decode()}\n"
            f"Stderr:\n{completed.stderr.decode()}"
        )

    if completed.returncode != 0:
        raise RuntimeError("NetMHC command failed.\n" + diagnostics())

    # A zero exit status alone is not trusted: the table must exist too.
    if not os.path.exists(output_file):
        raise FileNotFoundError(
            f"Expected output file not found: {output_file}\n" + diagnostics()
        )

def process_peptide_batch(peptide_protein_pairs, allele):
    """Score one batch of peptides against a single allele.

    The peptides are written to a temporary file and run through
    ``run_netmhcpan`` using the module-level ``MHC_CLASS``. The tab-separated
    ``-xls`` table is parsed, taking ``BA_Rank`` when that column exists and
    ``Rank`` otherwise. Both temporary files are removed afterwards, even on
    failure.

    Args:
        peptide_protein_pairs: sequence of ``(peptide, protein_id)`` tuples.
        allele: allele name handed to the predictor.

    Returns:
        One dict per parsed output row with keys ``peptide``, ``rank``,
        ``allele``, ``protein`` and ``length``. ``protein`` is the first
        protein paired with that peptide in the input, or ``"unknown"``.
        No rank threshold is applied here.

    Raises:
        Whatever ``run_netmhcpan`` raises (e.g. RuntimeError when the predictor
        fails, FileNotFoundError when it writes no table).
    """
    # Bound before the try so the cleanup in `finally` never hits an unbound
    # name (which would mask the real error and leak the temp file).
    pep_file = None
    out_file = None
    try:
        with NamedTemporaryFile(mode='w', suffix='.pep', delete=False) as temp:
            pep_file = temp.name  # record first: a failed write must still be cleaned up
            temp.writelines(f"{peptide}\n" for peptide, _ in peptide_protein_pairs)

        # Put the output next to the peptide file (the platform temp dir)
        # rather than a hard-coded /tmp.
        out_file = os.path.join(
            os.path.dirname(pep_file), f"hostdb_{uuid.uuid4().hex}.xls"
        )
        run_netmhcpan(pep_file, [allele], out_file, MHC_CLASS)

        # The first line of the -xls table holds allele names, not column headers.
        df = pd.read_csv(out_file, sep='\t', skiprows=1)
        rank_col = 'BA_Rank' if 'BA_Rank' in df.columns else 'Rank'
        df = df[['Peptide', rank_col]].dropna()

        # First protein per peptide, mirroring the previous linear "first
        # match" search but with O(1) lookups instead of O(batch) per row.
        protein_of = {}
        for peptide, protein_id in peptide_protein_pairs:
            protein_of.setdefault(peptide, protein_id)

        return [
            {
                'peptide': peptide,
                'rank': rank,
                'allele': allele,
                'protein': protein_of.get(peptide, "unknown"),
                'length': len(peptide),
            }
            for peptide, rank in zip(df['Peptide'], df[rank_col])
        ]
    finally:
        for path in (pep_file, out_file):
            if path and os.path.exists(path):
                os.remove(path)

# Main script starts here

# Read every allele listed for the selected MHC class, then keep only the
# names carrying the configured species' prefix.
all_alleles_raw = load_alleles_from_file(ALLELE_FILE[MHC_CLASS])

species_allele_prefix = {"human": "HLA-", "mouse": "H-2-"}.get(SPECIES)
if species_allele_prefix is None:
    raise ValueError(f"Unknown species: {SPECIES}")
all_alleles = list(filter(lambda name: name.startswith(species_allele_prefix), all_alleles_raw))

print(f"🧬 Running database generation for {len(all_alleles)} alleles.")

# Peptides already submitted to the predictor, keyed by allele; process_allele
# reads and updates the set belonging to its own allele.
seen_peptides_by_allele = defaultdict(set)

# Preprocess FASTA proteins once.
# Split each sequence at every character that is not one of the 20 standard
# amino acids. This covers ambiguity codes (X, B, Z, J), U and O, and also stop
# ('*') or gap ('-') symbols that some FASTA files contain, so no such character
# ever ends up inside a peptide sent to NetMHC. Each remaining stretch becomes its
# own fragment, named "<record id>_frag<index>".
ambiguous_split_pattern = re.compile(r"[^ACDEFGHIKLMNPQRSTVWY]")
min_kmer_length = min(KMER_LENGTHS)
all_proteins = []

for record in SeqIO.parse(FASTA_FILE, "fasta"):
    raw_seq = str(record.seq).upper()
    for idx, frag in enumerate(ambiguous_split_pattern.split(raw_seq)):
        # Fragments shorter than the shortest k-mer cannot yield any peptide.
        if len(frag) >= min_kmer_length:
            all_proteins.append((f"{record.id}_frag{idx}", frag))
print(f"✅ Loaded {len(all_proteins)} protein fragments from {FASTA_FILE}.")


# This function will process all proteins for a given allele

def _flush_batch(batch, allele, output_path):
    """Score one batch, append its rows to ``output_path``, mark its peptides seen; return the row count."""
    new_hits = process_peptide_batch(batch, allele)
    if new_hits:
        # Header only when the file is new, so repeated flushes append cleanly.
        pd.DataFrame(new_hits).to_csv(
            output_path, mode='a', index=False, header=not os.path.exists(output_path)
        )
    # Mark peptides as processed even when no rows came back, so they are
    # never re-submitted for this allele.
    seen_peptides_by_allele[allele].update(p for p, _ in batch)
    return len(new_hits)


def process_allele(allele):
    """Score every unique k-mer of the loaded proteome against one allele.

    Walks ``all_proteins`` in order and builds k-mers from ``KMER_LENGTHS`` and
    ``STRIDE``. A peptide is skipped when it was already submitted for this
    allele (``seen_peptides_by_allele[allele]``) or is already waiting in the
    pending batch, so each peptide is predicted once. It is credited to the
    first protein that produced it. Peptides go to ``process_peptide_batch`` in
    chunks of ``BATCH_SIZE``, and rows are appended to
    ``<OUTPUT_DIR>/<normalized allele>_hits.csv``.

    Returns:
        Total number of rows written for this allele.
    """
    seen = seen_peptides_by_allele[allele]
    output_path = os.path.join(OUTPUT_DIR, f"{normalize_allele_format(allele)}_hits.csv")
    num_proteins = len(all_proteins)
    batch_peptides = []  # pending (peptide, protein_id) pairs, in discovery order
    queued = set()       # peptides currently in batch_peptides, for O(1) duplicate checks
    num_hits = 0
    batches = 0
    start_time = time.time()

    for protein_so_far, (protein_id, sequence) in enumerate(all_proteins, start=1):
        if protein_so_far % 100 == 0:
            elapsed = time.time() - start_time
            time_left = (elapsed / protein_so_far) * (num_proteins - protein_so_far)
            eta = datetime.timedelta(seconds=int(time_left))
            print(f"Processed {protein_so_far}/{num_proteins} proteins for {allele}... ETA: {eta}")

        # The old `p not in batch_peptides` compared a str with (peptide, id)
        # tuples and never matched, so peptides shared by proteins in the same
        # pending batch were predicted and written twice.
        for p in generate_kmers(sequence, KMER_LENGTHS, STRIDE):
            if p not in seen and p not in queued:
                queued.add(p)
                batch_peptides.append((p, protein_id))

        while len(batch_peptides) >= BATCH_SIZE:
            batch, batch_peptides = batch_peptides[:BATCH_SIZE], batch_peptides[BATCH_SIZE:]
            num_hits += _flush_batch(batch, allele, output_path)
            queued.difference_update(p for p, _ in batch)  # now tracked by `seen`
            batches += 1
            if batches % 10 == 0:
                print(f"Processed {batches} batches for {allele} so far.")

    # Final leftovers smaller than BATCH_SIZE.
    if batch_peptides:
        num_hits += _flush_batch(batch_peptides, allele, output_path)
        batches += 1
        print(f"Processed final batch {batches} for {allele}.")

    return num_hits

# Main processing loop: one task per allele, at most MAX_THREADS at a time.
# Prediction runs in external NetMHC subprocesses (see run_netmhcpan), which
# worker threads can overlap.
failed_alleles = []
with ThreadPoolExecutor(max_workers=MAX_THREADS) as executor:
    futures = {executor.submit(process_allele, allele): allele for allele in all_alleles}
    total = len(futures)

    for future in as_completed(futures):
        allele = futures[future]
        try:
            hits = future.result()
        except Exception as e:
            # Top-level boundary: record the failure and keep going with the
            # remaining alleles. Include the exception type, since a bare
            # message such as "name 'x' is not defined" hides that it was a NameError.
            failed_alleles.append(allele)
            print(f"❌ {allele}: {type(e).__name__}: {e}")
        else:
            print(f"✅ {allele}: {hits} hits found.")

# Only claim success when every allele actually completed.
if failed_alleles:
    print(f"⚠️ {len(failed_alleles)}/{total} alleles failed: {', '.join(failed_alleles)}")
else:
    print("✅ All alleles processed.")

🧬 Running database generation for 6 alleles.
✅ Loaded 20644 protein fragments from /home/benjamin-marwedel/Downloads/human_proteome.fasta.
❌ HLA-DPA10103-DPB10101: name 'run_netmhcpan' is not defined
❌ HLA-DPA10103-DPB10401: name 'run_netmhcpan' is not defined
❌ HLA-DPA10103-DPB10202: name 'run_netmhcpan' is not defined
❌ HLA-DPA10103-DPB10402: name 'run_netmhcpan' is not defined
❌ HLA-DPA10103-DPB10201: name 'run_netmhcpan' is not defined
❌ HLA-DPA10103-DPB10301: name 'run_netmhcpan' is not defined
✅ All alleles processed.
