In [1]:
import pandas as pd
import numpy as np
import re
import string

from Bio import SeqIO, PDB
from natsort import natsort_keygen
from pathlib import Path

import altair as alt
import theme
# Remove Altair's limit on the number of rows a chart dataset may contain,
# so the full alignment table can be passed to the charts built later on.
alt.data_transformers.disable_max_rows()

DataTransformerRegistry.enable('default')

### Format structural alignment

In [2]:
def alignment_to_df(fasta_path, ref_id, start_site=1):
    """
    Convert an aligned FASTA file into a site-by-structure residue table.

    Every non-reference sequence is walked column by column against the
    reference sequence; the resulting per-structure tables are outer-merged
    on the reference numbering.

    Numbering rules
    ---------------
    * Each column in which the reference has a residue receives the next
      integer site, starting at ``start_site``.
    * A column in which the reference is gapped but the other sequence has a
      residue is labelled as an insertion after the most recent numbered
      site: ``<site>a``, ``<site>b``, ... An insertion that precedes the
      first reference residue is anchored on ``start_site - 1``.
    * Columns gapped in both sequences are dropped.

    Parameters
    ----------
    fasta_path : path to the aligned FASTA file (all sequences same length).
    ref_id : id of the reference sequence. Only the part of each FASTA record
        id before the first underscore is used as its id.
    start_site : integer site assigned to the first reference residue.

    Returns
    -------
    DataFrame with columns ``struct_site`` (str), ``{ref_id}_aa`` and one
    ``{id}_aa`` column per non-reference sequence, naturally sorted by
    ``struct_site``.

    Raises
    ------
    ValueError
        If ``ref_id`` is absent, if the file contains no sequence other than
        the reference, if two aligned sequences differ in length, or if more
        than 26 consecutive insertions follow a single site.
    """
    # Later records with the same truncated id overwrite earlier ones.
    seq_dict = {}
    for rec in SeqIO.parse(fasta_path, "fasta"):
        seq_dict[rec.id.split("_")[0]] = str(rec.seq)

    if ref_id not in seq_dict:
        raise ValueError(f"ref_id '{ref_id}' not found. Available IDs: {list(seq_dict.keys())}")

    ref_seq = seq_dict[ref_id]
    other_ids = [rid for rid in seq_dict if rid != ref_id]
    if not other_ids:
        # Previously this fell through to `None.sort_values` (AttributeError).
        raise ValueError(
            f"No sequences other than ref_id '{ref_id}' found in {fasta_path}."
        )

    letters = string.ascii_lowercase

    def insertion_label(base_num, k):
        """Label the k-th (0-based) insertion after numbered site base_num."""
        if k < len(letters):
            return f"{base_num}{letters[k]}"
        raise ValueError(f"Insertion index {k} exceeds available letters ({len(letters)})")

    ref_col = f"{ref_id}_aa"
    final_df = None

    for pdb in other_ids:
        seq = seq_dict[pdb]
        if len(ref_seq) != len(seq):
            raise ValueError("Aligned sequences must be the same length.")

        rows = []
        last_numeric = start_site - 1
        n_inserted = 0  # insertions seen since the last numbered site

        for ref_aa, aa in zip(ref_seq, seq):
            if ref_aa != "-":
                # Reference residue (aligned to a residue or to a gap):
                # advance the numbering and restart the insertion lettering.
                last_numeric += 1
                n_inserted = 0
                rows.append((str(last_numeric), ref_aa, aa))
            elif aa != "-":
                # Insertion relative to the reference.
                rows.append((insertion_label(last_numeric, n_inserted), "-", aa))
                n_inserted += 1
            # Columns gapped in both sequences contribute no row.

        # Build the frame once per structure, after all columns are processed
        # (it used to be rebuilt on every alignment column).
        df = pd.DataFrame(rows, columns=["struct_site", ref_col, f"{pdb}_aa"])

        if final_df is None:
            final_df = df
        else:
            final_df = pd.merge(final_df, df, on=["struct_site", ref_col], how="outer")

    return final_df.sort_values("struct_site", key=natsort_keygen()).reset_index(drop=True)

