# Analysis Notebook

This notebook is used for the analysis and visualisation of data from the ESM-2 hallucination pipeline.

# Cumulative Log-likelihood and melt temperature.

This cell plots the cumulative log-likelihood and melt temperature over epochs. It is used to demonstrate optimisation for increased thermostability over time. It is normalised to have both on a comparable scale. It takes the last output file of temBERTure as input.

In [None]:
import math
import pandas as pd
import matplotlib.pyplot as plt

# File path
# NOTE: hard-coded, machine-specific location - point this at your own temBERTure output file.
file_path = r"C:\Users\james\Masters_Degree\Thesis\protein_language_model_project\results\temberture_results\protease_best_scores_per_iteration_melt_temps.csv"
# Load data
df = pd.read_csv(file_path)

# Sort so that the per-protein cumulative sum and diff below run in iteration order
df = df.sort_values(['pdb_id', 'iteration'])

# Exclude problematic proteins
incomplete_pairs = ["2agl", "5ur0", "7mx6", "6i2a"]
df = df[~df['pdb_id'].isin(incomplete_pairs)]

# Compute cumulative score: running total of 'score' within each protein
df['cumulative_score'] = df.groupby('pdb_id')['score'].cumsum()

# Compute derivative of avg_prediction: change in avg_melt_temp between consecutive
# rows of the same protein (NaN on each protein's first row). Not used by the plots below.
df['avg_prediction_derivative'] = df.groupby('pdb_id')['avg_melt_temp'].diff()

# Colour palette
duo_palette = {
    'dark_blue': '#12436D',
    'orange': '#F46A25'
}

# Grid size: near-square layout with one subplot per protein
pdb_ids = df['pdb_id'].unique()
n_plots = len(pdb_ids)
n_cols = math.ceil(math.sqrt(n_plots))
n_rows = math.ceil(n_plots / n_cols)

# squeeze=False keeps `axes` two-dimensional even when the grid has a single row or column
fig, axes = plt.subplots(
    n_rows,
    n_cols,
    figsize=(6 * n_cols, 4 * n_rows),
    squeeze=False
)

# One subplot per protein; zip() stops after the last protein, leaving surplus grid cells untouched
for ax, pdb_id in zip(axes.flatten(), pdb_ids):
    group = df[df['pdb_id'] == pdb_id].copy()

    # Left axis: melt temperature
    ax.plot(
        group['iteration'],
        group['avg_melt_temp'],
        color=duo_palette['orange'],
        label='Melt Temperature'
    )
    ax.set_xlabel('Iteration', fontsize=20)
    ax.set_ylabel(
        'Melt Temperature (°C)',
        color=duo_palette['orange'],
        fontsize=20
    )
    # Same fixed y-range on every subplot so the melt-temperature curves can be compared directly
    ax.set_ylim(40, 90)
    ax.tick_params(axis='x', labelsize=20)
    ax.tick_params(axis='y', labelcolor=duo_palette['orange'], labelsize=20)
    ax.grid(True)

    # Add PDB ID title
    ax.set_title(f"{pdb_id}", fontsize=20)

    # Right axis: cumulative log-likelihood (twinx shares the x-axis but has its own y-scale)
    ax2 = ax.twinx()
    ax2.plot(
        group['iteration'],
        group['cumulative_score'],
        color=duo_palette['dark_blue'],
        linestyle='dashed',
        label='Cumulative Log-Likelihood'
    )
    ax2.set_ylabel(
        'Cumulative Log-Likelihood',
        color=duo_palette['dark_blue'],
        fontsize=20
    )
    ax2.tick_params(axis='y', labelcolor=duo_palette['dark_blue'], labelsize=20)
    ax2.tick_params(axis='x', labelsize=20)

    # Collect legend handles once, from the first protein's pair of axes;
    # they feed the single figure-level legend drawn after the loop
    if pdb_id == pdb_ids[0]:
        lines1, labels1 = ax.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        handles = lines1 + lines2
        labels = labels1 + labels2

# Hide unused subplots (grid cells beyond the number of proteins)
for i in range(n_plots, n_rows * n_cols):
    fig.delaxes(axes.flatten()[i])

# Global legend, anchored just below the bottom edge of the figure
fig.legend(
    handles,
    labels,
    loc='lower center',
    bbox_to_anchor=(0.5, -0.02),
    ncol=2
)

plt.tight_layout()
plt.show()


# Amino acid counts and propensity
The following scripts will visualise differences in amino acid counts and propensity. Set the output path to be a new CSV in your results directory.


In [None]:
# For every protein keep only the row with the highest iteration number,
# restricted to the summary columns, and write the result to CSV
last_iteration_index = df.groupby('pdb_id')['iteration'].idxmax()
summary_columns = ['pdb_id', 'iteration', 'score', 'sequence']
max_iteration_rows = df.loc[last_iteration_index, summary_columns]

output_path = r"C:\Users\james\Masters_Degree\Thesis\protein_language_model_project\results\protease_max_iteration_summary.csv"
max_iteration_rows.to_csv(output_path, index=False)

max_iteration_rows.head()


This section counts amino acid propensities and saves them into a new CSV. You should ideally select your results directory for this.

In [None]:
import pandas as pd
from collections import Counter

# ---------------- CONFIG ----------------
family = "protease"
output_csv_props = f"C:\\Users\\james\\Masters_Degree\\Thesis\\protein_language_model_project\\results\\{family}_amino_acid_proportions.csv"
# ----------------------------------------

# Make sure `max_iteration_rows` exists from the previous cell
all_aa = sorted("ACDEFGHIKLMNPQRSTVWY")
results = []

