In [1]:
import altair as alt

import pandas as pd

import polyclonal

import warnings
warnings.filterwarnings('ignore')

import Bio

In [2]:
import os

# Work from two directories above the directory the kernel started in, so the
# relative `data/`, `results/` and `scratch_notebooks/` paths below resolve.
# The starting directory is recorded on first execution, making this cell
# idempotent: re-running it no longer climbs two more levels each time.
try:
    _NOTEBOOK_DIR
except NameError:
    _NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_DIR, "..", ".."))

In [16]:
def plot_avg_escape(prob_escape, max_aa_subs=4):
    """Line chart of mean probability of escape vs. antibody concentration.

    Variants are binned by the number of whitespace-delimited amino-acid
    substitutions in ``aa_substitutions_reference``. Variants with
    ``max_aa_subs`` or more substitutions are pooled into a single bin labelled
    ``f">{max_aa_subs - 1}"``. Means are taken per (concentration, bin) for
    both ``prob_escape`` (shown as "censored to [0, 1]") and
    ``prob_escape_uncensored`` (shown as "not censored"). These two form the
    chart's column facets.

    Parameters
    ----------
    prob_escape : pandas.DataFrame
        Needs columns ``aa_substitutions_reference``,
        ``antibody_concentration``, ``prob_escape`` and
        ``prob_escape_uncensored``.
    max_aa_subs : int, optional
        Substitution count at or above which variants are pooled (default 4).

    Returns
    -------
    altair.Chart
        Chart faceted by censoring and colored by number of substitutions,
        with a symlog y axis.
    """
    mean_prob_escape = (
        prob_escape.assign(
            n_subs=lambda x: (
                x["aa_substitutions_reference"]
                .str.split()
                .map(len)
                .clip(upper=max_aa_subs)
                # pooled bin is labelled e.g. ">3" when max_aa_subs == 4
                .map(lambda n: str(n) if n < max_aa_subs else f">{max_aa_subs - 1}")
            )
        )
        .groupby(["antibody_concentration", "n_subs"], as_index=False)
        .aggregate({"prob_escape": "mean", "prob_escape_uncensored": "mean"})
        .rename(
            columns={
                "prob_escape": "censored to [0, 1]",
                "prob_escape_uncensored": "not censored",
            }
        )
        # long format: one row per (concentration, n_subs, censoring type)
        .melt(
            id_vars=["antibody_concentration", "n_subs"],
            var_name="censored",
            value_name="probability escape",
        )
    )

    mean_prob_escape_chart = (
        alt.Chart(mean_prob_escape)
        .encode(
            x=alt.X("antibody_concentration"),
            y=alt.Y(
                "probability escape",
                scale=alt.Scale(type="symlog", constant=0.05),
            ),
            column=alt.Column("censored", title=None),
            color=alt.Color("n_subs", title="n substitutions"),
            tooltip=[
                alt.Tooltip(c, format=".3g") if mean_prob_escape[c].dtype == float else c
                for c in mean_prob_escape.columns
            ],
        )
        .mark_line(point=True, size=0.5)
        .properties(width=200, height=125)
        .configure_axis(grid=False)
    )

    return mean_prob_escape_chart

In [3]:
# Pairwise inter-residue distances for chain A of the renumbered single-chain
# 4o5n structure. The spatially regularized Polyclonal models below use this
# table as their `spatial_distances` argument.
spatial_distances = polyclonal.pdb_utils.inter_residue_distances(
    "scratch_notebooks/221227_model_fitting/4o5n_renumbered_1chain.pdb",
    target_chains=["A"],
)

# display the distance table (site_1, site_2, distance, chain_1, chain_2)
spatial_distances

Unnamed: 0,site_1,site_2,distance,chain_1,chain_2
0,9,10,1.328212,A,A
1,9,11,3.850353,A,A
2,9,12,6.449567,A,A
3,9,13,9.701373,A,A
4,9,14,12.647217,A,A
...,...,...,...,...,...
254536,721,499,15.731319,A,A
254537,721,500,19.078522,A,A
254538,722,499,67.375801,A,A
254539,722,500,55.555973,A,A


