In [2]:
# Install dependencies for the script
!pip install --quiet numpy pandas matplotlib logomaker biopython

import os
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import logomaker
from itertools import combinations
from collections import defaultdict
from Bio.Data import IUPACData

# =============================================================
# 1. Imports and configuration
# =============================================================
INPUT_FILE = "2ongcealigned.cif"  # one file holding every aligned protein, each in its own data_ block
PROTEIN_OF_INTEREST = "2ONG_C".strip().upper()  # reference protein id, normalised to upper case
ERROR_THRESHOLD = 4.0   # largest CA-CA distance accepted as a match (presumably Å - confirm input units)
HARDEST_FIRST = True    # assign the most constrained target residues before the easy ones
DELTA = "Δ"
FREQ_CUTOFF = 0.15      # a letter enters the Δ tables when it reaches this frequency in either group

# =============================================================
# 1b. Output directory setup
# =============================================================
OUTPUT_DIR = "StructuralAlignmentOut"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Filename-safe rendering of the cutoff (the decimal point becomes "p")
_FREQ_TAG = str(FREQ_CUTOFF).replace('.', 'p')

# Normalized output filenames
OUTPUT_ASSIGNMENTS = os.path.join(OUTPUT_DIR, "residue_assignments.csv")
PAIRWISE_DELTA_OUT = os.path.join(OUTPUT_DIR, f"pairwise_delta_freq_ge_{_FREQ_TAG}.csv")
TOPDELTA_OUT = os.path.join(OUTPUT_DIR, f"topdelta_per_position_freq_ge_{_FREQ_TAG}.csv")

# =============================================================
# 2. Grouping options (dynamic, merged labels allowed)
# =============================================================
# GROUP_MODE:
#   - "grouped": use GROUP_REGEX to extract a token from the protein id, then map via GROUP_SPEC
#   - "all":     treat all proteins as one group "All"
GROUP_MODE = "grouped"  # or "all"
GROUP_REGEX = r".*_([A-Za-z]+)$"   # captures trailing token after underscore

# (suffix token, human label) pairs. Several tokens may share one label.
# Example: C and M collapse to "Monocyclic"; B and T collapse to "Bicyclic"; L -> "Linear".
GROUP_SPEC = [
    ("L", "Linear"),
    ("C", "Monocyclic"),
]

# Token -> label lookup, plus the labels in declaration order (repeats possible by design)
REGEX_MAP = dict(GROUP_SPEC)
ORDERED_LABELS = [label for _token, label in GROUP_SPEC]

# dict.fromkeys drops repeated labels while keeping first-seen order
ORDERED_LABELS_UNIQUE = list(dict.fromkeys(ORDERED_LABELS))
_def_seen = set(ORDERED_LABELS_UNIQUE)

# TARGET_MODE:
#   - "manual": use MANUAL_RESIDUES below
#   - "all":    use all residue numbers present in the reference protein
TARGET_MODE = "manual"  # or "all"

# Reference residue numbers to analyse (used when TARGET_MODE = "manual")
MANUAL_RESIDUES = [
    63,
    315, 324, 345, 348, 349, 352, 353, 356,
    427, 430, 452, 453, 454, 458, 492, 493, 496, 499,
    500, 502, 503, 504, 507, 509, 512, 573, 577, 578, 579, 581, 582,
]

# Upper-cased 3-letter codes -> single letter, extended with two uncommon amino acids
AA3_TO_1 = {
    **{three.upper(): one for three, one in IUPACData.protein_letters_3to1.items()},
    'SEC': 'U',
    'PYL': 'O',
}

# =============================================================
# 3. Plot configuration (fonts + logo colour + chunking readability)
#    All size values specified in centimeters (cm)
# =============================================================
# Fonts
mpl.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Times New Roman', 'Times', 'DejaVu Serif', 'serif'],
})

# Sequence-logo colour scheme (logomaker built-ins: 'chemistry', 'hydrophobicity', 'charge', 'monochrome')
LOGO_COLOR_SCHEME = 'chemistry'