# HA1 and HA2 are aligned separately, both against reference structure 4o5n.
# Site numbering starts at 9 for the HA1 alignment and at 330 for the HA2
# alignment (the first reference residue in each file gets that number).
ha1_aln = alignment_to_df('../results/foldmason/ha1/result_aa.fa', ref_id='4o5n', start_site=9)
ha2_aln = alignment_to_df('../results/foldmason/ha2/result_aa.fa', ref_id='4o5n', start_site=330)

# Preview the first rows of the HA1 table.
ha1_aln.head()

Unnamed: 0,struct_site,4o5n_aa,4r8w_aa,4kwm_aa
0,9,P,-,P
1,10,G,-,G
2,11,A,D,D
3,12,T,K,Q
4,13,L,I,I


### Add RSA and RMSD

In [3]:
# Maximum accessible surface area per one-letter amino-acid code, used as the
# denominator when converting ASA to RSA.
# NOTE(review): presumably the Tien et al. (2013) theoretical values, as the
# name suggests -- confirm the source.
MAX_ASA_TIEN = {
    'A': 129.0, 'C': 167.0, 'D': 193.0, 'E': 223.0, 'F': 240.0,
    'G': 104.0, 'H': 224.0, 'I': 197.0, 'K': 236.0, 'L': 201.0,
    'M': 224.0, 'N': 195.0, 'P': 159.0, 'Q': 225.0, 'R': 274.0,
    'S': 155.0, 'T': 172.0, 'V': 174.0, 'W': 285.0, 'Y': 263.0,
}


def processDSSP(dsspfile, chain=None, max_asa=MAX_ASA_TIEN):
    """
    Tabulate secondary structure and solvent accessibility from a DSSP file.

    Parameters
    ----------
    dsspfile : path to the DSSP output, read with ``PDB.make_dssp_dict``.
    chain : chain id to extract. If None, the file must contain exactly one
        chain, which is then used.
    max_asa : mapping from one-letter amino acid to maximum ASA; RSA is
        computed as ``ASA / max_asa[aa]``.

    Returns
    -------
    DataFrame with one row per residue of the chain and columns ``pdb_site``,
    ``amino_acid``, ``ASA``, ``RSA``, ``SS`` and ``SS_class``. ``pdb_site`` is
    the residue number as-is, or a string ``<number><suffix>`` when the
    residue id carries a non-blank letter suffix (so the column can hold a
    mix of numbers and strings).

    Raises
    ------
    AssertionError
        If ``chain`` is None and the file does not contain exactly one chain.
    ValueError
        If ``chain`` is not in the file, or a secondary-structure code is not
        one of the recognised ones.
    KeyError
        If a residue's amino acid is missing from ``max_asa``.
    """
    # Amino-acid codes starting with a lower-case letter are recorded as 'C'
    # (the original pattern was named for DSSP's cysteine notation).
    lowercase_code = re.compile('[a-z]')

    ss_to_class = {
        'G': 'helix', 'H': 'helix', 'I': 'helix',
        'P': 'helix',  # double check if P is helix
        'B': 'strand', 'E': 'strand',
        'T': 'loop', 'S': 'loop', '-': 'loop',
    }

    dssp_records = PDB.make_dssp_dict(dsspfile)[0]
    available_chains = {chainid for (chainid, _resid) in dssp_records.keys()}

    if chain is None:
        assert len(available_chains) == 1, "chain is None, but multiple chains"
        chain = next(iter(available_chains))
    elif chain not in available_chains:
        raise ValueError("Invalid chain {0}".format(chain))

    table = {
        column: []
        for column in ('pdb_site', 'amino_acid', 'ASA', 'RSA', 'SS', 'SS_class')
    }

    for (chainid, resid), fields in dssp_records.items():
        if chainid != chain:
            continue

        raw_aa, ss, acc = fields[:3]
        aa = 'C' if lowercase_code.match(raw_aa) else raw_aa

        # resid[1] is the residue number, resid[2] an optional letter suffix.
        suffix = resid[2]
        if suffix and not suffix.isspace():
            site = str(resid[1]) + suffix.strip()
        else:
            site = resid[1]

        rsa = acc / float(max_asa[aa])

        if ss not in ss_to_class:
            raise ValueError("invalid SS of {0}".format(ss))

        table['pdb_site'].append(site)
        table['amino_acid'].append(aa)
        table['ASA'].append(acc)
        table['RSA'].append(rsa)
        table['SS'].append(ss)
        table['SS_class'].append(ss_to_class[ss])

    return pd.DataFrame(table)

