# Install & Import Dependencies

In [None]:
# ! pip install rcsb-api
# ! pip install biopython
# ! pip install swifter

In [None]:
from rcsbapi.data import DataQuery as Query
import json
from rcsbapi.search import search_attributes as attrs
import pandas as pd
# from Bio.Align.PairwiseAligner import pairwise2
from Bio.Seq import Seq
# from Bio.Align import substitution_matrices
import re
# from Bio.pairwise2 import format_alignment
import os
import subprocess
import swifter
from concurrent.futures import ProcessPoolExecutor, as_completed
# blosum62_matrix = substitution_matrices.load("BLOSUM62")
import time
from Bio.Align import PairwiseAligner, substitution_matrices

# Initialize aligner once (outside the function for performance)
# Module-level PairwiseAligner read by create_multi_class_mask_v1 below.
aligner = PairwiseAligner()
aligner.substitution_matrix = substitution_matrices.load("BLOSUM62")
aligner.mode = "global"  # Equivalent to globalds
aligner.open_gap_score = -10      # Gap open penalty
aligner.extend_gap_score = -0.5   # Gap extension penalty
# Equivalent to penalize_end_gaps=False in pairwise2
# NOTE(review): keep the two end-gap assignments AFTER the generic gap scores
# above; per the Biopython docs the generic attributes also set the end-gap
# scores, so reordering would presumably re-enable end-gap penalties - confirm.
aligner.target_end_gap_score = 0.0  # No penalty for gaps at end of target
aligner.query_end_gap_score = 0.0   # No penalty for gaps at end of query
# blosum62

In [None]:
# import os
# num_cores = os.cpu_count()
# if num_cores is not None:
#     print(f"Number of logical CPU cores: {num_cores}")
# else:
#     print("Could not determine CPU count.")

# RCSB Search Query

In [None]:
# Search term 1: source organism scientific name is exactly "Homo sapiens".
q1 = attrs.rcsb_entity_source_organism.scientific_name == "Homo sapiens"
# Search term 2: experimental method is exactly "X-RAY DIFFRACTION".
q2 = attrs.exptl.method == "X-RAY DIFFRACTION"

In [None]:
# Combine both search terms into a single query with the `&` operator.
query = q1 & q2

In [None]:
# Execute the search and materialise every identifier it yields into a list.
results = query()
output = list(results)

In [None]:
len(output)

# RCSB Data Query

In [None]:
# NOTE: this rebinds `query` (until now the search query) to a Data API request.
# For every id collected in `output`, request the experimental method plus, per
# polymer entity, the instance entity id, the UniProt sequence, the PDB
# one-letter sequence and the UniProt source organism.
query = Query(
    input_type="entries",
    input_ids=output,
    return_data_list=[
        "exptl.method",
        "polymer_entities.polymer_entity_instances.rcsb_polymer_entity_instance_container_identifiers.entity_id",
        "polymer_entities.uniprots.rcsb_uniprot_protein.sequence",
        "polymer_entities.entity_poly.pdbx_seq_one_letter_code",
        "polymer_entities.uniprots.rcsb_uniprot_protein.source_organism"
    ]
)
query.exec(progress_bar=True)
# The response is indexed below as response_data['data']['entries'].
response_data = query.get_response()
# response_data

In [None]:
len(response_data['data']['entries'])

# Creating Pandas DF

In [None]:
# Four parallel columns; one element is appended to each for every
# (entry, polymer entity, UniProt record) whose taxonomy id is 9606.
# NOTE: despite its name, `pbd_ids` collects the PDB one-letter sequences.
rcsb_ids = []
rcsb_entity_ids = []
uniprot_seqs = []
pbd_ids = []

for result in response_data['data']['entries']:
    for entity in result['polymer_entities']:
        # Entities without any UniProt mapping contribute no rows.
        if not entity['uniprots']:
            continue
        for uniprot in entity['uniprots']:
            if uniprot['rcsb_uniprot_protein']['source_organism']['taxonomy_id'] != 9606:
                continue
            rcsb_ids.append(result['rcsb_id'])
            # Only the first polymer entity instance is consulted for the entity id.
            rcsb_entity_ids.append(
                entity['polymer_entity_instances'][0]
                ['rcsb_polymer_entity_instance_container_identifiers']['entity_id']
            )
            uniprot_seqs.append(uniprot['rcsb_uniprot_protein']['sequence'])
            pbd_ids.append(entity['entity_poly']['pdbx_seq_one_letter_code'])

