In [6]:
import Bio.SeqUtils
import Bio.PDB, Bio.PDB.Residue
from Bio import SeqIO
# pdb_path = 'BDF_GDF_unrelaxed_rank_001_alphafold2_multimer_v3_model_5_seed_000.pdb'
# chain_to_seq = {str(record.id): str(record.seq) for record in SeqIO.parse(pdb_path, 'pdb-atom')}


def extract_sequence_with_seqio(mmcif_path):
    """
    Extracts the sequence from an mmCIF file using Bio.SeqIO.

    Args:
        mmcif_path (str): Path to the mmCIF file.

    Returns:
        str: The amino acid sequence as a single-letter code string.
    """
    sequences = []
    for record in SeqIO.parse(mmcif_path, "cif-atom"):
        sequences.append(str(record.seq))
        print(record.id)
    return ''.join(sequences)

# Usage
# cif_file_path = "fold_mll4_1100_end_rbbp5_wdr5_p53x2/fold_mll4_1100_end_rbbp5_wdr5_p53x2_model_0.cif"
# sequence = extract_sequence_with_seqio(cif_file_path)
# print("Extracted sequence:", sequence)
# print(len(sequence))


????:A
????:B
????:C
????:D
????:E
Extracted sequence: VHSKSSQYRRLRTEWKNNVYLARSRIQGLGLYAAKDLEKHTMVIEYIGTIIRNEVANRREKIYEEQNRGIYMFRINNEHVIDATLTGGPARYINHSCAPNCVAEVVTFDKEDKIIIISSRRIPKGEELTYDYQFDFEDDQHKIPCHCGAWNCRKWMNMNLELLESFGQNYPEEADGTLDCISMALTCTFNRWGTLLAVGCNDGRIVIWDFLTRGIAKIISAHIHPVCSLCWSRDGHKLVSASTDNIVSQWDVLSGDCDQRFRFPSPILKVQYHPRDQNKVLVCPMKSAPVMLTLSDSKHVVLPVDDDSDLNVVASFDRRGEYIYTGNAKGKILVLKTDSQDLVASFRVTTGTSNTTAIKSIEFARKGSCFLINTADRIIRVYDGREILTCGRDGEPEPMQKLQDLVNRTPWKKCCFSGDGEYIVAGSARQHALYIWEKSIGNLVKILHGTRGELLLDVAWHPVRPIIASISSGVVSIWAQNQVENWSAFAPDFKELDENVEYEERESEFDIEDEDKSEPEQTGADAAEDEEVDVTSVDPIAAFCSSDEELEDSKALLYLPIAPEVEDPEENPYGPPPDAVQTSLMDEGASSEKKRQSSADGSQPPKKKPKTTNIELQGVPNDEVHPLLGVKGDGKSKKKQAGRPKGSKGKEKDSPFKPKLYKGDRGLPLEGSAKGKVQAELSQPLTAGGAISELLMATEEKKPETEAARAQPTPSSSATQSKPTPVKPNYALKFTLAGHTKAVSSVKFSPNGEWLASSSADKLIKIWGAYDGKFEKTISGHKLGISDVAWSSDSNLLVSASDDKTLKIWDVSSGKCLKTLKGHSNYVFCCNFNPQSNLIVSGSFDESVRIWDVKTGKCLKTLPAHSDPVSAVHFNRDGSLIVSSSYDGLCRIWDTASGQCLKTLIDDDNPPVSFVKFSPNGKYILAATLDNTLKLWDYSKGKCLK



In [8]:
from Bio.PDB import MMCIFParser, PPBuilder

def extract_sequence_from_mmcif(mmcif_path):
    """
    Extracts the amino acid sequence from the ATOM records of an mmCIF file.

    Args:
        mmcif_path (str): Path to the mmCIF file.

    Returns:
        str: The amino acid sequence as a single-letter code string.
    """
    parser = MMCIFParser(QUIET=True)
    structure = parser.get_structure("Model", mmcif_path)

    # Use the PPBuilder to build polypeptides and extract the sequence
    ppb = PPBuilder()
    sequences = []

    for pp in ppb.build_peptides(structure):
        sequences.append(pp.get_sequence())  # Seq object
        pp.get_id()

    # Combine sequences if there are multiple chains
    return ''.join(str(seq) for seq in sequences)