# For each structure, chain A supplies the *_ha1_ss table and chain B the
# *_ha2_ss table; only the columns used downstream are kept.
h3_ha1_ss, h3_ha2_ss = (
    processDSSP("../data/4o5n.dssp", chain=chain_id)[['amino_acid', 'RSA', 'SS']]
    for chain_id in ('A', 'B')
)

h5_ha1_ss, h5_ha2_ss = (
    processDSSP("../data/4kwm.dssp", chain=chain_id)[['amino_acid', 'RSA', 'SS']]
    for chain_id in ('A', 'B')
)

h7_ha1_ss, h7_ha2_ss = (
    processDSSP("../data/4r8w.dssp", chain=chain_id)[['amino_acid', 'RSA', 'SS']]
    for chain_id in ('A', 'B')
)

In [5]:
def merge_rsa_by_alignment(aln_df, rsa_df, aa_col, prefix=True):
    """
    Attach per-residue RSA/SS values to an alignment table, pairing by order.

    The rows of ``aln_df`` whose ``aa_col`` entry is a single letter are
    paired, top to bottom, with the first rows of ``rsa_df``. All other rows
    (gaps, missing values) receive NaN in the added columns.

    Parameters
    ----------
    aln_df : DataFrame with an alignment column (e.g., '4o5n_aa') holding
        one-letter residues or non-letter placeholders such as '-'.
    rsa_df : DataFrame with an 'amino_acid' column plus the fields to copy,
        with rows in the same order as the residues in ``aa_col``.
    aa_col : Name of the column in ``aln_df`` to traverse (required).
    prefix : If True, the copied columns are renamed to f"{aa_col}_<name>"
        to avoid collisions.

    Returns
    -------
    ``aln_df`` joined with the copied columns (object dtype).

    Raises
    ------
    ValueError
        If ``rsa_df`` lacks 'amino_acid', has fewer rows than there are
        letters in ``aa_col``, or its amino acids disagree
        (case-insensitively) with the letters in ``aa_col``.
    """
    if 'amino_acid' not in rsa_df.columns:
        raise ValueError("rsa_df must contain 'amino_acid' column.")

    aln_chars = aln_df[aa_col].astype(str)
    residue_mask = aln_chars.str.fullmatch(r"[A-Za-z]")

    # Number of rsa_df rows that will be consumed.
    n_residues = int(residue_mask.sum())
    n_available = len(rsa_df)
    if n_available < n_residues:
        raise ValueError(
            f"rsa_df has {n_available} rows but df[{aa_col}] has {n_residues} letters."
        )

    # Only the leading n_residues rows of rsa_df take part, renumbered 0..n-1.
    rsa_head = rsa_df.iloc[:n_residues].reset_index(drop=True)

    # Case-insensitive check that both residue sequences agree.
    residue_rows = aln_chars[residue_mask]
    expected = residue_rows.str.upper().to_numpy()
    observed = rsa_head['amino_acid'].astype(str).str.upper().to_numpy()
    differs = expected != observed
    if differs.any():
        # Report at most five offending pairs: the aln_df row label together
        # with the conflicting residues.
        aln_labels = residue_rows.index[differs]
        rsa_positions = np.flatnonzero(differs)
        examples = [
            f"{i}: df[{aa_col}]={aln_df.loc[i, aa_col]!r} vs rsa={rsa_head.loc[j, 'amino_acid']!r}"
            for i, j in zip(aln_labels[:5], rsa_positions[:5])
        ]
        raise ValueError("Amino-acid mismatch at rows [" + ", ".join(examples) + "].")

    # Everything except the residue letter itself is copied over.
    extra_cols = [name for name in rsa_df.columns if name != 'amino_acid']

    # Start from an all-NaN frame aligned to aln_df, then fill residue rows.
    to_add = pd.DataFrame(index=aln_df.index, columns=extra_cols, dtype='object')
    to_add.loc[residue_mask, extra_cols] = rsa_head[extra_cols].to_numpy()

    if prefix:
        to_add = to_add.add_prefix(f"{aa_col}_")

    return aln_df.join(to_add)