In [None]:
len(rcsb_ids), len(rcsb_entity_ids), len(uniprot_seqs), len(pbd_ids)

In [None]:
# Assemble the parallel lists into one table.
# NOTE: the 'pbd_id' column holds the PDB one-letter sequence
# (entity_poly.pdbx_seq_one_letter_code), not an identifier.
df = pd.DataFrame(
    data = {'rcsb_id': rcsb_ids, 'rcsb_entity_ids': rcsb_entity_ids, 'uniprot_seq': uniprot_seqs, 'pbd_id': pbd_ids}
)

In [None]:
df.head()

In [None]:
# Persist the table to a hard-coded absolute path on the author's machine;
# adjust the path before running elsewhere.
df.to_excel("/Users/haripat/Desktop/SF/protein/data/protein_constructs.xlsx")

In [None]:
df.shape

In [None]:
# Replace the in-memory table with a workbook loaded from a hard-coded path.
# The cells below read a 'label_mask' column from this frame.
# df = pd.read_excel('/Users/haripat/Desktop/SF/protein/data/protein_constructs.xlsx')
df = pd.read_excel('/Users/haripat/Desktop/SF/protein/data/protein_constructs_w_label_masks.xlsx')

In [None]:
df.head()

In [None]:
df[df['label_mask'].isna()].shape

# Cleaning Data

In [None]:
def sanitize_sequence_advanced(sequence: str) -> str:
    """
    Reduce a PDB one-letter sequence to the 20 standard amino-acid letters.

    Processing steps, in order:
    1. The sequence is upper-cased so modified-residue codes match regardless
       of case.
    2. Known parenthesised modified-residue codes are replaced by their parent
       residue (or removed, for the terminal caps ``(ACE)`` / ``(NH2)``).
    3. Any parenthesised code that is still present stands for ONE unmapped
       residue and is collapsed to a single ``X``.  Without this step the code
       would be translated character by character, e.g. ``(CSO)`` -> ``XCSXX``,
       inserting spurious residues into the sequence.
    4. Whitespace (e.g. line breaks) is removed; it is not a residue.
    5. Every remaining character outside ``ACDEFGHIKLMNPQRSTVWY`` becomes ``X``.

    Args:
        sequence: Raw sequence text; non-string values (e.g. NaN) are accepted.

    Returns:
        The sanitized sequence, or ``""`` when *sequence* is not a string.
    """
    if not isinstance(sequence, str):
        return ""

    ptm_replacements = {
        "(MSE)": "M",  # Selenomethionine -> Methionine
        "(SEP)": "S",  # Phosphoserine -> Serine
        "(TPO)": "T",  # Phosphothreonine -> Threonine
        "(PTR)": "Y",  # Phosphotyrosine -> Tyrosine
        "(NEP)": "K",  # N-Epsilon-Phospholysine -> Lysine
        "(MLY)": "K",  # Monomethyllysine -> Lysine
        "(M2L)": "K",  # Dimethyllysine -> Lysine
        "(M3L)": "K",  # Trimethyllysine -> Lysine
        "(ALY)": "K",  # Acetyllysine -> Lysine
        "(HLY)": "K",  # Hydroxylysine -> Lysine
        "(M1G)": "R",  # Monomethylarginine -> Arginine
        "(M2G)": "R",  # Dimethylarginine -> Arginine
        "(CIR)": "R",  # Citrulline -> Arginine
        "(HYP)": "P",  # Hydroxyproline -> Proline
        "(CGU)": "E",  # Gamma-carboxyglutamate -> Glutamate
        "(NH2)": "",   # C-Terminal Amidation -> Remove
        "(ACE)": "",   # N-Acetyl Group -> Remove
    }

    processed_seq = sequence.upper()
    for mod_code, standard_aa in ptm_replacements.items():
        processed_seq = processed_seq.replace(mod_code, standard_aa)

    # One unmapped modified residue -> one unknown residue.
    processed_seq = re.sub(r"\([^()]*\)", "X", processed_seq)
    # Layout whitespace must not turn into residues.
    processed_seq = re.sub(r"\s+", "", processed_seq)

    return re.sub(r"[^ACDEFGHIKLMNPQRSTVWY]", "X", processed_seq)