# Logo layout; lengths are in centimetres and converted with CM_TO_IN when a figure is created
MAX_POSITIONS_PER_FIG = 40   # positions drawn per figure before a new chunk starts
MIN_GLOBAL_FREQ = 0.02       # a letter is drawn only if it reaches this frequency at some position
BASE_WIDTH_CM = 5.0          # figure width before any positions are added
WIDTH_PER_POS_CM = 0.7       # extra figure width per plotted position
FIG_HEIGHT_CM = 10.0         # figure height
DPI = 400                    # resolution used when a logo is saved
XTICK_LABEL_ROT = 90         # rotation of the x tick labels, in degrees

# Matplotlib sizes figures in inches
CM_TO_IN = 1.0 / 2.54

# Write each logo chunk to a PNG in OUTPUT_DIR when True
SAVE_LOGO_PNG = False

# =============================================================
# 3b. Quick group preview (headers only)
# =============================================================

def _scan_protein_ids(path: str):
    """Return the upper-cased id of every ``data_`` block header in *path*.

    Only lines that begin with ``data_`` contribute. The id is the rest of
    the line with surrounding whitespace removed. Ids are returned in file
    order and repeats are kept.
    """
    marker = "data_"
    with open(path, "r") as handle:
        return [
            row[len(marker):].strip().upper()
            for row in handle
            if row.startswith(marker)
        ]


def _group_label_from_id(pid: str) -> str:
    """Map a protein id to its group label.

    In ``"all"`` mode every id belongs to ``"All"``. Otherwise GROUP_REGEX is
    matched against the id. The captured token is translated through
    REGEX_MAP, falling back to the raw token when it has no entry. Ids the
    regex does not match are labelled ``"Unassigned"``.
    """
    if GROUP_MODE == "all":
        return "All"
    hit = re.match(GROUP_REGEX, pid)
    if hit is None:
        return "Unassigned"
    suffix = hit.group(1)
    return REGEX_MAP.get(suffix, suffix)

# Run the preview
# List the protein ids found in the file headers, bucketed by group label
_all_ids = _scan_protein_ids(INPUT_FILE)
print("\n=== Group Preview (from file headers only) ===")
if not _all_ids:
    print("No 'data_' blocks found in the input file.")
elif GROUP_MODE == "all":
    print(f"GROUP_MODE='all' -> single group 'All' (n={len(_all_ids)})")
    for p in sorted(_all_ids):
        print(f"  - {p}")
else:
    # label -> ids carrying that label, in file order
    _preview_groups = {}
    for pid in _all_ids:
        _preview_groups.setdefault(_group_label_from_id(pid), []).append(pid)

    # "Unassigned" ids are listed below but left out of the named total
    total_named = sum(
        len(ids) for label, ids in _preview_groups.items() if label != "Unassigned"
    )
    print(f"Total proteins: {len(_all_ids)} | Grouped (named): {total_named}")
    for g in sorted(_preview_groups):
        members = sorted(_preview_groups[g])
        print(f"\nGroup: {g}  (n={len(members)})")
        for m in members:
            print(f"  - {m}")


=== Group Preview (from file headers only) ===
Total proteins: 88 | Grouped (named): 88

Group: Linear  (n=59)
  - A0A059SVB0_L
  - A0A068B0N9_L
  - A0A097ZLN9_L
  - A0A097ZLP5_L
  - A0A097ZLS9_L
  - A0A140KFH3_L
  - A0A1Q1N939_L
  - A0A2R3ZE38_L
  - A0A2R3ZE39_L
  - A0A345BJ04_L
  - A0A3G9EWY9_L
  - A0A4Y5QWA6_L
  - A0A4Y5QZ62_L
  - A0A5B9G8E4_L
  - A0A6C0M6B5_L
  - A0A7G8EIM4_L
  - A0A7L7S5P2_L
  - A0A7L7S5U1_L
  - A0A7L7T499_L
  - A0A7L7TAQ2_L
  - ATY48639_L
  - B1NA83_L
  - B1NA84_L
  - B6F137_L
  - C0KWV5_L
  - C0KWV7_L
  - C0KY88_L
  - C0PPR1_L
  - D4N3A0_L
  - D4N3A1_L
  - D5SL78_L
  - DINCTG000_NP1212847_L
  - F2XF93_L
  - F2XFA6_L
  - F8TWD1_L
  - F8TWD2_L
  - G5CV39_L
  - H6UQ81_L
  - H6WBC5_L
  - HMOCHR44410_L
  - J7HWK5_L
  - P0CV94_L
  - P0CV95_L
  - Q1XBU5_L
  - Q2XSC5_L
  - Q5SBP3_L
  - Q675L2_L
  - Q6ZH94_L
  - Q84UV0_L
  - Q8H2B4_L
  - Q96376_L
  - Q9SPN0_L
  - Q9SPN1_L
  - R4I6X2_L
  - SLECTG065_NP871_L
  - SLECTG127_NP665_L
  - SLECTG240_NP1756_L
  - U5PZT6_L
  - VO