# Attach the RSA/SS columns of each structure to the HA1 alignment, one
# structure at a time; each call adds '<aa_col>_RSA' and '<aa_col>_SS'.
ha1_aln_ss = merge_rsa_by_alignment(ha1_aln, h3_ha1_ss, '4o5n_aa')
ha1_aln_ss = merge_rsa_by_alignment(ha1_aln_ss, h5_ha1_ss, '4kwm_aa')
ha1_aln_ss = merge_rsa_by_alignment(ha1_aln_ss, h7_ha1_ss, '4r8w_aa')

# Same for the HA2 alignment, using the chain B DSSP tables.
ha2_aln_ss = merge_rsa_by_alignment(ha2_aln, h3_ha2_ss, '4o5n_aa')
ha2_aln_ss = merge_rsa_by_alignment(ha2_aln_ss, h5_ha2_ss, '4kwm_aa')
ha2_aln_ss = merge_rsa_by_alignment(ha2_aln_ss, h7_ha2_ss, '4r8w_aa')

# Preview the first rows of the annotated HA1 table.
ha1_aln_ss.head()

Unnamed: 0,struct_site,4o5n_aa,4r8w_aa,4kwm_aa,4o5n_aa_RSA,4o5n_aa_SS,4kwm_aa_RSA,4kwm_aa_SS,4r8w_aa_RSA,4r8w_aa_SS
0,9,P,-,P,1.081761,-,1.138365,-,,
1,10,G,-,G,0.153846,-,0.173077,-,,
2,11,A,D,D,0.054264,E,0.098446,-,0.259067,-
3,12,T,K,Q,0.267442,E,0.217778,E,0.279661,E
4,13,L,I,I,0.0,E,0.0,E,0.005076,E


In [6]:
from pathlib import Path

def parse_rmsd_txt(path, id=None):
    """
    Read a per-position RMSD text file into a two-column DataFrame.

    The first line of the file is skipped. Every following line of the form
    ``<integer>: <value>`` yields one row; lines that do not match are
    ignored. The literal value ``None`` is stored as a missing value, any
    other value is converted with ``float``.

    Parameters
    ----------
    path : path to the text file.
    id : label used to name the value column, ``rmsd_{id}``.

    Returns
    -------
    DataFrame with columns ``aln_idx`` and ``rmsd_{id}``.
    """
    line_pattern = re.compile(r"\s*(\d+):\s*(.*)\s*$")
    records = []

    for raw_line in Path(path).read_text().splitlines()[1:]:
        hit = line_pattern.match(raw_line)
        if hit is None:
            continue
        index_text, value_text = hit.groups()
        position = int(index_text)
        value_text = value_text.strip()
        value = float(value_text) if value_text != "None" else None
        records.append((position, value))

    return pd.DataFrame(records, columns=["aln_idx", f"rmsd_{id}"])

# Per-position RMSD tables for each pair of structures, one file per pair
# and per alignment (HA1 / HA2). The `id` sets the value column name, e.g.
# 'rmsd_h3h5'.
h3_h5_ha1_rmsd = parse_rmsd_txt('../results/rmsd/h3_h5_ha1_rmsd.txt', id='h3h5')
h3_h5_ha2_rmsd = parse_rmsd_txt('../results/rmsd/h3_h5_ha2_rmsd.txt', id='h3h5')

h3_h7_ha1_rmsd = parse_rmsd_txt('../results/rmsd/h3_h7_ha1_rmsd.txt', id='h3h7')
h3_h7_ha2_rmsd = parse_rmsd_txt('../results/rmsd/h3_h7_ha2_rmsd.txt', id='h3h7')

h5_h7_ha1_rmsd = parse_rmsd_txt('../results/rmsd/h5_h7_ha1_rmsd.txt', id='h5h7')
h5_h7_ha2_rmsd = parse_rmsd_txt('../results/rmsd/h5_h7_ha2_rmsd.txt', id='h5h7')