In [None]:
# Selenocysteine (U) is treated as cysteine (C) in both sequences.
# For the PDB sequence the swap has to happen BEFORE sanitising:
# sanitize_sequence_advanced maps every character outside the 20 standard
# residues (U included) to "X", so replacing U in its output never matches.
# The negative lookahead skips a U that sits inside a parenthesised
# modified-residue code such as "(CGU)", so those codes reach the sanitiser intact.
df['pdb_sequence_sanitized'] = df['pbd_id'].apply(
    lambda seq: sanitize_sequence_advanced(
        re.sub(r"U(?![^()]*\))", "C", seq) if isinstance(seq, str) else seq
    )
)
df['uniprot_seq'] = df['uniprot_seq'].str.replace("U", "C")

In [None]:
# df[df['pbd_id'].str.contains("\(")]['pbd_id']

# Labeling Data

In [None]:
# df = pd.read_csv('project_data_v3.csv')

In [None]:
df.head()

In [None]:
def create_multi_class_mask_v1(uniprot_sequence: str, pdb_construct_sequence: str) -> list[int] | None:
    """
    Generates a multi-class modification mask by globally aligning a UniProt
    sequence with a PDB construct sequence.

    The mask is the same length as the UniProt sequence. Each position is labeled:
    - 0: Maintained (the residue is the same in both sequences)
    - 1: Deleted (the residue is in UniProt but absent in the PDB construct)
    - 2: Mutated (the residue is present but changed to a different amino acid)

    Args:
        uniprot_sequence: The full-length, canonical protein sequence.
        pdb_construct_sequence: The engineered sequence from the PDB.

    Returns:
        A list of integers (0, 1, or 2) representing the modification mask,
        or None if no alignment can be generated.
    """
    # try:
      # print("--- Performing Global Alignment ---")

    # Perform the alignment
    alignments = aligner.align(uniprot_sequence, pdb_construct_sequence)

    if not alignments:
        return None

    # Best alignment is the first one
    best_alignment = alignments[0]

    # print('best_alignment', best_alignment)

    # Convert alignment object to strings
    aligned_uniprot = best_alignment.aligned[0]
    aligned_pdb = best_alignment.aligned[1]

    # Build the full aligned sequences
    seq1_aligned = []
    seq2_aligned = []

    idx_uniprot, idx_pdb = 0, 0
    for (start1, end1), (start2, end2) in zip(aligned_uniprot, aligned_pdb):
        # Handle gaps in UniProt
        while idx_uniprot < start1:
            seq1_aligned.append(uniprot_sequence[idx_uniprot])
            seq2_aligned.append("-")
            idx_uniprot += 1

        # Handle gaps in PDB
        while idx_pdb < start2:
            seq1_aligned.append("-")
            seq2_aligned.append(pdb_construct_sequence[idx_pdb])
            idx_pdb += 1

        # Add aligned region
        for i in range(end1 - start1):
            seq1_aligned.append(uniprot_sequence[start1 + i])
            seq2_aligned.append(pdb_construct_sequence[start2 + i])

        idx_uniprot = end1
        idx_pdb = end2

    # Build modification mask
    modification_mask = []
    for u_char, p_char in zip(seq1_aligned, seq2_aligned):
        if u_char == "-":
            continue  # Ignore gaps in UniProt
        if p_char == "-":
            modification_mask.append(1)  # Deleted
        elif u_char == p_char:
            modification_mask.append(0)  # Maintained
        else:
            modification_mask.append(2)  # Mutated

    # Validation check
    if len(modification_mask) != len(uniprot_sequence):
        
        return None

    return modification_mask
    # except:
    #   return None, None