In [None]:
# =============================================================
# 4. Function: extract alpha carbons
# =============================================================

def extract_alpha_carbons(path: str):
    """Collect the alpha-carbon (CA) record of every residue in an alignment file.

    A line starting with ``data_`` opens a new protein block; its id is the
    upper-cased remainder of that line. Inside a block each ``ATOM`` line is
    split on whitespace and read by position (zero-based): field 3 is the
    atom name, 5 the three-letter residue name, 8 the residue number and
    10-12 the x, y, z coordinates. This is presumably the usual mmCIF
    ``_atom_site`` column order; confirm for files written by other tools.

    The atom name is compared as a field rather than searching the raw line
    for ``" CA "``, because that substring also appears on every atom whose
    chain/asym id happens to be ``CA``.

    Parameters
    ----------
    path : str
        Path of the alignment file to read.

    Returns
    -------
    list of tuple
        ``(protein_id, residue_number, one_letter_code, xyz)`` in file order.
        ``xyz`` is a float NumPy array of length 3. Residue names missing
        from ``AA3_TO_1`` get the code ``"-"``. Lines that are too short or
        hold non-numeric values in the numeric fields are skipped.
    """
    out = []
    current = None
    with open(path, "r") as f:
        for line in f:
            if line.startswith("data_"):
                current = line[len("data_"):].strip().upper()
                continue
            # Atom records before the first (or inside an unnamed) block are ignored
            if not current or not line.startswith("ATOM"):
                continue
            fields = line.split()
            try:
                # CIF may quote atom names, so strip quotes before comparing
                if fields[3].strip("\"'") != "CA":
                    continue
                resname_3 = fields[5]
                resnum = int(fields[8])
                x, y, z = float(fields[10]), float(fields[11]), float(fields[12])
            except (IndexError, ValueError):
                # Truncated record or placeholder such as "." / "?" in a numeric field
                continue
            aa1 = AA3_TO_1.get(resname_3.upper(), "-")
            out.append((current, resnum, aa1, np.array([x, y, z], float)))
    return out

# =============================================================
# 5. Extract and prepare reference data 
# =============================================================
alpha_carbons = extract_alpha_carbons(INPUT_FILE)

# Residue number -> one-letter code for every CA of the reference protein
specified_residues = {rnum: aa for pid, rnum, aa, _ in alpha_carbons if pid == PROTEIN_OF_INTEREST}
if not specified_residues:
    # Without this check the run only fails later with a KeyError on an empty table
    raise SystemExit(
        f"Reference protein '{PROTEIN_OF_INTEREST}' has no CA atoms in {INPUT_FILE}. "
        "Check PROTEIN_OF_INTEREST against the data_ block names."
    )

# Decide the target residue set based on TARGET_MODE
if TARGET_MODE == "all":
    TARGET_POSITIONS = sorted(specified_residues)
else:
    TARGET_POSITIONS = list(MANUAL_RESIDUES)
    # Requested residues that the reference lacks produce no rows downstream; say so up front
    _missing_targets = [r for r in TARGET_POSITIONS if r not in specified_residues]
    if _missing_targets:
        print(f"WARNING: residues not found in {PROTEIN_OF_INTEREST} and ignored: {_missing_targets}")