for _, record in max_iteration_rows.iterrows():
    sequence = record['sequence']
    n_residues = len(sequence)
    residue_counts = Counter(sequence)

    # Identification columns first, then one "<AA>_prop" column per amino acid
    entry = {
        "pdb_id": record['pdb_id'],
        "iteration": record['iteration'],
        "score": record['score'],
        "length": n_residues
    }
    for residue in all_aa:
        if n_residues > 0:
            entry[f"{residue}_prop"] = residue_counts.get(residue, 0) / n_residues
        else:
            # Empty sequence: report 0 rather than dividing by zero
            entry[f"{residue}_prop"] = 0

    results.append(entry)

# Create DataFrame and save
df_props = pd.DataFrame(results)
df_props.to_csv(output_csv_props, index=False)

print(f"Saved amino acid proportions to {output_csv_props}")
df_props.head()


# Sequence property analysis

This section plots out changes in propensity across proteins, and broadly analyses physicochemical properties of sequences.

The original analysis utilises a results table manually collated by copying the output of previous steps. 

Hydropathy, aromaticity, instability index, aliphatic index, isoelectric point, and net positive and net negative counts were omitted from the final analysis for brevity. They are, however, available in the attached results table, and the script below can be run on them. The secondary-structure proportions were based on a Biopython analysis of the sequence. These were omitted because Foldseek was used for finding folds and because DSSP is a superior method (DSSP itself was left out of the analysis for brevity). Flexibility from Biopython was originally captured but not used in the analysis: molecular dynamics methods are superior, but their computational expense precluded their use.

To refactor this into an automated pipeline, however, an additional script is needed that collates the previous results from temBERTure and EvoProtGrad and, for each protein (both wild type and variant), provides the following columns:

Protein family: ['family']

Protein: ['pdb_id']

Wild type or variant: ['wt_or_var']

Sequence: ['sequence']

Average Melt Temperature: ['avg_melt_temp']

Hydropathy: ['gravy']

All 20 proteinogenic amino acids (where {X} is one letter code in capital letter): ['{X}_prop']


These are the other columns utilised. 

This cell will print visualisations for Aromaticity, instability, aliphatic index, hydropathy, sec structure %, flexibility, but these are redundant and not used in this analysis. Their functionality is maintained however, as they may yield insights in the future if large datasets are analysed.

If you don't have Abadi installed as a font, just remove the font_manager and font_prop parts of code, or use them to define a font you do have. Or install Abadi, it's a nice font.

In [None]:
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import matplotlib as mpl
import pandas as pd

# ----------------- CONFIG -----------------
# NOTE: hard-coded, machine-specific path to the collated results table
results_csv = r"C:\Users\james\Masters_Degree\Thesis\protein_language_model_project\supplementary_data\results_table.csv"
# ------------------------------------------

# Palette and colormap
# (the dark-blue key is spelled 'dark blue', with a space, in this cell)
duo_palette = {'dark blue': '#12436D', 'orange': '#F46A25'}
# Diverging map: dark blue at the low end, white at the midpoint, orange at the high end
colors = [(0, "#12436D"), (0.5, "white"), (1, "#F46A25")]
accessible_cmap = LinearSegmentedColormap.from_list("darkblue_white_orange", colors)

# Load dataset with wt/var info (rebinds `df`, replacing any earlier frame of that name)
df = pd.read_csv(results_csv)
incomplete_pairs = ["2agl", "5ur0", "7mx6", "6i2a"]

# Split into wild-type and variant tables indexed by pdb_id, so that the
# var - wt subtractions below align on protein ID; then drop the excluded proteins
wt = df[df['wt_or_var'] == 'wt'].set_index('pdb_id')
var = df[df['wt_or_var'] == 'var'].set_index('pdb_id')
wt = wt[~wt.index.isin(incomplete_pairs)]
var = var[~var.index.isin(incomplete_pairs)]

# ---------------- Secondary structure deltas ----------------
# var - wt per protein, multiplied by 100
# (presumably the columns are fractions, so the result is in percentage points - TODO confirm)
delta_helix = (var['sec_struc_helix'] - wt['sec_struc_helix']) * 100
delta_turn = (var['sec_struc_turn'] - wt['sec_struc_turn']) * 100
delta_sheet = (var['sec_struc_sheet'] - wt['sec_struc_sheet']) * 100

secstruc_deltas = pd.DataFrame({'Helix': delta_helix, 'Turn': delta_turn, 'Sheet': delta_sheet})

# Transposed so that structure types are rows and proteins are columns
plt.figure(figsize=(12, 6))
sns.heatmap(secstruc_deltas.T, annot=True, cmap=accessible_cmap,
            square=False, cbar_kws={"shrink": 0.8}, center=0)
plt.title("Change in Secondary Structure Proportions", pad=20, fontsize=20)
plt.ylabel("Secondary Structure", labelpad=10, fontsize=20)
plt.xticks(fontsize=16, rotation=90)
plt.yticks(fontsize=20, rotation=0)
plt.tight_layout(pad=2)
plt.show()

# ---------------- Stability/physchem deltas ----------------
# Each delta is var - wt, aligned on pdb_id
deltas_temp = var['avg_melt_temp'] - wt['avg_melt_temp']
deltas_instability = var['instability_index'] - wt['instability_index']
delta_aromaticity = var['aromaticity'] - wt['aromaticity']
delta_aliphatic = var['aliphatic_index'] - wt['aliphatic_index']
delta_gravy = var['gravy'] - wt['gravy']
delta_isoelectric = var['isoelectric_point'] - wt['isoelectric_point']
delta_flex = var['flexibility_mean'] - wt['flexibility_mean']