def format_alignment_for_display(alignment):
    """
    Render a 5-tuple alignment as a score line plus a three-row text view.

    *alignment* is unpacked as (gapped UniProt string, gapped PDB string,
    score, begin, end); begin and end are not used.  The middle row marks each
    column with "|" for identical characters, " " when either side is a gap
    ("-"), and "." for any other mismatch.
    """
    uniprot_alg, pdb_alg, score, _begin, _end = alignment

    def column_marker(u_char, p_char):
        # Equality is tested first, exactly as in the column rules above.
        if u_char == p_char:
            return "|"
        if u_char == '-' or p_char == '-':
            return " "
        return "."

    connector = "".join(
        column_marker(u_char, p_char) for u_char, p_char in zip(uniprot_alg, pdb_alg)
    )

    return (
        f"Score: {score}\n\n"
        f"UniProt: {uniprot_alg}\n"
        f"         {connector}\n"
        f"PDB    : {pdb_alg}"
    )

In [None]:
# Pick one row (index label 3) as a worked example for the mask builder.
i = 3
pdb_sequence = df.loc[i, 'pdb_sequence_sanitized']
uniprot_sequence = df.loc[i, 'uniprot_seq']

In [None]:
pdb_sequence

In [None]:
uniprot_sequence

In [None]:
# Build the mask for the sample row chosen above and print it.
pdb_seq = pdb_sequence
uniprot_seq = uniprot_sequence

# `result` is either a list of 0/1/2 labels or None.
result = create_multi_class_mask_v1(uniprot_seq, pdb_seq)

# NOTE: the disabled report below unpacks `result` as (mask, alignment), but
# create_multi_class_mask_v1 returns only the mask; adapt before re-enabling.
# if result:
#     mask, alignment = result

#     print("\n" + "="*80)
#     print("RESULTS")
#     print("="*80)

#     # print("\n--- Visual Alignment ---")
#     # print(format_alignment_for_display(alignment))

#     print(f"\n--- Multi-Class Mask (first 100 values) ---")
#     print(mask)

#     # --- Statistics ---
#     maintained_count = mask.count(0)
#     deleted_count = mask.count(1)
#     mutated_count = mask.count(2)

#     print("\n--- Summary ---")
#     print(f"UniProt Sequence Length: {len(uniprot_seq)}")
#     print(f"Mask Length:             {len(mask)}")
#     print(f"Residues Maintained (0): {maintained_count}")
#     print(f"Residues Deleted (1):    {deleted_count}")
#     print(f"Residues Mutated (2):    {mutated_count}")
print(result)

In [None]:
# df['label_mask'] = df.swifter.apply(lambda row: create_multi_class_mask(row['uniprot_seq'], row['pdb_sequence_sanitized'])[0], axis=1)

In [None]:
# df['label_mask'] = [
#     create_multi_class_mask(row['uniprot_seq'], row['pdb_sequence_sanitized'])[0]
#     for _, row in df.iterrows()
# ]

In [None]:
def create_multi_class_mask(uniprot_sequence: str, pdb_construct_sequence: str) -> list[int] | None:
    """
    Generates a multi-class modification mask by globally aligning a UniProt
    sequence with a PDB construct sequence.

    The mask is the same length as the UniProt sequence. Each position is labeled:
    - 0: Maintained (the residue is the same in both sequences)
    - 1: Deleted (the residue is in UniProt but absent in the PDB construct)
    - 2: Mutated (the residue is present but changed to a different amino acid)

    Args:
        uniprot_sequence: The full-length, canonical protein sequence.
        pdb_construct_sequence: The engineered sequence from the PDB.

    Returns:
        A list of integers (0, 1, or 2) representing the modification mask,
        or None if no alignment can be generated.
    """
    # print("--- Performing Global Alignment ---")


    alignments = pairwise2.align.globalds(
        uniprot_sequence,
        pdb_construct_sequence,
        blosum62,
        -10,  # Gap open penalty
        -0.5  # Gap extend penalty
    )

    if not alignments:
        print("Error: Could not generate an alignment.")
        return None

    best_alignment = alignments[0]
    aligned_uniprot, aligned_pdb, score, begin, end = best_alignment


    modification_mask = []

    for uniprot_char, pdb_char in zip(aligned_uniprot, aligned_pdb):
        if uniprot_char == '-':
            # This case means there's an insertion in the PDB sequence (e.g., a tag).
            # It doesn't correspond to a position in the UniProt sequence, so we skip it.
            continue

        if pdb_char == '-':
            # A gap in the PDB sequence means the UniProt residue was deleted.
            modification_mask.append(1) # 1 = Deleted
        elif uniprot_char == pdb_char:
            # The characters match, so the residue was maintained.
            modification_mask.append(0) # 0 = Maintained
        else:
            # The characters are different, so the residue was mutated.
            modification_mask.append(2) # 2 = Mutated

    if len(modification_mask) != len(uniprot_sequence):
        print(f"Error: Mask length ({len(modification_mask)}) does not match UniProt sequence length ({len(uniprot_sequence)}).")
        return None

    return modification_mask

