In [1]:
import altair as alt

import pandas as pd

import polyclonal

import warnings
warnings.filterwarnings('ignore')

import Bio

In [2]:
import os

# All paths below (e.g. "data/PDBs/4o5n.pdb") are relative to the directory two
# levels above this notebook. A bare relative `os.chdir` is not idempotent:
# re-running the cell would climb two more levels each time. Only move when the
# data directory read by the next cell is not already reachable from here.
if not os.path.isdir("data/PDBs"):
    os.chdir('../../')

In [8]:
# Pairwise inter-residue distances computed by polyclonal from the 4o5n PDB
# file, restricted to chains A and B. As the displayed output shows, the
# resulting frame has columns `site_1`, `site_2`, `distance`, `chain_1` and
# `chain_2`; it is passed as `spatial_distances` to the two-epitope
# `Polyclonal` models fit further down.
# NOTE(review): distance units are presumably angstroms — confirm in the
# polyclonal documentation.
spatial_distances = polyclonal.pdb_utils.inter_residue_distances(
    "data/PDBs/4o5n.pdb",
    target_chains=["A", "B"],
)

spatial_distances

Unnamed: 0,site_1,site_2,distance,chain_1,chain_2
0,9,10,1.328212,A,A
1,9,11,3.469929,B,B
2,9,12,6.336130,B,B
3,9,13,9.189821,B,B
4,9,14,8.930696,B,A
...,...,...,...,...,...
260276,497,499,15.936294,B,B
260277,497,500,16.632641,B,B
260278,498,499,23.859705,B,B
260279,498,500,13.285421,B,B