# The dictionary keys double as the plot labels in the loop below
df_deltas_vertical = pd.DataFrame({
    'avg_melt_temp': deltas_temp,
    'instability_index': deltas_instability,
    'aromaticity': delta_aromaticity,
    'aliphaticity': delta_aliphatic,
    'Hydropathicity': delta_gravy,
    'isoelectric_point': delta_isoelectric,
    'Flexibility': delta_flex
})

# One bar chart per property: a bar per protein, dark blue for an increase, orange otherwise
for col in df_deltas_vertical.columns:
    values = df_deltas_vertical[col]
    colours = [duo_palette['orange'] if v <= 0 else duo_palette['dark blue'] for v in values]
    plt.figure(figsize=(10, 5))
    # NOTE(review): recent seaborn releases warn when `palette` is given without `hue` -
    # confirm against the installed seaborn version
    sns.barplot(x=values.index, y=values.values, palette=colours)
    plt.axhline(0, color='black', linestyle='--')
    plt.ylabel(f"Change in {col.replace('_',' ').title()}", fontsize=20)
    plt.xticks(fontsize=20, rotation=90)
    plt.yticks(fontsize=20)
    plt.title(f"Change in {col.replace('_',' ').title()} across proteins", fontsize=20)
    plt.tight_layout()
    plt.show()

# ---------------- Amino acid deltas ----------------
amino_acid_props = [f"{aa}_prop" for aa in "ACDEFGHIKLMNPQRSTVWY"]

# var - wt for every "<AA>_prop" column, times 100 (matching the "%" axis labels below);
# the transposed copy has amino acids as rows and proteins as columns
df_aa_deltas_pct = (var[amino_acid_props] - wt[amino_acid_props]) * 100
df_aa_deltas_pct_T = df_aa_deltas_pct.transpose()

# Weighted average change: each protein's delta is weighted by the mean of its
# wild-type and variant sequence lengths
wt['length'] = wt['sequence'].str.len()
var['length'] = var['sequence'].str.len()
avg_lengths = (wt['length'] + var['length']) / 2

# `.values` is applied positionally across the protein columns, so this relies on
# avg_lengths and df_aa_deltas_pct_T listing the proteins in the same order
weighted_deltas = df_aa_deltas_pct_T * avg_lengths.values
total_residues = avg_lengths.sum()
weighted_avg_change_aa = weighted_deltas.sum(axis=1) / total_residues
weighted_avg_change_aa.index = [aa.replace('_prop', '') for aa in weighted_avg_change_aa.index]

plt.figure(figsize=(10, 5))
colours = [duo_palette['orange'] if v <= 0 else duo_palette['dark blue'] for v in weighted_avg_change_aa.values]
sns.barplot(x=weighted_avg_change_aa.index, y=weighted_avg_change_aa.values, palette=colours)
plt.axhline(0, color='black', linestyle='--')
plt.ylabel("Change in proportion (%)", fontsize=20)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.title("Absolute change of amino acid proportions across proteins", fontsize=20)
plt.tight_layout()
plt.show()

# Net change counts: +1 for every protein where the amino acid increased, -1 where it
# decreased, 0 where unchanged, summed over proteins
sign_matrix = np.sign(df_aa_deltas_pct_T)
# Strip the "_prop" suffix in place; the heatmap and totals below use the renamed index
df_aa_deltas_pct_T.index = [aa.replace('_prop','') for aa in df_aa_deltas_pct_T.index]
sign_matrix.index = df_aa_deltas_pct_T.index
net_change_aa = sign_matrix.sum(axis=1)

plt.figure(figsize=(10, 5))
colours = [duo_palette['orange'] if v <= 0 else duo_palette['dark blue'] for v in net_change_aa.values]
sns.barplot(x=net_change_aa.index, y=net_change_aa.values, palette=colours)
plt.axhline(0, color='black', linestyle='--')
plt.ylabel("Net change count across proteins")
plt.xlabel("Amino Acid")
plt.title("Net change of amino acids across proteins")
plt.tight_layout()
plt.show()

# Heatmap of per-protein amino acid deltas
plt.figure(figsize=(10, 10))
# Symmetric colour limits so that white corresponds exactly to zero change
max_abs = np.abs(df_aa_deltas_pct_T.values).max()
sns.heatmap(df_aa_deltas_pct_T, annot=True, fmt=".1f",
            cmap=accessible_cmap, center=0,
            vmin=-max_abs, vmax=max_abs,
            cbar_kws={"label": "Change in proportion (%)", "shrink": 0.8})
plt.title("Change in Amino Acid Proportions", pad=20, fontsize=20)
plt.xlabel("PDB ID", fontsize=20)
plt.yticks(fontsize=20)
plt.xticks(fontsize=20)
plt.tight_layout(pad=2)
plt.show()

# Total % change across all proteins: unweighted sum of the per-protein deltas
total_pct_change_aa = df_aa_deltas_pct_T.sum(axis=1)
plt.figure(figsize=(10, 5))
colours = [duo_palette['orange'] if v <= 0 else duo_palette['dark blue'] for v in total_pct_change_aa.values]
sns.barplot(x=total_pct_change_aa.index, y=total_pct_change_aa.values, palette=colours)
plt.axhline(0, color='black', linestyle='--')
plt.ylabel("Total change in proportion (%)", fontsize=20)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.title("Total % change of amino acid proportions across proteins", fontsize=20)
plt.tight_layout()
plt.show()


The following cell plots amino acid substitutions across the dataset.