In [7]:
def merge_aln_rmsd(aln_df, rmsd_df):
    """
    Append the RMSD column(s) of ``rmsd_df`` to ``aln_df``, row by row.

    Rows are paired purely by position: the i-th row of ``rmsd_df`` goes with
    the i-th row of ``aln_df``, whatever the index labels of either frame.
    The ``aln_idx`` column is dropped from the result.

    NOTE(review): this assumes the row order of ``aln_df`` matches the
    ``aln_idx`` order of the RMSD file -- only the lengths are checked.

    Parameters
    ----------
    aln_df : alignment table.
    rmsd_df : table with an ``aln_idx`` column plus RMSD value column(s),
        with the same number of rows as ``aln_df``.

    Returns
    -------
    DataFrame with the columns of ``aln_df`` followed by the non-``aln_idx``
    columns of ``rmsd_df``, carrying the index of ``aln_df``.

    Raises
    ------
    AssertionError
        If the two frames have different numbers of rows.
    """
    assert len(aln_df) == len(rmsd_df), (
        f"aln_df has {len(aln_df)} rows but rmsd_df has {len(rmsd_df)}"
    )
    # pd.concat(axis=1) aligns on index labels, so give rmsd_df the labels of
    # aln_df first. Without this, a non-default index on aln_df would yield
    # misaligned rows padded with NaN even though the lengths agree.
    rmsd_aligned = rmsd_df.set_axis(aln_df.index, axis=0)
    return pd.concat(
        [aln_df, rmsd_aligned], axis=1
    ).drop(columns=['aln_idx'])

# Attach the three pairwise RMSD columns to the HA1 table ...
_ha1_with_rmsd = ha1_aln_ss
for _rmsd_table in (h3_h5_ha1_rmsd, h3_h7_ha1_rmsd, h5_h7_ha1_rmsd):
    _ha1_with_rmsd = merge_aln_rmsd(_ha1_with_rmsd, _rmsd_table)

# ... and to the HA2 table.
_ha2_with_rmsd = ha2_aln_ss
for _rmsd_table in (h3_h5_ha2_rmsd, h3_h7_ha2_rmsd, h5_h7_ha2_rmsd):
    _ha2_with_rmsd = merge_aln_rmsd(_ha2_with_rmsd, _rmsd_table)

# Stack HA1 on top of HA2 and order the combined table by site label.
struct_align_df = pd.concat([_ha1_with_rmsd, _ha2_with_rmsd], ignore_index=True)
struct_align_df = (
    struct_align_df
    .sort_values("struct_site", key=natsort_keygen())
    .reset_index(drop=True)
)

struct_align_df.head()

Unnamed: 0,struct_site,4o5n_aa,4r8w_aa,4kwm_aa,4o5n_aa_RSA,4o5n_aa_SS,4kwm_aa_RSA,4kwm_aa_SS,4r8w_aa_RSA,4r8w_aa_SS,rmsd_h3h5,rmsd_h3h7,rmsd_h5h7
0,9,P,-,P,1.081761,-,1.138365,-,,,9.178066,,
1,10,G,-,G,0.153846,-,0.173077,-,,,8.183151,,
2,11,A,D,D,0.054264,E,0.098446,-,0.259067,-,5.04805,1.735252,4.172437
3,12,T,K,Q,0.267442,E,0.217778,E,0.279661,E,3.93908,1.489412,4.694308
4,13,L,I,I,0.0,E,0.0,E,0.005076,E,3.725425,1.019553,3.75921


### Add DMS background residues and numbering

In [8]:
# remove sites that are missing in the structural alignment
#
# For each DMS dataset, load the unique (site, wildtype) pairs and drop the
# sites listed in the matching *_missing list.

# NOTE(review): the H3 list holds integers while the H5/H7 lists hold
# strings -- presumably the H3 'site' column is read as integers whereas the
# other two contain lettered sites such as '328a'; confirm against the CSVs.
h3_missing = [*range(1, 9), *range(326, 330), *range(503, 505)]
h3_wt = pd.read_csv(
    '../data/MDCKSIAT1_entry_func_effects.csv'
)[['site', 'wildtype']].drop_duplicates().reset_index(drop=True).query(
    "site not in @h3_missing"
)

h5_missing = [*map(str, range(1, 9)),
              *map(str, range(325, 339)),
              *map(str, range(503, 552)),
              "328a", "328b", "328c", "510a"]
# For H5, sites whose label contains '-' are excluded as well; the python
# query engine is requested for the `.str.contains` call in the expression.
h5_wt = pd.read_csv(
    '../data/293T_entry_func_effects.csv'
)[['site', 'wildtype']].drop_duplicates().reset_index(drop=True).query(
    "site not in @h5_missing and ~site.str.contains('-', na=False)",
    engine="python"
)

h7_missing = [*map(str, range(326, 331)), 
              *map(str, range(500, 515)), "328a"]