In [4]:
def reassign_b_factor(
    input_pdbfile,
    output_pdbfile,
    df,
    metric_col,
    *,
    site_col="site",
    chain_col="chain",
    missing_metric=0,
    model_index=0,
):
    """Write a copy of a PDB structure whose B-factors hold a per-site metric.

    In model ``model_index``, every atom of every residue (including all
    alternates of disordered residues and atoms) gets its B-factor set to the
    ``metric_col`` value of that residue's (site, chain) in ``df``. Residues
    without an entry get ``missing_metric``. The whole parsed structure is
    then written to ``output_pdbfile``.

    Parameters
    ----------
    input_pdbfile : str
        PDB file to read.
    output_pdbfile : str
        Path the modified structure is written to.
    df : pandas.DataFrame
        Must contain ``metric_col``, ``site_col`` and ``chain_col``. Sites must
        be an integer dtype or strings of digits. They are matched against the
        residue number ``residue.get_id()[1]``, so insertion codes are ignored.
    metric_col : str
        Column whose value is written as the B-factor.
    site_col, chain_col : str
        Columns giving residue number and chain ID.
    missing_metric : number or dict
        Value for residues absent from ``df``. A dict maps chain ID to value.
    model_index : int
        Index of the model (in ``get_models()`` order) to modify.

    Raises
    ------
    ValueError
        If ``df`` lacks a required column, ``site_col`` is not integer-like, or
        ``df`` has chains that are not in the selected model.
    """
    # subset `df` to needed columns and error check it
    cols = [metric_col, site_col, chain_col]
    for col in cols:
        if col not in df.columns:
            raise ValueError(f"`df` lacks column {col}")
    # explicit copy: the site dtype conversion below must not write into a
    # view derived from the caller's frame
    df = df[cols].drop_duplicates().copy()
    # NOTE: one metric per (site, chain) is not enforced. If a site has several
    # values, the last one wins when the lookup dict is built below.

    # accept any integer dtype (int32, int64, ...), not only the platform `int`
    if not pd.api.types.is_integer_dtype(df[site_col]):
        # if we have string type, convert to int
        if df[site_col].map(type).eq(str).all():
            encodes_int = df[site_col].str.fullmatch(r"\d+")
            if encodes_int.all():
                df[site_col] = df[site_col].astype(int)
            else:
                # this may raise an error if there are sites like 214a; before fixing
                # such errors, need to check the `residue.get_id()[1]` command below
                raise ValueError(
                    f"`site_col` has non-integer entries:\n{df[site_col][~encodes_int]}"
                )
        else:
            raise ValueError(f"`site_col` is neither str nor int:\n{df[site_col]}")

    # `import Bio` alone does not guarantee the `Bio.PDB` subpackage is loaded,
    # so import it explicitly here
    import Bio.PDB

    # read PDB, catch warnings about discontinuous chains
    with warnings.catch_warnings():
        warnings.simplefilter(
            "ignore", category=Bio.PDB.PDBExceptions.PDBConstructionWarning
        )
        pdb = Bio.PDB.PDBParser().get_structure("_", input_pdbfile)

    # get the model out of the PDB
    model = list(pdb.get_models())[model_index]

    # make sure all chains in PDB
    missing_chains = set(df[chain_col]) - {chain.id for chain in model.get_chains()}
    if missing_chains:
        raise ValueError(f"`df` has chains not in PDB: {missing_chains}")

    # make missing_metric a dict if it isn't already
    if not isinstance(missing_metric, dict):
        missing_metric = {chain.id: missing_metric for chain in model.get_chains()}

    # loop over all chains and do coloring
    for chain in model.get_chains():
        chain_id = chain.id
        site_to_val = (
            df.query(f"{chain_col} == @chain_id")
            .set_index(site_col)[metric_col]
            .to_dict()
        )
        for residue in chain:
            site = residue.get_id()[1]
            try:
                metric_val = site_to_val[site]
            except KeyError:
                metric_val = missing_metric[chain_id]
            # for disordered residues, get list of them
            try:
                residuelist = residue.disordered_get_list()
            except AttributeError:
                residuelist = [residue]
            for r in residuelist:
                for atom in r:
                    # for disordered atoms, get list of them
                    try:
                        atomlist = atom.disordered_get_list()
                    except AttributeError:
                        atomlist = [atom]
                    for a in atomlist:
                        a.bfactor = metric_val

    # write PDB
    io = Bio.PDB.PDBIO()
    io.set_structure(pdb)
    io.save(output_pdbfile)