In [None]:
# -*- coding: utf-8 -*-
"""
Heatmap of amino acid substitutions (wt → var), ignoring self substitutions.
Assumes `df` is already loaded in the notebook with columns:
['pdb_id', 'wt_or_var', 'sequence'].
"""

from collections import Counter
import numpy as np
import matplotlib.pyplot as plt

# -------------------------------
# Pivot to wide form: wt vs var sequences
# -------------------------------
# The inner merge keeps only proteins that have both a wt and a var row
wt_sequences = df[df['wt_or_var'] == 'wt'][['pdb_id', 'sequence']].rename(columns={'sequence': 'wt_sequence'})
var_sequences = df[df['wt_or_var'] == 'var'][['pdb_id', 'sequence']].rename(columns={'sequence': 'var_sequence'})
df_sequences = pd.merge(wt_sequences, var_sequences, on='pdb_id')

# Remove incomplete pairs
incomplete_pairs = ["2agl", "5ur0", "7mx6", "6i2a"]
df_sequences = df_sequences[~df_sequences['pdb_id'].isin(incomplete_pairs)]

# -------------------------------
# Count amino acid transitions, ignoring self substitutions
# -------------------------------
# Sequences are compared position by position without alignment; if the lengths differ,
# only the first min_len residues are compared
# (presumably wt and var have equal length, i.e. substitutions only - TODO confirm)
transition_counter = Counter()
for _, row in df_sequences.iterrows():
    wt_seq = row['wt_sequence']
    var_seq = row['var_sequence']
    min_len = min(len(wt_seq), len(var_seq))
    for wt_aa, var_aa in zip(wt_seq[:min_len], var_seq[:min_len]):
        if wt_aa != var_aa:
            transition_counter[(wt_aa, var_aa)] += 1

# -------------------------------
# Build substitution matrix
# -------------------------------
# Axis labels: every residue seen on either side of at least one substitution, sorted
amino_acids = sorted(
    set(w for (w, v) in transition_counter.keys()).union(
        v for (w, v) in transition_counter.keys()
    )
)
aa_to_idx = {aa: i for i, aa in enumerate(amino_acids)}

# Rows are the wild-type residue, columns the variant residue
matrix = np.zeros((len(amino_acids), len(amino_acids)), dtype=float)
for (wt_aa, var_aa), count in transition_counter.items():
    i = aa_to_idx[wt_aa]
    j = aa_to_idx[var_aa]
    matrix[i, j] = count

# Blank diagonal (self-substitutions): NaN cells are left uncoloured by imshow
np.fill_diagonal(matrix, np.nan)

# -------------------------------
# Plot heatmap
# -------------------------------
plt.figure(figsize=(10, 8))
cax = plt.imshow(matrix, cmap="Blues")
cbar = plt.colorbar(cax, label='Count of substitutions')

plt.xticks(range(len(amino_acids)), amino_acids, fontsize=20)
plt.yticks(range(len(amino_acids)), amino_acids, fontsize=20)
plt.xlabel("Variant amino acid", fontsize=20)
plt.ylabel("Wildtype amino acid", fontsize=20)
plt.title("Heatmap of amino acid substitutions", fontsize=20)
plt.tight_layout()
plt.show()


# Motif counting

The following cell analyses thermophilic motif counts in the sequences.

In [None]:
# -*- coding: utf-8 -*-
"""
Motif Δ heatmaps with incomplete-pair filtering.
Assumes `df` is already defined with columns: ['pdb_id', 'wt_or_var', 'sequence'].
"""

import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.font_manager import FontProperties
import pandas as pd

# =========================
# Plotting theme
# =========================
# Diverging colormap: dark blue (low) -> white (middle) -> orange (high)
colors = [(0, "#12436D"), (0.5, "white"), (1, "#F46A25")]
accessible_cmap = LinearSegmentedColormap.from_list("darkblue_white_orange", colors)

# Exclude incomplete pairs (apply again here in case df was reused)
incomplete_pairs = ["2agl", "5ur0", "7mx6", "6i2a"]
results_csv = r"C:\Users\james\Masters_Degree\Thesis\protein_language_model_project\supplementary_data\results_table.csv"
keep_rows = ~df["pdb_id"].isin(incomplete_pairs)
df = df[keep_rows].copy()

# Normalize sequence casing
df["sequence"] = df["sequence"].str.upper()

# =========================
# Motif definitions
# =========================
# General motif set
motifs = [
    "GSGSG", "GGGGG", "GGGGGGGG", "EAAAK", "PAPAP", "GGGGS",
    "AEAAAKEAAAKA", "VSQTSKLTR", "AETVFPDV", "PLG", "LWA",
    "TRHRQPR", "GWE", "AGNRVRR", "SVG", "RRRRRRR", "HHHHH", "GFLG",
    "GS", "PPP", "AAAA", "KKK", "EEE", "VVV", "GGG", "LLL"
]

# Four-residue motifs
over_represented_all = [
    "EEEE", "EEEK", "EEER", "MRRR", "EEKR", "IRRR", "EEEF", "EEKK", "EEEV",
    "EEIK", "EIKK", "EEIV", "EIKR", "EKRR", "IKRR", "EEKV", "EEIR", "EEEN",
    "EEEY", "EERV", "KRRR", "ELWW", "RRRV", "AEER", "EEPR", "EERR", "AERR",
    "RRRY", "ERRR", "EGPR", "ERRV", "EEGP", "EEGR", "EERW", "GPRR", "EEPP"
]