In [3]:
def reassign_b_factor(
    input_pdbfile,
    output_pdbfile,
    df,
    metric_col,
    *,
    site_col="site",
    chain_col="chain",
    missing_metric=0,
    model_index=0,
):
    """Write a copy of a PDB file whose B-factors are replaced by a per-site metric.

    Every atom of every residue in every chain of the selected model gets its
    B-factor set to the metric value given in `df` for that (chain, site), or
    to `missing_metric` when `df` has no entry for it. The whole structure is
    then saved to `output_pdbfile`.

    Parameters
    ----------
    input_pdbfile : str
        Path of the PDB file to read.
    output_pdbfile : str
        Path the re-annotated structure is written to.
    df : pandas.DataFrame
        Must contain the columns named by `metric_col`, `site_col` and
        `chain_col`. Other columns are ignored. The caller's frame is never
        modified.
    metric_col : str
        Column holding the value to store as B-factor.
    site_col : str
        Column holding residue numbers: any integer dtype, or strings made
        only of digits (these are converted to int).
    chain_col : str
        Column holding PDB chain identifiers.
    missing_metric : number or dict
        B-factor for residues absent from `df`. A dict is looked up by chain
        ID; any other value is applied to every chain.
    model_index : int
        Position of the model to re-annotate within the parsed structure.

    Raises
    ------
    ValueError
        If `df` lacks a required column, if `site_col` is neither integer nor
        all-digit strings, or if `df` names a chain absent from the model.
    KeyError
        If `missing_metric` is a dict lacking a chain that needs a fallback.

    Notes
    -----
    Uniqueness of the metric per (site, chain) is not enforced. If several
    distinct values remain after dropping duplicate rows, the one appearing
    last in `df` is used.
    """
    # subset `df` to needed columns and error check it
    cols = [metric_col, site_col, chain_col]
    for col in cols:
        if col not in df.columns:
            raise ValueError(f"`df` lacks column {col}")
    # explicit copy: the dtype conversion below must not write into a view
    df = df[cols].drop_duplicates().copy()

    # `dtype != int` only matches the platform default integer, so int32,
    # unsigned or nullable integer columns were wrongly rejected
    if not pd.api.types.is_integer_dtype(df[site_col]):
        # if we have string type, convert to int
        if not df[site_col].map(type).eq(str).all():
            raise ValueError(f"`site_col` is neither str nor int:\n{df[site_col]}")
        encodes_int = df[site_col].str.fullmatch(r"\d+")
        if not encodes_int.all():
            # sites like 214a are rejected; before supporting them, the
            # `residue.get_id()[1]` lookup below (which ignores insertion
            # codes) would need to be revisited
            raise ValueError(
                f"`site_col` has non-integer entries:\n{df[site_col][~encodes_int]}"
            )
        df[site_col] = df[site_col].astype(int)

    # `import Bio` by itself does not load the `Bio.PDB` subpackage, so import
    # what is needed explicitly instead of relying on another package having
    # imported it as a side effect
    from Bio.PDB import PDBIO, PDBParser
    from Bio.PDB.PDBExceptions import PDBConstructionWarning

    # read PDB, catch warnings about discontinuous chains
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=PDBConstructionWarning)
        pdb = PDBParser().get_structure("_", input_pdbfile)

    # get the model out of the PDB
    model = list(pdb.get_models())[model_index]
    chain_ids = [chain.id for chain in model.get_chains()]

    # make sure all chains in PDB
    missing_chains = set(df[chain_col]) - set(chain_ids)
    if missing_chains:
        raise ValueError(f"`df` has chains not in PDB: {missing_chains}")

    # make missing_metric a dict if it isn't already
    if not isinstance(missing_metric, dict):
        missing_metric = {chain_id: missing_metric for chain_id in chain_ids}

    # loop over all chains and do coloring
    for chain in model.get_chains():
        chain_id = chain.id
        # boolean mask rather than `query`, so any column name works
        site_to_val = (
            df[df[chain_col] == chain_id].set_index(site_col)[metric_col].to_dict()
        )
        for residue in chain:
            site = residue.get_id()[1]
            if site in site_to_val:
                metric_val = site_to_val[site]
            else:
                metric_val = missing_metric[chain_id]
            # for disordered residues, get list of them
            try:
                residuelist = residue.disordered_get_list()
            except AttributeError:
                residuelist = [residue]
            for r in residuelist:
                for atom in r:
                    # for disordered atoms, get list of them
                    try:
                        atomlist = atom.disordered_get_list()
                    except AttributeError:
                        atomlist = [atom]
                    for a in atomlist:
                        a.bfactor = metric_val

    # write PDB
    io = PDBIO()
    io.set_structure(pdb)
    io.save(output_pdbfile)