In [10]:
def get_b_factor_sum(
    model,
    serum,
    *,
    min_times_seen=3,
    max_site=326,
    chain="A",
    input_pdbfile="data/PDBs/4o5n.pdb",
    outdir="scratch_notebooks/221227_model_fitting/230108_escape_pdb",
):
    """Write one PDB per epitope with B-factors set to summed site escape.

    The per-site metric is ``total positive + total negative`` from
    ``model.mut_escape_site_summary_df(min_times_seen=min_times_seen)``.
    It is restricted to sites below ``max_site`` and assigned to ``chain``.
    For each epitope of ``model``, that metric is written into a copy of
    ``input_pdbfile`` with ``reassign_b_factor``, and sites absent from the
    table get 0.

    Parameters
    ----------
    model : polyclonal.Polyclonal
        Fitted model.
    serum : str
        Prefix for the output file names.
    min_times_seen, max_site, chain, input_pdbfile, outdir
        Keyword-only. The defaults reproduce the values this function
        originally hard-coded.

    Returns
    -------
    pandas.DataFrame
        Columns ``epitope`` and ``PDB file`` listing the files written.
    """
    site_summary = model.mut_escape_site_summary_df(min_times_seen=min_times_seen)
    site_summary = site_summary.assign(
        site=site_summary["site"].astype(int),
        sum=site_summary["total positive"] + site_summary["total negative"],
    )
    # `.assign` on the filtered frame returns a new frame, instead of setting a
    # column on a `.loc` slice
    site_summary = site_summary.loc[site_summary["site"] < max_site].assign(chain=chain)

    result_files = []
    for epitope in model.epitopes:
        # spaces (e.g. in serum or epitope names) are replaced for safe paths
        output_pdbfile = os.path.join(
            outdir, f"{serum}_sum_epitope_{epitope}.pdb"
        ).replace(" ", "_")
        pdb_dir = os.path.dirname(output_pdbfile)
        if pdb_dir:
            os.makedirs(pdb_dir, exist_ok=True)
        result_files.append((epitope, output_pdbfile))
        reassign_b_factor(
            input_pdbfile=input_pdbfile,
            output_pdbfile=output_pdbfile,
            df=site_summary.query("epitope == @epitope"),
            metric_col="sum",
            site_col="site",
            chain_col="chain",
            missing_metric=0,
            model_index=0,
        )

    return pd.DataFrame(result_files, columns=["epitope", "PDB file"])

### Fit models

In [20]:
# AUSAB-05: keep variants with enough no-antibody counts, then fit a
# one-epitope model to the selected serum concentrations.
prob_escape = pd.read_csv(
    "results/prob_escape/libA_221223_1_AUSAB-05_1_prob_escape.csv",
    keep_default_na=False,
    na_values="nan",
).query(
    "`no-antibody_count` >= no_antibody_count_threshold"
)  # filter for those with sufficient no-antibody counts
assert prob_escape.notnull().all().all()