# Usage
#mmcif_path = "structure.cif"
# sequence = extract_sequence_from_mmcif(cif_file_path)
# print("Extracted sequence:", sequence)


Extracted sequence: VHSKSSQYRRLRTEWKNNVYLARSRIQGLGLYAAKDLEKHTMVIEYIGTIIRNEVANRREKIYEEQNRGIYMFRINNEHVIDATLTGGPARYINHSCAPNCVAEVVTFDKEDKIIIISSRRIPKGEELTYDYQFDFEDDQHKIPCHCGAWNCRKWMNMNLELLESFGQNYPEEADGTLDCISMALTCTFNRWGTLLAVGCNDGRIVIWDFLTRGIAKIISAHIHPVCSLCWSRDGHKLVSASTDNIVSQWDVLSGDCDQRFRFPSPILKVQYHPRDQNKVLVCPMKSAPVMLTLSDSKHVVLPVDDDSDLNVVASFDRRGEYIYTGNAKGKILVLKTDSQDLVASFRVTTGTSNTTAIKSIEFARKGSCFLINTADRIIRVYDGREILTCGRDGEPEPMQKLQDLVNRTPWKKCCFSGDGEYIVAGSARQHALYIWEKSIGNLVKILHGTRGELLLDVAWHPVRPIIASISSGVVSIWAQNQVENWSAFAPDFKELDENVEYEERESEFDIEDEDKSEPEQTGADAAEDEEVDVTSVDPIAAFCSSDEELEDSKALLYLPIAPEVEDPEENPYGPPPDAVQTSLMDEGASSEKKRQSSADGSQPPKKKPKTTNIELQGVPNDEVHPLLGVKGDGKSKKKQAGRPKGSKGKEKDSPFKPKLYKGDRGLPLEGSAKGKVQAELSQPLTAGGAISELLMATEEKKPETEAARAQPTPSSSATQSKPTPVKPNYALKFTLAGHTKAVSSVKFSPNGEWLASSSADKLIKIWGAYDGKFEKTISGHKLGISDVAWSSDSNLLVSASDDKTLKIWDVSSGKCLKTLKGHSNYVFCCNFNPQSNLIVSGSFDESVRIWDVKTGKCLKTLPAHSDPVSAVHFNRDGSLIVSSSYDGLCRIWDTASGQCLKTLIDDDNPPVSFVKFSPNGKYILAATLDNTLKLWDYSKGKCLKTYTGHKNEKYCIFANFSVTGGKWIVSGSEDNLVYI