In [4]:
def get_b_factor_sum(
    model,
    serum,
    *,
    input_pdbfile="data/PDBs/4o5n.pdb",
    output_dir="scratch_notebooks/221227_model_fitting",
    min_times_seen=3,
    site_upper_bound=326,
    chain="A",
):
    """Write one PDB per epitope with B-factors set to summed site escape.

    The per-site summary of `model` is taken with the given `min_times_seen`,
    the metric `sum = total positive + total negative` is computed, and sites
    numbered below `site_upper_bound` are mapped onto `chain`. For each epitope
    in `model.epitopes` the structure in `input_pdbfile` is rewritten by
    `reassign_b_factor`; residues without a value get a B-factor of 0.

    Parameters
    ----------
    model : object
        Fitted model exposing `mut_escape_site_summary_df(min_times_seen=...)`
        and `epitopes`. The summary must have columns `site`, `epitope`,
        `total positive` and `total negative`.
    serum : str
        Label used as prefix of the output file names.
    input_pdbfile : str
        Structure whose B-factors are reassigned.
    output_dir : str
        Directory that receives the PDB files; created if missing.
    min_times_seen : int
        Forwarded to `mut_escape_site_summary_df`.
    site_upper_bound : int
        Only sites strictly below this number are kept.
        NOTE(review): presumably the sites resolved on `chain` in the
        structure — confirm against the PDB numbering.
    chain : str
        Chain ID assigned to every retained site.

    Returns
    -------
    pandas.DataFrame
        Columns `epitope` and `PDB file`, one row per file written.
    """
    metric_col = "sum"

    # The summary is computed once; the original also built an unfiltered
    # summary first and immediately discarded it.
    site_summary = model.mut_escape_site_summary_df(min_times_seen=min_times_seen)
    site_summary = site_summary.assign(site=site_summary["site"].astype(int))
    site_summary = site_summary.assign(
        **{metric_col: site_summary["total positive"] + site_summary["total negative"]}
    )
    # `.assign` on the filtered rows returns a new frame, so nothing is written
    # into a slice of the summary
    site_summary = site_summary.loc[site_summary["site"] < site_upper_bound].assign(
        chain=chain
    )

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    result_files = []
    for epitope in model.epitopes:
        # The name is interpolated once. A further `str.format` pass would
        # raise if `serum` or the epitope name contained braces.
        filename = f"{serum}_sum_epitope_{epitope}.pdb".replace(" ", "_")
        output_pdbfile = f"{output_dir}/{filename}" if output_dir else filename
        result_files.append((epitope, output_pdbfile))
        reassign_b_factor(
            input_pdbfile=input_pdbfile,
            output_pdbfile=output_pdbfile,
            df=site_summary[site_summary["epitope"] == epitope],
            metric_col=metric_col,
            site_col="site",
            chain_col="chain",
            missing_metric=0,
            model_index=0,
        )

    return pd.DataFrame(result_files, columns=["epitope", "PDB file"])

### Load escape data and fit a model for each serum

In [6]:
# Serum AUSAB-05: read the probability-of-escape table
prob_escape = pd.read_csv(
    "results/prob_escape/libA_221223_1_AUSAB-05_1_prob_escape.csv",
    keep_default_na=False,
    na_values="nan",
)
# keep rows with sufficient no-antibody counts; the threshold name carries no
# `@`, so `query` reads it as a column of the table
prob_escape = prob_escape.query("`no-antibody_count` >= no_antibody_count_threshold")
assert prob_escape.notnull().all().all()

# restrict to the two antibody concentrations used for this fit
prob_escape_filtered_05 = prob_escape[
    prob_escape["antibody_concentration"].isin([0.0074, 0.0111])
]

# single-epitope model, with columns renamed to the names given to `data_to_fit`
model_05 = polyclonal.Polyclonal(
    data_to_fit=prob_escape_filtered_05.rename(
        columns={
            "antibody_concentration": "concentration",
            "aa_substitutions_reference": "aa_substitutions",
        }
    ),
    n_epitopes=1,
    alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
)

# fit model
opt_res = model_05.fit(reg_escape_weight=0.1, logfreq=200)

# display results
display(model_05.activity_wt_barplot())
display(
    model_05.mut_escape_plot(
        addtl_slider_stats={"times_seen": 3},
        init_floor_at_zero=False,
    )
)

# First fitting site-level model.
# Starting optimization of 503 parameters at Fri Jan  6 11:33:37 2023.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.019135       44868       44865           0           0           0              0               0       3.6049
          42      1.2486      612.49      606.83      2.1069           0           0              0               0       3.5517
# Successfully finished at Fri Jan  6 11:33:38 2023.
# Starting optimization of 3244 parameters at Fri Jan  6 11:33:38 2023.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.023646      774.35      748.62      22.181  8.0512e-33           0              0               0       3.5517
          64      1.5823      746.34      739.99      2.1956    0.088411           0              0               0       4.0

In [9]:
# Serum AUSAB-07: read the probability-of-escape table
prob_escape = pd.read_csv(
    "results/prob_escape/libA_221223_1_AUSAB-07_1_prob_escape.csv",
    keep_default_na=False,
    na_values="nan",
)
# keep rows with sufficient no-antibody counts; the threshold name carries no
# `@`, so `query` reads it as a column of the table
prob_escape = prob_escape.query("`no-antibody_count` >= no_antibody_count_threshold")
assert prob_escape.notnull().all().all()