concentrations_05 = [0.0074, 0.0111]
prob_escape_filtered_05 = prob_escape.loc[
    prob_escape["antibody_concentration"].isin(concentrations_05)
]
# floats are matched exactly: fail loudly if a requested concentration is absent
assert set(concentrations_05) <= set(prob_escape_filtered_05["antibody_concentration"])

model_05 = polyclonal.Polyclonal(
    n_epitopes=1,
    data_to_fit=prob_escape_filtered_05.rename(
        columns={
            "antibody_concentration": "concentration",
            "aa_substitutions_reference": "aa_substitutions",
        }
    ),
    alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
)

# fit model
opt_res = model_05.fit(
    logfreq=200,
    reg_escape_weight=0.1,
)

# display results
display(model_05.activity_wt_barplot())
display(model_05.mut_escape_plot(addtl_slider_stats={"times_seen": 3}, init_floor_at_zero=False))

# First fitting site-level model.
# Starting optimization of 503 parameters at Sun Jan  8 18:06:46 2023.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.019455       44868       44865           0           0           0              0               0       3.6049
          42      1.3459      612.49      606.83      2.1069           0           0              0               0       3.5517
# Successfully finished at Sun Jan  8 18:06:47 2023.
# Starting optimization of 3244 parameters at Sun Jan  8 18:06:47 2023.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.020361      774.35      748.62      22.181  8.0512e-33           0              0               0       3.5517
          64      1.6567      746.34      739.99      2.1956    0.088411           0              0               0       4.0

In [21]:
# mean probability of escape by substitution count, AUSAB-05 fitting subset
plot_avg_escape(prob_escape_filtered_05)

In [6]:
# AUSAB-07: keep variants with enough no-antibody counts, then fit a
# one-epitope model to the selected serum concentrations.
prob_escape = pd.read_csv(
    "results/prob_escape/libA_221223_1_AUSAB-07_1_prob_escape.csv",
    keep_default_na=False,
    na_values="nan",
).query(
    "`no-antibody_count` >= no_antibody_count_threshold"
)  # filter for those with sufficient no-antibody counts
assert prob_escape.notnull().all().all()

concentrations_07 = [0.0010, 0.0015, 0.0023, 0.0034]
prob_escape_filtered_07 = prob_escape.loc[
    prob_escape["antibody_concentration"].isin(concentrations_07)
]
# floats are matched exactly: fail loudly if a requested concentration is absent
assert set(concentrations_07) <= set(prob_escape_filtered_07["antibody_concentration"])

# reference-site list; also passed as `sites` to the spatially regularized
# models (AUSAB-11, cocktail) fitted later in the notebook
reference_sites = pd.read_csv("data/site_map.csv")["reference_site"].tolist()

# one-epitope model without site list / spatial distances
model_07 = polyclonal.Polyclonal(
    n_epitopes=1,
    data_to_fit=prob_escape_filtered_07.rename(
        columns={
            "antibody_concentration": "concentration",
            "aa_substitutions_reference": "aa_substitutions",
        }
    ),
    alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
)

# fit model; only the escape regularization weight is set explicitly
opt_res = model_07.fit(
    logfreq=200,
    reg_escape_weight=0.1,
)

# display results
display(model_07.activity_wt_barplot())
display(model_07.mut_escape_plot(addtl_slider_stats={"times_seen": 3}, init_floor_at_zero=False))

# First fitting site-level model.
# Starting optimization of 503 parameters at Sun Jan  8 16:03:36 2023.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.025975       89223       89218           0           0           0              0               0       5.1932
          92      3.0457      2430.7      2418.5      9.2091           0           0              0               0       2.9907
# Successfully finished at Sun Jan  8 16:03:39 2023.
# Starting optimization of 3244 parameters at Sun Jan  8 16:03:40 2023.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.028219      2970.5      2904.8      62.781  3.7391e-32           0              0               0       2.9907
         145      4.6616      2725.3      2657.2      60.504      3.9396           0              0               0       3.7