# Collect chosen target CA coords from reference (set lookup instead of scanning the list per atom)
_target_lookup = set(TARGET_POSITIONS)
target_coords = [
    (rnum, aa, coords)
    for pid, rnum, aa, coords in alpha_carbons
    if pid == PROTEIN_OF_INTEREST and rnum in _target_lookup
]
if not target_coords:
    raise SystemExit(
        f"None of the requested target residues exist in '{PROTEIN_OF_INTEREST}'. "
        "Check MANUAL_RESIDUES / TARGET_MODE."
    )

# Target residue number -> one-letter code in the reference
ref_resname = {rnum: aa for rnum, aa, _ in target_coords}

# =============================================================
# 6. Group residues by protein (excluding reference protein)
# =============================================================
# protein id -> [(residue number, one-letter code, CA xyz), ...] for every non-reference protein
by_protein = defaultdict(list)
for protein_id, res_number, res_letter, ca_xyz in alpha_carbons:
    if protein_id == PROTEIN_OF_INTEREST:
        continue
    by_protein[protein_id].append((res_number, res_letter, ca_xyz))

# One dict per (protein, target residue), filled by the assignment step
assign_rows = []

# =============================================================
# 7. Residue mapping and assignment (one-to-one per protein, with error checking)
# =============================================================
def _assign_targets(pid, residues):
    """Match each reference target residue to at most one residue of one protein.

    For every target in ``target_coords`` the protein's residues are ranked by
    CA-CA distance. Targets are then processed in turn. Each one takes its
    closest residue within ``ERROR_THRESHOLD`` that is either free or held by
    a target that is farther away. A target that loses its residue this way
    gets one chance to move to its closest still-free residue within the
    threshold. If none exists its row is reset to unassigned with status
    ``lost_swap_no_alt``, so the residue it lost is not reported (or counted)
    twice.

    Parameters
    ----------
    pid : str
        Protein id written to the ``Protein`` column.
    residues : list of tuple
        ``(residue_number, one_letter_code, xyz)`` for that protein.

    Returns
    -------
    list of dict
        One row per entry of ``target_coords``, in that order. Every row
        carries the same keys. On unassigned rows ``Distance`` holds the
        distance to the nearest residue.
    """
    target_numbers = [t[0] for t in target_coords]

    # Ranked candidates per target, plus the figures used to order the targets
    candidates, stats = {}, {}
    for t_rnum, _, t_xyz in target_coords:
        pairs = [(float(np.linalg.norm(t_xyz - xyz)), rnum, aa) for rnum, aa, xyz in residues]
        pairs.sort(key=lambda x: (x[0], x[1]))  # by distance, ties by residue number
        candidates[t_rnum] = pairs
        min_d = pairs[0][0] if pairs else float('inf')
        second = pairs[1][0] if len(pairs) > 1 else float('inf')
        stats[t_rnum] = {
            'num_in': sum(1 for p in pairs if p[0] <= ERROR_THRESHOLD),
            'min_d': min_d,
            'ambiguity': second - min_d,  # gap between best and second-best candidate
        }

    if HARDEST_FIRST:
        # Fewest in-threshold candidates first, then smallest gap, then smallest distance
        def order_key(r):
            return (stats[r]['num_in'], stats[r]['ambiguity'], stats[r]['min_d'])
    else:
        def order_key(r):
            return stats[r]['min_d']
    order = sorted(target_numbers, key=order_key)

    def make_row(t_num, status, rnum='-', aa='-', dist=None):
        """Build one output row; unassigned rows report the nearest distance."""
        best = candidates[t_num][0] if candidates[t_num] else (float('inf'), '-', '-')
        return {
            'Target Residue Number': t_num,
            'Target Residue Name': ref_resname[t_num],
            'Protein': pid,
            'Residue Name': aa,
            'Residue Number': rnum,
            'Distance': best[0] if dist is None else dist,
            'Status': status,
            'Nearest Residue Name': best[2],
            'Nearest Residue Number': best[1],
            'Nearest Distance': best[0],
        }

    owner = {}  # residue number -> (target number, distance) currently holding it
    assign = {t: make_row(t, 'no_unused_within_threshold') for t in target_numbers}

    def set_assign(t_num, rnum, aa, dist):
        owner[rnum] = (t_num, dist)
        assign[t_num] = make_row(t_num, 'assigned', rnum, aa, dist)

    def reassign_displaced(t_num):
        """Move a displaced target to its closest free residue, if any."""
        for d2, r2, aa2 in candidates[t_num]:
            if d2 > ERROR_THRESHOLD:
                break  # candidates are sorted, nothing closer follows
            if r2 not in owner:
                set_assign(t_num, r2, aa2, d2)
                return True
        # No alternative: drop the residue that was taken away instead of keeping it on the row
        assign[t_num] = make_row(t_num, 'lost_swap_no_alt')
        return False

    for t_rnum in order:
        for d, rnum, aa in candidates[t_rnum]:
            if d > ERROR_THRESHOLD:
                break
            cur = owner.get(rnum)
            if cur is None:
                set_assign(t_rnum, rnum, aa, d)
                break
            prev_t, prev_d = cur
            if d < prev_d:
                # This target is closer: take the residue and let the previous holder look again
                set_assign(t_rnum, rnum, aa, d)
                reassign_displaced(prev_t)
                break

    return [assign[t] for t in target_numbers]