# restrict to the four antibody concentrations used for this fit
prob_escape_filtered_07 = prob_escape[
    prob_escape["antibody_concentration"].isin([0.0010, 0.0015, 0.0023, 0.0034])
]

# site list shared by this and the following model cells
reference_sites = pd.read_csv("data/site_map.csv").loc[:, "reference_site"].tolist()

# two-epitope model using the reference sites and the inter-residue distances
model_07 = polyclonal.Polyclonal(
    data_to_fit=prob_escape_filtered_07.rename(
        columns={
            "antibody_concentration": "concentration",
            "aa_substitutions_reference": "aa_substitutions",
        }
    ),
    n_epitopes=2,
    sites=reference_sites,
    spatial_distances=spatial_distances,
    alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
)

# fit model
opt_res = model_07.fit(
    reg_escape_weight=0.1,
    reg_spatial_weight=0.0,
    reg_spatial2_weight=0.0005,
    reg_uniqueness_weight=0,
    reg_uniqueness2_weight=1,
    logfreq=200,
)

# display results
display(model_07.activity_wt_barplot())
display(
    model_07.mut_escape_plot(
        addtl_slider_stats={"times_seen": 3},
        init_floor_at_zero=False,
    )
)

# First fitting site-level model.
# Starting optimization of 1006 parameters at Fri Jan  6 11:35:48 2023.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.068907       89018       89007           0           0           0              0               0       11.386
         135      8.9729      2440.9      2424.7      2.7583           0      10.576              0        0.083337       2.7186
# Successfully finished at Fri Jan  6 11:35:57 2023.
# Starting optimization of 6488 parameters at Fri Jan  6 11:35:57 2023.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.074765      2974.4      2907.2      31.612  4.5229e-32      10.576              0          22.245       2.7186
         171      13.608      2783.7      2727.1      30.634      1.2771      18.963              0          2.4987       3.

In [10]:
# Serum AUSAB-11: read the probability-of-escape table
prob_escape = pd.read_csv(
    "results/prob_escape/libA_221223_1_AUSAB-11_1_prob_escape.csv",
    keep_default_na=False,
    na_values="nan",
)
# keep rows with sufficient no-antibody counts; the threshold name carries no
# `@`, so `query` reads it as a column of the table
prob_escape = prob_escape.query("`no-antibody_count` >= no_antibody_count_threshold")
assert prob_escape.notnull().all().all()

# restrict to the three antibody concentrations used for this fit
prob_escape_filtered_11 = prob_escape[
    prob_escape["antibody_concentration"].isin([0.0067, 0.0100, 0.0150])
]

# two-epitope model; `reference_sites` comes from the AUSAB-07 cell above
model_11 = polyclonal.Polyclonal(
    data_to_fit=prob_escape_filtered_11.rename(
        columns={
            "antibody_concentration": "concentration",
            "aa_substitutions_reference": "aa_substitutions",
        }
    ),
    n_epitopes=2,
    sites=reference_sites,
    spatial_distances=spatial_distances,
    alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
)

# fit model
opt_res = model_11.fit(
    reg_escape_weight=0.1,
    reg_spatial_weight=0.0,
    reg_spatial2_weight=0.0005,
    reg_uniqueness_weight=0,
    reg_uniqueness2_weight=1,
    logfreq=200,
)

# display results
display(model_11.activity_wt_barplot())
display(
    model_11.mut_escape_plot(
        addtl_slider_stats={"times_seen": 3},
        init_floor_at_zero=False,
    )
)

# First fitting site-level model.
# Starting optimization of 1006 parameters at Fri Jan  6 11:36:52 2023.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.044176       58261       58253           0           0           0              0               0       8.0095
          15      1.3997      5971.9      5913.2      2.6324           0      54.974              0         0.22631      0.90466
