# Fit polyclonal model
Here we fit [polyclonal](https://jbloomlab.github.io/polyclonal) models to the data.

First, import Python modules:

In [1]:
import pickle

import altair as alt

import pandas as pd

import polyclonal

import yaml

In [2]:
# allow more rows for Altair
_ = alt.data_transformers.disable_max_rows()

In [3]:
import os
os.chdir('../../')

## Read input data

Get parameterized variable from [papermill](https://papermill.readthedocs.io/)

In [4]:
# papermill parameters cell (tagged as `parameters`)
prob_escape_csv = None
n_threads = None
pickle_file = None
antibody = None

In [5]:
# Parameters
prob_escape_csv = "results/prob_escape/libA_221108_1_1C04-5G04_1_prob_escape.csv"
pickle_file = "results/polyclonal_fits/libA_221108_1_1C04-5G04_1.pickle"
n_threads = 2


Read the probabilities of escape, and filter for those with sufficient no-antibody counts:

In [6]:
print(f"\nReading probabilities of escape from {prob_escape_csv}")

prob_escape = pd.read_csv(
    prob_escape_csv, keep_default_na=False, na_values="nan"
).query("`no-antibody_count` >= no_antibody_count_threshold")
assert prob_escape.notnull().all().all()


Reading probabilities of escape from results/prob_escape/libA_221108_1_1C04-5G04_1_prob_escape.csv


Read the rest of the configuration and input data:

In [7]:
# get information from config
with open("config.yaml") as f:
    config = yaml.safe_load(f)

antibody = prob_escape["antibody"].unique()
assert len(antibody) == 1, antibody
antibody = antibody[0]

# get site numbering map and the reference sites in order
site_numbering_map = pd.read_csv(config["site_numbering_map"])
reference_sites = site_numbering_map.sort_values("sequential_site")[
    "reference_site"
].tolist()

# get the polyclonal configuration for this antibody
with open(config["polyclonal_config"]) as f:
    polyclonal_config = yaml.safe_load(f)
if antibody not in polyclonal_config:
    raise ValueError(f"`polyclonal_config` lacks configuration for {antibody=}")
antibody_config = polyclonal_config[antibody]

# print names of variables and settings
print(f"{antibody=}")
print(f"{n_threads=}")
print(f"{pickle_file=}")
print(f"{antibody_config=}")

antibody='1C04-5G04'
n_threads=2
pickle_file='results/polyclonal_fits/libA_221108_1_1C04-5G04_1.pickle'
antibody_config={'min_epitope_activity_to_include': 0.2, 'plot_kwargs': {'addtl_slider_stats': {'times_seen': 3, 'functional effect': -1.38}, 'slider_binding_range_kwargs': {'n_models': {'step': 1}, 'times_seen': {'step': 1, 'min': 1, 'max': 25}}, 'heatmap_max_at_least': 2, 'heatmap_min_at_least': -2}, 'max_epitopes': 1, 'fit_kwargs': {'reg_escape_weight': 0.1, 'reg_spread_weight': 0.25, 'reg_activity_weight': 1.0}}


## Some summary statistics
Note that these statistics are only for the variants that passed upstream filtering in the pipeline.

Number of variants per concentration:

In [8]:
display(
    prob_escape.groupby("antibody_concentration").aggregate(
        n_variants=pd.NamedAgg("barcode", "nunique")
    )
)

Unnamed: 0_level_0,n_variants
antibody_concentration,Unnamed: 1_level_1
1.37,26585
2.05,26585
3.08,26585
4.62,26585
6.93,26585
10.4,26585


Plot mean probability of escape across all variants with the indicated number of mutations.
Note that this plot weights each variant the same in the means regardless of how many barcode counts it has.
We plot means for both censored (set to between 0 and 1) and uncensored probabilities of escape.
Also, note it uses a symlog scale for the y-axis.
Mouseover points for values:

In [10]:
max_aa_subs = 4  # group if >= this many substitutions

mean_prob_escape = (
    prob_escape.assign(
        n_subs=lambda x: (
            x["aa_substitutions_reference"]
            .str.split()
            .map(len)
            .clip(upper=max_aa_subs)
            .map(lambda n: str(n) if n < max_aa_subs else f">{max_aa_subs - 1}")
        )
    )
    .groupby(["antibody_concentration", "n_subs"], as_index=False)
    .aggregate({"prob_escape": "mean", "prob_escape_uncensored": "mean"})
    .rename(
        columns={
            "prob_escape": "censored to [0, 1]",
            "prob_escape_uncensored": "not censored",
        }
    )
    .melt(
        id_vars=["antibody_concentration", "n_subs"],
        var_name="censored",
        value_name="probability escape",
    )
)

mean_prob_escape_chart = (
    alt.Chart(mean_prob_escape)
    .encode(
        x=alt.X("antibody_concentration"),
        y=alt.Y(
            "probability escape",
            scale=alt.Scale(type="symlog", constant=0.05),
        ),
        column=alt.Column("censored", title=None),
        color=alt.Color("n_subs", title="n substitutions"),
        tooltip=[
            alt.Tooltip(c, format=".3g") if mean_prob_escape[c].dtype == float else c
            for c in mean_prob_escape.columns
        ],
    )
    .mark_line(point=True, size=0.5)
    .properties(width=200, height=125)
    .configure_axis(grid=False)
)

mean_prob_escape_chart

  for col_name, dtype in df.dtypes.iteritems():


## Fit `polyclonal` model
First, get the fitting related keyword arguments from the configuration passed by `snakemake`:

In [11]:
model = polyclonal.Polyclonal(
    n_epitopes=2,
    data_to_fit=prob_escape.rename(
        columns={
            "antibody_concentration": "concentration",
            "aa_substitutions_reference": "aa_substitutions",
        }
    ),
    alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
    sites=reference_sites,
)

# fit model
opt_res = model.fit(
    logfreq=200,
    reg_escape_weight=0.1,
    reg_spread_weight=0.25,
    reg_activity_weight=1.0,
    reg_uniqueness_weight=0.3,  # regularize epitope similarity
)
model.mut_escape_plot()

# First fitting site-level model.
# Starting optimization of 1010 parameters at Fri Dec  2 13:26:44 2022.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.083948      1249.8      1246.4           0           0           0              0               0       3.4611
         139       9.562      435.11       425.2      5.0073           0           0       0.098243        0.017707       4.7874
# Successfully finished at Fri Dec  2 13:26:54 2022.
# Starting optimization of 6488 parameters at Fri Dec  2 13:26:54 2022.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.072612      629.39      569.54      52.414  2.4784e-31           0       0.098243          2.5566       4.7874
          83      6.8014      574.54         562      6.4051     0.42618           0       0.011563         0.15379       5.

  for col_name, dtype in df.dtypes.iteritems():


In [14]:
prob_escape_single = prob_escape.loc[prob_escape['n_aa_substitutions'] == 1]

model = polyclonal.Polyclonal(
    n_epitopes=2,
    data_to_fit=prob_escape_single.rename(
        columns={
            "antibody_concentration": "concentration",
            "aa_substitutions_reference": "aa_substitutions",
        }
    ),
    alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
    sites=reference_sites,
)

# fit model
opt_res = model.fit(
    logfreq=200,
    reg_escape_weight=0.1,
    reg_spread_weight=0.25,
    reg_activity_weight=1.0,
    reg_uniqueness_weight=0.3,  # regularize epitope similarity
)
model.mut_escape_plot()

# First fitting site-level model.
# Starting optimization of 732 parameters at Fri Dec  2 13:28:08 2022.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.012256      121.99      118.53           0           0           0              0               0       3.4611
          87      1.2246      24.416      16.375      3.1699           0           0       0.033231       0.0089425       4.8284
# Successfully finished at Fri Dec  2 13:28:09 2022.
# Starting optimization of 3858 parameters at Fri Dec  2 13:28:09 2022.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.022694      140.04      109.38      25.181  1.3326e-31           0       0.033231         0.62585       4.8284
          49     0.97325      109.51      103.92     0.22984   0.0095984           0     9.7078e-06      7.9476e-05       5.3

  for col_name, dtype in df.dtypes.iteritems():


In [16]:
prob_escape_double = prob_escape.loc[prob_escape['n_aa_substitutions'] == 2]

model = polyclonal.Polyclonal(
    n_epitopes=2,
    data_to_fit=prob_escape_double.rename(
        columns={
            "antibody_concentration": "concentration",
            "aa_substitutions_reference": "aa_substitutions",
        }
    ),
    alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
    sites=reference_sites,
)

# fit model
opt_res = model.fit(
    logfreq=200,
    reg_escape_weight=0.1,
    reg_spread_weight=0.25,
    reg_activity_weight=1.0,
    reg_uniqueness_weight=0.3,  # regularize epitope similarity
)
model.mut_escape_plot()

# First fitting site-level model.
# Starting optimization of 838 parameters at Fri Dec  2 13:28:42 2022.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.021541      309.31      305.85           0           0           0              0               0       3.4611
          42     0.93051      127.65       120.9       1.367           0           0       0.023671       0.0044845       5.3566
# Successfully finished at Fri Dec  2 13:28:43 2022.
# Starting optimization of 5272 parameters at Fri Dec  2 13:28:43 2022.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0     0.02353      156.99      135.14      15.728  5.7283e-32           0       0.023671         0.74579       5.3566
          60      1.6029      141.99      135.14      1.2666     0.19215           0      0.0011144         0.02284       5.3

  for col_name, dtype in df.dtypes.iteritems():


In [17]:
prob_escape_triple = prob_escape.loc[prob_escape['n_aa_substitutions'] == 3]

model = polyclonal.Polyclonal(
    n_epitopes=2,
    data_to_fit=prob_escape_triple.rename(
        columns={
            "antibody_concentration": "concentration",
            "aa_substitutions_reference": "aa_substitutions",
        }
    ),
    alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
    sites=reference_sites,
)

# fit model
opt_res = model.fit(
    logfreq=200,
    reg_escape_weight=0.1,
    reg_spread_weight=0.25,
    reg_activity_weight=1.0,
    reg_uniqueness_weight=0.3,  # regularize epitope similarity
)
model.mut_escape_plot()

# First fitting site-level model.
# Starting optimization of 898 parameters at Fri Dec  2 13:30:26 2022.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.026337      335.97       332.5           0           0           0              0               0       3.4611
          46      1.2865      134.22      127.61      1.1034           0           0        0.01567       0.0027254       5.4867
# Successfully finished at Fri Dec  2 13:30:28 2022.
# Starting optimization of 5614 parameters at Fri Dec  2 13:30:28 2022.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.026472      147.22      127.91      13.266    6.26e-32           0        0.01567         0.54028       5.4867
          42      1.2943      135.28      128.69     0.93308    0.043813           0     0.00067844        0.012657          

  for col_name, dtype in df.dtypes.iteritems():


In [20]:
prob_escape_single.loc[prob_escape_single['antibody_concentration'] == 1.37]

Unnamed: 0,library,antibody_sample,no-antibody_sample,aa_substitutions_sequential,n_aa_substitutions,barcode,prob_escape,prob_escape_uncensored,antibody_count,no-antibody_count,antibody_neut_standard_count,no-antibody_neut_standard_count,total_no_antibody_count,no_antibody_count_threshold,aa_substitutions_reference,antibody,antibody_concentration
332906,libA,221108_1_antibody_1C04-5G04_1.37_1,221108_1_no-antibody_control_1,K297I,1,ATAACACAAAAAAGTA,0.0226,0.0226,68114,270456,68369,6147,10886757,22,K278I,1C04-5G04,1.37
332920,libA,221108_1_antibody_1C04-5G04_1.37_1,221108_1_no-antibody_control_1,S164A,1,TTGAAAGATTATGAAG,0.0350,0.0350,17017,43771,68369,6147,10886757,22,S145A,1C04-5G04,1.37
332922,libA,221108_1_antibody_1C04-5G04_1.37_1,221108_1_no-antibody_control_1,E363D,1,ATCAGCAAGCGGAGAC,0.0822,0.0822,16531,18089,68369,6147,10886757,22,E344D,1C04-5G04,1.37
332925,libA,221108_1_antibody_1C04-5G04_1.37_1,221108_1_no-antibody_control_1,G161S,1,TAACGGGACCAGTGTA,0.0262,0.0262,13557,46464,68369,6147,10886757,22,G142S,1C04-5G04,1.37
332932,libA,221108_1_antibody_1C04-5G04_1.37_1,221108_1_no-antibody_control_1,D209A,1,CCTTAGTGTAATAAAA,0.0188,0.0188,11659,55637,68369,6147,10886757,22,D190A,1C04-5G04,1.37
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
399389,libA,221108_1_antibody_1C04-5G04_1.37_1,221108_1_no-antibody_control_1,H454Q,1,TTTTGAATGTGTTCCT,0.0000,0.0000,0,32,68369,6147,10886757,22,H435Q,1C04-5G04,1.37
399433,libA,221108_1_antibody_1C04-5G04_1.37_1,221108_1_no-antibody_control_1,K257R,1,TTTTTACAAAGAGTCC,0.0000,0.0000,0,61,68369,6147,10886757,22,K238R,1C04-5G04,1.37
399434,libA,221108_1_antibody_1C04-5G04_1.37_1,221108_1_no-antibody_control_1,R241Y,1,TTTTTACAGTTACGAA,0.0000,0.0000,0,50,68369,6147,10886757,22,R222Y,1C04-5G04,1.37
399449,libA,221108_1_antibody_1C04-5G04_1.37_1,221108_1_no-antibody_control_1,Q382T,1,TTTTTCCCCACGCAAG,0.0000,0.0000,0,48,68369,6147,10886757,22,Q363T,1C04-5G04,1.37


In [21]:
prob_escape_double.loc[prob_escape_double['antibody_concentration'] == 1.37]

Unnamed: 0,library,antibody_sample,no-antibody_sample,aa_substitutions_sequential,n_aa_substitutions,barcode,prob_escape,prob_escape_uncensored,antibody_count,no-antibody_count,antibody_neut_standard_count,no-antibody_neut_standard_count,total_no_antibody_count,no_antibody_count_threshold,aa_substitutions_reference,antibody,antibody_concentration
332908,libA,221108_1_antibody_1C04-5G04_1.37_1,221108_1_no-antibody_control_1,I393V L403F,2,CAAGCAAAGAGGAATG,0.0798,0.0798,25973,29260,68369,6147,10886757,22,I374V L384F,1C04-5G04,1.37
332927,libA,221108_1_antibody_1C04-5G04_1.37_1,221108_1_no-antibody_control_1,K101A N235K,2,GTTAGGTGTGGGACCC,0.4424,0.4424,12946,2631,68369,6147,10886757,22,K82A N216K,1C04-5G04,1.37
332929,libA,221108_1_antibody_1C04-5G04_1.37_1,221108_1_no-antibody_control_1,N235R K387H,2,AGAATTTGTCACGAGT,0.1645,0.1645,12439,6797,68369,6147,10886757,22,N216R K368H,1C04-5G04,1.37
332937,libA,221108_1_antibody_1C04-5G04_1.37_1,221108_1_no-antibody_control_1,S250D Q375S,2,TATTTTGGAACAGTTT,0.0199,0.0199,11170,50458,68369,6147,10886757,22,S231D Q356S,1C04-5G04,1.37
332938,libA,221108_1_antibody_1C04-5G04_1.37_1,221108_1_no-antibody_control_1,K140N A449L,2,GGTTGAAGGTTTGGTT,0.1183,0.1183,11138,8464,68369,6147,10886757,22,K121N A430L,1C04-5G04,1.37
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
399304,libA,221108_1_antibody_1C04-5G04_1.37_1,221108_1_no-antibody_control_1,I350F K465L,2,TTTTACAGGCTACTCA,0.0000,0.0000,0,53,68369,6147,10886757,22,I331F K446L,1C04-5G04,1.37
399370,libA,221108_1_antibody_1C04-5G04_1.37_1,221108_1_no-antibody_control_1,T150V K154R,2,TTTTCGTCGGTCTTTC,0.0000,0.0000,0,50,68369,6147,10886757,22,T131V K135R,1C04-5G04,1.37
399430,libA,221108_1_antibody_1C04-5G04_1.37_1,221108_1_no-antibody_control_1,S212F H412N,2,TTTTTAACAATAACCT,0.0000,0.0000,0,80,68369,6147,10886757,22,S193F H393N,1C04-5G04,1.37
399451,libA,221108_1_antibody_1C04-5G04_1.37_1,221108_1_no-antibody_control_1,K154D V366S,2,TTTTTCGAGATAAATG,0.0000,0.0000,0,48,68369,6147,10886757,22,K135D V347S,1C04-5G04,1.37


In [23]:
prob_escape.loc[prob_escape['antibody_concentration'] == 1.37]

Unnamed: 0,library,antibody_sample,no-antibody_sample,aa_substitutions_sequential,n_aa_substitutions,barcode,prob_escape,prob_escape_uncensored,antibody_count,no-antibody_count,antibody_neut_standard_count,no-antibody_neut_standard_count,total_no_antibody_count,no_antibody_count_threshold,aa_substitutions_reference,antibody,antibody_concentration
332905,libA,221108_1_antibody_1C04-5G04_1.37_1,221108_1_no-antibody_control_1,R111S V366M R402S,3,TATCTACCTAACGAAA,0.0920,0.0920,95694,93568,68369,6147,10886757,22,R92S V347M R383S,1C04-5G04,1.37
332906,libA,221108_1_antibody_1C04-5G04_1.37_1,221108_1_no-antibody_control_1,K297I,1,ATAACACAAAAAAGTA,0.0226,0.0226,68114,270456,68369,6147,10886757,22,K278I,1C04-5G04,1.37
332907,libA,221108_1_antibody_1C04-5G04_1.37_1,221108_1_no-antibody_control_1,A125M P246H I393A G398Q F411Y,5,CTTTCAATTATGAGAC,0.3589,0.3589,44394,11122,68369,6147,10886757,22,A106M P227H I374A G379Q F392Y,1C04-5G04,1.37
332908,libA,221108_1_antibody_1C04-5G04_1.37_1,221108_1_no-antibody_control_1,I393V L403F,2,CAAGCAAAGAGGAATG,0.0798,0.0798,25973,29260,68369,6147,10886757,22,I374V L384F,1C04-5G04,1.37
332909,libA,221108_1_antibody_1C04-5G04_1.37_1,221108_1_no-antibody_control_1,Q94N A231T P343Q V366S,4,CAGCAACTTTCAGATA,0.0431,0.0431,25911,54000,68369,6147,10886757,22,Q75N A212T P324Q V347S,1C04-5G04,1.37
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
399473,libA,221108_1_antibody_1C04-5G04_1.37_1,221108_1_no-antibody_control_1,R111I G161N K410E,3,TTTTTTCCCCTATGCA,0.0000,0.0000,0,32,68369,6147,10886757,22,R92I G142N K391E,1C04-5G04,1.37
399475,libA,221108_1_antibody_1C04-5G04_1.37_1,221108_1_no-antibody_control_1,A182V K283A S289T T337L,4,TTTTTTGAACACCAAG,0.0000,0.0000,0,153,68369,6147,10886757,22,A163V K264A S270T T318L,1C04-5G04,1.37
399478,libA,221108_1_antibody_1C04-5G04_1.37_1,221108_1_no-antibody_control_1,Q330A K487T,2,TTTTTTGCATAGAGAG,0.0000,0.0000,0,46,68369,6147,10886757,22,Q311A K468T,1C04-5G04,1.37
399480,libA,221108_1_antibody_1C04-5G04_1.37_1,221108_1_no-antibody_control_1,T179R G205C P343R V366L E378D L458I,6,TTTTTTTCACCAACTG,0.0000,0.0000,0,84,68369,6147,10886757,22,T160R G186C P324R V347L E359D L439I,1C04-5G04,1.37


In [11]:
spatial_distances = polyclonal.pdb_utils.inter_residue_distances(
    "scratch_notebooks/221111_model-fitting/4o5n.pdb",
    target_chains=["A", "B"],
)

spatial_distances

Unnamed: 0,site_1,site_2,distance,chain_1,chain_2
0,9,10,1.328212,A,A
1,9,11,3.469929,B,B
2,9,12,6.336130,B,B
3,9,13,9.189821,B,B
4,9,14,8.930696,B,A
...,...,...,...,...,...
260276,497,499,15.936294,B,B
260277,497,500,16.632641,B,B
260278,498,499,23.859705,B,B
260279,498,500,13.285421,B,B


In [46]:
model = polyclonal.Polyclonal(
    n_epitopes=2,
    data_to_fit=prob_escape.rename(
        columns={
            "antibody_concentration": "concentration",
            "aa_substitutions_reference": "aa_substitutions",
        }
    ),
    alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
    sites=reference_sites,
)

# fit model
opt_res = model.fit(
    logfreq=200,
    reg_escape_weight=0.1,
    reg_spread_weight=0.25,
    reg_activity_weight=1.0,
    reg_uniqueness_weight=0.3,  # regularize epitope similarity
    reg_spatial2_weight=0.001
)
model.mut_escape_plot()

# First fitting site-level model.
# Starting optimization of 1010 parameters at Wed Nov 30 13:40:14 2022.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.060234      1249.8      1246.4           0           0           0              0               0       3.4611
         139      9.3753      435.11       425.2      5.0073           0           0       0.098243        0.017707       4.7874
# Successfully finished at Wed Nov 30 13:40:24 2022.
# Starting optimization of 6488 parameters at Wed Nov 30 13:40:24 2022.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.075179      629.39      569.54      52.414  2.4784e-31           0       0.098243          2.5566       4.7874
          83      7.1101      574.54         562      6.4051     0.42618           0       0.011563         0.15379       5.

  for col_name, dtype in df.dtypes.iteritems():


In [47]:
display(model.activity_wt_barplot())

In [60]:
def reassign_b_factor(
    input_pdbfile,
    output_pdbfile,
    df,
    metric_col,
    *,
    site_col="site",
    chain_col="chain",
    missing_metric=0,
    model_index=0,
):

    # subset `df` to needed columns and error check it
    cols = [metric_col, site_col, chain_col]
    for col in cols:
        if col not in df.columns:
            raise ValueError(f"`df` lacks column {col}")
    df = df[cols].drop_duplicates()
    if len(df) != len(df.groupby([site_col, chain_col])):
        raise ValueError("non-unique metric for a site in a chain")

#     if df[site_col].dtype != int:
#         raise ValueError("function currently requires `site_col` to be int")

    # read PDB, catch warnings about discontinuous chains
    with warnings.catch_warnings():
        warnings.simplefilter(
            "ignore", category=Bio.PDB.PDBExceptions.PDBConstructionWarning
        )
        pdb = Bio.PDB.PDBParser().get_structure("_", input_pdbfile)

    # get the model out of the PDB
    model = list(pdb.get_models())[model_index]

    # make sure all chains in PDB
    missing_chains = set(df[chain_col]) - {chain.id for chain in model.get_chains()}
    if missing_chains:
        raise ValueError(f"`df` has chains not in PDB: {missing_chains}")

    # make missing_metric a dict if it isn't already
    if not isinstance(missing_metric, dict):
        missing_metric = {chain.id: missing_metric for chain in model.get_chains()}

    # loop over all chains and do coloring
    for chain in model.get_chains():
        chain_id = chain.id
        site_to_val = (
            df.query(f"{chain_col} == @chain_id")
            .set_index(site_col)[metric_col]
            .to_dict()
        )
        for residue in chain:
            site = residue.get_id()[1]
            try:
                metric_val = site_to_val[site]
            except KeyError:
                metric_val = missing_metric[chain_id]
            # for disordered residues, get list of them
            try:
                residuelist = residue.disordered_get_list()
            except AttributeError:
                residuelist = [residue]
            for r in residuelist:
                for atom in r:
                    # for disordered atoms, get list of them
                    try:
                        atomlist = atom.disordered_get_list()
                    except AttributeError:
                        atomlist = [atom]
                    for a in atomlist:
                        a.bfactor = metric_val

    # write PDB
    io = Bio.PDB.PDBIO()
    io.set_structure(pdb)
    io.save(output_pdbfile)

In [61]:
model.mut_escape_site_summary_df()

Unnamed: 0,epitope,site,wildtype,mean,total positive,max,min,total negative,n mutations
0,1,-9,A,-0.001172,0.000000,-0.001172,-0.001172,-0.001172,1
1,1,-4,D,-0.005142,0.000000,-0.005142,-0.005142,-0.005142,1
2,1,-3,A,-0.002190,0.000000,-0.002129,-0.002251,-0.004380,2
3,1,-2,D,0.000156,0.005342,0.003080,-0.002412,-0.004565,5
4,1,-1,T,-0.002416,0.000000,-0.002133,-0.002699,-0.004832,2
...,...,...,...,...,...,...,...,...,...
490,1,537,W,-0.023594,0.000000,-0.009212,-0.037976,-0.047189,2
491,1,538,A,-0.015609,0.000000,-0.009087,-0.028650,-0.062436,4
492,1,539,C,-0.003064,0.000000,-0.003064,-0.003064,-0.003064,1
493,1,540,Q,-0.004828,0.000358,0.000358,-0.008450,-0.014841,3


In [63]:
site_summary_df = model.mut_escape_site_summary_df()
site_summary_df['site'] = site_summary_df['site'].astype(float)
                                      
site_summary_df = site_summary_df.loc[(site_summary_df['site'] < 326) &
                                      (site_summary_df['site'] > 0)
                                     ]
                                      
site_summary_df['site'] = site_summary_df['site'].astype(int)
site_summary_df['chain'] = 'A'

site_summary_df['sum'] = site_summary_df['total positive'] + site_summary_df['total negative']

site_summary_df

Unnamed: 0,epitope,site,wildtype,mean,total positive,max,min,total negative,n mutations,chain,sum
5,1,1,Q,0.000544,0.010459,0.008452,-0.006590,-0.008283,4,A,0.002176
6,1,2,K,0.002125,0.008015,0.007134,-0.001642,-0.001642,3,A,0.006374
7,1,3,I,-0.004356,0.106804,0.020328,-0.032871,-0.185207,18,A,-0.078403
8,1,4,P,-0.011548,0.116220,0.038450,-0.070772,-0.335640,19,A,-0.219420
9,1,5,G,-0.019586,0.046363,0.033597,-0.112364,-0.379320,17,A,-0.332957
...,...,...,...,...,...,...,...,...,...,...,...
289,1,320,M,-0.007440,0.000000,-0.007440,-0.007440,-0.007440,1,A,-0.007440
290,1,321,R,-0.013752,0.000000,-0.004925,-0.022580,-0.027505,2,A,-0.027505
291,1,323,V,0.056724,0.851738,0.341134,-0.026568,-0.057604,14,A,0.794135
292,1,324,P,-0.011839,0.019198,0.012691,-0.040831,-0.113913,8,A,-0.094715


In [64]:
import warnings
import Bio

reassign_b_factor(input_pdbfile='scratch_notebooks/221111_model-fitting/4o5n.pdb',
                  output_pdbfile='scratch_notebooks/221130_summary_figures/libA_cocktail_b-factors_sum.pdb',
                  df=site_summary_df,
                  metric_col='sum',
                  site_col="site",
                  chain_col="chain",
                  missing_metric=0,
                  model_index=0,)