# Runs of glutamate
polyE_motifs = ["EE", "EEE", "EEEE", "EEEEE", "EEEEEE", "EEEEEEE", "EEEEEEEE", "EEEEEEEEE"]

# Homopolymer runs of length 3 to 7 for each selected residue, grouped by residue
polyX_residues = ["M", "V", "A", "K", "P", "E", "L", "R", "H"]
polyX_motifs = [
    residue * run_length
    for residue in polyX_residues
    for run_length in range(3, 7 + 1)
]

# =========================
# Helpers
# =========================
def safe_counts(seq: object, motif_list):
    """Return a {motif: count} dict for every motif in motif_list.

    Counts come from ``str.count``, i.e. non-overlapping occurrences.
    Anything that is not a string (e.g. NaN) yields a count of 0 for each motif.
    """
    is_text = isinstance(seq, str)
    return {motif: (seq.count(motif) if is_text else 0) for motif in motif_list}

def summarize_motifs(df_in: pd.DataFrame, motif_list, label: str, out_csv=None):
    """
    Count motifs per sequence, aggregate to per (pdb_id, wt_or_var),
    save CSV, and return (wt_table, var_table, delta_table).

    Parameters
    ----------
    df_in : pandas.DataFrame
        Must contain 'pdb_id', 'wt_or_var' and 'sequence' columns.
    motif_list : sequence of str
        Motifs to count; also fixes the column order of the returned tables.
    label : str
        Prefix of the default output file name (``{label}_counts.csv``).
    out_csv : str or path-like, optional
        Where to write the per-(pdb_id, wt_or_var) summary. When omitted, the
        original hard-coded results location is used.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(wt, var, delta)``, each indexed by pdb_id with one column per motif;
        ``delta`` is ``var - wt``.
    """
    counts = df_in["sequence"].apply(lambda s: safe_counts(s, motif_list))
    # Build the count table on df_in's own index. A fresh 0..n-1 index would make the
    # column-wise concat below pair counts with the wrong proteins (and add all-NaN
    # rows) whenever df_in has been filtered and its index therefore has gaps.
    counts_df = pd.DataFrame(counts.tolist(), index=df_in.index)
    result = pd.concat([df_in[["pdb_id", "wt_or_var"]], counts_df], axis=1)
    summary = result.groupby(["pdb_id", "wt_or_var"], as_index=False).sum(numeric_only=True)

    if out_csv is None:
        out_csv = rf"C:\Users\james\Masters_Degree\Thesis\protein_language_model_project\results\{label}_counts.csv"
    summary.to_csv(out_csv, index=False)

    # Columns become a (motif, wt_or_var) MultiIndex; xs() picks one class at a time
    pivoted = summary.pivot(index="pdb_id", columns="wt_or_var")
    wt = pivoted.xs("wt", axis=1, level=1).reindex(columns=motif_list, fill_value=0).fillna(0)
    var = pivoted.xs("var", axis=1, level=1).reindex(columns=motif_list, fill_value=0).fillna(0)
    delta = (var - wt)

    return wt, var, delta

def plot_delta_heatmap(delta_df: pd.DataFrame, title: str, x_label: str, cbar_label: str, figsize=(14, 8), annotate=False):
    """Draw `delta_df` as a zero-centred diverging heatmap and show it.

    Rows are labelled as proteins (pdb_id); `x_label` names the columns and
    `cbar_label` the colour bar. With `annotate=True` every cell is labelled
    using integer formatting, so the values must be integers.
    """
    cell_format = "d" if annotate else ""

    plt.figure(figsize=figsize)
    heatmap_axes = sns.heatmap(
        delta_df,
        cmap=accessible_cmap,
        center=0,
        annot=annotate,
        fmt=cell_format,
        cbar_kws={'label': cbar_label},
    )

    heatmap_axes.set_title(title, fontsize=18)
    heatmap_axes.set_xlabel(x_label, fontsize=14)
    heatmap_axes.set_ylabel("Protein (pdb_id)", fontsize=14)

    plt.tight_layout()
    plt.show()

# Each section below counts one motif family, writes "<label>_counts.csv" via
# summarize_motifs, and plots the per-protein change (var - wt) as a heatmap.

# =========================
# 1) General motif set
# =========================
_, _, delta_general = summarize_motifs(df, motifs, label="motif")
plot_delta_heatmap(
    delta_general,
    title="Change in Motif Counts (var − wt)",
    x_label="Motif",
    cbar_label="Δ (var − wt) in motif count",
    figsize=(14, 8),
    annotate=False
)

# =========================
# 2) Over-represented quads
# =========================
_, _, delta_overrep = summarize_motifs(df, over_represented_all, label="overrep_quads")
plot_delta_heatmap(
    delta_overrep,
    title="Change in Over-Represented Quad Motifs (var − wt)",
    x_label="Quadruplet",
    cbar_label="Δ (var − wt) in quad count",
    figsize=(14, 8),
    annotate=False
)

# =========================
# 3) Poly-E
# =========================
_, _, delta_polyE = summarize_motifs(df, polyE_motifs, label="polyE")
# Cast to int because the annotated heatmap formats each cell with "d"
plot_delta_heatmap(
    delta_polyE.astype(int),
    title="Change in Poly-E Motif Counts (var − wt)",
    x_label="Poly-E Motif",
    cbar_label="Δ (var − wt) in motif count",
    figsize=(14, 8),
    annotate=True
)

# =========================
# 4) Poly-X
# =========================
_, _, delta_polyX = summarize_motifs(df, polyX_motifs, label="polyX")
# Cast to int because the annotated heatmap formats each cell with "d"
plot_delta_heatmap(
    delta_polyX.astype(int),
    title="Change in Poly-X Motif Counts (V, M, A, K, P, E, L, R, H; n=3–7)",
    x_label="Poly-X Motif",
    cbar_label="Δ (var − wt) in motif count",
    figsize=(16, 9),
    annotate=True
)

