# Protein structure analysis notebook
This notebook is used to evaluate and visualise structural metrics.

It assumes you have a results table in the same format as the one in the supplementary data, and a protein quality master CSV in its corresponding format. The script below plots the distribution of clashscores for each iteration group, i.e. the epoch number of the mutated sequence used to generate the AlphaFold2 model, and then runs a Tukey HSD test across the groups.



In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties, fontManager
import matplotlib as mpl
from statsmodels.stats.multicomp import pairwise_tukeyhsd

# ----------------------------
# Font configuration
# ----------------------------
# Register a local TrueType font with matplotlib and make it the default family.
# NOTE(review): absolute, machine-specific path; addfont() will fail on any
# machine where this file does not exist.
font_path = r"C:\Users\james\Downloads\abadi-mt_freefontdownload_org\abadi-mt.ttf"
fontManager.addfont(font_path)
# font_prop is reused below to style the axis labels and the title explicitly.
font_prop = FontProperties(fname=font_path)
mpl.rcParams['font.family'] = font_prop.get_name()

# ----------------------------
# Load quality data
# ----------------------------
# Rows without a clashscore are dropped up front, so both the plots and the
# Tukey test below only ever see rows with a value in that column.
quality_df = pd.read_csv(
    r"C:\Users\james\Masters_Degree\Thesis\protein_language_model_project\supplementary_data\protein_quality_master.csv"
).dropna(subset=["clashscore"])

# ----------------------------
# Boxplot of clashscore by iteration group
# ----------------------------
# Left-to-right order of the groups on the x axis, and one colour per group.
iteration_order = ["wt", "1", "5", "10", "var"]
custom_colors = ['#12436D', '#28A197', '#801650', '#F46A25', '#A285D1']

# sns.set() rewrites matplotlib rcParams, including font.family (its `font`
# argument defaults to "sans-serif"). Passing the custom font here keeps the
# family configured above in effect for tick labels as well. The theme is set
# before the figure is created so that all of its settings apply to this figure.
sns.set(style="whitegrid", font=font_prop.get_name())
plt.figure(figsize=(10, 6))

# Box per iteration group, outliers shown.
sns.boxplot(
    data=quality_df,
    x="iteration_number",
    y="clashscore",
    palette=custom_colors,
    showfliers=True,
    order=iteration_order
)

# Individual observations overlaid as semi-transparent jittered points.
sns.stripplot(
    data=quality_df,
    x="iteration_number",
    y="clashscore",
    color="black",
    alpha=0.4,
    jitter=0.2,
    order=iteration_order
)

plt.xlabel("Iteration Number", fontsize=12, fontproperties=font_prop)
plt.ylabel("Clash Score", fontsize=12, fontproperties=font_prop)
plt.title("Clash Scores by Iteration Group", fontsize=24, fontproperties=font_prop)
plt.grid(True)
plt.tight_layout()
plt.show()

# ----------------------------
# Tukey HSD test
# ----------------------------
# Pairwise comparison of clashscore between every pair of iteration groups
# at a 5% significance level; the summary table is printed.
tukey = pairwise_tukeyhsd(
    quality_df["clashscore"],
    quality_df["iteration_number"],
    alpha=0.05,
)
print(tukey.summary())


# Salt bridges and contacts
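The next step counts salt-bridge partners at the mutated positions of the wild-type (WT) and variant structures, and then summarises where the mutations fall along the sequence and how the number of contacts changes.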

In [None]:
# -*- coding: utf-8 -*-
"""
Salt-bridge partner change per protein + regression of Delta melt vs Delta salt-bridge.

- Protein list comes from results_table.csv (all families).
- WT/VAR structures discovered by walking the project tree.
- Bar chart: sign-only colours (orange +, dark blue −, grey 0).
- Regression: Delta melt temperature (VAR − WT) vs Delta salt-bridge partners,
  using a robust fallback fit (no SVD issues) and avoiding the 'Δ' glyph.
"""

import os
import re
from typing import Optional, Dict, List, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from Bio.PDB import PDBParser
from Bio import pairwise2

# ----------------------------
# Config
# ----------------------------

# Results table that supplies the list of proteins (and melt temperatures).
seq_csv = r"C:\Users\james\Masters_Degree\Thesis\protein_language_model_project\supplementary_data\results_table.csv"

# Root of the tree that is walked to discover WT/VAR PDB files. The script uses
# this name but never defined it (NameError). It is derived here as the parent
# of the folder holding seq_csv.
# NOTE(review): assumes the structure folders live under the project directory
# that contains supplementary_data - confirm, or set the path explicitly.
project_root = os.path.dirname(os.path.dirname(seq_csv))

# Directory-name fragments that identify wild-type and variant structure folders.
WT_MARKER  = "wild_type_structures"
VAR_MARKER = "collected_models"

chain_hint = "A"           # preferred chain ID; the first chain is used if absent
salt_cutoff = 4.0          # distance cutoff passed to count_saltbridge_partners
include_histidine = False  # when True, histidine variants count as basic residues

# Explicit exclusions
incomplete_pairs = {"2agl", "5ur0", "7mx6", "6i2a", "1taq", "5fkw"}

# Colours
NEG_COLOR   = "#12436D"
POS_COLOR   = "#F46A25"
ZERO_COLOR  = "#BFBFBF"
SCATTER_COLOR = "#444444"
LINE_COLOR    = "#F46A25"

plt.rcParams["axes.unicode_minus"] = False

# ----------------------------
# Amino-acid & atom maps
# ----------------------------
# Three-letter residue name -> one-letter code. Besides the 20 standard residues
# this maps MSE to M and the HSD/HSE/HSP histidine names to H. Names not listed
# here are rendered as "X" by load_chain_residues.
AA3_TO_1 = {
    "ALA":"A","ARG":"R","ASN":"N","ASP":"D","CYS":"C","GLU":"E","GLN":"Q","GLY":"G",
    "HIS":"H","ILE":"I","LEU":"L","LYS":"K","MET":"M","PHE":"F","PRO":"P","SER":"S",
    "THR":"T","TRP":"W","TYR":"Y","VAL":"V","MSE":"M","HSD":"H","HSE":"H","HSP":"H"
}

# Residue names treated as negatively / positively charged.
ACIDIC = {"ASP", "GLU"}
BASIC  = {"ARG", "LYS"}

# Atom names whose coordinates are used for the salt-bridge distance test.
ACIDIC_O_ATOMS = {"ASP": {"OD1", "OD2"}, "GLU": {"OE1", "OE2"}}
# The histidine entries are only consulted when include_histidine is True.
BASIC_N_ATOMS  = {
    "ARG": {"NH1", "NH2", "NE"},
    "LYS": {"NZ"},
    "HIS": {"ND1", "NE2"},
    "HSD": {"ND1"}, "HSE": {"NE2"}, "HSP": {"ND1", "NE2"},
}

# ----------------------------
# Robust ID extraction
# ----------------------------
# A PDB-style ID: a digit followed by three alphanumerics, not embedded in a
# longer alphanumeric run. Explicit lookarounds are used instead of \b because
# \b treats "_" as a word character, so an ID delimited by underscores
# (e.g. "model_1abc_relaxed") would never match.
ID_PATTERN = re.compile(r"(?<![A-Za-z0-9])([0-9][A-Za-z0-9]{3})(?![A-Za-z0-9])")

def extract_pdb_id(filename: str) -> Optional[str]:
    """Return the lower-cased 4-character PDB ID found in a file name.

    The directory and extension are stripped first. A delimited token matching
    ID_PATTERN is preferred; if none exists, the first run of four
    alphanumeric characters is used as a fallback.

    Returns:
        The ID in lower case, or None when the base name contains no run of
        four alphanumeric characters.
    """
    base = os.path.splitext(os.path.basename(filename))[0]
    m = ID_PATTERN.search(base)
    if m:
        return m.group(1).lower()
    # Fallback: no digit-led, delimited token; take the first 4 alphanumerics.
    m2 = re.search(r"(?i)([A-Za-z0-9]{4})", base)
    return m2.group(1).lower() if m2 else None

def walk_index(root: str, marker_substr: str) -> Dict[str, List[str]]:
    """Index the ``.pdb`` files found below *root* by their PDB ID.

    Only directories whose full path contains *marker_substr*
    (case-insensitive) are considered. Files for which no ID can be
    extracted are ignored. Each ID maps to a sorted list of file paths.
    """
    needle = marker_substr.lower()
    found: Dict[str, List[str]] = {}
    for folder, _subdirs, filenames in os.walk(root):
        if needle not in folder.lower():
            continue
        pdb_files = (fn for fn in filenames if fn.lower().endswith(".pdb"))
        for fn in pdb_files:
            pdb_id = extract_pdb_id(fn)
            if pdb_id:
                found.setdefault(pdb_id, []).append(os.path.join(folder, fn))
    # Sort so that "first path" is deterministic regardless of walk order.
    for paths in found.values():
        paths.sort()
    return found

# ----------------------------
# Structure utilities
# ----------------------------
def pick_chain(structure, hint=None):
    """Return the chain named *hint* from the first model, else its first chain."""
    first_model = structure[0]
    if not hint or hint not in first_model:
        # No usable hint: fall back to whichever chain the model yields first.
        return list(first_model.get_chains())[0]
    return first_model[hint]

def load_chain_residues(pdb_path, chain_hint):
    """Parse a PDB file and return the residues and sequence of one chain.

    The chain is selected with pick_chain. Residues whose id flag is not a
    blank (``res.id[0] != " "``) are skipped. Residue names missing from
    AA3_TO_1 contribute "X" to the sequence.

    Returns:
        (list of residue objects, one-letter sequence string, chain id)
    """
    structure = PDBParser(QUIET=True).get_structure("x", pdb_path)
    chain = pick_chain(structure, chain_hint)
    kept = [res for res in chain if res.id[0] == " "]
    letters = [AA3_TO_1.get((res.get_resname() or "").upper(), "X") for res in kept]
    return kept, "".join(letters), chain.id

def residue_salt_atoms(res, role):
    """Return the coordinates of the charged-group atoms of *res* for *role*.

    *role* is "acidic" (atoms from ACIDIC_O_ATOMS) or "basic" (atoms from
    BASIC_N_ATOMS, restricted to BASIC plus the histidine names when
    include_histidine is set). The result is a float array with one row per
    matching atom, or an empty (0, 3) array when nothing matches.
    """
    resname = (res.get_resname() or "").upper()

    # Work out which atom names are relevant for this residue and role.
    if role == "acidic":
        wanted = ACIDIC_O_ATOMS.get(resname, set())
    elif role == "basic":
        allowed = BASIC | ({"HIS","HSD","HSE","HSP"} if include_histidine else set())
        wanted = BASIC_N_ATOMS.get(resname, set()) if resname in allowed else set()
    else:
        wanted = set()

    picked = []
    if wanted:
        picked = [atom.coord for atom in res if atom.get_name().upper() in wanted]
    if not picked:
        return np.zeros((0,3), dtype=float)
    return np.array(picked, dtype=float)

def residue_charge_class(res):
    """Classify *res* as "acidic", "basic" or None (not charged).

    Histidine names count as basic only when include_histidine is set.
    """
    resname = (res.get_resname() or "").upper()
    if resname in ACIDIC:
        return "acidic"
    histidine_counts = include_histidine and resname in {"HIS","HSD","HSE","HSP"}
    return "basic" if (resname in BASIC or histidine_counts) else None

def count_saltbridge_partners(residue, others, cutoff_A=4.0):
    """Count residues in *others* that form a salt bridge with *residue*.

    A partner is any other residue of the opposite charge class that has at
    least one charged-group atom within *cutoff_A* of one of *residue*'s
    charged-group atoms. Returns 0 for uncharged residues or residues whose
    charged-group atoms are absent.
    """
    charge = residue_charge_class(residue)
    if charge is None:
        return 0
    own = residue_salt_atoms(residue, role=charge)
    if own.size == 0:
        return 0

    partner_role = "acidic" if charge != "acidic" else "basic"
    limit_sq = cutoff_A**2  # compare squared distances; avoids the sqrt
    n_partners = 0
    for candidate in others:
        if candidate is residue:
            continue
        theirs = residue_salt_atoms(candidate, role=partner_role)
        if theirs.size == 0:
            continue
        # All-pairs squared distances between own atoms and the candidate's atoms.
        sq_dist = np.sum((own[:,None,:] - theirs[None,:,:])**2, axis=-1)
        if np.any(sq_dist <= limit_sq):
            n_partners += 1
    return n_partners

def align_and_mutations(wt_res, wt_seq, var_res, var_seq):
    """Yield (wt_residue, var_residue, wt_aa, var_aa) for each substitution.

    The two sequences are globally aligned (match 2, mismatch -1, gap open -5,
    gap extend -0.5) and only the first alignment is used. Aligned columns
    with a gap on either side, identical letters, or an "X" on either side
    are not reported.
    """
    alignment = pairwise2.align.globalms(wt_seq, var_seq, 2, -1, -5, -0.5, one_alignment_only=True)[0]
    wt_pos = var_pos = 0  # next unconsumed residue on each side
    for wt_char, var_char in zip(alignment.seqA, alignment.seqB):
        wt_obj = var_obj = None
        if wt_char != "-":
            wt_obj = wt_res[wt_pos]
            wt_pos += 1
        if var_char != "-":
            var_obj = var_res[var_pos]
            var_pos += 1
        if wt_obj is None or var_obj is None:
            continue  # gap column: no residue pair to compare
        if wt_char == var_char or wt_char == "X" or var_char == "X":
            continue  # unchanged, or unknown residue on one side
        yield (wt_obj, var_obj, wt_char, var_char)

def process_pair(pdbid, wt_path, var_path):
    """Build one record per mutated position for a WT/VAR structure pair.

    Each record holds the residue numbers, residue names and one-letter codes
    on both sides, the salt-bridge partner count in each structure, and the
    difference (VAR minus WT).
    """
    wt_res, wt_seq, _ = load_chain_residues(wt_path, chain_hint)
    var_res, var_seq, _ = load_chain_residues(var_path, chain_hint)

    def _record(wt_obj, var_obj, wt_aa, var_aa):
        # Partner counts are evaluated within each residue's own structure.
        n_wt = count_saltbridge_partners(wt_obj, wt_res, salt_cutoff)
        n_var = count_saltbridge_partners(var_obj, var_res, salt_cutoff)
        return {
            "pdb_id": pdbid,
            "wt_resSeq": wt_obj.get_id()[1],
            "var_resSeq": var_obj.get_id()[1],
            "wt_resname": wt_obj.get_resname(),
            "var_resname": var_obj.get_resname(),
            "wt_aa": wt_aa,
            "var_aa": var_aa,
            "WT_saltbridges": n_wt,
            "VAR_saltbridges": n_var,
            "Delta_saltbridges": n_var - n_wt,
        }

    return [_record(*hit) for hit in align_and_mutations(wt_res, wt_seq, var_res, var_seq)]

# ----------------------------
# Safe linear regression
# ----------------------------
def safe_linear_fit(x: np.ndarray, y: np.ndarray) -> Tuple[float,float,float]:
    """Fit ``y = slope * x + intercept`` and return (slope, intercept, r2).

    Pairs with a non-finite x or y are discarded first. Degenerate inputs are
    handled without fitting:

    - fewer than two points: slope 0, intercept = mean of y (0 if empty), r2 NaN
    - all x identical: slope 0, intercept = mean of y, r2 0

    If ``np.polyfit`` raises, the closed-form least-squares solution is used.
    r2 is 0 when y is (numerically) constant, otherwise the squared Pearson
    correlation, or NaN if that correlation is not finite.
    """
    xs = np.asarray(x, dtype=np.float64)
    ys = np.asarray(y, dtype=np.float64)
    finite = np.isfinite(xs) & np.isfinite(ys)
    xs, ys = xs[finite], ys[finite]

    if xs.size < 2:
        intercept = float(np.nanmean(ys) if ys.size else 0.0)
        return 0.0, intercept, np.nan

    x_bar, y_bar = xs.mean(), ys.mean()
    ss_x = np.sum((xs - x_bar)**2)
    if ss_x == 0:
        # Vertical data: slope is undefined, report a flat line through mean(y).
        return 0.0, y_bar, 0.0

    try:
        slope, intercept = np.polyfit(xs, ys, 1)
    except Exception:
        # Closed-form fallback: slope = cov(x, y) / var(x).
        slope = np.sum((xs - x_bar)*(ys - y_bar)) / ss_x
        intercept = y_bar - slope*x_bar

    if np.allclose(ys, ys.mean()):
        r2 = 0.0
    else:
        r = np.corrcoef(xs, ys)[0,1]
        r2 = float(r**2) if np.isfinite(r) else np.nan
    return float(slope), float(intercept), r2

# ----------------------------
# 1) Read protein list
# ----------------------------
seq_df = pd.read_csv(seq_csv)
# IDs are lower-cased so they compare equal to the IDs from extract_pdb_id.
seq_df["pdb_id"] = seq_df["pdb_id"].astype(str).str.lower()
display_ids = sorted(set(seq_df["pdb_id"]) - incomplete_pairs)

# ----------------------------
# 2) Build WT/VAR indices
# ----------------------------
wt_index  = walk_index(project_root, WT_MARKER)
var_index = walk_index(project_root, VAR_MARKER)

# ----------------------------
# 3) Process pairs
# ----------------------------
all_rows = []
failed, missing = {}, []
for pid in display_ids:
    wt_paths = wt_index.get(pid, [])
    var_paths = var_index.get(pid, [])
    if not wt_paths or not var_paths:
        missing.append(pid)
        continue
    try:
        # Only the first (alphabetically sorted) path on each side is used.
        all_rows.extend(process_pair(pid, wt_paths[0], var_paths[0]))
    except Exception as e:
        # Best-effort: one bad structure must not abort the whole run, but the
        # failure is recorded with its type so it can be reported below.
        failed[pid] = f"{type(e).__name__}: {e}"

# Report skipped proteins instead of dropping them silently: downstream they
# are indistinguishable from proteins with a genuine net change of zero.
if missing:
    print(f"No WT and/or VAR structure found for {len(missing)} protein(s): "
          f"{', '.join(missing)}")
for pid, reason in failed.items():
    print(f"Failed to process {pid}: {reason}")

df = pd.DataFrame(all_rows)

# ----------------------------
# 4) Aggregate Delta salt-bridge per protein
# ----------------------------
# Sum the per-mutation deltas within each protein.
per_protein = (df.groupby("pdb_id")["Delta_saltbridges"].sum()
               if not df.empty else pd.Series(dtype=float))
# Every ID in display_ids gets a row; IDs with no rows in df (no mutations
# found, missing structures, or a processing failure) are filled with 0.0.
per_protein = per_protein.reindex(display_ids, fill_value=0.0).sort_values()

# ----------------------------
# 5) Bar chart
# ----------------------------
# Colour encodes the sign only: positive, negative, or exactly zero.
bar_colors = [POS_COLOR if v>0 else NEG_COLOR if v<0 else ZERO_COLOR
              for v in per_protein.values]

# Figure height grows with the number of proteins (minimum 3 inches).
plt.figure(figsize=(12, max(3, 0.35*len(per_protein))))
plt.barh(per_protein.index, per_protein.values, color=bar_colors, edgecolor="black", linewidth=0.5)
plt.axvline(0, color="black", linewidth=1)
plt.xlabel("Total Delta salt-bridge partners (VAR − WT)")
plt.ylabel("Protein (pdb_id)")
plt.title("Salt-bridge partner change per protein")
plt.tight_layout()
plt.show()

# ----------------------------
# 6) Delta melt regression
# ----------------------------
if "avg_melt_temp" not in seq_df.columns:
    raise KeyError("results_table.csv must contain 'avg_melt_temp' for regression plot.")

melt = seq_df[["pdb_id","wt_or_var","avg_melt_temp"]].copy()
melt["avg_melt_temp"] = pd.to_numeric(melt["avg_melt_temp"], errors="coerce")

# Mean melt temperature per (protein, wt_or_var), pivoted to one row per
# protein with one column per wt_or_var value (column names lower-cased).
melt_agg = melt.dropna(subset=["avg_melt_temp"]).groupby(["pdb_id","wt_or_var"], as_index=False)["avg_melt_temp"].mean()
melt_pivot = melt_agg.pivot(index="pdb_id", columns="wt_or_var", values="avg_melt_temp")
melt_pivot.columns = [str(c).lower() for c in melt_pivot.columns]

# NOTE(review): melt_pivot is not used after this point and safe_linear_fit is
# never called in this cell, so no regression is actually fitted or plotted
# here - the step appears to be unfinished; confirm.


print("Salt-bridge partner change counts:")
print(f"  Increased: {(per_protein>0).sum()}")
print(f"  Decreased: {(per_protein<0).sum()}")
print(f"  No change: {(per_protein==0).sum()}")


In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

# ----------------------------
# Config
# ----------------------------
mut_path = r"C:\Users\james\Masters_Degree\Thesis\protein_language_model_project\supplementary_data\mutation_contact_counts_all.csv"
seq_path = r"C:\Users\james\Masters_Degree\Thesis\protein_language_model_project\supplementary_data\var_wt_sequences.csv"
# NOTE(review): the salt-bridge cell above also excludes "1taq" and "5fkw";
# confirm whether the shorter list here is intentional.
incomplete_pairs = {"2agl", "5ur0", "7mx6", "6i2a"}

# ----------------------------
# Colormap (blue -> white -> orange)
# ----------------------------
# Diverging map: dark blue at 0, white at the midpoint, orange at 1.
colors = [(0, "#12436D"), (0.5, "white"), (1, "#F46A25")]
accessible_cmap = LinearSegmentedColormap.from_list("darkblue_white_orange", colors)

# ----------------------------
# Load and filter
# ----------------------------
mut = pd.read_csv(mut_path)
seq = pd.read_csv(seq_path)

# Drop the excluded proteins from both tables. The comparison is an exact
# string match, so it relies on pdb_id being stored in lower case in the CSVs.
mut = mut[~mut["pdb_id"].isin(incomplete_pairs)].copy()
seq = seq[~seq["pdb_id"].isin(incomplete_pairs)].copy()

# ----------------------------
# Prepare positions and lengths
# ----------------------------
mut = mut.rename(columns={"var_resSeq": "position"})  # or "wt_resSeq" if preferred
# Non-numeric positions become missing and are dropped in the merge step below.
mut["position"] = pd.to_numeric(mut["position"], errors="coerce").astype("Int64")
seq["length"] = seq["sequence"].str.len().astype("Int64")

# Attach the sequence length to every mutation row.
# NOTE(review): if seq holds more than one row per pdb_id, this left merge
# repeats each mutation row once per sequence row - verify against the CSV.
df = (mut.merge(seq[["pdb_id", "length"]], on="pdb_id", how="left")
        .dropna(subset=["position", "length"])
        .assign(position=lambda d: d["position"].astype(int),
                length=lambda d: d["length"].astype(int)))

# Position as a percentage of the sequence: position 1 -> 0 and
# position == length -> 100. A length of 1 would give a zero denominator, so
# it is replaced by 1. Values outside 0..100 are clipped to that range.
denom = (df["length"] - 1).replace(0, 1)
df["pct"] = np.clip(100 * (df["position"] - 1) / denom, 0, 100)

# ========================
# 1) Normalized mutation-location heatmaps
# ========================
BIN_SIZE = 2
bins   = np.arange(0, 100 + BIN_SIZE, BIN_SIZE)
labels = [f"{int(b)}–{int(b+BIN_SIZE)}" for b in bins[:-1]]
# With right=False the bins are half-open [lo, hi), so a value of exactly 100
# (a mutation at the last residue, or any position clipped to 100) falls
# outside the final bin and would be dropped as NaN. Clamp just below 100 so
# those mutations are counted in the last bin.
pct_for_binning = df["pct"].clip(upper=np.nextafter(100.0, 0.0))
df["pct_bin"] = pd.cut(pct_for_binning, bins=bins, labels=labels, include_lowest=True, right=False)

# Mutation counts per protein and bin. observed=False keeps bins with no
# mutations and makes the categorical group-by behaviour explicit.
heat = (df.groupby(["pdb_id", "pct_bin"], observed=False).size()
          .unstack(fill_value=0).reindex(columns=labels, fill_value=0))
# Row-normalise so that each protein's bins sum to 1 (rows with no mutations
# stay all-zero instead of dividing by zero).
heat_norm = heat.div(heat.sum(axis=1).replace(0, np.nan), axis=0).fillna(0)

plt.figure(figsize=(16, max(4, 0.35 * heat_norm.shape[0])))
sns.heatmap(heat_norm, cmap=accessible_cmap, vmin=0, vmax=heat_norm.values.max(),
            cbar_kws={"label": "Fraction of mutations"})
plt.xlabel("Sequence %")
plt.ylabel("Protein")
plt.title("Heatmap of mutation locations by % of sequence")
plt.tight_layout()
plt.show()

# Single-row summary: share of all mutations (across proteins) in each bin.
overall = (heat.sum(axis=0) / heat.values.sum())
plt.figure(figsize=(16, 2.6))
sns.heatmap(overall.to_frame().T, cmap=accessible_cmap, vmin=0, vmax=overall.max(),
            cbar_kws={"label": "Global fraction"})
plt.yticks([0.5], ["All proteins"])
plt.xlabel("Sequence % (binned)")
plt.title("Overall normalized mutation distribution")
plt.tight_layout()
plt.show()

# ========================
# 2) Mean Δ-contacts heatmaps
# ========================
# Accept the misspelled column name as well, in case the CSV carries the typo.
delta_col = next((c for c in ["Delta_contacts", "Delta_contaacts"] if c in mut.columns), None)
if delta_col is None:
    raise KeyError(f"Delta column not found. Columns: {mut.columns.tolist()}")

df[delta_col] = pd.to_numeric(df[delta_col], errors="coerce")

BIN_SIZE = 5
bins   = np.arange(0, 100 + BIN_SIZE, BIN_SIZE)
labels = [f"{int(b)}–{int(b+BIN_SIZE)}" for b in bins[:-1]]
# With right=False the bins are half-open [lo, hi), so a value of exactly 100
# (a mutation at the last residue, or any position clipped to 100) falls
# outside the final bin and would be dropped as NaN. Clamp just below 100 so
# those mutations contribute to the last bin.
pct_for_binning = df["pct"].clip(upper=np.nextafter(100.0, 0.0))
df["pct_bin"] = pd.cut(pct_for_binning, bins=bins, labels=labels, include_lowest=True, right=False)

# Mean delta per protein and bin; bins without mutations stay NaN (blank).
# observed=False makes the categorical group-by behaviour explicit.
heat_mean = (df.groupby(["pdb_id", "pct_bin"], observed=False)[delta_col]
               .mean().unstack().reindex(columns=labels))

# Symmetric colour limits so that white corresponds to zero change.
v = np.nanmax(np.abs(heat_mean.values))
plt.figure(figsize=(16, max(4, 0.35 * heat_mean.shape[0])))
sns.heatmap(heat_mean, cmap=accessible_cmap, center=0, vmin=-v, vmax=v,
            cbar_kws={"label": "Mean Δ contacts"})
plt.xlabel("Sequence %")
plt.ylabel("Protein")
plt.title("Location of change in contacts")
plt.tight_layout()
plt.show()

# Single-row summary: mean delta over all proteins per bin, drawn on the same
# colour scale as the per-protein heatmap above.
overall_mean = df.groupby("pct_bin", observed=False)[delta_col].mean().reindex(labels)
plt.figure(figsize=(16, 2.6))
sns.heatmap(overall_mean.to_frame().T, cmap=accessible_cmap, center=0, vmin=-v, vmax=v,
            cbar_kws={"label": "Mean Δ contacts"})
plt.yticks([0.5], ["All proteins"])
plt.xlabel("Sequence % (binned)")
plt.title("Overall mean change in contacts across sequence bins")
plt.tight_layout()
plt.show()


# Foldseek analysis

This section analyses and plots Foldseek results. For each 1% bin of the query sequence, it calculates the mean alignment TM-score (`alntmscore`) of all hits that span that bin.

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.font_manager import FontProperties, fontManager
import matplotlib as mpl

# ----------------------------
# Font setup
# ----------------------------
# Register a local TrueType font and make it matplotlib's default family.
# NOTE(review): absolute, machine-specific path; addfont() will fail on any
# machine where this file does not exist.
font_path = r"C:\Users\james\Downloads\abadi-mt_freefontdownload_org\abadi-mt.ttf"
fontManager.addfont(font_path)
# font_prop is passed explicitly to axis labels, titles and the colorbar label.
font_prop = FontProperties(fname=font_path)
mpl.rcParams["font.family"] = font_prop.get_name()

# ----------------------------
# Accessible colormap
# ----------------------------
# Diverging map: dark blue at 0, white at the midpoint, orange at 1.
colors = [(0, "#12436D"), (0.5, "white"), (1, "#F46A25")]
accessible_cmap = LinearSegmentedColormap.from_list("darkblue_white_orange", colors)

# ----------------------------
# Build heatmap matrix (mean per bin only)
# ----------------------------
def build_heatmap_matrix(df_subset):
    """Return a PDB x percent-bin matrix of mean ``alntmscore`` values.

    Every hit is spread over the integer percent bins between
    ``qstart / qlen`` and ``qend / qlen`` (bins are capped at 99). Hits with a
    missing qstart, qend or qlen are skipped. The PDB key is the part of the
    ``PDB`` column before the first underscore. Returns None when no hit
    contributes to any bin.
    """
    def _bin_records():
        # One record per (hit, covered bin).
        for _, hit in df_subset.iterrows():
            qstart, qend, qlen = hit["qstart"], hit["qend"], hit["qlen"]
            if pd.isna(qstart) or pd.isna(qend) or pd.isna(qlen):
                continue
            protein = str(hit["PDB"]).strip().split("_")[0]
            first = int((qstart / qlen) * 100)
            last = int((qend / qlen) * 100)
            for pct in range(first, min(last + 1, 100)):
                yield {"PDB": protein, "bin": pct, "alntmscore": hit["alntmscore"]}

    per_bin = pd.DataFrame(list(_bin_records()))
    if per_bin.empty:
        return None

    # Average all hits that cover the same bin of the same protein.
    averaged = per_bin.groupby(["PDB", "bin"])["alntmscore"].mean().reset_index()
    return averaged.pivot(index="PDB", columns="bin", values="alntmscore")

# ----------------------------
# Detect homopolymers
# ----------------------------
def find_homopolymers(seq, run_length=3):
    """Return positions (as % of sequence length) of homopolymer runs.

    While a run of identical characters is at least *run_length* long, every
    further character of the run adds one position, computed from the current
    index shifted back by ``run_length // 2``. A missing sequence gives an
    empty list.
    """
    marks = []
    if pd.isna(seq):
        return marks

    text = str(seq)
    total = len(text)
    back_shift = run_length // 2
    streak = 1  # length of the current run of identical characters
    for idx, (previous, current) in enumerate(zip(text, text[1:]), start=1):
        if current != previous:
            streak = 1
            continue
        streak += 1
        if streak >= run_length:
            marks.append((idx - back_shift) / total * 100)
    return marks

def build_homopolymer_dict(df, run_length=3):
    """Map each PDB key to the homopolymer positions of its query sequence.

    The key is the part of the ``PDB`` column before the first underscore.
    Only the first row seen for a key is used; its ``qseq`` value (None when
    the column is absent) is passed to find_homopolymers.
    """
    positions_by_pdb = {}
    for _, record in df.iterrows():
        key = str(record["PDB"]).strip().split("_")[0]
        if key in positions_by_pdb:
            continue
        positions_by_pdb[key] = find_homopolymers(
            record.get("qseq", None), run_length=run_length
        )
    return positions_by_pdb

# ----------------------------
# Plot heatmap
# ----------------------------
def plot_heatmap(heatmap_matrix, title, homopolymer_dict=None, center=None):
    """Draw a PDB x percent-bin heatmap, optionally marking homopolymer runs.

    Args:
        heatmap_matrix: DataFrame indexed by PDB ID with integer percent bins
            (0..99) as columns, or None.
        title: Figure title; also used in the "no data" message.
        homopolymer_dict: Optional mapping PDB ID -> list of positions in
            percent of sequence length; each is drawn as a green tick on the
            row of that PDB.
        center: Passed to ``sns.heatmap``; when given, the colorbar is
            labelled as a change in TM-score instead of ``alntmscore``.

    Rows are sorted by their mean value, highest first. Nothing is drawn (a
    message is printed instead) when the matrix is None or empty.
    """
    # An empty frame (e.g. a difference matrix with no common PDBs) cannot be
    # drawn by sns.heatmap, so it is treated like a missing matrix.
    if heatmap_matrix is None or heatmap_matrix.empty:
        print(f"No data for {title}")
        return

    # The heatmap x axis is in column-index units, while homopolymer positions
    # are in percent. Reindexing to all 100 bins makes column index == percent
    # bin, so the markers line up even when some bins have no hits (those
    # columns are NaN and are left blank).
    full_matrix = heatmap_matrix.reindex(columns=range(100))

    pdb_order = full_matrix.mean(axis=1).sort_values(ascending=False).index
    heatmap_matrix_sorted = full_matrix.loc[pdb_order]

    plt.figure(figsize=(14, 8))
    ax = sns.heatmap(
        heatmap_matrix_sorted,
        cmap=accessible_cmap,
        center=center,
        cbar_kws={
            "label": "alntmscore" if center is None else "Change in TM-score"
        },
        xticklabels=10,
        yticklabels=True,
    )

    if homopolymer_dict is not None:
        for i, pdb in enumerate(pdb_order):
            if pdb in homopolymer_dict:
                for frac in homopolymer_dict[pdb]:
                    # Row i spans y in [i, i + 1]; i + 0.5 is its centre.
                    ax.plot(
                        frac,
                        i + 0.5,
                        marker="|",
                        color="#32CD32",
                        markersize=10,
                        linewidth=2,
                    )

    plt.xlabel("Position through sequence (% bins)", fontsize=20, fontproperties=font_prop)
    plt.ylabel("PDB ID", fontsize=20, fontproperties=font_prop)
    plt.title(title, fontsize=20, fontproperties=font_prop)

    ax.tick_params(axis="x", labelsize=20)
    ax.tick_params(axis="y", labelsize=20)

    # Style the colorbar to match the axis labels.
    cbar = ax.collections[0].colorbar
    cbar.ax.yaxis.label.set_size(20)
    cbar.ax.yaxis.label.set_fontproperties(font_prop)
    cbar.ax.tick_params(labelsize=20)

    plt.tight_layout()
    plt.show()

# ----------------------------
# Main analysis for a file
# ----------------------------
def process_file(filepath, label):
    """Load one Foldseek results CSV and plot VAR, WT and VAR-minus-WT heatmaps.

    The qstart, qend, qlen and alntmscore columns are coerced to numbers.
    Rows are split on the ``variant`` column ("var" / "wt"). When both
    matrices exist, their difference over the shared PDBs and bins is plotted
    and the 20 largest per-PDB mean differences are printed.
    """
    print(f"\n--- Processing {label} ---\n")

    df = pd.read_csv(filepath)
    numeric_cols = ["qstart", "qend", "qlen", "alntmscore"]
    for name in numeric_cols:
        df[name] = pd.to_numeric(df[name], errors="coerce")

    homopolymer_dict = build_homopolymer_dict(df, run_length=3)

    # One heatmap per variant class, VAR first and then WT.
    matrices = {}
    for tag in ("var", "wt"):
        subset = df[df["variant"] == tag].copy()
        matrices[tag] = build_heatmap_matrix(subset)
        plot_heatmap(matrices[tag], f"{label} – {tag.upper()}: TM-score across sequence",
                     homopolymer_dict=homopolymer_dict)

    var_matrix, wt_matrix = matrices["var"], matrices["wt"]
    if var_matrix is None or wt_matrix is None:
        return

    # Difference, restricted to PDBs and bins present in both matrices.
    shared_rows = var_matrix.index.intersection(wt_matrix.index)
    shared_cols = var_matrix.columns.intersection(wt_matrix.columns)
    diff_matrix = (
        var_matrix.loc[shared_rows, shared_cols]
        - wt_matrix.loc[shared_rows, shared_cols]
    )

    plot_heatmap(diff_matrix, f"{label} – Difference VAR − WT",
                 homopolymer_dict=homopolymer_dict, center=0)

    per_pdb_mean = diff_matrix.mean(axis=1)
    summary_stats = pd.DataFrame({"mean_delta": per_pdb_mean})
    summary_stats = summary_stats.sort_values("mean_delta", ascending=False)

    print("Per-PDB difference summary (VAR − WT):")
    print(summary_stats.head(20))

# ----------------------------
# Run for local and global
# ----------------------------
# NOTE(review): absolute, machine-specific paths; adjust before running elsewhere.
local_path = r"C:\Users\james\Masters_Degree\Thesis\protein_language_model_project\supplementary_data\local_foldseek_results.csv"
global_path = r"C:\Users\james\Masters_Degree\Thesis\protein_language_model_project\supplementary_data\global_foldseek_results.csv"

# Each call prints a header and produces the VAR, WT and difference heatmaps
# for that results file.
process_file(local_path, "Local Foldseek")
process_file(global_path, "Global Foldseek")