for pid, residues in by_protein.items():
    assign_rows.extend(_assign_targets(pid, residues))

# =============================================================
# 8. Save and validate assignments
# =============================================================
# Persist every (protein, target) row, assigned or not
assign_df = pd.DataFrame(assign_rows)
assign_df.to_csv(OUTPUT_ASSIGNMENTS, index=False)
print(f"Saved residue assignments to: {OUTPUT_ASSIGNMENTS}")

# Sanity check: within one protein a residue should back at most one target
mask_assigned = assign_df["Status"].eq("assigned")
_targets_per_residue = (
    assign_df.loc[mask_assigned]
    .groupby(["Protein", "Residue Number"])['Target Residue Number']
    .nunique()
)
reuse = _targets_per_residue[_targets_per_residue > 1]
if len(reuse) > 0:
    print("WARNING: some residues were reused across targets:")
    print(reuse)

# =============================================================
# 9. Grouping 
# =============================================================
if GROUP_MODE == 'all':
    assign_df['Group'] = 'All'
else:
    def _extract_group(protein_id: str) -> str:
        """Return the label for the token GROUP_REGEX captures, else 'Unassigned'."""
        m = re.match(GROUP_REGEX, protein_id)
        if m:
            token = m.group(1)
            # Tokens without a REGEX_MAP entry form a group named after the token itself
            return REGEX_MAP.get(token, token)
        return 'Unassigned'
    assign_df['Group'] = assign_df['Protein'].apply(_extract_group)

print("\n=== Group Summary ===")
present_groups_raw = [g for g in sorted(assign_df['Group'].dropna().unique().tolist()) if g != 'Unassigned']

# Always bound: the frequency step reads present_groups in grouped mode and used to
# hit a NameError when every protein was Unassigned.
present_groups = []
if GROUP_MODE == 'all':
    print("Running in 'all' mode: single group 'All'.")
elif not present_groups_raw:
    print("No valid groups found (all proteins Unassigned).")
else:
    # Labels declared in GROUP_SPEC come first, in that order; any other label follows alphabetically
    present_groups = [lbl for lbl in ORDERED_LABELS_UNIQUE if lbl in present_groups_raw]
    present_groups += [lbl for lbl in present_groups_raw if lbl not in present_groups]

    print(f"Total groups detected: {len(present_groups)}")
    for g in present_groups:
        members = sorted(assign_df.loc[assign_df['Group'] == g, 'Protein'].unique().tolist())
        print(f"\nGroup: {g}  (n={len(members)})")
        for m in members:
            print(f"  - {m}")

# =============================================================
# 10. Frequency computation 
# =============================================================
use_groups = present_groups if GROUP_MODE == 'grouped' else ['All']

# One row per (group, target position, residue letter) with its share of that position's rows.
# Unassigned rows carry '-' as residue name and are counted like any other value.
freq_rows = []
for group in use_groups:
    if GROUP_MODE == 'grouped':
        gdf = assign_df[assign_df['Group'] == group]
    else:
        gdf = assign_df.copy()
    for pos in TARGET_POSITIONS:
        letters_at_pos = gdf.loc[gdf['Target Residue Number'] == pos, 'Residue Name']
        freq_rows.extend(
            {
                "Group": group,
                "Residue Position": pos,
                "Residue Name": letter,
                "Frequency": share,
            }
            for letter, share in letters_at_pos.value_counts(normalize=True).items()
        )