h7_wt = pd.read_csv(
    '../data/293_2-6_entry_func_effects.csv'
)[['site', 'wildtype']].drop_duplicates().reset_index(drop=True).query(
    "site not in @h7_missing"
)

In [9]:
def add_wt_cols(
    aln_df, wt_df,
    ref_col="4o5n_aa",      # col to align against with letters or '-'
    site_col="site",        # col in wt_df with site numbering
    aa_col="wildtype",      # column in wt_df with wt residue
    out_site_col="wt_site",
    out_aa_col="wt_aa",
):
    """
    Add wild-type site and residue columns to an alignment table by position.

    ``wt_df`` is naturally sorted by ``site_col``. The k-th row of ``aln_df``
    (top to bottom) whose ``ref_col`` entry is a single letter then receives
    the site and residue of the k-th sorted wild-type row. Rows without a
    letter in ``ref_col``, and letter rows beyond the end of ``wt_df``, get
    NaN in both output columns.

    Parameters
    ----------
    aln_df : alignment table.
    wt_df : table with the wild-type site (``site_col``) and residue
        (``aa_col``) of each site.
    ref_col : column of ``aln_df`` holding letters or non-letter
        placeholders such as '-'.
    site_col, aa_col : column names in ``wt_df``.
    out_site_col, out_aa_col : names of the columns added to the result.

    Returns
    -------
    (out_df, stats) where ``out_df`` is ``aln_df`` plus the two new columns
    and ``stats`` is a dict with ``n_compared`` (rows that received a
    wild-type residue), ``n_match`` (of those, rows where the ``ref_col``
    letter equals the wild-type residue, case-insensitively) and
    ``pct_match`` (percentage rounded to two decimals, NaN if nothing was
    compared).
    """
    # Sort by the caller-supplied site column (this used to be hard-coded to
    # "site", which ignored `site_col`).
    wt = wt_df.reset_index(drop=True).sort_values(site_col, key=natsort_keygen())
    is_letter = aln_df[ref_col].astype(str).str.fullmatch(r"[A-Za-z]")

    # 0-based rank of each letter row among all letter rows; only ranks that
    # fall inside wt can be filled.
    idx = is_letter.cumsum() - 1
    take = is_letter & (idx < len(wt))

    # prefill outputs as NA
    out_site = pd.Series(np.nan, index=aln_df.index, dtype="object")
    out_aa   = pd.Series(np.nan, index=aln_df.index, dtype="object")

    # fill where we have letters and WT left
    pos = idx[take].to_numpy()
    out_site.loc[take] = wt[site_col].to_numpy()[pos]
    out_aa.loc[take]   = wt[aa_col].to_numpy()[pos]

    # summary stats over the rows that actually received a wild-type residue
    comp_idx = take[take].index
    n_compared = len(comp_idx)
    if n_compared:
        ref_up = aln_df.loc[comp_idx, ref_col].astype("string").str.upper().to_numpy()
        wt_up  = pd.Series(out_aa, dtype="string").loc[comp_idx].str.upper().to_numpy()
        n_match = int((ref_up == wt_up).sum())
        pct_match = float(np.round(100.0 * n_match / n_compared, 2))
    else:
        n_match = 0
        pct_match = np.nan

    out_df = aln_df.assign(**{out_site_col: out_site, out_aa_col: out_aa})
    return out_df, {"n_compared": n_compared, "n_match": n_match, "pct_match": pct_match}

# Add the DMS wild-type site/residue columns for H3, H5 and H7 in turn, each
# matched against the column of its corresponding structure, and print the
# agreement statistics after every step.
aln_out = struct_align_df
for _wt_table, _ref_col, _aa_out, _site_out in (
    (h3_wt, "4o5n_aa", "h3_wt_aa", "h3_site"),
    (h5_wt, "4kwm_aa", "h5_wt_aa", "h5_site"),
    (h7_wt, "4r8w_aa", "h7_wt_aa", "h7_site"),
):
    aln_out, stats = add_wt_cols(
        aln_out,
        _wt_table,
        ref_col=_ref_col,
        out_aa_col=_aa_out,
        out_site_col=_site_out,
    )
    print(stats)

{'n_compared': 490, 'n_match': 457, 'pct_match': 93.27}
{'n_compared': 487, 'n_match': 453, 'pct_match': 93.02}
{'n_compared': 485, 'n_match': 485, 'pct_match': 100.0}