In [45]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.colors import BoundaryNorm, ListedColormap
from matplotlib.patches import Patch
def plot_pae_plddt(pae_as_arr: np.array, plddt_array,nodes,edges):
    # Define AlphaFold plDDT color scheme
    plddt_cmap = ListedColormap(['#FF7D45', '#FFDB13', '#65CBF3', '#0053D6'])  # Orange, Yellow, Cyan, Blue
    bounds = [0, 50, 70, 90, 100]  # Define confidence levels
    norm = BoundaryNorm(bounds, plddt_cmap.N)  # Use BoundaryNorm for discrete segments
    
    
    # Identify start and end of continuous blocks where plDDT > 40
    # high_confidence = np.where(plddt_array > 40, 1, 0)
    # borders = np.diff(np.concatenate([[0], high_confidence, [0]]))
    # start_indices = np.where(borders == 1)[0]
    # end_indices = np.where(borders == -1)[0] - 1
    
    # Create the figure and axes
    fig, ax = plt.subplots(figsize=(15, 15))
    matrix_ax = ax.matshow(pae_as_arr, vmin=0., vmax=np.max(pae_as_arr), cmap='Greens_r')
    fig.colorbar(matrix_ax, label='PAE', ax=ax)
    
    # Add plDDT bars to X and Y axes
    divider_width = 0.02
    divider_offset = 0.02
    
    # plDDT bar for X-axis
    x_cb_ax = fig.add_axes([
        ax.get_position().x0,
        ax.get_position().y1 + divider_offset,  # Position above the matrix
        ax.get_position().width,
        divider_width
    ])
    x_cb_ax.imshow(
        plddt_array.reshape(1, -1),
        aspect='auto',
        cmap=plddt_cmap,
        norm=norm
    )
    x_cb_ax.set_xticks([])
    x_cb_ax.set_yticks([])
    
    # plDDT bar for Y-axis
    y_cb_ax = fig.add_axes([
        ax.get_position().x1 + divider_offset,  # Position to the right of the matrix
        ax.get_position().y0,
        divider_width,
        ax.get_position().height
    ])
    y_cb_ax.imshow(
        plddt_array.reshape(-1, 1),
        aspect='auto',
        cmap=plddt_cmap,
        norm=norm
    )
    y_cb_ax.set_xticks([])
    y_cb_ax.set_yticks([])
    
    # Draw border lines for high confidence regions
    # for start, end in zip(start_indices, end_indices):
    #     ax.axhline(start - 0.5, color='red', linestyle='-', linewidth=0.5)  # Top border
    #     ax.axhline(end + 0.5, color='red', linestyle='-', linewidth=0.5)    # Bottom border
    #     ax.axvline(start - 0.5, color='red', linestyle='-', linewidth=0.5)  # Left border
    #     ax.axvline(end + 0.5, color='red', linestyle='-', linewidth=0.5)    # Right border

    # Draw square borders and annotate with names
    for name, start, end in nodes:
        rect = Rectangle(
            (start - 0.5, start - 0.5),  # Bottom-left corner
            end - start,  # Width
            end - start,  # Height
            edgecolor='blue',
            facecolor='none',
            linewidth=2
        )
        ax.add_patch(rect)
        # Annotate the square
        ax.text(
            start + (end - start) / 2,  # X position (center of the square)
            start + (end - start) / 2,  # Y position (center of the square)
            name,  # Annotation text
            color='black',
            fontsize=10,
            ha='center',
            va='center',
            bbox=dict(facecolor='white', edgecolor='none', alpha=0.6)  # Optional: Background for text
        )

    #from here code for edges
    # Mapping of square names to their centers for drawing edges
    square_centers = {
        name: (start + (end - start) / 2, start + (end - start) / 2)  # Center of the square
        for name, start, end in nodes
    }
    
    # Draw edges as L-shaped lines
    for edge in edges:
        if edge[0] in square_centers and edge[1] in square_centers:
            x1, y1 = square_centers[edge[0]]
            x2, y2 = square_centers[edge[1]]
            # Break the L-shape into two parts
            # Horizontal line
            ax.plot([x1, x2], [y1, y1], color='#800080', linestyle='-', linewidth=2.5)
            # Vertical line
            ax.plot([x2, x2], [y1, y2], color='#800080', linestyle='-', linewidth=2.5)

    # Add a plDDT legend
    legend_elements = [
        Patch(facecolor='#FF7D45', edgecolor='black', label='Very low (0-50)'),
        Patch(facecolor='#FFDB13', edgecolor='black', label='Low (50-70)'),
        Patch(facecolor='#65CBF3', edgecolor='black', label='Confident (70-90)'),
        Patch(facecolor='#0053D6', edgecolor='black', label='Very high (90-100)')
    ]
    ax.legend(
        handles=legend_elements,
        loc='best',  # Position the legend above the plot
        bbox_to_anchor=(1,1.2),  # Center above the plot with some padding
        title="plDDT Confidence",
        frameon=True
    )
    
    # Add title above the entire plot
    fig.suptitle("Predicted Aligned Error (PAE) with plDDT Bars \nand Confidence>40 Borders", fontsize=16, x=0.4,y=0.9)
    
    # Show the plot
    plt.savefig('all_plot.png', format='png', dpi=300, bbox_inches='tight')
    plt.show()