In [17]:
# mean probability of escape by substitution count, AUSAB-07 fitting subset
plot_avg_escape(prob_escape_filtered_07)

In [8]:
# AUSAB-11: keep variants with enough no-antibody counts, then fit a
# three-epitope, spatially regularized model to the selected concentrations.
prob_escape = pd.read_csv(
    "results/prob_escape/libA_221223_1_AUSAB-11_1_prob_escape.csv",
    keep_default_na=False,
    na_values="nan",
).query(
    "`no-antibody_count` >= no_antibody_count_threshold"
)  # filter for those with sufficient no-antibody counts
assert prob_escape.notnull().all().all()

concentrations_11 = [0.0067, 0.0100, 0.0150]
prob_escape_filtered_11 = prob_escape.loc[
    prob_escape["antibody_concentration"].isin(concentrations_11)
]
# floats are matched exactly: fail loudly if a requested concentration is absent
assert set(concentrations_11) <= set(prob_escape_filtered_11["antibody_concentration"])

model_11 = polyclonal.Polyclonal(
    n_epitopes=3,
    data_to_fit=prob_escape_filtered_11.rename(
        columns={
            "antibody_concentration": "concentration",
            "aa_substitutions_reference": "aa_substitutions",
        }
    ),
    alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
    sites=reference_sites,
    spatial_distances=spatial_distances,
)

# fit model: the first uniqueness and spatial penalties are switched off in
# favor of the second variants of each
opt_res = model_11.fit(
    logfreq=200,
    reg_escape_weight=0.1,
    reg_uniqueness_weight=0,
    reg_uniqueness2_weight=1,
    reg_spatial_weight=0.0,
    reg_spatial2_weight=0.01,
)

# display results
display(model_11.activity_wt_barplot())
display(model_11.mut_escape_plot(addtl_slider_stats={"times_seen": 3}, init_floor_at_zero=False))

# First fitting site-level model.
# Starting optimization of 1509 parameters at Sun Jan  8 16:06:26 2023.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.061891       57023       57011           0           0           0              0               0       12.014
         122      8.4394      5641.4      5612.9      1.7786           0      26.403              0        0.071812      0.18928
# Successfully finished at Sun Jan  8 16:06:34 2023.
# Starting optimization of 9732 parameters at Sun Jan  8 16:06:34 2023.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0     0.10045      6670.3      6608.4      23.057  3.6386e-32      26.403              0          12.193      0.18928
         139      14.846      6470.9      6425.4       17.78      0.3414      23.044              0          3.4493      0.8

In [7]:
# AUSAB-13: every antibody concentration is used (no concentration subset);
# one-epitope model without site list / spatial distances.
prob_escape_13 = (
    pd.read_csv(
        "results/prob_escape/libA_221027_1_AUSAB-13_1_prob_escape.csv",
        keep_default_na=False,
        na_values="nan",
    )
    # keep only variants with sufficient no-antibody counts
    .query("`no-antibody_count` >= no_antibody_count_threshold")
)
assert prob_escape_13.notnull().all().all()

model_13 = polyclonal.Polyclonal(
    data_to_fit=prob_escape_13.rename(
        columns={
            "aa_substitutions_reference": "aa_substitutions",
            "antibody_concentration": "concentration",
        }
    ),
    n_epitopes=1,
    alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
)

# fit; only the escape regularization weight is set explicitly, the other
# regularization weights are left at their defaults
opt_res = model_13.fit(reg_escape_weight=0.1, logfreq=200)

# show fitted activities and per-mutation escape
display(model_13.activity_wt_barplot())
display(
    model_13.mut_escape_plot(
        addtl_slider_stats={"times_seen": 3},
        init_floor_at_zero=False,
    )
)

# First fitting site-level model.
# Starting optimization of 503 parameters at Sun Jan  8 16:03:55 2023.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.035253  1.6471e+05   1.647e+05           0           0           0              0               0       5.0008
          83      3.7893      843.77       832.7      6.5188           0           0              0               0       4.5542