# Successfully finished at Fri Jan  6 11:36:53 2023.
# Starting optimization of 6488 parameters at Fri Jan  6 11:36:53 2023.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.065317      7099.1      6921.1      40.707  5.9335e-32      54.974              0          81.363      0.90466
         182      13.355      6478.7      6419.7      32.622     0.83533      20.461              0          3.2139       1.

In [11]:
# Serum AUSAB-13: read the probability-of-escape table (all concentrations kept)
prob_escape_13 = pd.read_csv(
    "results/prob_escape/libA_221027_1_AUSAB-13_1_prob_escape.csv",
    keep_default_na=False,
    na_values="nan",
)
# keep rows with sufficient no-antibody counts; the threshold name carries no
# `@`, so `query` reads it as a column of the table
prob_escape_13 = prob_escape_13.query(
    "`no-antibody_count` >= no_antibody_count_threshold"
)
assert prob_escape_13.notnull().all().all()

# two-epitope model; `reference_sites` comes from the AUSAB-07 cell above
model_13 = polyclonal.Polyclonal(
    data_to_fit=prob_escape_13.rename(
        columns={
            "antibody_concentration": "concentration",
            "aa_substitutions_reference": "aa_substitutions",
        }
    ),
    n_epitopes=2,
    sites=reference_sites,
    spatial_distances=spatial_distances,
    alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
)

# fit model
opt_res = model_13.fit(
    reg_escape_weight=0.1,
    reg_spatial_weight=0.0,
    reg_spatial2_weight=0.0005,
    reg_uniqueness_weight=0,
    reg_uniqueness2_weight=1,
    logfreq=200,
)

# display results
display(model_13.activity_wt_barplot())
display(
    model_13.mut_escape_plot(
        addtl_slider_stats={"times_seen": 3},
        init_floor_at_zero=False,
    )
)

# First fitting site-level model.
# Starting optimization of 1006 parameters at Fri Jan  6 11:37:54 2023.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.075322  1.6416e+05  1.6415e+05           0           0           0              0               0       11.001
          82      9.0704      848.47      837.86      1.7832           0      3.5513              0         0.15505        5.127
# Successfully finished at Fri Jan  6 11:38:03 2023.
# Starting optimization of 6474 parameters at Fri Jan  6 11:38:03 2023.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0     0.10189      1093.1      1006.5      19.327  2.2396e-32      3.5513              0          58.611        5.127
         162      17.803      949.25      920.66      14.826        1.79      6.2682              0         0.58849       5.

Get RGB values for epitope colors:

In [15]:
import matplotlib.colors

# Report each epitope color of `model_13` as hex and as an RGB list whose
# components are rounded to three decimals.
for epitope, hex_color in model_13.epitope_colors.items():
    rgb = [
        round(channel, 3)
        for channel in matplotlib.colors.to_rgb(hex_color)
    ]
    print(f"{epitope}: hex color is {hex_color}; RGB tuple is {rgb}")

1: hex color is #0072B2; RGB tuple is [0.0, 0.447, 0.698]
2: hex color is #CC79A7; RGB tuple is [0.8, 0.475, 0.655]


Also get RGB values for negative site escape (orange):

In [16]:
# RGB components (rounded to three decimals) of the orange '#E69F00'
rgb = [
    round(channel, 3)
    for channel in matplotlib.colors.to_rgb('#E69F00')
]
rgb

[0.902, 0.624, 0.0]

### Save PDB files with reassigned B-factor values

In [18]:
# Fitted model for each serum label; the label prefixes the output file names
serum_models = {
    "AUSAB-05": model_05,
    "AUSAB-07": model_07,
    "AUSAB-11": model_11,
    "AUSAB-13": model_13,
}

# Write the per-epitope PDB files for every serum (returned tables are unused)
for serum in serum_models.keys():
    get_b_factor_sum(model=serum_models[serum], serum=serum)