print("Motif Δ analysis done.")


In [None]:
# -*- coding: utf-8 -*-
"""
Mutation location and Δcontact heatmaps (per-protein).
Assumes:
- mut_path: CSV with at least ['pdb_id','wt_resSeq','var_resSeq', 'Delta_contacts' or 'Delta_contaacts'].
- seq_path: CSV with ['pdb_id','sequence'].
"""

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

# --- Config ---
# NOTE: hard-coded, machine-specific input paths
mut_path = r"C:\Users\james\Masters_Degree\Thesis\protein_language_model_project\supplementary_data\mutation_contact_counts_all.csv"
seq_path = r"C:\Users\james\Masters_Degree\Thesis\protein_language_model_project\supplementary_data\var_wt_sequences.csv"
POS_COL = "var_resSeq"   # use "wt_resSeq" if you want WT numbering
BIN_SIZE = 5             # percentage bin size

# --- Colormap ---
# Diverging map: dark blue (low) -> white (middle) -> orange (high)
colors = [(0, "#12436D"), (0.5, "white"), (1, "#F46A25")]
accessible_cmap = LinearSegmentedColormap.from_list("darkblue_white_orange", colors)

# --- Load data ---
mut = pd.read_csv(mut_path)
seq = pd.read_csv(seq_path)

# Pick delta contacts column: the first of the two accepted spellings that is present,
# or None, in which case the delta-contacts heatmap is skipped
delta_col = next((c for c in ["Delta_contacts", "Delta_contaacts"] if c in mut.columns), None)

# --- Position cleaning ---
# Rename the chosen residue-number column to "position", coerce it to numeric and
# drop rows whose position could not be parsed
mut = mut.rename(columns={POS_COL: "position"})
mut["position"] = pd.to_numeric(mut["position"], errors="coerce")
mut = mut.dropna(subset=["position"]).assign(position=lambda d: d["position"].astype(int))

if delta_col is not None:
    mut[delta_col] = pd.to_numeric(mut[delta_col], errors="coerce")

# --- Sequence lengths ---
seq["length"] = seq["sequence"].str.len().astype(int)
# Left merge, then drop mutations whose pdb_id has no sequence length.
# NOTE(review): if `seq` holds more than one row per pdb_id, this merge repeats each
# mutation row once per match - verify the sequence file has a single row per protein.
# This also rebinds `df`, replacing any earlier frame of that name.
df = mut.merge(seq[["pdb_id", "length"]], on="pdb_id", how="left").dropna(subset=["length"])
df["length"] = df["length"].astype(int)

# --- Map positions to % ---
# position 1 -> 0 %, position == length -> 100 %; a length of 1 would give a zero
# denominator, so it is replaced by 1. Values outside 0-100 are clipped.
# assumes `position` is a 1-based index into the sequence - TODO confirm for resSeq numbering
denom = (df["length"] - 1).replace(0, 1)
df["pct"] = np.clip(100 * (df["position"] - 1) / denom, 0, 100)

# --- Binning ---
bins = np.arange(0, 100 + BIN_SIZE, BIN_SIZE)
labels = [f"{int(b)}–{int(b+BIN_SIZE)}" for b in bins[:-1]]
# With right=False every bin is left-closed, [lo, hi), so a value equal to the top edge
# (pct == 100, i.e. a mutation at the last residue) would fall outside all bins, become
# NaN and silently vanish from both heatmaps. Nudging the top edge up by one
# floating-point step keeps those mutations in the last bin without moving any other edge.
cut_edges = bins.astype(float)
cut_edges[-1] = np.nextafter(cut_edges[-1], np.inf)
df["pct_bin"] = pd.cut(df["pct"], bins=cut_edges, labels=labels, include_lowest=True, right=False)

# --- Heatmap 1: mutation *frequency* per bin (normalized per protein) ---
# Count mutations per (protein, bin); reindex so that every bin appears as a column, in order
heat_counts = (df.groupby(["pdb_id", "pct_bin"])
                 .size()
                 .unstack(fill_value=0)
                 .reindex(columns=labels, fill_value=0))

# Divide each row by its total so every protein's bins sum to 1;
# a zero total becomes NaN first (avoiding division by zero) and is then shown as 0
row_sums = heat_counts.sum(axis=1).replace(0, np.nan)
heat_norm = heat_counts.div(row_sums, axis=0).fillna(0)

# Figure height grows with the number of proteins (minimum 4)
plt.figure(figsize=(16, max(4, 0.35 * heat_norm.shape[0])))
sns.heatmap(
    heat_norm,
    cmap="Blues",
    vmin=0, vmax=heat_norm.values.max(),
    cbar_kws={"label": "Fraction of mutations"}
)
plt.xlabel("Sequence % (binned)")
plt.ylabel("Protein (pdb_id)")
plt.title("Mutation frequency across sequence bins — per protein")
plt.tight_layout()
plt.show()

