In [55]:
import os
import random
from collections import defaultdict
from Bio import SeqIO
import pandas as pd
import os
import json
import numpy as np
import pickle
from Bio.PDB import PDBParser
import re

In [58]:
def shuffle_and_create_fasta(df, input_fasta, output_dir):
    """
    Write one shuffled FASTA file per target, replacing every subunit of the
    target with a randomly chosen sequence that comes from a different target.

    The target ID of a sequence (and of a `Target` value) is taken to be the
    part of its ID before the first "s". For each subunit of a target, a
    replacement is drawn from `input_fasta` under these rules:

      * sequences of the hard-coded excluded targets are never used;
      * sequences whose ID starts with the current target's ID are never used;
      * a source target contributes at most one sequence per output file;
      * when the subunit count is not 1, sequences of source targets whose
        own stoichiometry contains that same count are skipped.

    The chosen record is written `count` times to `<output_dir>/<Target>.fasta`.

    Note from the original author: homomers have sequence IDs starting with
    'T', heteromers start with 'H'. The selection logic below does not depend
    on that prefix.

    Args:
        df (pd.DataFrame): DataFrame with the columns `Target` and
            `Stoichiometry` (e.g. "A2B3"). It is modified in place.
        input_fasta (str): Path to the input FASTA file containing all sequences.
        output_dir (str): Directory to save the new shuffled FASTA files
            (created if missing).

    Returns:
        pd.DataFrame: The same `df` object with a new column `Target*` holding,
        per target, a comma-separated list of "<sequence id> (<stoichiometry>)"
        entries, one per subunit. The stoichiometry is "Unknown" when the
        source target is not listed in `df`.

    Raises:
        ValueError: If no sequence satisfies the rules for some subunit.
            FASTA files of targets processed before the failure are kept.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Index every record by its ID (a later duplicate ID replaces an earlier one)
    sequence_dict = {rec.id: rec for rec in SeqIO.parse(input_fasta, "fasta")}

    # Create a lookup table for stoichiometry, keyed by target ID
    stoichiometry_dict = {
        row['Target'].split("s")[0]: row['Stoichiometry'] for _, row in df.iterrows()
    }

    # Target IDs to exclude
    excluded_targets = {"H1111", "H1114", "T1115", "H1137", "H1166", "H1167", "H1168", "H1185"}

    # Sequences that may ever be used, with their source target ID. This part
    # of the filtering does not depend on the target being processed.
    candidates = []
    for seq_id, record in sequence_dict.items():
        source_id = seq_id.split("s")[0]
        if source_id not in excluded_targets:
            candidates.append((seq_id, source_id, record))

    # Subunit counts of each source target, filled on first use so that
    # every stoichiometry string is parsed at most once.
    source_counts_cache = {}

    chosen_subunits = []  # One "Target*" entry per DataFrame row

    # Process each target in the DataFrame
    for _, row in df.iterrows():
        target_id = row['Target']
        target_prefix = target_id.split("s")[0]
        subunit_counts = parse_stoichiometry(row['Stoichiometry'])

        # Drop the sequences that belong to the current target
        valid_candidates = [
            (source_id, record)
            for seq_id, source_id, record in candidates
            if not seq_id.startswith(target_prefix)
        ]

        selected_sequences = []
        selected_subunits = []  # Track selected subunit IDs with stoichiometry
        used_multimers = set()  # Source targets already used for this target

        for subunit, count in subunit_counts.items():
            filtered_sequences = []
            for source_id, record in valid_candidates:
                source_counts = source_counts_cache.get(source_id)
                if source_counts is None:
                    source_counts = set(
                        parse_stoichiometry(
                            stoichiometry_dict.get(source_id, "Unknown")
                        ).values()
                    )
                    source_counts_cache[source_id] = source_counts

                # Exclude same multimer
                if source_id in used_multimers:
                    continue

                # Exclude same count for X≠1
                if count != 1 and count in source_counts:
                    continue

                filtered_sequences.append(record)

            if not filtered_sequences:
                raise ValueError(
                    f"No replacement sequence available for subunit {subunit} "
                    f"(count {count}) of target {target_id}"
                )

            # Randomly select a replacement sequence. random.sample is kept
            # (rather than random.choice) so seeded runs stay reproducible.
            chosen = random.sample(filtered_sequences, 1)[0]
            selected_sequences.extend([chosen] * count)  # Repeat it count times

            chosen_source_id = chosen.id.split("s")[0]
            used_multimers.add(chosen_source_id)  # Mark multimer as used

            # Add the sequence ID and its stoichiometry
            seq_stoichiometry = stoichiometry_dict.get(chosen_source_id, "Unknown")
            selected_subunits.append(f"{chosen.id} ({seq_stoichiometry})")

        # Write the new FASTA file
        output_path = os.path.join(output_dir, f"{target_id}.fasta")
        with open(output_path, "w") as f:
            SeqIO.write(selected_sequences, f, "fasta")

        # Add selected subunits to the list
        chosen_subunits.append(",".join(selected_subunits))

    # Add the new column to the DataFrame
    df['Target*'] = chosen_subunits
    return df

def parse_stoichiometry(stoichiometry):
    """
    Turn a stoichiometry string into a mapping of subunit label to copy number.

    Every occurrence of a single uppercase letter immediately followed by one
    or more digits is read as one subunit; anything else in the string is
    ignored. If a label occurs more than once, the last occurrence wins.

    Args:
        stoichiometry (str): Stoichiometry string (e.g., "A2B3").

    Returns:
        dict: Subunit labels mapped to integer counts, in order of first
        appearance (e.g., {"A": 2, "B": 3}); empty when nothing matches.
    """
    counts_by_label = {}
    for hit in re.finditer(r"([A-Z])(\d+)", stoichiometry):
        label, digits = hit.groups()
        counts_by_label[label] = int(digits)
    return counts_by_label