# Successfully finished at Sun Jan  8 16:03:59 2023.
# Starting optimization of 3237 parameters at Sun Jan  8 16:03:59 2023.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.040198      1059.8      1002.4       52.92  3.4112e-32           0              0               0       4.5542
         133      6.1014      939.61      909.86      22.893      1.9968           0              0               0        4.

In [18]:
# mean probability of escape by substitution count, AUSAB-13 (all concentrations)
plot_avg_escape(prob_escape_13)

Get RGB values for epitope colors:

In [9]:
import matplotlib.colors

# Report each AUSAB-11 epitope color as hex and as 0-1 RGB fractions
# rounded to three decimals.
for epitope, hex_color in model_11.epitope_colors.items():
    channels = matplotlib.colors.to_rgb(hex_color)
    rgb = [round(channel, 3) for channel in channels]
    print(f"{epitope}: hex color is {hex_color}; RGB tuple is {rgb}")

1: hex color is #0072B2; RGB tuple is [0.0, 0.447, 0.698]
2: hex color is #CC79A7; RGB tuple is [0.8, 0.475, 0.655]
3: hex color is #4C3549; RGB tuple is [0.298, 0.208, 0.286]


Also get RGB values for negative site escape (orange):

In [16]:
# imported here too so this cell does not depend on another cell's import
import matplotlib.colors

# RGB fractions (rounded to 3 decimals) of the orange used for negative site escape
rgb = [round(val, 3) for val in matplotlib.colors.to_rgb('#E69F00')]
rgb

[0.902, 0.624, 0.0]

### Save PDB files with B-factors reassigned to summed site escape

In [11]:
# Write per-epitope B-factor PDB files for each serum model
# (AUSAB-05 is deliberately left out of this set).
serum_models = {
    "AUSAB-07": model_07,
    "AUSAB-11": model_11,
    "AUSAB-13": model_13,
}

for serum, serum_model in serum_models.items():
    get_b_factor_sum(serum_model, serum)

### Also get cocktail data

In [14]:
# 1C04 + 5G04 antibody cocktail: every concentration is used; two-epitope
# model with site list and spatial distances.
prob_escape_cocktail = (
    pd.read_csv(
        "results/prob_escape/libA_221108_1_1C04-5G04_1_prob_escape.csv",
        keep_default_na=False,
        na_values="nan",
    )
    # keep only variants with sufficient no-antibody counts
    .query("`no-antibody_count` >= no_antibody_count_threshold")
)
assert prob_escape_cocktail.notnull().all().all()

model_cocktail = polyclonal.Polyclonal(
    data_to_fit=prob_escape_cocktail.rename(
        columns={
            "aa_substitutions_reference": "aa_substitutions",
            "antibody_concentration": "concentration",
        }
    ),
    n_epitopes=2,
    alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
    sites=reference_sites,
    spatial_distances=spatial_distances,
)

# fit: the first uniqueness and spatial penalties are switched off in favor
# of the second variants of each
opt_res = model_cocktail.fit(
    logfreq=200,
    reg_escape_weight=0.1,
    reg_uniqueness_weight=0,
    reg_uniqueness2_weight=1,
    reg_spatial_weight=0.0,
    reg_spatial2_weight=0.0005,
)

# show fitted activities and per-mutation escape
display(model_cocktail.activity_wt_barplot())
display(
    model_cocktail.mut_escape_plot(
        addtl_slider_stats={"times_seen": 3},
        init_floor_at_zero=False,
    )
)

# First fitting site-level model.
# Starting optimization of 1010 parameters at Sun Jan  8 16:14:47 2023.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.069858      1249.8      1246.4           0           0           0              0               0       3.4611
          30      2.1741      441.79      433.79      1.1553           0      1.7007              0        0.079176       5.0661