from typing import List, Tuple
# class SubunitInfo:
#     name: SubunitName
#     chain_names: List[str]
#     indexs: Tuple[int, int]
#     sequence: str
def extract_nodes_indexs(plddt_array: np.array, threshold: int = 40) -> List[Tuple[str,int, int]]:
    """
    Divide full sequence into nodes extracting only segments with pLDDT > threshold.

    Args:
        plddt_array (np.array): Array containing pLDDT scores for each residue.
        threshold (int): Threshold for pLDDT scores.

    Returns:
        List[Tuple[int, int]]: A list of tuples containing the start and end indices of each node.
    """
    nodes = []
    #subunits = []
    # Identify start and end of continuous blocks where pLDDT > threshold
    high_confidence = np.where(plddt_array > threshold, 1, 0)
    print(f"Number of Residues that Omitted by plddt: {len(plddt_array)}-{(high_confidence == 1).sum()}={len(plddt_array)-(high_confidence == 1).sum()}\n")
    borders = np.diff(np.concatenate([[0], high_confidence, [0]]))
    start_indices = np.where(borders == 1)[0]
    end_indices = np.where(borders == -1)[0] - 1

    chains = np.unique(np.array(json_full_data['token_chain_ids']))
    chain_occ_counter = {chain_id:0 for chain_id in chains}
    for start, end in zip(start_indices, end_indices):
        #print((int(start), int(end + 1)))
        chains_ids_in_node = np.unique(np.array(json_full_data['token_chain_ids'][start:end+1]))
        subunit_name = ""
        for chain_id in chains_ids_in_node:
                         subunit_name += chain_id + str(chain_occ_counter[chain_id]+1)
                         chain_occ_counter[chain_id] +=1
        nodes.append((subunit_name,int(start), int(end -1)))

    return nodes

In [46]:
#main
import numpy as np
import json
from Bio.PDB import MMCIFParser
from collections import defaultdict


# Load the JSON file
with open("fold_mll4_1100_end_rbbp5_wdr5_p53x2/fold_mll4_1100_end_rbbp5_wdr5_p53x2_full_data_0.json", "r") as file:
    json_full_data = json.load(file)
    
cif_file_path = "fold_mll4_1100_end_rbbp5_wdr5_p53x2/fold_mll4_1100_end_rbbp5_wdr5_p53x2_model_0.cif"
pae_as_arr = np.array(json_full_data['pae'])
atom_plddts = json_full_data['atom_plddts']
atom_chain_ids = json_full_data['atom_chain_ids']
token_res_ids = json_full_data['token_res_ids']

# Parse the CIF file
parser = MMCIFParser(QUIET=True)
structure = parser.get_structure("model", cif_file_path)

# Map atoms to residues
residue_plddt_sum = defaultdict(float)
residue_atom_count = defaultdict(int)

atom_index = 0  # Track atom index for atom_plddts

for model in structure:
    for chain in model:
        for residue in chain:
            residue_key = (chain.id, residue.id[1])  # Use (chain ID, residue number) as key 
            for atom in residue:
                if atom_index < len(atom_plddts):
                    residue_plddt_sum[residue_key] += atom_plddts[atom_index]
                    residue_atom_count[residue_key] += 1
                    atom_index += 1
                else:
                    print(f"Warning: atom_index {atom_index} exceeds atom_plddts length.")
                    break

# Calculate average plDDT for each residue
average_residue_plddt = {
    key: residue_plddt_sum[key] / residue_atom_count[key]
    for key in residue_plddt_sum
}
# Assuming 'average_residue_plddt' and 'pae_as_arr' are already defined
    # Convert average_residue_plddt to a list in the correct order
plddt_values = [average_residue_plddt[key] for key in sorted(average_residue_plddt)]
plddt_array = np.array(plddt_values)
# nodes =[('A1', 0, 157), ('B1', 158, 165), ('B2', 166, 517), ('B3', 518, 519), ('B4', 523, 548), ('B5', 549, 567), ('B6', 606, 630), ('C1', 722, 1029), ('D1', 1117, 1323), ('D2', 1355, 1384), ('E1', 1509, 1716), ('E2', 1746, 1777)]
# edges = [('B2', 'B4'), ('B2', 'B5'), ('B2', 'B6'), ('B2', 'C1'), ('D2', 'E2')]
nodes = extract_nodes_indexs(plddt_array)
print(nodes)
#plot_pae_plddt(pae_as_arr, plddt_array,nodes,edges)

Number of Residues that Omitted by plddt: 1815-1363=452

[('A1', 0, 155), ('B1', 158, 163), ('B2', 166, 515), ('B3', 518, 517), ('B4', 523, 546), ('B5', 549, 565), ('B6', 606, 628), ('C1', 722, 1027), ('D1', 1117, 1321), ('D2', 1355, 1382), ('E1', 1509, 1714), ('E2', 1746, 1775)]


In [72]:
for node in nodes:
    start = node[1]
    end = node[2]
    # in1=end-1
    # in_af=end+1
    print(f"end-1: {plddt_array[end-1]},end:{plddt_array[end]} end+1: {plddt_array[end+1]}")
print(len(plddt_array))