In [10]:
# Write the final annotated alignment to disk (without the row index).
# NOTE(review): presumably the output directory already exists -- confirm.
aln_out.to_csv('../results/structural_alignment/structural_alignment.csv', index=False)
aln_out.head()

Unnamed: 0,struct_site,4o5n_aa,4r8w_aa,4kwm_aa,4o5n_aa_RSA,4o5n_aa_SS,4kwm_aa_RSA,4kwm_aa_SS,4r8w_aa_RSA,4r8w_aa_SS,rmsd_h3h5,rmsd_h3h7,rmsd_h5h7,h3_site,h3_wt_aa,h5_site,h5_wt_aa,h7_site,h7_wt_aa
0,9,P,-,P,1.081761,-,1.138365,-,,,9.178066,,,9,S,9,K,,
1,10,G,-,G,0.153846,-,0.173077,-,,,8.183151,,,10,T,10,S,,
2,11,A,D,D,0.054264,E,0.098446,-,0.259067,-,5.04805,1.735252,4.172437,11,A,11,D,11.0,D
3,12,T,K,Q,0.267442,E,0.217778,E,0.279661,E,3.93908,1.489412,4.694308,12,T,12,Q,12.0,K
4,13,L,I,I,0.0,E,0.0,E,0.005076,E,3.725425,1.019553,3.75921,13,L,13,I,13.0,I


### Sanity check: RSA correlation

In [12]:
def plot_rsa_correlation(aln_df, pdb_x, pdb_y, colors=['#5773CC', '#FFB900']):
    """
    Scatter the RSA of two structures against each other, site by site.

    Points are coloured by whether both structures carry the same amino acid
    at the site, and the correlation between the two RSA columns is printed
    in the top-left corner of the panel.

    Parameters
    ----------
    aln_df : alignment table holding ``struct_site`` plus, for both
        structures, the ``{pdb}_aa`` and ``{pdb}_aa_RSA`` columns.
    pdb_x : structure id plotted on the y axis.
    pdb_y : structure id plotted on the x axis.
    colors : two colours, for sites with a shared amino acid (first) and
        sites with different amino acids (second).

    Returns
    -------
    Layered Altair chart: the scatter plot plus the ``r = ...`` label.
    """
    aa_first, aa_second = f'{pdb_x}_aa', f'{pdb_y}_aa'
    rsa_first, rsa_second = f'{pdb_x}_aa_RSA', f'{pdb_y}_aa_RSA'

    # The correlation is computed on aln_df as given; the scatter below is
    # drawn from its de-duplicated rows.
    correlation = aln_df[rsa_first].corr(aln_df[rsa_second])
    label_frame = pd.DataFrame({'text': [f"r = {correlation:.2f}"]})

    shared_aa_color = alt.Color(
        'same_wildtype:N',
        title=['Shared amino', 'acid at site'],
        scale=alt.Scale(domain=[True, False], range=[colors[0], colors[1]]),
    )

    scatter = alt.Chart(aln_df.drop_duplicates())
    scatter = scatter.transform_calculate(
        same_wildtype=f'datum["{pdb_x}_aa"] == datum["{pdb_y}_aa"]'
    )
    scatter = scatter.mark_circle(size=30, opacity=0.6, color='#899DA4')
    scatter = scatter.encode(
        y=alt.Y(rsa_first, title=f'RSA in PDB {pdb_x}'),
        x=alt.X(rsa_second, title=f'RSA in PDB {pdb_y}'),
        color=shared_aa_color,
        tooltip=['struct_site', aa_first, aa_second, rsa_first, rsa_second],
    )
    scatter = scatter.properties(width=175, height=175)

    # Text label pinned at a fixed pixel offset from the top-left corner.
    corner_label = alt.Chart(label_frame).mark_text(
        align='left',
        baseline='top',
        fontSize=16,
        fontWeight='normal',
        color='black',
    ).encode(
        text='text:N',
        x=alt.value(5),
        y=alt.value(5),
    )

    return scatter + corner_label

# One RSA scatter panel per pair of structures, concatenated side by side
# with Altair's `|` operator; the resulting chart is the cell output.
(
    plot_rsa_correlation(aln_out, '4o5n', '4kwm') |
    plot_rsa_correlation(aln_out, '4o5n', '4r8w') |
    plot_rsa_correlation(aln_out, '4kwm', '4r8w')
)