# Successfully finished at Sun Jan  8 16:14:49 2023.
# Starting optimization of 6488 parameters at Sun Jan  8 16:14:49 2023.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0     0.12777       614.7      571.86      15.527  1.6044e-32      1.7007              0          20.544       5.0661
          85      9.3592      575.56      564.83      4.4008     0.13674     0.46006              0         0.17158       5.

In [15]:
# Write one B-factor PDB per cocktail epitope; the returned table lists the files.
get_b_factor_sum(model_cocktail, '1C04-5G04_cocktail')

Unnamed: 0,epitope,PDB file
0,1,scratch_notebooks/221227_model_fitting/230108_...
1,2,scratch_notebooks/221227_model_fitting/230108_...


In [19]:
# mean probability of escape by substitution count, 1C04 + 5G04 cocktail
plot_avg_escape(prob_escape_cocktail)

In [23]:
# Save each model's interactive escape plot as HTML in the current working
# directory, named "<serum>_mut_escape_plot.html".
serum_models = {
    "AUSAB-05": model_05,
    "AUSAB-07": model_07,
    "AUSAB-11": model_11,
    "AUSAB-13": model_13,
    "1C04-5G04_cocktail": model_cocktail,
}

for serum, serum_model in serum_models.items():
    mut_escape_plot = serum_model.mut_escape_plot(
        addtl_slider_stats={"times_seen": 3},
        init_floor_at_zero=False,
    )
    mut_escape_plot.save(f"{serum}_mut_escape_plot.html")

In [27]:
# 1C04 monoclonal antibody: keep variants with enough no-antibody counts, then
# fit a one-epitope model to the selected concentrations.
prob_escape = pd.read_csv(
    "results/prob_escape/libA_221021_1_1C04_1_prob_escape.csv",
    keep_default_na=False,
    na_values="nan",
).query(
    "`no-antibody_count` >= no_antibody_count_threshold"
)  # filter for those with sufficient no-antibody counts
assert prob_escape.notnull().all().all()

# 0.10 and 0.80 are intentionally excluded from fitting
concentrations_1C04 = [0.20, 0.40]
prob_escape_filtered_1C04 = prob_escape.loc[
    prob_escape["antibody_concentration"].isin(concentrations_1C04)
]
# floats are matched exactly: fail loudly if a requested concentration is absent
assert set(concentrations_1C04) <= set(prob_escape_filtered_1C04["antibody_concentration"])

model_1C04 = polyclonal.Polyclonal(
    n_epitopes=1,
    data_to_fit=prob_escape_filtered_1C04.rename(
        columns={
            "antibody_concentration": "concentration",
            "aa_substitutions_reference": "aa_substitutions",
        }
    ),
    alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
)

# fit model
opt_res = model_1C04.fit(
    logfreq=200,
    reg_escape_weight=0.1,
)

# display results
display(model_1C04.activity_wt_barplot())
display(model_1C04.mut_escape_plot(addtl_slider_stats={"times_seen": 3}, init_floor_at_zero=False))

# First fitting site-level model.
# Starting optimization of 507 parameters at Sun Jan  8 18:21:32 2023.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.020357       24571       24571           0           0           0              0               0      0.18124
         188      4.3824      445.05      422.96      15.958           0           0              0               0       6.1287
# Successfully finished at Sun Jan  8 18:21:37 2023.
# Starting optimization of 3254 parameters at Sun Jan  8 18:21:37 2023.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.027565      644.82      486.08      152.61  1.2866e-31           0              0               0       6.1287
         138       3.734      439.82      398.96      32.242      3.0578           0              0               0       5.5

In [28]:
# Save the 1C04 escape plot. Save the chart built here, not the
# `mut_escape_plot` variable left over from the earlier save loop (that one
# holds a different model's chart).
mut_escape_plot_1C04 = model_1C04.mut_escape_plot(
    addtl_slider_stats={"times_seen": 3}, init_floor_at_zero=False
)
mut_escape_plot_1C04.save('1C04_mut_escape_plot.html')