freq_df = pd.DataFrame(freq_rows)
if freq_df.empty:
    raise SystemExit("No frequency data computed. Check grouping or targets and inputs.")

# Reindex positions for plotting: consecutive index <-> original residue number
unique_positions = freq_df["Residue Position"].unique()
reverse_position_mapping = dict(enumerate(sorted(unique_positions)))
position_mapping = {pos: idx for idx, pos in reverse_position_mapping.items()}
freq_df["Reindexed Position"] = freq_df["Residue Position"].map(position_mapping)

# =============================================================
# 11. Δ-frequency summary per position
# =============================================================
if GROUP_MODE == 'grouped':
    if len(use_groups) < 2:
        print("\nΔ-frequency summary skipped, need at least 2 groups.")
    else:
        # Rows: (position, residue letter); one frequency column per group, 0.0 where absent
        pivot = freq_df.pivot_table(
            index=["Residue Position", "Residue Name"],
            columns="Group",
            values="Frequency",
            fill_value=0.0,
        ).reset_index()

        # ensure all present groups exist as columns
        for gcol in use_groups:
            if gcol not in pivot.columns:
                pivot[gcol] = 0.0

        pair_rows = []
        for g1, g2 in combinations(use_groups, 2):
            sub = pivot[["Residue Position", "Residue Name", g1, g2]].copy()
            sub["delta_mag"] = (sub[g1] - sub[g2]).abs()
            # Keep only letters passing frequency cutoff in at least one group
            mask = (sub[g1] >= FREQ_CUTOFF) | (sub[g2] >= FREQ_CUTOFF)
            sub = sub[mask]
            sub.insert(0, "Group 1", g1)
            sub.insert(1, "Group 2", g2)
            pair_rows.append(sub)

        # pair_rows always has one frame per pair, so emptiness is judged after concatenation
        pairwise_delta_df = pd.concat(pair_rows, ignore_index=True)
        if pairwise_delta_df.empty:
            print("\nNo Δ rows passed the frequency cutoff. Try lowering FREQ_CUTOFF.")
        else:
            pairwise_delta_df.to_csv(PAIRWISE_DELTA_OUT, index=False)
            print(f"\nPairwise Δ-frequency table saved to {PAIRWISE_DELTA_OUT}")

            # Per-position top delta across all pairs and residues.
            # With three or more groups the concatenated frame has NaN in the columns of
            # groups outside a pair. GroupBy.first() skips NaN per column and would stitch
            # values from different rows together, so keep whole rows with drop_duplicates.
            ranked = pairwise_delta_df.sort_values(
                ["Residue Position", "delta_mag"], ascending=[True, False]
            )
            topdelta = ranked.drop_duplicates(subset="Residue Position", keep="first")
            lead_cols = ["Residue Position"]
            topdelta = topdelta[
                lead_cols + [c for c in topdelta.columns if c not in lead_cols]
            ].reset_index(drop=True)

            topdelta.to_csv(TOPDELTA_OUT, index=False)
            print(f"Top-Δ per position table saved to {TOPDELTA_OUT}")
            try:
                from IPython.display import display
            except ImportError:
                # Plain-Python run: fall back to a text table
                print(topdelta.to_string(index=False))
            else:
                display(topdelta)
else:
    print("\nSkipping Δ-frequency summary because GROUP_MODE='all'.")

# =============================================================
# 12. Sequence logos plotting
# =============================================================
valid_letters = set(AA3_TO_1.values())
_groups_for_plot = use_groups
_freq_ticks = np.arange(0, 1.01, 0.2)  # y axis ticks shared by every figure

