# Fit polyclonal model
Here we fit [polyclonal](https://jbloomlab.github.io/polyclonal) models to the data.

First, import Python modules:

In [1]:
import pickle

import altair as alt

import pandas as pd

import polyclonal

import yaml

In [2]:
# allow more rows for Altair
_ = alt.data_transformers.disable_max_rows()

In [3]:
import os
os.chdir('../../')

## Read input data

Get parameterized variable from [papermill](https://papermill.readthedocs.io/)

In [52]:
# papermill parameters cell (tagged as `parameters`)
prob_escape_csv = None
n_threads = None
pickle_file = None
antibody = None

In [53]:
# Parameters
prob_escape_csv = "results/prob_escape/libA_221108_1_1C04-5G04_1_prob_escape.csv"
pickle_file = "results/polyclonal_fits/libA_221108_1_1C04-5G04_1.pickle"
n_threads = 2


Read the probabilities of escape, and filter for those with sufficient no-antibody counts:

In [54]:
print(f"\nReading probabilities of escape from {prob_escape_csv}")

prob_escape = pd.read_csv(
    prob_escape_csv, keep_default_na=False, na_values="nan"
).query("`no-antibody_count` >= no_antibody_count_threshold")
assert prob_escape.notnull().all().all()


Reading probabilities of escape from results/prob_escape/libA_221108_1_1C04-5G04_1_prob_escape.csv


Read the rest of the configuration and input data:

In [55]:
# get information from config
with open("config.yaml") as f:
    config = yaml.safe_load(f)

antibody = prob_escape["antibody"].unique()
assert len(antibody) == 1, antibody
antibody = antibody[0]

# get site numbering map and the reference sites in order
site_numbering_map = pd.read_csv(config["site_numbering_map"])
reference_sites = site_numbering_map.sort_values("sequential_site")[
    "reference_site"
].tolist()

# get the polyclonal configuration for this antibody
with open(config["polyclonal_config"]) as f:
    polyclonal_config = yaml.safe_load(f)
if antibody not in polyclonal_config:
    raise ValueError(f"`polyclonal_config` lacks configuration for {antibody=}")
antibody_config = polyclonal_config[antibody]

# print names of variables and settings
print(f"{antibody=}")
print(f"{n_threads=}")
print(f"{pickle_file=}")
print(f"{antibody_config=}")

antibody='1C04-5G04'
n_threads=2
pickle_file='results/polyclonal_fits/libA_221108_1_1C04-5G04_1.pickle'
antibody_config={'min_epitope_activity_to_include': 0.2, 'plot_kwargs': {'addtl_slider_stats': {'times_seen': 3, 'functional effect': -1.38}, 'slider_binding_range_kwargs': {'n_models': {'step': 1}, 'times_seen': {'step': 1, 'min': 1, 'max': 25}}, 'heatmap_max_at_least': 2, 'heatmap_min_at_least': -2}, 'max_epitopes': 1, 'fit_kwargs': {'reg_escape_weight': 0.1, 'reg_spread_weight': 0.25, 'reg_activity_weight': 1.0}}


## Some summary statistics
Note that these statistics are only for the variants that passed upstream filtering in the pipeline.

Number of variants per concentration:

In [56]:
display(
    prob_escape.groupby("antibody_concentration").aggregate(
        n_variants=pd.NamedAgg("barcode", "nunique")
    )
)

Unnamed: 0_level_0,n_variants
antibody_concentration,Unnamed: 1_level_1
1.37,26585
2.05,26585
3.08,26585
4.62,26585
6.93,26585
10.4,26585


Plot mean probability of escape across all variants with the indicated number of mutations.
Note that this plot weights each variant the same in the means regardless of how many barcode counts it has.
We plot means for both censored (set to between 0 and 1) and uncensored probabilities of escape.
Also, note it uses a symlog scale for the y-axis.
Mouseover points for values:

In [57]:
max_aa_subs = 4  # group if >= this many substitutions

mean_prob_escape = (
    prob_escape.assign(
        n_subs=lambda x: (
            x["aa_substitutions_reference"]
            .str.split()
            .map(len)
            .clip(upper=max_aa_subs)
            .map(lambda n: str(n) if n < max_aa_subs else f">{max_aa_subs - 1}")
        )
    )
    .groupby(["antibody_concentration", "n_subs"], as_index=False)
    .aggregate({"prob_escape": "mean", "prob_escape_uncensored": "mean"})
    .rename(
        columns={
            "prob_escape": "censored to [0, 1]",
            "prob_escape_uncensored": "not censored",
        }
    )
    .melt(
        id_vars=["antibody_concentration", "n_subs"],
        var_name="censored",
        value_name="probability escape",
    )
)

mean_prob_escape_chart = (
    alt.Chart(mean_prob_escape)
    .encode(
        x=alt.X("antibody_concentration"),
        y=alt.Y(
            "probability escape",
            scale=alt.Scale(type="symlog", constant=0.05),
        ),
        column=alt.Column("censored", title=None),
        color=alt.Color("n_subs", title="n substitutions"),
        tooltip=[
            alt.Tooltip(c, format=".3g") if mean_prob_escape[c].dtype == float else c
            for c in mean_prob_escape.columns
        ],
    )
    .mark_line(point=True, size=0.5)
    .properties(width=200, height=125)
    .configure_axis(grid=False)
)

mean_prob_escape_chart

  for col_name, dtype in df.dtypes.iteritems():


## Fit `polyclonal` model
First, get the fitting related keyword arguments from the configuration passed by `snakemake`:

In [65]:
model = polyclonal.Polyclonal(
    n_epitopes=2,
    data_to_fit=prob_escape.rename(
        columns={
            "antibody_concentration": "concentration",
            "aa_substitutions_reference": "aa_substitutions",
        }
    ),
    alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
    sites=reference_sites,
)

# fit model
opt_res = model.fit(
    logfreq=200,
    reg_escape_weight=0.1,
    reg_spread_weight=0.25,
    reg_activity_weight=1.0,
    reg_uniqueness_weight=0.3,  # regularize epitope similarity
)
model.mut_escape_plot()

# First fitting site-level model.
# Starting optimization of 1010 parameters at Wed Nov 30 19:46:24 2022.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.062239      1249.8      1246.4           0           0           0              0               0       3.4611
         139      9.3347      435.11       425.2      5.0073           0           0       0.098243        0.017707       4.7874
# Successfully finished at Wed Nov 30 19:46:34 2022.
# Starting optimization of 6488 parameters at Wed Nov 30 19:46:34 2022.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.074147      629.39      569.54      52.414  2.4784e-31           0       0.098243          2.5566       4.7874
          83      7.0297      574.54         562      6.4051     0.42618           0       0.011563         0.15379       5.

  for col_name, dtype in df.dtypes.iteritems():


In [11]:
spatial_distances = polyclonal.pdb_utils.inter_residue_distances(
    "scratch_notebooks/221111_model-fitting/4o5n.pdb",
    target_chains=["A", "B"],
)

spatial_distances

Unnamed: 0,site_1,site_2,distance,chain_1,chain_2
0,9,10,1.328212,A,A
1,9,11,3.469929,B,B
2,9,12,6.336130,B,B
3,9,13,9.189821,B,B
4,9,14,8.930696,B,A
...,...,...,...,...,...
260276,497,499,15.936294,B,B
260277,497,500,16.632641,B,B
260278,498,499,23.859705,B,B
260279,498,500,13.285421,B,B


In [46]:
model = polyclonal.Polyclonal(
    n_epitopes=2,
    data_to_fit=prob_escape.rename(
        columns={
            "antibody_concentration": "concentration",
            "aa_substitutions_reference": "aa_substitutions",
        }
    ),
    alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
    sites=reference_sites,
)

# fit model
opt_res = model.fit(
    logfreq=200,
    reg_escape_weight=0.1,
    reg_spread_weight=0.25,
    reg_activity_weight=1.0,
    reg_uniqueness_weight=0.3,  # regularize epitope similarity
    reg_spatial2_weight=0.001
)
model.mut_escape_plot()

# First fitting site-level model.
# Starting optimization of 1010 parameters at Wed Nov 30 13:40:14 2022.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.060234      1249.8      1246.4           0           0           0              0               0       3.4611
         139      9.3753      435.11       425.2      5.0073           0           0       0.098243        0.017707       4.7874
# Successfully finished at Wed Nov 30 13:40:24 2022.
# Starting optimization of 6488 parameters at Wed Nov 30 13:40:24 2022.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.075179      629.39      569.54      52.414  2.4784e-31           0       0.098243          2.5566       4.7874
          83      7.1101      574.54         562      6.4051     0.42618           0       0.011563         0.15379       5.

  for col_name, dtype in df.dtypes.iteritems():


In [47]:
display(model.activity_wt_barplot())

In [60]:
def reassign_b_factor(
    input_pdbfile,
    output_pdbfile,
    df,
    metric_col,
    *,
    site_col="site",
    chain_col="chain",
    missing_metric=0,
    model_index=0,
):

    # subset `df` to needed columns and error check it
    cols = [metric_col, site_col, chain_col]
    for col in cols:
        if col not in df.columns:
            raise ValueError(f"`df` lacks column {col}")
    df = df[cols].drop_duplicates()
    if len(df) != len(df.groupby([site_col, chain_col])):
        raise ValueError("non-unique metric for a site in a chain")

#     if df[site_col].dtype != int:
#         raise ValueError("function currently requires `site_col` to be int")

    # read PDB, catch warnings about discontinuous chains
    with warnings.catch_warnings():
        warnings.simplefilter(
            "ignore", category=Bio.PDB.PDBExceptions.PDBConstructionWarning
        )
        pdb = Bio.PDB.PDBParser().get_structure("_", input_pdbfile)

    # get the model out of the PDB
    model = list(pdb.get_models())[model_index]

    # make sure all chains in PDB
    missing_chains = set(df[chain_col]) - {chain.id for chain in model.get_chains()}
    if missing_chains:
        raise ValueError(f"`df` has chains not in PDB: {missing_chains}")

    # make missing_metric a dict if it isn't already
    if not isinstance(missing_metric, dict):
        missing_metric = {chain.id: missing_metric for chain in model.get_chains()}

    # loop over all chains and do coloring
    for chain in model.get_chains():
        chain_id = chain.id
        site_to_val = (
            df.query(f"{chain_col} == @chain_id")
            .set_index(site_col)[metric_col]
            .to_dict()
        )
        for residue in chain:
            site = residue.get_id()[1]
            try:
                metric_val = site_to_val[site]
            except KeyError:
                metric_val = missing_metric[chain_id]
            # for disordered residues, get list of them
            try:
                residuelist = residue.disordered_get_list()
            except AttributeError:
                residuelist = [residue]
            for r in residuelist:
                for atom in r:
                    # for disordered atoms, get list of them
                    try:
                        atomlist = atom.disordered_get_list()
                    except AttributeError:
                        atomlist = [atom]
                    for a in atomlist:
                        a.bfactor = metric_val

    # write PDB
    io = Bio.PDB.PDBIO()
    io.set_structure(pdb)
    io.save(output_pdbfile)

In [61]:
model.mut_escape_site_summary_df()

Unnamed: 0,epitope,site,wildtype,mean,total positive,max,min,total negative,n mutations
0,1,-9,A,-0.001172,0.000000,-0.001172,-0.001172,-0.001172,1
1,1,-4,D,-0.005142,0.000000,-0.005142,-0.005142,-0.005142,1
2,1,-3,A,-0.002190,0.000000,-0.002129,-0.002251,-0.004380,2
3,1,-2,D,0.000156,0.005342,0.003080,-0.002412,-0.004565,5
4,1,-1,T,-0.002416,0.000000,-0.002133,-0.002699,-0.004832,2
...,...,...,...,...,...,...,...,...,...
490,1,537,W,-0.023594,0.000000,-0.009212,-0.037976,-0.047189,2
491,1,538,A,-0.015609,0.000000,-0.009087,-0.028650,-0.062436,4
492,1,539,C,-0.003064,0.000000,-0.003064,-0.003064,-0.003064,1
493,1,540,Q,-0.004828,0.000358,0.000358,-0.008450,-0.014841,3


In [63]:
site_summary_df = model.mut_escape_site_summary_df()
site_summary_df['site'] = site_summary_df['site'].astype(float)
                                      
site_summary_df = site_summary_df.loc[(site_summary_df['site'] < 326) &
                                      (site_summary_df['site'] > 0)
                                     ]
                                      
site_summary_df['site'] = site_summary_df['site'].astype(int)
site_summary_df['chain'] = 'A'

site_summary_df['sum'] = site_summary_df['total positive'] + site_summary_df['total negative']

site_summary_df

Unnamed: 0,epitope,site,wildtype,mean,total positive,max,min,total negative,n mutations,chain,sum
5,1,1,Q,0.000544,0.010459,0.008452,-0.006590,-0.008283,4,A,0.002176
6,1,2,K,0.002125,0.008015,0.007134,-0.001642,-0.001642,3,A,0.006374
7,1,3,I,-0.004356,0.106804,0.020328,-0.032871,-0.185207,18,A,-0.078403
8,1,4,P,-0.011548,0.116220,0.038450,-0.070772,-0.335640,19,A,-0.219420
9,1,5,G,-0.019586,0.046363,0.033597,-0.112364,-0.379320,17,A,-0.332957
...,...,...,...,...,...,...,...,...,...,...,...
289,1,320,M,-0.007440,0.000000,-0.007440,-0.007440,-0.007440,1,A,-0.007440
290,1,321,R,-0.013752,0.000000,-0.004925,-0.022580,-0.027505,2,A,-0.027505
291,1,323,V,0.056724,0.851738,0.341134,-0.026568,-0.057604,14,A,0.794135
292,1,324,P,-0.011839,0.019198,0.012691,-0.040831,-0.113913,8,A,-0.094715


In [64]:
import warnings
import Bio

reassign_b_factor(input_pdbfile='scratch_notebooks/221111_model-fitting/4o5n.pdb',
                  output_pdbfile='scratch_notebooks/221130_summary_figures/libA_cocktail_b-factors_sum.pdb',
                  df=site_summary_df,
                  metric_col='sum',
                  site_col="site",
                  chain_col="chain",
                  missing_metric=0,
                  model_index=0,)