indices = np.where(plddt_array > 40)[0]  # Returns the indices
print(type(plddt_array))
if plddt_array.ndim == 1:
    print("The array is one-dimensional.")
print(indices)  # Output: [2 4]
print(indices[0])
# while index < len(arr):
    
# for first in arr:
#     while first+1 == 
# while index = prev+1:
#     index++
#     new_nodes.append()
#print(arr)

end-1: 61.56642857142856,end:62.2175 end+1: 52.888888888888886
end-1: 51.458749999999995,end:44.54222222222223 end+1: 45.22666666666667
end-1: 43.196250000000006,end:43.17777777777777 end+1: 42.958333333333336
end-1: 42.958333333333336,end:38.55111111111111 end+1: 40.642857142857146
end-1: 53.22666666666666,end:50.97375 end+1: 46.30222222222221
end-1: 61.61000000000001,end:52.31555555555556 end+1: 42.54625
end-1: 63.45750000000001,end:53.735 end+1: 40.660000000000004
end-1: 81.525,end:73.22125 end+1: 66.16571428571429
end-1: 63.300000000000004,end:62.5975 end+1: 46.53888888888889
end-1: 55.843999999999994,end:50.06666666666666 end+1: 47.14
end-1: 63.945555555555565,end:63.7075 end+1: 47.76222222222222
end-1: 61.71600000000001,end:53.75222222222222 end+1: 49.74
1815
<class 'numpy.ndarray'>
The array is one-dimensional.
[   0    1    2 ... 1774 1775 1776]
0


In [80]:
def find_consecutive_groups_in_order(arr, gap_threshold=1):
    """
    Extracts groups of consecutive integers from an array while preserving the order.

    Args:
        arr (list[int]): The input list of integers.
        gap_threshold (int): The maximum allowed gap between consecutive numbers in a group.

    Returns:
        list[list[int]]: A list of groups, each containing consecutive numbers.
    """

    groups = []  # To store groups of consecutive numbers
    current_group = [int(arr[0])]  # Start with the first number

    for i in range(1, len(arr)):
        if arr[i] - arr[i - 1] <= gap_threshold:
            current_group.append(int(arr[i]))
        else:
            groups.append(current_group)
            current_group = [int(arr[i])]

    # Append the last group
    groups.append(current_group)

    return groups

# Example usage
gap_threshold = 1
groups = find_consecutive_groups_in_order(indices, gap_threshold)
#print(groups)  # Output: [[3, 4, 5, 6], [9, 10], [20, 21, 22], [25]]
for group in groups:
    print(f"{group[0]},{group[len(group)-1]}")

0,156
158,164
166,516
518,518
523,547
549,566
606,629
722,1028
1117,1322
1355,1383
1509,1715
1746,1776


In [84]:
#outline_cheack
for group in groups:
        print(f"{group[0]},{group[len(group)-1]}")
        print(f"end-1:{plddt_array[len(group)-2]} end: {plddt_array[len(group)-1]},end+1: {plddt_array[len(group)]}")
        print(f"aminoac:{plddt_array[len(group)-2]} aminoac: {plddt_array[len(group)-1]},aminoac:{plddt_array[len(group)]}")



0,156
end-1:62.2175 end: 52.888888888888886,36.745000000000005
158,164
end-1:80.705 end: 77.58111111111111,79.58583333333333
166,516
end-1:75.63714285714286 end: 74.5125,73.23714285714286
518,518
end-1:14.981111111111112 end: 54.494285714285716,57.237
523,547
end-1:73.31166666666667 end: 62.127272727272725,71.28999999999999
549,566
end-1:75.98625 end: 81.065,80.75999999999999
606,629
end-1:58.698181818181816 end: 73.31166666666667,62.127272727272725
722,1028
end-1:68.01875 end: 64.4075,64.63
1117,1322
end-1:88.66875 end: 88.32857142857144,87.48499999999999
1355,1383
end-1:69.1225 end: 69.46125,80.9525
1509,1715
end-1:88.32857142857144 end: 87.48499999999999,83.0542857142857
1746,1776
end-1:80.9525 end: 81.72500000000001,79.85249999999999


In [89]:
type(atom_plddts)
#print(atom_plddts)
type(atom_plddts[0])
print(len(atom_plddts))
14156

14156


In [93]:
len(average_residue_plddt)
type(average_residue_plddt)
#print(average_residue_plddt[155])


dict