# --- Heatmap 2: Δcontacts per bin (if available) ---
if delta_col is not None:
    # Mean delta-contacts per (protein, bin); bins without mutations stay NaN
    heat_mean = (df.groupby(["pdb_id", "pct_bin"])[delta_col]
                   .mean()
                   .unstack()
                   .reindex(columns=labels))

    # Symmetric colour limits around zero, ignoring NaN cells
    v = np.nanmax(np.abs(heat_mean.values))
    plt.figure(figsize=(16, max(4, 0.35 * heat_mean.shape[0])))
    sns.heatmap(
        heat_mean,
        cmap=accessible_cmap,
        center=0,
        vmin=-v, vmax=v,
        cbar_kws={"label": "Mean Δ contacts"}
    )
    plt.xlabel("Sequence % (binned)")
    plt.ylabel("Protein (pdb_id)")
    plt.title("Change in contacts (Δ) across sequence bins — per protein")
    plt.tight_layout()
    plt.show()


In [None]:
#!/usr/bin/env python
import pandas as pd
from collections import Counter
from pathlib import Path
import sys

# One-letter codes of the 20 amino acids that are counted; also fixes the output order
AA20 = list("ACDEFGHIKLMNPQRSTVWY")

def concat_sequences(series: pd.Series) -> str:
    """Join all non-missing entries of `series` into one upper-case string."""
    present = series.dropna()
    fragments = present.astype(str).tolist()
    return "".join(fragments).upper()

def count_aa(seq: str) -> pd.Series:
    """Return an int64 Series, indexed by AA20 in order, with each amino acid's count in `seq`.

    Characters that are not in AA20 are ignored.
    """
    tally = Counter(seq)
    return pd.Series([tally[aa] for aa in AA20], index=AA20, dtype="int64")

def main(input_csv, outdir=None):
    """Write pooled wild-type / variant amino-acid composition tables for one results CSV.

    Parameters
    ----------
    input_csv : str or Path
        CSV with 'wt_or_var' and 'sequence' columns. The process exits via
        ``sys.exit`` if the file does not exist.
    outdir : str or Path, optional
        Output directory (created if needed). Defaults to the folder that
        contains `input_csv`.

    Two CSVs are written, both with amino acids as columns:
    ``<stem>_global_aa_props_wide.csv`` (wt/var counts, proportions and their
    difference times 100) and ``<stem>_overall_aa_props_wide.csv`` (counts and
    proportions over all pooled sequences).
    """
    source = Path(input_csv)
    if not source.is_file():
        sys.exit(f"Input not found: {source}")

    target_dir = source.parent if not outdir else Path(outdir)
    target_dir.mkdir(parents=True, exist_ok=True)

    table = pd.read_csv(source)
    # Normalise the class labels (surrounding whitespace, letter case) before matching
    table["wt_or_var"] = table["wt_or_var"].astype(str).str.strip().str.lower()

    # Pool every sequence of each class into one long string
    pooled = {}
    for tag in ("wt", "var"):
        subset = table[table["wt_or_var"] == tag]
        pooled[tag] = "" if subset.empty else concat_sequences(subset["sequence"])

    if pooled["wt"] or pooled["var"]:
        pooled_all = pooled["wt"] + pooled["var"]
    else:
        # Nothing labelled wt or var: fall back to every sequence in the file
        pooled_all = concat_sequences(table["sequence"])

    def zero_counts() -> pd.Series:
        return pd.Series({aa: 0 for aa in AA20}, dtype="int64")

    wt_counts = count_aa(pooled["wt"]) if pooled["wt"] else zero_counts()
    var_counts = count_aa(pooled["var"]) if pooled["var"] else zero_counts()
    all_counts = count_aa(pooled_all)

    wt_total = int(wt_counts.sum())
    var_total = int(var_counts.sum())
    all_total = int(all_counts.sum())

    # Per-class table, amino acids as index; proportions are 0.0 when a class is empty
    per_class = pd.DataFrame({"wt_count": wt_counts, "var_count": var_counts})
    if wt_total > 0:
        per_class["wt_prop"] = per_class["wt_count"] / wt_total
    else:
        per_class["wt_prop"] = 0.0
    if var_total > 0:
        per_class["var_prop"] = per_class["var_count"] / var_total
    else:
        per_class["var_prop"] = 0.0
    per_class["delta_pp"] = (per_class["var_prop"] - per_class["wt_prop"]) * 100

    overall = pd.DataFrame({"all_count": all_counts})
    if all_total > 0:
        overall["all_prop"] = overall["all_count"] / all_total
    else:
        overall["all_prop"] = 0.0

    # Transposed on output: rows are the metrics, columns the amino acids
    stem = source.stem
    out_global_wide = target_dir / f"{stem}_global_aa_props_wide.csv"
    out_overall_wide = target_dir / f"{stem}_overall_aa_props_wide.csv"
    per_class.T.to_csv(out_global_wide)
    overall.T.to_csv(out_overall_wide)

    print(f"[OK] WT total:  {wt_total}")
    print(f"[OK] VAR total: {var_total}")
    print(f"[OK] ALL total: {all_total}")
    print(f"[OK] Wrote: {out_global_wide}")
    print(f"[OK] Wrote: {out_overall_wide}")

if __name__ == "__main__":
    # Inside a Jupyter/IPython kernel sys.argv holds the kernel's own launch arguments
    # (e.g. "-f <connection file>"), not user input. Treating them as paths made main()
    # exit with "Input not found: -f", so command-line arguments are only honoured when
    # this code runs as a plain script.
    cli_args = [] if "ipykernel" in sys.modules else sys.argv[1:]
    csv_path = cli_args[0] if len(cli_args) > 0 else r"C:\Users\james\Masters_Degree\Thesis\protein_language_model_project\supplementary_data\results_table.csv"
    outdir = cli_args[1] if len(cli_args) > 1 else None
    main(csv_path, outdir)


The following script plots difference in propensities across dataset against thermophilic and mesophilic propensities described by Taylor & Vaisman (2010). The CSV of their propensity table is available in the supplementary data folder.