for group in _groups_for_plot:
    # Position x letter frequency matrix for this group
    group_freqs = freq_df[freq_df['Group'] == group].copy()
    gp = group_freqs.pivot_table(
        index='Reindexed Position', columns='Residue Name', values='Frequency', fill_value=0.0
    )
    # Drop anything that is not a single amino-acid letter (e.g. the '-' placeholder)
    letter_cols = sorted(
        c for c in gp.columns if isinstance(c, str) and len(c) == 1 and c in valid_letters
    )
    gp = gp[letter_cols]
    if gp.empty:
        print(f"[skip] No valid AA columns for group '{group}'.")
        continue

    # Prune globally-rare letters for readability; keep everything if nothing would survive
    column_peak = gp.max(axis=0)
    frequent_cols = [c for c in gp.columns if column_peak[c] >= MIN_GLOBAL_FREQ]
    if frequent_cols:
        gp = gp[frequent_cols]

    gp = gp.sort_index()
    all_pos = gp.index.to_list()

    # One figure per run of at most MAX_POSITIONS_PER_FIG positions
    for ci, start in enumerate(range(0, len(all_pos), MAX_POSITIONS_PER_FIG), start=1):
        chunk = all_pos[start:start + MAX_POSITIONS_PER_FIG]
        window = gp.loc[chunk]
        xticklabels = [reverse_position_mapping[i] for i in window.index]
        n_cols = len(xticklabels)
        window_seq = window.reset_index(drop=True)  # align glyphs to 0..n-1 for plotting only

        # Figure size is configured in cm; Matplotlib expects inches
        fig_width_cm = max(BASE_WIDTH_CM, BASE_WIDTH_CM + WIDTH_PER_POS_CM * n_cols)
        fig_height_cm = FIG_HEIGHT_CM
        fig, ax = plt.subplots(figsize=(fig_width_cm * CM_TO_IN, fig_height_cm * CM_TO_IN))

        _ = logomaker.Logo(window_seq, ax=ax, color_scheme=LOGO_COLOR_SCHEME)

        ax.set_xlim(-0.5, n_cols - 0.5)
        ax.set_xlabel('Residue Position', fontsize=14, fontweight='bold')
        ax.set_ylabel('Frequency', fontsize=14, fontweight='bold')
        ax.set_yticks(_freq_ticks)
        ax.set_yticklabels([f"{t:.1f}" for t in _freq_ticks], fontsize=11, fontweight='bold')
        ax.set_xticks(range(n_cols))
        ax.set_xticklabels(xticklabels, rotation=XTICK_LABEL_ROT, fontsize=10, fontweight='bold')
        for spine in ax.spines.values():
            spine.set_linewidth(1.2)

        # Title shows the group name (e.g., "Linear", "Monocyclic", "Bicyclic")
        ax.set_title(group, fontsize=16, pad=10)

        plt.tight_layout()
        if SAVE_LOGO_PNG:
            safe_group = re.sub(r"[^A-Za-z0-9_\-]+", "_", group).lower()
            first_label, last_label = xticklabels[0], xticklabels[-1]
            out_png = os.path.join(
                OUTPUT_DIR,
                f"{safe_group}_logo_chunk{ci}_pos{first_label}-{last_label}.png"
            )
            plt.savefig(out_png, dpi=DPI, bbox_inches='tight')
        plt.show()
        plt.close(fig)

# =============================================================
# 13. Per-protein summary output
# =============================================================
# Rich table rendering when IPython is available, plain text otherwise
try:
    from IPython.display import display as _display
except Exception:
    _display = None

# Numeric copies of the residue columns; the '-' placeholder becomes NaN
_df = assign_df.copy()
_df['Target Residue Number'] = pd.to_numeric(_df['Target Residue Number'], errors='coerce')
_raw_resnum = _df['Residue Number']
_df['Residue Number'] = pd.to_numeric(_raw_resnum.mask(_raw_resnum == '-', np.nan), errors='coerce')

_summary_cols = [
    'Target Residue Number', 'Target Residue Name',
    'Residue Name', 'Residue Number', 'Distance', 'Status',
]
for pid in sorted(_df['Protein'].unique()):
    protein_rows = _df.loc[_df['Protein'] == pid]
    print("\n==============================")
    print(f"Protein: {pid} (rows={len(protein_rows)})")
    out = protein_rows.sort_values(['Target Residue Number'])[_summary_cols]
    if _display is None:
        print(out.to_string(index=False))
    else:
        _display(out)

# =============================================================
# 14. Completion message
# =============================================================
print("\nScript completed successfully.")