In [None]:
# Debug pass over the slice df[500:5000]: the mask returned for each row is
# discarded, so the only output is the row banner plus whatever
# create_multi_class_mask itself prints.
for i, row in df[500:5000].iterrows():
    print("-----")
    print(f"IN ROW {i}")
    create_multi_class_mask(row['uniprot_seq'], row['pdb_sequence_sanitized'])

In [None]:
def process_group(group_df):
    """
    Compute the modification mask for every row of *group_df*.

    Returns a dict mapping each row's index label to the value returned by
    create_multi_class_mask for that row's 'uniprot_seq' and
    'pdb_sequence_sanitized' columns.
    """
    return {
        row_label: create_multi_class_mask(row['uniprot_seq'], row['pdb_sequence_sanitized'])
        for row_label, row in group_df.iterrows()
    }

def process_groups(groups):
    """
    Run process_group over each DataFrame in *groups* and merge the results.

    Returns a single dict; when the same key occurs in several groups, the
    value from the later group wins.
    """
    return {
        row_label: mask
        for group_df in groups
        for row_label, mask in process_group(group_df=group_df).items()
    }

def partition_dataframe(df, chunk_size=100, list_size=10):
    """
    Split *df* into row chunks and bundle consecutive chunks into lists.

    Args:
        df: DataFrame to split (rows are taken positionally via ``iloc``).
        chunk_size: Maximum number of rows per chunk.
        list_size: Maximum number of chunks per bundle.

    Returns:
        A list of bundles; each bundle is a list of DataFrame chunks, in the
        original row order.
    """
    total_rows = len(df)
    chunk_count = (total_rows + chunk_size - 1) // chunk_size  # ceiling division
    row_chunks = [
        df.iloc[k * chunk_size:min((k + 1) * chunk_size, total_rows)]
        for k in range(chunk_count)
    ]

    total_chunks = len(row_chunks)
    bundle_count = (total_chunks + list_size - 1) // list_size  # ceiling division
    return [
        row_chunks[k * list_size:min((k + 1) * list_size, total_chunks)]
        for k in range(bundle_count)
    ]


In [None]:
def prcoess_data_cc(df, blosum62_mat, max_workers=4):
    """
    Compute modification masks for *df* in parallel worker processes.

    The frame is split with partition_dataframe; each bundle of chunks is
    handed to process_groups in a ProcessPoolExecutor and the per-bundle
    dicts are merged.

    Args:
        df: DataFrame with 'uniprot_seq' and 'pdb_sequence_sanitized' columns.
        blosum62_mat: Substitution matrix; stored in the module-level name
            ``blosum62`` for code that reads that global.
        max_workers: Number of worker processes (default 4, the value that
            was previously hard-coded).

    Returns:
        A dict mapping row index labels to masks.  Previously the merged dict
        was built and then dropped, so the function always returned None.
    """
    # NOTE(review): worker processes must be able to import process_groups;
    # presumably fine with a fork start method - confirm for notebook use.
    global blosum62
    blosum62 = blosum62_mat
    gps = partition_dataframe(df)
    ops = dict()
    with ProcessPoolExecutor(max_workers) as executor:
        for result in executor.map(process_groups, gps):
            ops.update(result)
    return ops

In [None]:
df[0:1000].shape

In [None]:
# `blosum62_matrix` is never defined (its assignment in the import cell is
# commented out), so load the BLOSUM62 matrix right here instead.
prcoess_data_cc(df[0:1000], blosum62_mat=substitution_matrices.load("BLOSUM62"))