In [None]:
# -*- coding: utf-8 -*-
"""
Overlay plots: Variant vs comparator groups for AA propensities.
Line chart style with metrics box (L1, L2, |Δ| <= threshold).
"""

import math
import numpy as np
import matplotlib.pyplot as plt

# Palette: one colour per line in each overlay plot
duo_palette = {
    'dark blue': '#12436D',  # Variant
    'orange':    '#F46A25',  # Comparator
}

# Canonical AA order: fixes the left-to-right order of amino acids on the x-axis
# and which input columns are treated as amino-acid columns
AA_ORDER = ["A","C","D","E","F","G","H","I","K","L",
            "M","N","P","Q","R","S","T","V","W","Y"]

def plot_variant_overlays_line(df, delta_thresh=0.5):
    """
    Plot overlays of Variant vs other groups using line charts.

    One subplot is drawn per comparator row. Each shows the Variant profile and
    the comparator profile over the amino acids, plus a metrics box with the
    L1 distance, the L2 distance and the number of amino acids whose absolute
    difference is at most `delta_thresh` (amino acids with a NaN in either
    profile are left out of the metrics).

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain 'Group' column and amino acid columns (A, C, D, ...).
        One row must be 'Variant' (matched case-insensitively; if several
        rows match, the first is used and the others are ignored).
    delta_thresh : float, optional
        Threshold for counting "close" pairs (default 0.5).

    Returns
    -------
    None
        Displays plots inline.

    Raises
    ------
    ValueError
        If there is no 'Group' column, no 'Variant' row, or no comparator row.
    """
    if "Group" not in df.columns:
        raise ValueError("Input DataFrame must contain a 'Group' column.")

    # Ensure correct AA column order (column names are matched ignoring case and
    # surrounding whitespace)
    aa_cols = [c for c in df.columns if c != "Group" and c.strip().upper() in AA_ORDER]
    col_lookup = {c.strip().upper(): c for c in aa_cols}
    cols = [col_lookup[aa] for aa in AA_ORDER if aa in col_lookup]

    # Ensure numeric (work on a copy so the caller's frame is left untouched)
    df = df.copy()
    for c in cols:
        df[c] = pd.to_numeric(df[c], errors="coerce")

    # Extract Variant row
    mask_variant = df["Group"].str.casefold() == "variant"
    if not mask_variant.any():
        raise ValueError("Row 'Variant' not found in column 'Group'.")
    variant_vals = df.loc[mask_variant, cols].iloc[0].to_numpy(dtype=float)

    # Take names and values of the comparator rows together, row by row. Looking each
    # row up again by its name failed for a missing (NaN) group name and plotted the
    # first row repeatedly when two comparators shared a name.
    comparator_rows = df.loc[~mask_variant]
    comparators = comparator_rows["Group"].tolist()
    comparator_vals = comparator_rows[cols].to_numpy(dtype=float)

    n = len(comparators)
    if n == 0:
        # Without this guard matplotlib rejects the zero-row subplot grid with an
        # error that says nothing about the input data
        raise ValueError("No comparator rows found: 'Group' contains only 'Variant'.")
    ncols = 2 if n > 1 else 1
    nrows = math.ceil(n / ncols)

    fig, axes = plt.subplots(
        nrows, ncols,
        figsize=(6*ncols, 3.8*nrows),
        squeeze=False
    )
    x = np.arange(len(cols))
    xticklabels = [c.strip().upper() for c in cols]

    for ax, grp, other_vals in zip(axes.ravel(), comparators, comparator_vals):
        # Metrics, computed only where both profiles have a value
        mask = ~np.isnan(variant_vals) & ~np.isnan(other_vals)
        diffs = variant_vals[mask] - other_vals[mask]
        l1 = float(np.sum(np.abs(diffs)))
        l2 = float(np.sqrt(np.sum(diffs**2)))
        qual_count = int(np.sum(np.abs(diffs) <= delta_thresh))

        # Line plots
        ax.plot(x, variant_vals, marker="o", linewidth=2,
                label="Variant", color=duo_palette['dark blue'])
        ax.plot(x, other_vals, marker="s", linewidth=2,
                label=grp, color=duo_palette['orange'])

        # Formatting
        ax.set_title(f"Variant vs {grp}")
        ax.set_xticks(x, labels=xticklabels)
        ax.set_xlabel("Amino acid")
        ax.set_ylabel("Propensity (%)")
        ax.legend(loc="upper right")

        # Metrics box (bottom-left corner, in axes coordinates)
        ax.text(
            0.01, 0.02,
            f"L1 distance = {l1:.2f}\n"
            f"L2 distance = {l2:.2f}\n"
            f"Pairs ≤ {delta_thresh} = {qual_count}",
            fontsize=8,
            transform=ax.transAxes,
            ha="left", va="bottom",
            bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.75, edgecolor="none")
        )

    # Hide unused axes (an odd number of comparators leaves one empty grid cell)
    for ax in axes.ravel()[len(comparators):]:
        ax.axis("off")

    fig.tight_layout()
    plt.show()
# Load the propensity table; its first column holds the group names and is
# renamed to "Group", the column name the plotting function requires
aa_propensities_csv = r"C:\Users\james\Masters_Degree\Thesis\protein_language_model_project\supplementary_data\AA_propensities_over_all.csv"
aa_df = pd.read_csv(aa_propensities_csv)
first_column = aa_df.columns[0]
aa_df = aa_df.rename(columns={first_column: "Group"})

plot_variant_overlays_line(aa_df, delta_thresh=0.5)
