# Figure 5: universal antagonism in peptide libraries

In [None]:
import os
import json,math
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.colors as clr
import numpy as np
import pandas as pd
import pickle
import seaborn as sns
sns.set_context('talk')
import warnings
from scipy.stats import percentileofscore
warnings.filterwarnings('ignore')
idx = pd.IndexSlice
import sys
if not "../" in sys.path:
    sys.path.insert(1, "../")
from scripts.preprocess import write_conc_uM
from secondary_scripts.hhatv4_ec50_mcmc import load_raw_data_hhatv4
from secondary_scripts.mskcc_ec50_mcmc import hill_back
from scripts.plotting import standalone_legend
#from utils.fitting import (
#    hillFunction4p, 
#    hillFunction, 
#    Hill, 
#    inverseHill, 
#    student_t_ci, 
#    find_bounds_on_min, 
#    cost_fit_hill, 
#    r_squared, 
#    cost_fit_hill_4p
#)

import h5py
import corner

pj = os.path.join

In [None]:
# Inputs are read relative to the repository root (one level above this
# notebook); figure panels are written to a local output folder.
root_dir = ".."
res_dir = pj(root_dir, "results", "for_plots")
data_dir = pj(root_dir, 'data')
fig_dir = 'panels5'

# Switch to True to write every panel into fig_dir
do_save_plots = False

In [None]:
#@title Aesthetic parameters
# Default figure resolution (dots per inch) for all figures in this notebook
plt.rcParams.update({"figure.dpi": 60})

In [None]:
#@title Labeling parameters
# Short display names for the seven TCR ids used in the data:
# C = CMV, G = gp100, N = neoantigen. Note that ids "5" and "6" map to G3 and
# G2 respectively (not in numeric order).
tcr_rename = dict(zip(
    ["1", "2", "3", "4", "5", "6", "7"],
    ["C1", "C2", "C3", "G1", "G3", "G2", "N1"],
))

# Mathtext labels for the mutated (self) peptide and the agonist peptide,
# used in axis labels and cartoon annotations.
pSelf = r'p$\mathregular{^{APL}}$'
pAg = r'p$\mathregular{^{Ag}}$'

## Simplified plots for cartoons

# Load data

In [None]:
#@title Experimental
# Experimental CD25 EC50 table, without the rows that have no peptide ("None")
cd25data = pd.read_hdf(pj(data_dir, "dose_response", 'fullCD25EC50df.hdf'))
cd25data = cd25data.query("Peptide != 'None'")
cd25data

In [None]:
# Model taus of the two HHAT peptides that are highlighted on the model curve.
# Indexing (rather than .get) fails loudly if a key is missing instead of
# silently propagating None into the index operations below.
with open(pj(data_dir, "pep_tau_map_others.json"), "r") as h:
    other_taus = json.load(h)
hhat_pep_taus = {k: other_taus[k] for k in ("HHAT-WT", "HHAT-p8F")}
# tau -> peptide name without the "HHAT-" prefix (e.g. "WT", "p8F")
inverse_pep_dict = {a: k.split("-")[1] for k, a in hhat_pep_taus.items()}

pep_order = list(hhat_pep_taus.keys())
modelTaus = pd.Series([hhat_pep_taus[a] for a in pep_order],
    index=pd.MultiIndex.from_tuples([a.split("-") for a in pep_order], names=["TCR", "Peptide"])
).to_frame(name="Tau")

# Load the model curve and its CI, separate from it the peptides' points
fullModelDf = (pd.read_hdf(pj(res_dir, 'model_predict_aebs_hhat_wt_neoantigen.h5'))
               .xs("PC9-GL", level="Target").rename(write_conc_uM, level="TCR_Antigen_Pulse_uM")
              )
fullModelDf = pd.concat({"HHAT":fullModelDf}, names=["TCR"])

# Rows at the highlighted peptides' taus are plotted as points, the rest as the curve
hhat_tau_values = list(hhat_pep_taus.values())
modelCurves = fullModelDf.drop(hhat_tau_values, level="TCR_Antigen_tau")

# .copy(): the boolean .loc selection is then modified; work on an explicit
# copy rather than a possibly-linked slice (SettingWithCopy).
modelFCs = fullModelDf.loc[fullModelDf.index.isin(hhat_tau_values, level="TCR_Antigen_tau")].copy()
modelFCs["Peptide"] = [inverse_pep_dict.get(k)
        for k in modelFCs.index.get_level_values("TCR_Antigen_tau")]
modelFCs = modelFCs.set_index("Peptide", append=True)

# Copies
modelCurves2 = modelCurves.copy()
modelFCs2 = modelFCs.copy()
modelTaus2 = modelTaus.copy()

display(modelCurves)
display(modelFCs)
display(modelTaus)

In [None]:
#@title Balachandran-Model
balachandranModelDf = pd.read_hdf(pj(res_dir, 'mskcc_antagonism_fc_predictions_corrected_revised.h5'),
                                  key='fc_stats')
# Load individual MCMC samples, need to compute percentages for each sample to get a CI on these percentages.
balachandranModelSamples = pd.read_hdf(pj(res_dir, 'mskcc_antagonism_fc_predictions_corrected_revised.h5'),key='fc_samples')

# Compute percentages within each MCMC sample and estimate CI on these:
# fraction of antagonist, null and enhancer peptides for each TCR and antigen,
# at the 1 uM antigen pulse, classified on the log2 fold-change.
dens_name = "TCR_Antigen_Pulse_uM"
df_pc9_samples = np.log2(balachandranModelSamples.xs("1uM", level=dens_name))
ag_types_names = ['Antagonist','Null','Enhancer']
null_thresh = 1.0  # log2 fold-change; |log2 FC| < null_thresh is "Null"

totals_peps = df_pc9_samples.groupby(["Antigen", "TCR"]).count()


def _fractions_by_tcr(masks):
    """Fraction of peptides per (TCR, Type) for each boolean mask.

    `masks` holds one boolean frame (same shape as df_pc9_samples) per entry
    of ag_types_names, in that order. The result is indexed by (TCR, Type)
    with one column per MCMC sample.
    """
    fracs = {name: mask.groupby(["Antigen", "TCR"]).sum() / totals_peps
             for name, mask in zip(ag_types_names, masks)}
    return (pd.concat(fracs, names=["Type"])
            .sort_index(level=["Antigen", "TCR"])
            .droplevel(["Antigen"])
            .reorder_levels(["TCR", "Type"])
            .sort_index())


# Every peptide with a prediction counts toward the total. notna() (instead of
# > -inf) keeps peptides with FC == 0 (log2 = -inf) so "all" is exactly 100 %.
is_all = df_pc9_samples.notna()
is_antagonist = df_pc9_samples <= -null_thresh
is_enhancer = df_pc9_samples >= null_thresh

# Cumulative percentages. Version 1: antag, antag+null, all
df_fracs_cumul_samples1 = _fractions_by_tcr(
    [is_antagonist, df_pc9_samples < null_thresh, is_all])
# Version 2: all, enhancer+null, enhancer
df_fracs_cumul_samples2 = _fractions_by_tcr(
    [is_all, df_pc9_samples > -null_thresh, is_enhancer])
# Version 3: individual fractions of each type, no cumulative
df_fracs_samples = _fractions_by_tcr(
    [is_antagonist, np.abs(df_pc9_samples) < null_thresh, is_enhancer])


def _sample_stats(fracs):
    """Mean, median and 2.5/97.5 percentiles across MCMC samples (columns)."""
    return pd.concat({
        "mean": fracs.mean(axis=1),
        "median": fracs.median(axis=1),
        "percentile_2.5": fracs.quantile(q=0.025, axis=1),
        "percentile_97.5": fracs.quantile(q=0.975, axis=1)
    }, names=["stats"], axis=1)


df_fracs_cumul_samples_stats1 = _sample_stats(df_fracs_cumul_samples1)
df_fracs_cumul_samples_stats2 = _sample_stats(df_fracs_cumul_samples2)
df_fracs_samples_stats = _sample_stats(df_fracs_samples)


In [None]:
# Select the sample which yields the median fraction of antigens
#@title MSKCC EC50s
# EC50s come from our own fits of the MSKCC data (MCMC samples and MAP fit),
# not from the published EC50 values.
balachandranSamplesDf = pd.read_hdf(pj(res_dir, "mskcc_antagonism_fc_predictions_corrected_revised.h5"), key="EC50_samples")
balachandranDf2 = pd.read_hdf(pj(res_dir, "mskcc_antagonism_fc_predictions_corrected_revised.h5"), key="EC50_fits")

# MCMC sample whose per-(TCR, Type) fractions are closest (L1 distance) to the
# across-sample medians. np.argmin gives a position; convert it to a column
# label because it is used for label-based column lookups below.
# NOTE(review): assumes EC50_samples and fc_samples share column labels.
dist_to_median = np.abs(df_fracs_samples.values
                        - df_fracs_samples_stats["median"].values.reshape(-1, 1)).sum(axis=0)
median_sample = df_fracs_samples.columns[np.argmin(dist_to_median)]
balachandranDf = pd.concat({"median":balachandranSamplesDf[median_sample].to_frame(name="EC50 (M)"), 
                            "best":balachandranDf2[("MAP", "log_ec50_M")].to_frame(name="EC50 (M)")}, 
                           axis=1, names=["stat", "parameter"])
# Stored values are log10 EC50s; convert to M
balachandranDf = 10.0 ** balachandranDf
balachandranSamplesDf = 10.0 ** balachandranSamplesDf
display(balachandranDf)
print("Sample closest to median:", median_sample)
print("Distances to true median:", df_fracs_samples[median_sample] - df_fracs_samples_stats["median"])

In [None]:
#@title Antagonism percentages
# Per-TCR cumulative percentages of antagonist / null / enhancer peptides,
# from the MCMC sample closest to the median fractions (column "best"),
# joined with the per-sample percentages used for confidence intervals.
antigenList1, antigenList2, tupleList = [], [], []
ro = ['7','4','6','5','1','2','3']
fullPlottingDf = np.log2(balachandranModelSamples[median_sample].to_frame("best")
                         .query("TCR_Antigen_Pulse_uM == '1uM'"))

for ag in ro:
    plottingDf = fullPlottingDf.query("TCR == @ag")
    n_peps = plottingDf.shape[0]
    antagonistPercent = 100*plottingDf[plottingDf['best'] <= -null_thresh].shape[0]/n_peps
    enhancerPercent = 100*plottingDf[plottingDf['best'] >= null_thresh].shape[0]/n_peps

    antigenList1.append([antagonistPercent,100-enhancerPercent,100])  # Cumulative: antag, antag+null, all
    antigenList2.append([100,100-antagonistPercent,enhancerPercent])  # Cumulative: all, enhancer+null, enhancer
    tupleList.append(ag)


def _percentage_frame(rows, sample_fracs):
    """Stack per-TCR percentage rows into a (TCR, Type) frame with a "best"
    column, then append the per-MCMC-sample fractions converted to %."""
    perc = pd.DataFrame(np.array(rows, dtype=float), index=tupleList, columns=ag_types_names)
    perc.index.name = 'TCR'
    perc.columns.name = 'Type'
    perc = perc.stack().to_frame("best").sort_index()
    # Add computed CI samples
    perc = pd.concat([perc, sample_fracs*100.0], axis=1)
    perc.columns.name = "sample"
    return perc


antagonistPercentageDf1 = _percentage_frame(antigenList1, df_fracs_cumul_samples1)
antagonistPercentageDf2 = _percentage_frame(antigenList2, df_fracs_cumul_samples2)

display(antagonistPercentageDf1)
display(antagonistPercentageDf2)

# Panel A: Process explanation schematic

In [None]:
#@title Left: Cell representation

In [None]:
#@title Middle: Cartoon representation
numPoints = 5
centerToSelectingDistance = 6
resolution = 1
agQuality = 12
dummyVal = 1

valList = [agQuality]
idList = ['Selecting']
for i in range(numPoints):
  newX = agQuality - centerToSelectingDistance-resolution*int(numPoints/2)
  newX = newX+resolution*i
  valList.append(newX)
  if i == 0:
    idList.append('Non-conservative')
  elif i == numPoints-1:
    idList.append('Conservative')
  else:
    idList.append('')

df = pd.DataFrame({'Antigen Quality':valList,'Peptide':idList,'Dummy':[dummyVal]*(numPoints+1)})
g = sns.relplot(data=df,y='Dummy',x='Antigen Quality',hue='Peptide',hue_order=['Selecting','Non-conservative','Conservative',''],palette=['k','r','g','grey'],s=500,zorder=100,edgecolor='k',linewidth=0,aspect=2,legend=False,style='Peptide',style_order=['','Non-conservative','Conservative','Selecting'],markers=['o','o','o','D'])
ax = g.axes.flat[0]
ax.set_ylim([0.94,1.01])
ax.set_xlim([-3.5,13.5])
for i in range(numPoints):
    start = valList[i+1]
    mid = (valList[0]-valList[1])/2
    if i == 0:
      color = 'r'
    elif i == len(valList)-2:
      color = 'g'
    else:
      color = 'grey'
    ax.annotate('',xy=(agQuality,dummyVal-(numPoints-(i))*0.005),xytext=(start,dummyVal-(numPoints-(i))*0.005),xycoords='data',arrowprops=dict(facecolor=color,arrowstyle='<-',linewidth=3,color=color),zorder=0)
ax.annotate('',xy=(agQuality,dummyVal),xytext=(agQuality,-0.0005+dummyVal-(numPoints-(0))*0.005),xycoords='data',arrowprops=dict(facecolor='k',arrowstyle='-',linewidth=3,color='k'),zorder=200)
ax.annotate('',xytext=(-3,dummyVal),xy=(13,dummyVal),xycoords='data',arrowprops=dict(facecolor='black',linewidth=2,width=0.5),zorder=0)

ax.annotate('Single AA\nmutations',xy=(valList[3]-2.5,0.985), xytext=(valList[3]-4.5, 0.985),ha='center',va='center',arrowprops=dict(arrowstyle='-[, widthB=2.0, lengthB=0.8',linewidth=1.5))
#ax.annotate('Possible self-peptides (p$\mathregular{^{MT}}$)',xy=(valList[3],1.01), xytext=(valList[3], 1.02),ha='center',va='center',arrowprops=dict(arrowstyle='-[, widthB=5.0, lengthB=0.8',linewidth=1.5))
#ax.annotate('Selecting peptide (p$\mathregular{^{WT}}$)',xy=(12,1.01), xytext=(12, 1.02),ha='center',va='center',arrowprops=dict(arrowstyle='-[, widthB=1.0, lengthB=0.8',linewidth=1.5))
ax.annotate('Possible self-peptides\n'+pSelf,xy=(valList[3],1.01), xytext=(valList[3], 1.02),ha='center',va='center',arrowprops=dict(arrowstyle='-[, widthB=5.0, lengthB=0.8',linewidth=1.5))
ax.annotate('Agonist peptide\n'+pAg,xy=(12,1.01), xytext=(12, 1.02),ha='center',va='center',arrowprops=dict(arrowstyle='-[, widthB=1.0, lengthB=0.8',linewidth=1.5))


ax.annotate('Low impact',xy=(12.2,dummyVal-(numPoints-(4))*0.005),ha='left',va='top',color='g')
ax.annotate('High impact',xy=(12.2,dummyVal-(numPoints-(0))*0.005),ha='left',va='bottom',color='r')

#ax.annotate('Antigen\nquality',xy=(-0.5,1),ha='center',va='center',bbox=dict(facecolor='w',edgecolor='w'))
ax.annotate('Antigen\nstrength',xy=(valList[3]-4.5,1),ha='center',va='center',bbox=dict(facecolor='w',edgecolor='w'))

ax.set_ylabel('')
ax.set_xlabel('')
ax.spines[['left', 'bottom']].set_visible(False)
ax.set_yticks([])
ax.set_xticks([])

if do_save_plots:
    g.figure.savefig(pj(fig_dir, '5A_center-explanationSchematic.pdf'),bbox_inches='tight',transparent=True)

In [None]:
#@title Right: Cartoon representation (multiple strengths)
numPoints = 5
centerToSelectingDistance = 6
resolution = 1
agQuality = 12
dummyVal = 1

qualities = [12,10,8]

trueValList = []
trueIDlist = []
dummyList = []
for j in range(3):
  valList = [qualities[j]]
  idList = ['Selecting']
  for i in range(numPoints):
    newX = qualities[j] - centerToSelectingDistance-resolution*int(numPoints/2)
    newX = newX+resolution*i
    valList.append(newX)
    if i == 0:
      idList.append('Non-conservative')
    elif i == numPoints-1:
      idList.append('Conservative')
    else:
      idList.append('')
  trueValList+=valList
  trueIDlist+=idList
  dummyList+=[1-j*0.01]*len(valList)

valList = trueValList.copy()
idList = trueIDlist.copy()

df = pd.DataFrame({'Antigen Quality':valList,'Peptide':idList,'Dummy':dummyList})
g = sns.relplot(data=df,y='Dummy',x='Antigen Quality',hue='Peptide',hue_order=['Selecting','Non-conservative','Conservative',''],palette=['k','r','g','grey'],s=500,zorder=100,edgecolor='k',linewidth=0,aspect=2,legend=False,style='Peptide',style_order=['','Non-conservative','Conservative','Selecting'],markers=['o','o','o','D'])
ax = g.axes.flat[0]
ax.set_ylim([0.97,1.05])
ax.set_xlim([-3.5,13.5])
for i in range(numPoints):
    start = valList[i+1]
    mid = (valList[0]-valList[1])/2
    if i == 0:
      color = 'r'
    elif i == len(valList)-2:
      color = 'g'
    else:
      color = 'grey'

ax.axvspan(6.5,12.5,color='orange',alpha=0.3)
ax.axvspan(1.5,5.5,color='purple',alpha=0.3)
ax.axvline(color='k',linestyle=':',x=5.5)
ax.axvline(color='k',linestyle=':',x=1.5)
ax.axvline(color='k',linestyle=':',x=6.5)

ax.set_ylabel('')
ax.set_xlabel('')
ax.spines[['left', 'bottom']].set_visible(False)
ax.set_yticks([])
ax.set_xticks([])

ax.annotate('Strong',xy=(13.2,dummyVal),ha='left',va='center')
ax.annotate('Medium',xy=(13.2,dummyVal-0.01),ha='left',va='center')
ax.annotate('Weak',xy=(13.2,dummyVal-0.02),ha='left',va='center')

ax.annotate('',xytext=(-3,dummyVal),xy=(13,dummyVal),xycoords='data',arrowprops=dict(facecolor='black',linewidth=2,width=0.5),zorder=0)
ax.annotate('',xytext=(-3,dummyVal-0.01),xy=(13,dummyVal-0.01),xycoords='data',arrowprops=dict(facecolor='black',linewidth=2,width=0.5),zorder=0)
ax.annotate('',xytext=(-3,dummyVal-0.02),xy=(13,dummyVal-0.02),xycoords='data',arrowprops=dict(facecolor='black',linewidth=2,width=0.5),zorder=0)

ax.annotate('No effect',xy=(0,dummyVal+0.015),ha='center',va='center')#,fontweight='bold')
ax.annotate('Antagonism',xy=(3.5,dummyVal+0.015),color='purple',ha='center',va='center')#,fontweight='bold')
ax.annotate('Enhancement',xy=(9.5,dummyVal+0.015),color='orange',ha='center',va='center')#,fontweight='bold')

ax.annotate('Agonist\nstrength',xy=(14,dummyVal+0.015),ha='center',va='center',annotation_clip=False)

if do_save_plots:
    g.figure.savefig(pj(fig_dir, '5A_right-hypothesisSchematic.pdf'),bbox_inches='tight',transparent=True)

# Panel B: Experimental dose response curves

In [None]:
#@title Set threshold for antagonism and enhancement
antagonism_threshold = 0.5
enhancement_threshold = 2.0

In [None]:
#@title Plotting functions
def normalized_hill(x, a, b, h, n):
    if np.isfinite(h):
        xnorm = x / h
    else:
        xnorm = 0.0
    return (a - b) * xnorm**n / (xnorm**n + 1.0) + b


def geo_mean_apply(ser):
    return np.exp(np.mean(np.log(ser)))


def load_n4_ref(dose_dir):
    df_cd25_ec50s = (pd.read_hdf(pj(data_dir, "dose_response", "experimental_peptide_ec50s_blasts.h5"),
                key="df").xs("CD25fit", level="Method"))
    df_cd25_ec50s = df_cd25_ec50s.groupby(["TCR", "Peptide"]).apply(geo_mean_apply)

    # Rename HHAT peptides to HHAT-...
    rename_dict = {p:"HHAT-{}".format(p) for p in df_cd25_ec50s.xs("HHAT").index.unique()}
    rename_dict.update({p:"NYESO-{}".format(p) for p in df_cd25_ec50s.xs("NYESO").index.unique()})
    rename_dict.update({p:"OT1-{}".format(p) for p in df_cd25_ec50s.xs("OT1").index.unique()})
    df_cd25_ec50s = df_cd25_ec50s.rename(rename_dict, level="Peptide")

    # Choose reference absolute EC50 for N4: use CD25 EC50s
    # This means we will have different taus for OT-1 peptides vs. fig. 2
    # But that's OK, we are using a different set of EC50s
    # to illustrate the general procedure to predict antagonism
    ref_ec50_n4 = df_cd25_ec50s.at[("OT1", "OT1-N4")]  # M

    # Load N4's reference EC50
    ref_file = pj(data_dir, "reference_pep_tau_maps.json")
    with open(ref_file, "r") as file:
        tau_refs = json.load(file)

    ref_tau_n4 = tau_refs.get("N4")  # s

    return ref_ec50_n4, ref_tau_n4


def clean_nosub_duplicates(df):
    for tcr in df.index.get_level_values("TCR").unique():
        params_tcr = df.xs(tcr, level="TCR")
        antigen = params_tcr.index.get_level_values("Antigen").unique()[0]
        # Find all false substitutions
        wt_duplicates = {}
        for pep in params_tcr.index.get_level_values("Peptide").unique():
            if pep[0] == pep[2]:
                wt_duplicates[pep] = params_tcr.loc[(antigen, pep)]
        # Check they were all identical
        #print(wt_duplicates)
        # Replace them all by one WT row
        df = df.drop([(antigen, tcr, pep) for pep in wt_duplicates.keys()])
        df.loc[(antigen, tcr, "WT")] = list(wt_duplicates.values())[0]
    return df

def plot_ec50_curves(df_data, df_params):
    possible_concs = df_data["Dose (M)"].unique()
    conc_range = np.geomspace(possible_concs.min()*0.5, possible_concs.max()*2, 200)
    x_min = conc_range.min()
    # Compute each EC50 curve
    df_params2 = df_params[["V_inf", "backgnd", "log_ec50_M", "n"]].copy()
    df_params2["ec50_M"] = 10.0**df_params2["log_ec50_M"]
    df_params2 = df_params2.drop("log_ec50_M", axis=1)
    curves = np.stack([normalized_hill(conc_range, *df_params2.iloc[i])
                        for i in range(df_params.shape[0])])
    # colors: according to EC50
    # Clip EC50s to reasonable range before defining color
    norm = mpl.colors.Normalize(vmin=np.log(df_params2["ec50_M"].min()),
                    vmax=min(np.log(df_params2["ec50_M"].max()), np.log(x_min*1e12)))
    colors = [plt.cm.magma(1.0 - norm(np.log(df_params2.loc[p, "ec50_M"]))) for p in df_params.index]
    labels = df_params2.index.get_level_values("Peptide").values

    # Activation marker
    response_name = [lbl for lbl in df_data.columns if lbl.startswith("Response")]
    response_name = response_name[0]

    # Plot
    fig, ax = plt.subplots()
    default_figsize = fig.get_size_inches()
    fig.set_size_inches(default_figsize[0]*0.9, default_figsize[1]*0.9)
    ax.set_xscale("log")
    # Plot dose response data first
    for i, lbl in enumerate(df_params2.index):
        data_pts = df_data.loc[lbl]
        ax.plot(data_pts["Dose (M)"], data_pts[response_name], mfc="grey",
        mec=colors[i], ls="-", lw=1.5, marker="o", ms=8, mew=0.75,
        color=colors[i])

    # Plot fitted curves
    #for i in range(len(curves)):
    #    ax.plot(conc_range, curves[i]*100.0, color=colors[i], lw=2.0, label=labels[i])

    ylims = [-5.0, 105.0]  # %
    ax.set_ylim(ylims)
    #for i, p in enumerate(df_params.index):
    #    if x_min <= df_params.loc[p, "K_a"] <= conc_range.max():
    #        max_ampli = df_params.loc[p, "A"]*100.0
    #        ax.plot(df_params.loc[p, "K_a"], max_ampli / 2.0, ls="none",
    #            marker="^", color=colors[i], mec="r", mew=1.0, ms=8.0 )
            #ax.axvline(df_params.loc[p, "K_a"], ymin=0.0,
            #        ymax=(max_ampli/2.0 - ylims[0]) / (ylims[1] - ylims[0]),
            #        lw=1.0, ls="--", color=colors[i])

    ax.set_xlabel(r"Dose (M)")
    ax.set_ylabel(response_name)
    for side in ["top", "right"]:
        ax.spines[side].set_visible(False)
    #ax.set_xticklabels([])
    #ax.set_yticklabels([])
    #ax.set_xticks([])
    #ax.set_yticks([])
    #locmin = mpl.ticker.LogLocator(base=10.0, subs=np.arange(2, 10) * .1,
    #                                      numticks=100)
    #ax.xaxis.set_minor_locator(locmin)
    fig.tight_layout()
    return fig, ax


def plot_model_curve(df_ec50s, model_curve, ec50_ref, tau_ref, npow=6):
    """ model_curve also contains percentiles for CI plotting """
    # Convert mock EC50s to taus using the actual formula
    ser_ec50s = df_ec50s["EC50 (M)"]
    #ec50_ref = 1e-11  # N4, in M
    npow = 6
    converter = lambda x: tau_ref * (ec50_ref / x)**(1/npow)
    all_taus = ser_ec50s.apply(converter).values

    # Clip tau values larger than the max in the curve
    # we don't want the plot to shrink the antagonism region too much
    # due to a couple of outlier very strong agonists
    all_taus = all_taus.clip(min=-np.inf, max=model_curve.index.values.max())

    # Find closest tau in the curves
    where_closest = np.argmin(np.abs(model_curve.index.values[:, None] - all_taus[None, :]), axis=0)
    model_points = model_curve["best"].iloc[where_closest]

    # Colors
    # Clip EC50s to reasonable range before defining color
    norm = mpl.colors.Normalize(vmin=np.log(ser_ec50s.min()),
            vmax=min(np.log(ser_ec50s.max()), np.log(ser_ec50s.min()*1e12)))
    colors = plt.cm.magma(1.0 - norm(np.log(ser_ec50s)))
    labels = ser_ec50s.index.get_level_values("Peptide").values

    # Plot the model curve and highlight mock points
    fig, ax = plt.subplots()
    default_figsize = fig.get_size_inches()
    fig.set_size_inches(default_figsize[0]*0.9, default_figsize[1]*0.85)
    ax.set_yscale("log", base=2)
    ax.plot(model_curve.index.values, model_curve["best"], ls="-", color="k", lw=3.0)
    ax.fill_between(model_curve.index.values, model_curve["percentile_2.5"], model_curve["percentile_97.5"],
                    color="k", alpha=0.2)
    ax.axhline(1.0, ls="--", color="grey")
    for i in range(len(all_taus)):
        ax.plot(all_taus[i], model_points.iloc[i], marker="o", mfc=colors[i],
                mec="k", mew=0.75, ms=12)
    ax.set_ylabel(r"$FC_{\mathrm{TCR/CAR}}$")
    ax.set_xlabel(r"TCR antigenicity ($\tau$)", labelpad=10.0)
    ax.set_xticklabels([])
    ax.set_xticks([])
    #ax.set_yticklabels([])
    # highlight regions of antagonism, null, enhancement
    ylims = ax.get_ylim()
    ax.set_ylim(ylims)
    ax.set_xlim(ax.get_xlim())
    ax.fill_between(ax.get_xlim(), ylims[0], antagonism_threshold, color="purple", alpha=0.3, zorder=0)
    ax.fill_between(ax.get_xlim(), enhancement_threshold, ylims[1], color="orange", alpha=0.3, zorder=0)
    ax.annotate("Enhancement", color="orange", xy=(ax.get_xlim()[1]*0.95, 2.25), ha="right", va="bottom")
    ax.annotate("Antagonism", color="purple", xy=(ax.get_xlim()[1]*0.95, 0.4), ha="right", va="top")
    for side in ["top", "right"]:
        ax.spines[side].set_visible(False)
    fig.tight_layout()
    return fig, ax


def corner_plot_2d_mcmc(samples, pvec_best, pnames, sizes_kwargs={}, **kwargs):
    """
    Args:
        samples (np.ndarray): MCMC samples.
            Using the first two parameters only.
        pvec_best (np.ndarray): best parameter value
        pnames (list): list of fitted parameter names, using the first two.
        sizes_kwargs (dict): things like scaleup, small_lw, truth_lw, small_markersize.
            Also truth_color.

    Other kwargs are passed to corner.corner.

    Returns:
        fig (matplotlib.figure.Figure): cornerplot figure
    """
    # Make the corner plot, using aesthetical parameters in sizes_kwargs
    scaleup = sizes_kwargs.get("scaleup", 1.0)
    small_lw = sizes_kwargs.get("small_lw", 0.8) * scaleup
    truth_lw = sizes_kwargs.get("truth_lw", 1.25) * scaleup
    small_markersize = sizes_kwargs.get("small_markersize", 1.0) * scaleup
    tcr_color = np.asarray((0.0, 156.0, 75.0, 255.0)) / 255.0  # deep key lime green
    truth_color = sizes_kwargs.get("truth_color", tcr_color)
    #"xkcd:cornflower", #"xkcd:sage"
    reverse_plots = sizes_kwargs.get("reverse_plots", False)
    labelpad = sizes_kwargs.get("labelpad", len(pvec_best)**3/200.0)

    # Corner plot
    hist2d_kwargs = {"contour_kwargs":{"linewidths":small_lw},
                     "data_kwargs":{"ms":small_markersize}}
    fig, ax = plt.subplots()
    corner.hist2d(
        x=samples[0].flatten(),
        y=samples[1].flatten(),
        ax=ax,
        # Plot truths manually below to control line width
        **hist2d_kwargs,
        **kwargs
    )
    # Label the graph
    ax.set(xlabel=pnames[0], ylabel=pnames[1])

    # Add truths manually to control line width
    ax.axvline(pvec_best[0], color=truth_color, lw=truth_lw)
    ax.axhline(pvec_best[1], color=truth_color, lw=truth_lw)
    ax.plot(pvec_best[0], pvec_best[1], ls="none", marker="s",
            ms=5.0*small_markersize, mfc=truth_color, mec=truth_color)

    fig.set_size_inches(sizes_kwargs.get("figsize", fig.get_size_inches()))
    for side in ["top", "right"]:
        ax.spines[side].set_visible(False)
    fig.tight_layout()
    return fig, ax


In [None]:
#@title Load MSKCC data
# Import MSKCC dose response data, and our curve fitting parameters and ec50 data
mskcc_data = pd.read_hdf(pj(data_dir, "dose_response", "MSKCC_rawDf.hdf")).sort_index()
mskcc_params = pd.read_hdf(pj(res_dir, "mskcc_antagonism_fc_predictions_corrected_revised.h5"), key="EC50_fits")
choice_method = "MAP"
mskcc_ec50s = (10.0 ** mskcc_params.loc[:, (choice_method, ["log_ec50_M", "log_ec50_ugmL"])]
               .droplevel("Feature", axis=1)
                .rename({"log_ec50_ugmL":"EC50 (ug/mL)", "log_ec50_M":"EC50 (M)"}, axis=1)
              )
mskcc_params = mskcc_params.xs(choice_method, axis=1, level=0)
for df in [mskcc_data, mskcc_params, mskcc_ec50s]:
    print(df.index.get_level_values("Antigen").unique())

# In the parameters dataframe, drop all false substitutions, e.g. A7A:
# these are all copies of the WT, duplicated for heatmap plotting convenience
#mskcc_params = clean_nosub_duplicates(mskcc_params)

# Rename CD137 to 4-1BB
resp_name = "Response (4-1BB+ %)"
mskcc_data = mskcc_data.rename({"Response (CD137+ %)":resp_name}, axis=1)

# Change K_a from ug/ml to mol/l
print(mskcc_data)
print(mskcc_ec50s)
mskcc_params["K_a"] = mskcc_ec50s["EC50 (M)"]
print(mskcc_params)

In [None]:
df_response_inf = mskcc_data.copy()
peps_without_ec50 = (mskcc_ec50s["EC50 (M)"] == np.inf)
df_response_inf["INF"] = peps_without_ec50
df_response_inf = df_response_inf.set_index("INF", append=True).set_index("Dose (ug/mL)", append=True)
# Keep only the largest dose
df_response_inf = df_response_inf.xs(100.0, level="Dose (ug/mL)")[resp_name]

In [None]:
#@title Plot EC50 curves for the neoantigen
fig, ax = plot_ec50_curves(mskcc_data.loc["Neoantigen"], mskcc_params.loc["Neoantigen"])
if do_save_plots:
    fig.savefig(pj(fig_dir, "5B-mskcc_ec50_data.pdf"), transparent=True, bbox_inches="tight")
plt.show()
plt.close()

# Panel C: EC50 heatmaps

In [None]:
#@title Plot
from matplotlib.patches import Rectangle
# For heatmaps, distributions: use the median sample
# to avoid the clipped EC50s of null peptides
# For WT peptides bar graphs: use 'best'
stat_choice = "median"

tcrDict = {'CMV':['1','2','3'],'gp100':['4','5','6'],'Neoantigen':['7']}
tcrDict2 = {'7':'N1','4':'G1','6':'G2','5':'G3','1':'C1','2':'C2','3':'C3'}
wtseqDict = {'CMV': 'NLVPMVATV','gp100': 'IMDQVPFSV', 'Neoantigen': 'GRLKALCQR'}
for antigen in ['CMV','gp100','Neoantigen']:
  for tcr in tcrDict[antigen]:
    labeledCytDf = balachandranDf.xs(stat_choice, axis=1).query("Antigen == @antigen and TCR == @tcr")
    #KVPRNQDWL
    wtseq = wtseqDict[antigen]
    wtDf = pd.concat([labeledCytDf.query("Peptide == 'WT'")]*len(wtseq),keys=[x+str(i+1)+x for i,x in enumerate(wtseq)]).droplevel('Peptide')
    wtDf.index.names = ['Peptide']+list(wtDf.index.names)[1:]
    wtDf = wtDf.reset_index().set_index(labeledCytDf.index.names)
    labeledCytDf2 = labeledCytDf.query("Peptide != ['Irrelevant','WT']")
    labeledCytDf2 = pd.concat([labeledCytDf2,wtDf])
    labeledCytDf2['WT'] = [x[:2] for x in labeledCytDf2.index.get_level_values('Peptide')]
    labeledCytDf2['Mutant'] = [x[2] for x in labeledCytDf2.index.get_level_values('Peptide')]
    labeledCytDf2 = labeledCytDf2.set_index(['WT','Mutant'],append=True)

    fig = plt.figure(figsize=(5,10))
    plottingDf = labeledCytDf2.droplevel(['Peptide']).loc[:,'EC50 (M)'].unstack('WT')
    plottingDf = plottingDf[[x+str(i+1) for i,x in enumerate(wtseq)]]
    mutantOrder = 'WFYCMLIVDENHRKQGPSAT'
    plottingDf = plottingDf.droplevel(['Antigen','TCR']).reindex([x for x in mutantOrder],axis=0,level=0).astype(float)
    plottingDf = np.log10(plottingDf)
    #print('-'.join([tcr,antigen]))
    #display(plottingDf.max().max())
    #display(plottingDf.min().min())
    plottingDf = np.clip(plottingDf,a_min=-11,a_max=0)
    g = sns.heatmap(plottingDf,cbar_kws={'label':'EC$_{50}$ (M)', 'shrink':0.8}, cmap='magma_r')
    wtposes = [[i,plottingDf.index.unique('Mutant').tolist().index(x)] for i,x in enumerate(wtseq)]
    ax = plt.gca()
    ax.set_title(antigen+', TCR '+tcrDict2[tcr])# + r"$^{M\L}$")
    for wtpose in wtposes:
      ax.add_patch(Rectangle((wtpose[0], wtpose[1]), 1, 1, fill=False, edgecolor='w', lw=3))
    g.set_xticklabels([x.get_text()[0] for x in g.get_xticklabels()],rotation=0)
    #g.set_yticklabels([x for x in mutantOrder],rotation=0)
    g.set_yticklabels([x.get_text()[0] for x in g.get_yticklabels()],rotation=0)
    ax.set_xlabel('AA in '+pAg)
    ax.set_ylabel('AA in '+pSelf)
    ax.collections[0].colorbar.ax.tick_params(labelsize=14)

    ogyticks = g.collections[0].colorbar.get_ticks()[1:]
    newyticks = list(pd.unique([int(x) for x in ogyticks]))
    newyticklabels = ['10$^{'+str(x)+'}$' for x in newyticks]
    g.collections[0].colorbar.set_ticks(newyticks)
    g.collections[0].colorbar.set_ticklabels(newyticklabels)

    if tcr == "7":
        if do_save_plots:
            fig.savefig(pj(fig_dir, '5C-balachandranExperimentalEC50-Neoantigen7.pdf'),
                bbox_inches='tight',transparent=True, dpi=300)
        plt.show()
    plt.close()

# Panel D: EC50 distributions

In [None]:
#@title Plot
stat_choice = "median"
tcr_to_antigen_map = {"1":"CMV", "2":"CMV", "3":"CMV", "4":"gp100", "5":"gp100", "6":"gp100", "7":"Neoantigen"}
# Find the EC50s that give enhancement, antagonism, or nothing
# based on model predictions in PC9 at 1 uM
balachandranMixedDf = pd.concat([balachandranDf[(stat_choice, "EC50 (M)")].to_frame(name="EC50 (M)"),
        (balachandranModelSamples[median_sample].xs("1uM", level=dens_name, axis=0)
         .to_frame("FC"))], axis=1)
# limit between null and antagonism: weakest peptide (largest EC50) producing FC <= 1/2
ec50_limit_null = np.amax(balachandranMixedDf["EC50 (M)"].loc[balachandranMixedDf["FC"] <= 0.5])
# limit between antagonism and null2: strongest peptide (smallest EC50) producing FC <= 1/2
ec50_limit_null2 = np.amin(balachandranMixedDf["EC50 (M)"].loc[balachandranMixedDf["FC"] <= 0.5])
# Limit between antagonism and agonist: weakest peptide (largest EC50) producing FC > 2
ec50_limit_antag = np.amax(balachandranMixedDf["EC50 (M)"].loc[balachandranMixedDf["FC"] > 2.0])

# plot histogram of 1/EC50 as -log10(EC50 [M]); EC50s are capped at 1e8 uM = 100 M
fullPlottingDf = -(np.log10(balachandranDf.xs(stat_choice, axis=1)).clip(upper=2))
ro = ['7','4','6','5','1','2','3']
g = sns.displot(data=fullPlottingDf, x="EC50 (M)",kind='kde',color='k',rug=True,row='TCR',
                rug_kws={'height':0.1},facet_kws=dict(sharey=False),height=2.2,aspect=2.5,row_order=ro,zorder=300)
figsize = g.figure.get_size_inches()
g.figure.set_size_inches(figsize[0]*0.8, figsize[1])
# Row titles, in the same order as `ro`
titles = ['Neoantigen, TCR N1','gp100, TCR G1','gp100, TCR G2','gp100, TCR G3','CMV, TCR C1','CMV, TCR C2','CMV, TCR C3']
tupleList = []
for i,tcr in enumerate(ro):
    axis = g.axes.flat[i]
    axis.set_xlim([-2.5,12.5])
    # Annotate where the WT peptide is
    ag = tcr_to_antigen_map.get(tcr)
    axis.set_ylim(axis.get_ylim()[0], axis.get_ylim()[1]*1.25)
    x_wt = np.log10(1e0 / balachandranDf.at[(ag, tcr, "WT"), (stat_choice, "EC50 (M)")])
    # WT position in axes fraction for the xlim [-2.5, 12.5] (width 15)
    xfrac = (x_wt + 2.5) / 15
    axis.axvline(x_wt, 0, 0.8, color="k",linestyle='--',zorder=300)
    # Fraction of peptides weaker (larger EC50) than the WT peptide
    where_below = (balachandranDf.loc[(ag, tcr), (stat_choice, "EC50 (M)")]
                   > balachandranDf.at[(ag, tcr, "WT"),(stat_choice, "EC50 (M)")])
    frac_below = where_below.sum() / where_below.size
    frac_above = 1.0 - frac_below
    axis.annotate(pAg, xy=(xfrac, 0.87), xycoords="axes fraction", ha="center",
                  color="k",fontsize=16,zorder=300)

    plottingDf = fullPlottingDf.query("TCR == @tcr")

    axis.set_ylabel('# Peptides')
    axis.set_xlabel(r'1/EC$_{50}$ (M$^{-1}$)')
    axis.set_yticks([])
    axis.set_title(titles[i])
    # Labels are 10^k, so keep only integer-valued tick positions: truncating an
    # automatic tick at e.g. 2.5 with int() would mislabel it as 10^2.
    ogxticks = [x for x in axis.get_xticks() if float(x).is_integer()]
    newxticklabels = ['10$^{'+str(int(x))+'}$' for x in ogxticks]
    axis.set_xticks(ogxticks)
    axis.set_xticklabels(newxticklabels)
    tupleList.append(ag)
    axis.set_xlim([-2.5,12.5])

    # Marker on top of the WT dashed line (line spans 0-0.8 of the axis height)
    axis.plot(x_wt,0.8*axis.get_ylim()[1],marker='o',color='k',markersize=6)

if do_save_plots:
    g.figure.savefig(pj(fig_dir, '5D-balachandranEC50distributions-7plot_rowWise.pdf'),
                  bbox_inches='tight',transparent=True, dpi=300)
plt.show()
plt.close()

# Panel E: Model prediction pipeline

In [None]:
#@title Top: Ligand/receptor cell cartoon

In [None]:
#@title Middle: Model schematic (from fig 2)

In [None]:
#@title Bottom: Example of MCMC distribution
truth_color = np.asarray((1.0, 103.0, 146.0, 255.0)) / 255.0  # CAR blue
# Model variant whose MCMC results are shown
best_kmf = "(1, 2, 1)"
# Parameters to show: C_mth and I_th (indices into the parameter axis)
param_choice = [0, 1]
param_names = [r"$\log_{10} C^C_{m,th}$", r"$\log_{10} I^C_{th}$"]

# Read the analysis summary once: MAP best fit (log10 parameters) and burn-in fraction
with open(pj(root_dir, "results", "mcmc", "mcmc_analysis_tcr_car_both_conc.json"), "r") as f:
    kmf_analysis = json.load(f).get(best_kmf)
best_fits = np.asarray(kmf_analysis.get("param_estimates").get("MAP best"))  # log10
burn_in_frac = kmf_analysis.get("burn_in_frac")

# Load parameter samples (log10 of parameters) from the full MCMC file,
# keeping only the selected parameters and dropping the burn-in steps (axis 2)
with h5py.File(pj(root_dir, "results", "mcmc", "mcmc_results_tcr_car_both_conc.h5"), "r") as f:
    samples_dset = f["samples"][best_kmf]
    n_steps = samples_dset.shape[2]
    param_samples = samples_dset[param_choice, :, int(n_steps*burn_in_frac):]

best_fits = best_fits[param_choice]

fig, ax = corner_plot_2d_mcmc(param_samples, best_fits, param_names,
        sizes_kwargs={"scaleup":1.75, "truth_color":truth_color,
            "figsize":(4.25, 3.25)})
if do_save_plots:
    fig.savefig(pj(fig_dir, "5E_bottom-mcmc_distribution_example.pdf"), transparent=True, bbox_inches="tight")
plt.show()
plt.close()


# Panel F: Model FC curve for the neoantigen

In [None]:
#@title Plot
# Precomputed model FC curve stored under the ("HHAT", "1uM") entry of modelCurves
model_curve = modelCurves.loc[("HHAT", "1uM")]

# Reference agonist: CMV is treated like N4 from the CD25 EC50 dataset
ec50_n4, tau_n4 = load_n4_ref(data_dir)

fig, ax = plot_model_curve(mskcc_ec50s, model_curve, ec50_ref=ec50_n4, tau_ref=tau_n4)
no_effect_box = dict(facecolor='white', edgecolor='none')
ax.annotate("No effect", xy=(0.81, 0.41), xycoords="axes fraction", va='center', ha="center",
            color="grey", zorder=300, bbox=no_effect_box)
# Dotted guides at the antagonism and enhancement FC thresholds
for fc_threshold in (antagonism_threshold, enhancement_threshold):
    ax.axhline(y=fc_threshold, color='k', linestyle=':')

if do_save_plots:
    fig.savefig(pj(fig_dir, "5F-mskcc_neoag_on_model_curve.pdf"), transparent=True, bbox_inches="tight")
plt.show()
plt.close()


# Panel G: Model FC heatmap for neoantigen

In [None]:
#@title Plot
# Panel G: heatmaps of model log2(FC) predictions (median MCMC sample) for every
# single-substitution peptide, one figure per (antigen, TCR, antigen density).
# Only TCR 7 at 1uM is displayed (and saved when do_save_plots); others are closed.
from matplotlib.patches import Rectangle

# TCR ids per antigen, TCR display names, and the WT (pAg) sequence of each antigen
tcrDict = {'CMV':['1','2','3'],'gp100':['4','5','6'],'Neoantigen':['7']}
tcrDict2 = {'7':'N1','4':'G1','6':'G2','5':'G3','1':'C1','2':'C2','3':'C3'}
wtseqDict = {'CMV': 'NLVPMVATV','gp100': 'IMDQVPFSV', 'Neoantigen': 'GRLKALCQR'}
for antigen in ['CMV','gp100','Neoantigen']:
  for tcr in tcrDict[antigen]:
    for agDensity in ['1uM','1nM']:
      # Median-sample FC predictions for this antigen/TCR (all densities)
      labeledCytDf = (balachandranModelSamples[median_sample].to_frame("FC")
                    .query("Antigen == @antigen and TCR == @tcr"))
      #KVPRNQDWL
      wtseq = wtseqDict[antigen]
      # Repeat the WT row once per position, named like a self-substitution
      # (e.g. 'N1N'), so the WT residue fills its own cell in the grid
      wtDf = pd.concat([labeledCytDf.query("Peptide == 'WT'")]*len(wtseq),keys=[x+str(i+1)+x for i,x in enumerate(wtseq)]).droplevel('Peptide')
      wtDf.index.names = ['Peptide']+list(wtDf.index.names)[1:]
      wtDf = wtDf.reset_index().set_index(labeledCytDf.index.names)
      labeledCytDf2 = labeledCytDf.query("Peptide != ['Irrelevant','WT']")
      labeledCytDf2 = pd.concat([labeledCytDf2,wtDf])
      # Split peptide names into position label ('N1') and mutant residue ('A');
      # the fixed slices assume single-digit positions (9-mers)
      labeledCytDf2['WT'] = [x[:2] for x in labeledCytDf2.index.get_level_values('Peptide')]
      labeledCytDf2['Mutant'] = [x[2] for x in labeledCytDf2.index.get_level_values('Peptide')]
      labeledCytDf2 = labeledCytDf2.set_index(['WT','Mutant'],append=True)

      fig = plt.figure(figsize=(5,10))
      # Grid: rows = mutant residue (in mutantOrder), columns = WT position in sequence order
      plottingDf = labeledCytDf2.droplevel(['Peptide']).loc[:,'FC'].xs(agDensity, level=dens_name).unstack('WT')
      plottingDf = plottingDf[[x+str(i+1) for i,x in enumerate(wtseq)]]
      mutantOrder = 'WFYCMLIVDENHRKQGPSAT'
      plottingDf = plottingDf.droplevel(['Antigen','TCR']).reindex([x for x in mutantOrder],axis=0,level=0).astype(float)
      # Plot log2(FC) so antagonism and enhancement are symmetric around 0
      plottingDf = np.log2(plottingDf)
      #print('-'.join([tcr,antigen]))
      #display(plottingDf.max().max())
      #display(plottingDf.min().min())
      #plottingDf = np.clip(plottingDf,a_min=-11,a_max=0)
      g = sns.heatmap(plottingDf,cbar_kws={'label':r'$FC_{\mathrm{TCR/CAR}}$', 'shrink':0.8},cmap='PuOr_r',center=0)
      # (column, row) cell of the WT residue at each position, outlined in black
      wtposes = [[i,plottingDf.index.unique('Mutant').tolist().index(x)] for i,x in enumerate(wtseq)]
      ax = plt.gca()
      ax.set_title(antigen+', TCR '+tcrDict2[tcr])# + r"$^{M\L}$")
      for wtpose in wtposes:
        ax.add_patch(Rectangle((wtpose[0], wtpose[1]), 1, 1, fill=False, edgecolor='k', lw=3))
      # Keep only the residue letter of each tick label
      g.set_xticklabels([x.get_text()[0] for x in g.get_xticklabels()],rotation=0)
      #g.set_yticklabels([x for x in mutantOrder],rotation=0)
      g.set_yticklabels([x.get_text()[0] for x in g.get_yticklabels()],rotation=0)
      ax.set_xlabel('AA in '+pAg)
      ax.set_ylabel('AA in '+pSelf)
      ax.collections[0].colorbar.ax.tick_params(labelsize=14)
      # Relabel colorbar ticks (log2 values) as powers of 2, dropping the first
      # automatic tick and the last unique integer tick.
      # NOTE(review): int() truncates non-integer ticks, which would be mislabeled.
      ogyticks = g.collections[0].colorbar.get_ticks()[1:]
      newyticks = list(pd.unique([int(x) for x in ogyticks]))[:-1]
      newyticklabels = ['2$^{'+str(x)+'}$' for x in newyticks]
      g.collections[0].colorbar.set_ticks(newyticks)
      g.collections[0].colorbar.set_ticklabels(newyticklabels)
      if tcr == "7" and agDensity == "1uM":
          if do_save_plots:
              fig.savefig(pj(fig_dir, '5G-balachandranModelFC-'+antigen+','+tcr+'-'+agDensity+'.pdf'), 
                          bbox_inches="tight", transparent=True, dpi=300)
          plt.show()
      plt.close()

# Panel H: FC Distributions

In [None]:
#@title Plot
antigenList1,antigenList2,tupleList = [],[],[]
ro = ['7','4','6','5','1','2','3']
# True: pool every MCMC sample in the KDE; False: median sample only (shown with a rug)
distrib_all_mcmc_samples = False
if distrib_all_mcmc_samples:
    fullPlottingDf = (np.log2(balachandranModelSamples).query("TCR_Antigen_Pulse_uM == '1uM'")
                .stack().to_frame("FC"))
    show_rug = False
else:
    fullPlottingDf = np.log2(balachandranModelSamples[median_sample].to_frame("FC")
                            .query("TCR_Antigen_Pulse_uM == '1uM'"))
    show_rug = True
g = sns.displot(data=fullPlottingDf,x='FC',kind='kde',color='k',rug=show_rug,row='TCR',rug_kws={'height':0.1},
                facet_kws=dict(sharey=False),height=2.2,aspect=2.5,row_order=ro)
# Row titles, in the same order as `ro`
titles = ['Neoantigen, TCR N1','gp100, TCR G1','gp100, TCR G2','gp100, TCR G3','CMV, TCR C1','CMV, TCR C2','CMV, TCR C3']
figsize = g.figure.get_size_inches()
g.figure.set_size_inches(figsize[0]*0.8, figsize[1])
# Visible log2(FC) window; the shaded regions extend exactly to these limits
fc_xlim = (-5, 5)
for i,tcr in enumerate(ro):
    axis = g.axes.flat[i]
    plottingDf = fullPlottingDf.query("TCR == @tcr")
    # Thresholds at FC = 1/2 (antagonism) and FC = 2 (enhancement), in log2 units
    axis.axvline(color='k',linestyle=':',x=-1)
    axis.axvline(color='k',linestyle=':',x=1)
    antagonistPercent = 100*plottingDf[plottingDf['FC'] <= -1].shape[0]/plottingDf.shape[0]
    enhancerPercent = 100*plottingDf[plottingDf['FC'] >= 1].shape[0]/plottingDf.shape[0]
    # Fix the x-range before shading: spans drawn from the auto-scaled limits could
    # stop short of the final [-5, 5] window
    axis.set_xlim(fc_xlim)
    axis.axvspan(fc_xlim[0],-1,color='purple',alpha=0.3)
    axis.axvspan(1,fc_xlim[1],color='orange',alpha=0.3)
    axis.set_ylabel('# Peptides')
    axis.set_xlabel(r'$FC_{\mathrm{TCR/CAR}}$')
    axis.set_yticks([])
    axis.set_title(titles[i])
    # log2 ticks relabeled as powers of 2
    newxticks = [-4,-2,0,2,4]
    newxticklabels = ['2$^{'+str(x)+'}$' for x in newxticks]
    axis.set_xticks(newxticks)
    axis.set_xticklabels(newxticklabels)

if do_save_plots:
    g.figure.savefig(pj(fig_dir, '5H-balachandranModelFCdistributions-7plot_rowWise.pdf'),
                     bbox_inches='tight',transparent=True, dpi=300)
plt.show()
plt.close()

# Panel I: pWT EC50s

In [None]:
#@title Plot
stat_choice = "best"
# Category colors (orange/purple/white); the bars below are drawn black via palette=['k','k','k']
p = ['orange','purple','white']
# WT (pAg) EC50 of each TCR for the chosen statistic
plottingDf = balachandranDf.xs(stat_choice, axis=1).query("Peptide == 'WT'").reset_index()
# Effect category assigned to each TCR's WT peptide (used as the bar hue)
categoryDict = {'1':'Enhancer','2':'Enhancer','3':'Enhancer','4':'Antagonist','5':'Enhancer','6':'Antagonist','7':'Null'}
plottingDf['Category'] = [categoryDict[x] for x in plottingDf['TCR']]
# Within-antigen bar letter for each TCR id
renamingDict = {'7':'A','4':'A','6':'B','5':'C','1':'A','2':'B','3':'C'}
plottingDf = plottingDf.set_index('TCR').rename(renamingDict).reset_index()
g = sns.catplot(data=plottingDf,height=5,aspect=0.7,sharex=False,col_order=['Neoantigen','gp100','CMV'],y='EC50 (M)',order=['A','B','C'],x='TCR',hue='Category',kind='bar',dodge=False,linewidth=1,edgecolor='k',palette=['k','k','k'],hue_order=['Enhancer','Antagonist','Null'],legend=False,col='Antigen')
g.set(yscale='log')
# Reuse the shared raw-string pAg label (the old non-raw literal had an invalid '\m' escape)
g.axes.flat[0].set_ylabel(pAg+' EC50 (M)')
for axis in g.axes.flat:
    # Facet titles look like 'Antigen = CMV': keep the value, shorten 'Neoantigen'
    title = axis.get_title().split(' = ')[1].replace('Neoantigen','Neoag')
    axis.set_title(title)
    firstChar = title[0].upper()
    newlabels = [firstChar+str(i+1) for i in range(len(axis.get_xticklabels()))]
    axis.set_xticklabels(newlabels)
if do_save_plots:
    g.figure.savefig(pj(fig_dir, 'balachandranModelEndogenousEC50s-horizontal-separateTCR.pdf'),bbox_inches='tight',transparent=True)
plt.show()
plt.close()

In [None]:
#@title Plot
# Category colors (orange/purple/white); the bars below are drawn black via palette=['k','k','k']
p = ['orange','purple','white']
# WT (pAg) EC50 of each TCR for `stat_choice` (set in the previous cell)
plottingDf = balachandranDf.xs(stat_choice, axis=1).query("Peptide == 'WT'").reset_index()
categoryDict = {'1':'Enhancer','2':'Enhancer','3':'Enhancer','4':'Antagonist','5':'Enhancer','6':'Antagonist','7':'Null'}
plottingDf['Category'] = [categoryDict[x] for x in plottingDf['TCR']]
renamingDict = {'7':'A','4':'A','6':'B','5':'C','1':'A','2':'B','3':'C'}
plottingDf = plottingDf.set_index('TCR').rename(renamingDict).reset_index()
# EC50 is in M, so its inverse is in M^-1 (the column was previously mislabeled uM)
inv_ec50_col = '1/EC50 (1/M)'
plottingDf[inv_ec50_col] = 1.0 / plottingDf['EC50 (M)']
g = sns.catplot(data=plottingDf,height=5,aspect=0.7,sharex=False,col_order=['Neoantigen','gp100','CMV'],y=inv_ec50_col,order=['A','B','C'],x='TCR',hue='Category',kind='bar',dodge=False,linewidth=1,edgecolor='k',palette=['k','k','k'],hue_order=['Enhancer','Antagonist','Null'],legend=False,col='Antigen')
g.set(yscale='log')
# y-range: from the decade at or below the smallest bar to 10% above the largest
minVal = np.power(10,np.floor(np.log10(plottingDf[inv_ec50_col].min())))
maxVal = plottingDf[inv_ec50_col].max()
g.axes.flat[0].set_ylabel(pAg+' 1/EC$_{50}$ (M$^{-1}$)')
g.axes.flat[0].set_ylim([minVal,maxVal*1.1])
for axis in g.axes.flat:
    # Facet titles look like 'Antigen = CMV': keep the value, shorten 'Neoantigen'
    title = axis.get_title().split(' = ')[1].replace('Neoantigen','Neoag')
    axis.set_title(title)
    firstChar = title[0].upper()
    newlabels = [firstChar+str(i+1) for i in range(len(axis.get_xticklabels()))]
    axis.set_xticklabels(newlabels)
if do_save_plots:
    g.figure.savefig(pj(fig_dir, '5I-balachandranModelEndogenousEC50s-horizontal-separateTCR-invertedEC50.pdf'),
                     bbox_inches='tight',transparent=True)
plt.show()
# Close like the other panels so the figure does not stay open in pyplot
plt.close()

# Panel J: Peptide type distribution summary

In [None]:
#@title Plot
# NOTE(review): `stats_choice` is not read in this cell (other cells use `stat_choice`)
stats_choice = "best"
agDict = {'1':'CMV','2':'CMV','3':'CMV','4':'gp100','5':'gp100','6':'gp100','7':'Neoantigen'}

# Summary bar plot (y = antigen) with the legend; it is only saved, then closed
plottingDf = antagonistPercentageDf2.copy().stack("sample").to_frame("Percentage")
plottingDf['Antigen'] = [agDict[x] for x in plottingDf.index.get_level_values('TCR')]
g = sns.catplot(data=plottingDf.rename({'Neoantigen':'Neoag'}).reset_index(),x='Percentage',y='Antigen',
                hue='Type',dodge=False,palette=['purple','white','orange'],hue_order=['Antagonist','Null','Enhancer'],
                kind='bar',linewidth=1,edgecolor='k',height=3,aspect=1.6)
sns.move_legend(g, "center", bbox_to_anchor=(.5, 1), ncol=3, title='')
g.axes.flat[0].set_xticks([0,25,50,75,100])
g.axes.flat[0].set_xlabel('Peptides (%)')
if do_save_plots:
    g.figure.savefig(pj(fig_dir, '5J-balachandranModelFCdistributions-summaryLegend.pdf'),bbox_inches='tight',transparent=True)
# Close (rather than only clear) the legend figure so it is not left open in pyplot
plt.close(g.figure)

# Per-TCR bars, one column per antigen, with 95% percentile-interval error bars
renamingDict = {'7':'A','4':'A','6':'B','5':'C','1':'A','2':'B','3':'C'}
plottingDf = pd.concat({"Percentage":antagonistPercentageDf1.copy()}, axis=1, names=["Feature"]).stack("sample")
plottingDf['Antigen'] = [agDict[x] for x in plottingDf.index.get_level_values('TCR')]
g = sns.catplot(data=plottingDf.rename({'Neoantigen':'Neoag'}).rename(renamingDict,level='TCR').reset_index(),
                col='Antigen',y='Percentage',x='TCR',hue='Type',dodge=False,palette=['purple','white','orange'][::-1],
                hue_order=['Antagonist','Null','Enhancer'][::-1],kind='bar',linewidth=1,edgecolor='k',height=5,legend=False,
                order=['A','B','C'],col_order=['Neoantigen','gp100','CMV'],aspect=0.7,sharex=False,
                errorbar=('pi', 95))
g.axes.flat[0].set_yticks([0,25,50,75,100])
g.axes.flat[0].set_ylabel('Probability (%)')
for axis in g.axes.flat:
    # Facet titles look like 'Antigen = CMV': keep the value, shorten 'Neoantigen'
    title = axis.get_title().split(' = ')[1].replace('Neoantigen','Neoag')
    axis.set_title(title)
    firstChar = title[0].upper()
    newlabels = [firstChar+str(i+1) for i in range(len(axis.get_xticklabels()))]
    axis.set_xticklabels(newlabels)
if do_save_plots:
    g.figure.savefig(pj(fig_dir, '5J-balachandranModelFCdistributions-7plot-summary-horizontal-separateTCR.pdf'),
                     bbox_inches='tight',transparent=True)
plt.show()
plt.close()

## Cartoon dose response plot

In [None]:
#@title Dose response plots
ln10 = np.log(10.0)


def normalized_loghill(x, xmin, logh, n, a):
    """Return a log-scale Hill curve for the cartoon dose-response panel.

    The dose is mapped to u = ln(x / xmin) / (logh - ln(xmin)), which is 0 at
    ``xmin`` and 1 where ln(x) == ``logh`` (natural log of the EC50). The
    response is 10 ** (a * u**n / (u**n + 1)): it equals 1 at ``xmin`` and
    10 ** (a / 2) at the EC50, approaching 10 ** a for large doses when n > 0.

    Parameters: x dose(s); xmin reference dose; logh ln(EC50); n Hill
    exponent; a log10 amplitude.
    """
    u = np.log(x / xmin) / (logh - np.log(xmin))
    u_pow = u**n
    return np.exp(a * u_pow / (u_pow + 1.0) * ln10)

conc_range = np.logspace(-3, 2, 200)  # uM
x_min = conc_range.min()
# Three mock peptides: strong, intermediate and weak (EC50 in uM, log10 amplitude)
mock_ec50s = [8e-3, 4e-1, 1e1]
mock_amplitudes = [1.0, 0.8, 0.3]
mock_curves = np.stack([
    normalized_loghill(conc_range, x_min, np.log(ec50), 4, amplitude)
    for ec50, amplitude in zip(mock_ec50s, mock_amplitudes)
])
mock_colors = ["orange", "grey", "purple"]

# Cartoon panel: shrink the default figure size
fig, ax = plt.subplots()
default_figsize = fig.get_size_inches()
fig.set_size_inches(default_figsize[0]*0.6, default_figsize[1]*0.5)
ax.set(xscale="log", yscale="log")
for curve, color in zip(mock_curves, mock_colors):
    ax.plot(conc_range, curve, color=color, lw=5.0)

# Dashed EC50 markers ending at each curve's half-maximum (10**(a/2)),
# expressed as a fraction of the log-scaled y axis
ylims = ax.get_ylim()
log_ymin = np.log10(ylims[0])
log_yspan = np.log10(ylims[1]/ylims[0])
for ec50, amplitude, color in zip(mock_ec50s, mock_amplitudes, mock_colors):
    ax.axvline(ec50, ymin=0.0, ymax=(0.5*amplitude - log_ymin)/log_yspan,
               lw=2.0, ls="--", color=color)

ax.set(xlabel=r"Dose", ylabel="Activation")
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
# No tick labels; major ticks at each decade plus minor log ticks on x
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_xticks(np.logspace(-3, 2, 6))
ax.xaxis.set_minor_locator(mpl.ticker.LogLocator(base=10.0, subs=np.arange(2, 10) * .1,
                                                 numticks=100))
fig.tight_layout()
plt.show()
plt.close()

# Panel L: HHAT dose responses

In [None]:
import colorcet as cc

In [None]:
df_dose = (
    pd.read_hdf(pj(root_dir, "data", "dose_response", "hhatlibrary4_cellData.h5"))
    .droplevel("Donor")  # only Donor C is available anyways
)
df_dose.index = df_dose.index.set_names(["Peptide", "Concentration"])
# Keep the three activation markers, then replicate under donor labels A, C, D
markers_kept = ['41BB+', 'CD25+', 'PD1+']
df_dose = pd.concat([df_dose.loc[:, markers_kept]] * 3, keys=['A', 'C', 'D'], names=['Donor'])
df_dose

In [None]:
# Load EC50 fits. Only for 4-1BB+, which we used for MCMC fits
marker_choice = ["41BB+"]
fitDf = (
    pd.read_hdf(pj(root_dir, "results", "pep_libs", "hhatv4_ec50_mcmc_stats_backgnd.h5"))
    .xs("MAP", axis=1)
    .droplevel(["Antigen", "TCR"], axis=0)
)
# Add a Measurement level so the index resembles df_dose's, then copy the fits
# under each donor label (A, C, D) for plotting consistency
fitDf = pd.concat({marker_choice[0]: fitDf}, names=["Measurement"])
fitDf = pd.concat([fitDf] * 3, keys=['A', 'C', 'D'], names=['Donor'])
fitDf

In [None]:
# Long format: one 'Value' per (Donor, Peptide, [Peptide], Measurement)
df2 = df_dose.stack().to_frame('Value')
df2.index.names = ['[Peptide]' if x == 'Concentration' else x for x in df2.index.names]
# Accumulators for fitted curves (no suffix) and measured points (suffix 2)
xvals,yvals,peptideVals,measurementVals,dVals = [],[],[],[],[]
xvals2,yvals2,peptideVals2,measurementVals2,dVals2 = [],[],[],[],[]
# Parameter order of the vector passed to hill_back
param_order = ["V_inf", "n", "log_ec50_M", "logN_eff", "backgnd"]
tested_concs = df2.index.get_level_values("[Peptide]")
conc_lims = (tested_concs.min(), tested_concs.max())
# log10 concentration grid: one decade below the lowest tested dose up to the highest
Concentrations_plot=np.linspace(np.log10(conc_lims[0])-1.0, np.log10(conc_lims[1]), 101)
# The grid is identical for every curve (first grid point dropped), so build it once
curve_x = Concentrations_plot.ravel().tolist()[1:]
for measurement in marker_choice:
    for donor in pd.unique(fitDf.reset_index()['Donor']):
        fitDf2a = fitDf.loc[(donor, measurement), :]
        for peptide in fitDf2a.index.unique('Peptide'):
            pvec = fitDf2a.loc[peptide, param_order]
            # hill_back returns a fraction in [0, 1]
            # so multiply by 100 to get percentages
            newY = list(100.0*hill_back(curve_x, pvec))
            xvals+=curve_x
            yvals+=newY
            peptideVals+=[peptide]*len(curve_x)
            dVals+=[donor]*len(curve_x)
            measurementVals+=[measurement]*len(curve_x)

            # Measured points for THIS donor only: df2 holds one copy per donor,
            # so without the Donor filter every donor got all copies, mislabeled
            temp = df2.query("Donor == @donor and Peptide == @peptide and Measurement == @measurement").reset_index()
            temp['[Peptide]'] = np.log10(temp['[Peptide]'])
            xvals2+=list(temp['[Peptide]'].values)
            yvals2+=list(temp['Value'].values)
            dVals2+=[donor]*temp.shape[0]
            peptideVals2+=[peptide]*temp.shape[0]
            measurementVals2+=[measurement]*temp.shape[0]

drPlottingDf = pd.DataFrame({'x':xvals,'y':yvals,'Donor':dVals,'Peptide':peptideVals,'Measurement':measurementVals})
drPlottingDf2 = pd.DataFrame({'x':xvals2,'y':yvals2,'Donor':dVals2,'Peptide':peptideVals2,'Measurement':measurementVals2})
drPlottingDf

In [None]:
# Peptide mutation colors generated with colorcet.glasbey palette;
# 20 colors, used as the mutant-residue hue palette in the dose-response plots
pepPalette = """
    #d60000 #8c3bff #018700 #00acc6
    #97ff00 #ff7ed1 #6b004f #ffa52f
    #573b00 #005659 #0000dd #00fdcf
    #a17569 #bcb6ff #95b577 #bf03b8
    #645474 #790000 #0774d8 #fdf490
""".split()

In [None]:
sns.set_context('poster')
cyt = marker_choice[0]
badPeps = ['WildType','PMA/iono','DMSO','p8F']
# Fitted curves (plottingDf3) and measured points (plottingDf4) for this marker,
# without control peptides; the 'WT' peptide is relabeled 'F8L'
plottingDf3 = drPlottingDf.copy().query("Measurement == @cyt").query("Peptide != @badPeps").set_index('Peptide').rename({'WT':'F8L'}).reset_index()
plottingDf4 = drPlottingDf2.copy().query("Measurement == @cyt").query("Peptide != @badPeps").set_index('Peptide').rename({'WT':'F8L'}).reset_index()
# 'x' holds log10 concentrations; convert back to M for a log-scaled axis
plottingDf3['[Peptide] (M)'] = [np.power(10,x) for x in plottingDf3.reset_index()['x']]
plottingDf4['[Peptide] (M)'] = [np.power(10,x) for x in plottingDf4.reset_index()['x']]

if cyt not in ['IL-2','IFNg']:
    yaxislabel = cyt+'+ Frequency (%)'
    # Offset by the smallest measured value of this marker
    plottingDf3[yaxislabel] = np.add(np.min(df_dose.loc[:,[cyt]].iloc[:,0]),plottingDf3.reset_index()['y'])
    plottingDf4[yaxislabel] = np.add(np.min(df_dose.loc[:,[cyt]].iloc[:,0]),plottingDf4.reset_index()['y'])
else:
    # NOTE(review): cytDf is not defined in this section; this branch only runs for cytokines
    yaxislabel = '['+cyt+'] (nM)'
    plottingDf3[yaxislabel] = np.add(np.min(np.log10(cytDf.query("Donor == 'D'")).query("Cytokine == @cyt").iloc[:,0]),plottingDf3.reset_index()['y'])
    plottingDf4[yaxislabel] = np.add(np.min(np.log10(cytDf.query("Donor == 'D'")).query("Cytokine == @cyt").iloc[:,0]),plottingDf4.reset_index()['y'])

# Peptide names are '<WT residue><position><mutant residue>' (e.g. 'F8L');
# the fixed slices assume single-digit positions
ogPep = 'KQWLVWLFL'
plottingDf3['Position'] = [x[1] for x in plottingDf3['Peptide']]
plottingDf3['Initial AA'] = [ogPep[int(x[1])-1]+str(x[1]) for x in plottingDf3['Peptide']]
plottingDf3['OGPep'] = [x[0] for x in plottingDf3['Peptide']]
plottingDf3['Mutation'] = [x[2] for x in plottingDf3['Peptide']]

plottingDf4['Position'] = [x[1] for x in plottingDf4['Peptide']]
plottingDf4['Initial AA'] = [ogPep[int(x[1])-1]+str(x[1]) for x in plottingDf4['Peptide']]
plottingDf4['OGPep'] = [x[0] for x in plottingDf4['Peptide']]
plottingDf4['Mutation'] = [x[2] for x in plottingDf4['Peptide']]
display(plottingDf3)
# Hue order: mutant residues sorted alphabetically; columns: WT residue + position
ho = sorted(plottingDf3.set_index('Mutation').index.unique('Mutation').tolist())
co = [str(x+1) for x in range(9)]
co2 = [ogPep[x]+str(x+1) for x in range(9)]
plottingDf3.columns = [r'Initial AA (p$\mathrm{^{Ag}}$)' if x == 'Initial AA' else x for x in plottingDf3.columns]
plottingDf3.columns = [r'Mutation (p$\mathrm{^{APL}}$)' if x == 'Mutation' else x for x in plottingDf3.columns]
plottingDf4.columns = [r'Mutation (p$\mathrm{^{APL}}$)' if x == 'Mutation' else x for x in plottingDf4.columns]
g = sns.relplot(data=plottingDf3.query("Donor == 'A'"),x='[Peptide] (M)',col_order=co2,y=yaxislabel,hue=r'Mutation (p$\mathrm{^{APL}}$)',kind='line',palette=pepPalette,
                col=r'Initial AA (p$\mathrm{^{Ag}}$)',hue_order=ho)
# Overlay the measured points of the peptides mutated at each position
for i,axis in enumerate(g.axes.flat):
    pos = str(i+1)
    peps = plottingDf3.query("Position == @pos").set_index('Peptide').index.unique('Peptide').tolist()
    sns.scatterplot(data=plottingDf4.query("Peptide == @peps"),x='[Peptide] (M)',y=yaxislabel,hue=r'Mutation (p$\mathrm{^{APL}}$)',legend=False,palette=pepPalette,ax=g.axes.flat[i],hue_order=ho)
g.set(xscale='log')
sns.move_legend(g, "center left", bbox_to_anchor=(0, 1.05),ncol=25)

g.axes.flat[0].set_ylabel('Response (4-1BB+ %)')
if do_save_plots:
    # pj() adds the missing separator (fig_dir + name wrote 'panels5hhatlibrary...')
    g.figure.savefig(pj(fig_dir, 'hhatlibrary_doseResponses-'+cyt+'.pdf'),bbox_inches='tight',transparent=True)
plt.show()
plt.close()

# Panel M: heatmap of peptide EC50s

In [None]:
badPeps = ['WildType','PMA/iono','DMSO','p8F']
mskcc_h5 = pj(res_dir, 'mskcc_antagonism_fc_predictions_corrected_revised.h5')
hhat_h5 = pj(res_dir, 'hhatv4_antagonism_fc_predictions_corrected_revised.h5')
# HHAT relabeling: 'HHAT-L8F' -> 'HHAT'; the peptide names swap simultaneously
# ('WT' -> 'F8L' and 'L8F' -> 'WT')
hhat_antigen_names = {'HHAT-L8F':'HHAT'}
hhat_peptide_names = {'WT':'F8L','L8F':'WT'}

# FC samples (1 uM antigen only) and EC50 samples for both peptide libraries
dm1 = pd.read_hdf(mskcc_h5, key='fc_samples').query("TCR_Antigen_Pulse_uM == '1uM'")
dm1_b = pd.read_hdf(mskcc_h5, key='EC50_samples')
dm2 = (pd.read_hdf(hhat_h5, key='fc_samples')
       .rename(hhat_antigen_names)
       .query("TCR_Antigen_Pulse_uM == '1uM'")
       .rename(hhat_peptide_names)
       .query("Peptide != @badPeps"))
dm2_b = (pd.read_hdf(hhat_h5, key='EC50_samples')
         .rename(hhat_antigen_names)
         .rename(hhat_peptide_names)
         .query("Peptide != @badPeps"))
dm = pd.concat([dm1,dm2])
dm_b = pd.concat([dm1_b,dm2_b])

dfList,dfList2 = [],[]
dfList_b,dfList2_b = [],[]
# FC categories as half-open intervals [lower, upper)
fc_bounds = {'antagonists': (-np.inf, 0.5), 'null': (0.5, 2), 'agonists': (2, np.inf)}
for cat, (lowerBound, upperBound) in fc_bounds.items():
    # Per TCR and per sample column, count the peptides whose FC is in the category
    nulls = (lowerBound <= dm) & (dm < upperBound)
    fullData = nulls.groupby(["Antigen", "TCR"]).sum()
    for tcr in dm.index.unique('TCR'):
        counts = np.array(fullData.query("TCR == @tcr"))[0]
        # Column with the median count (the lower middle one for an even number of columns)
        original_median_index = np.argsort(counts)[(len(counts) - 1) // 2]
        tcrDf = np.log2(dm.query("TCR == @tcr").iloc[:,original_median_index].to_frame('FC'))
        tcrDf2 = dm_b.query("TCR == @tcr").iloc[:,original_median_index].to_frame('EC50 (M)')
        dfList.append(pd.concat([tcrDf],keys=[cat],names=['peptide_type']))
        dfList_b.append(pd.concat([tcrDf2],keys=[cat],names=['peptide_type']))
        if cat == 'null':
            dfList2.append(tcrDf)
            dfList2_b.append(tcrDf2)
fullFCDf = pd.concat(dfList)
fullEC50Df = pd.concat(dfList_b)

# EC50 samples appear to be stored as log10 values; 10**x converts them back
balachandranDf = np.power(10,pd.concat(dfList2_b))
display(balachandranDf)

balachandranModelDf = np.power(2,pd.concat(dfList2))
balachandranModelDf.index.names = ['TCR_Antigen_Density_uM' if x == 'TCR_Antigen_Pulse_uM' else x for x in balachandranModelDf.index.names]
display(balachandranModelDf)

In [None]:
#@title Plot
from matplotlib.patches import Rectangle
sns.set_context('talk')

cyt='41BB+'
# TCR ids per antigen, TCR display names, and the WT (pAg) sequence of each antigen
tcrDict = {'CMV':['1','2','3'],'gp100':['4','5','6'],'Neoantigen':['7'],'HHAT':['8']}
tcrDict2 = {'7':'N1','4':'G1','6':'G2','5':'G3','1':'C1','2':'C2','3':'C3','8':'N2'}
wtseqDict = {'CMV': 'NLVPMVATV','gp100': 'IMDQVPFSV', 'Neoantigen': 'GRLKALCQR','HHAT':'KQWLVWLFL'}
for antigen in ['HHAT']:
  for tcr in tcrDict[antigen]:
    labeledCytDf = balachandranDf.query("Antigen == @antigen and TCR == @tcr")
    wtseq = wtseqDict[antigen]
    # Repeat the WT row once per position, named like a self-substitution (e.g. 'K1K')
    wtDf = pd.concat([labeledCytDf.query("Peptide == 'WT'")]*len(wtseq),keys=[x+str(i+1)+x for i,x in enumerate(wtseq)]).droplevel('Peptide')
    wtDf.index.names = ['Peptide']+list(wtDf.index.names)[1:]
    wtDf = wtDf.reset_index().set_index(labeledCytDf.index.names)
    labeledCytDf2 = labeledCytDf.query("Peptide != 'WT'")
    labeledCytDf2 = pd.concat([labeledCytDf2,wtDf])
    # Position label ('K1') and mutant residue; slices assume single-digit positions
    labeledCytDf2['WT'] = [x[:2] for x in labeledCytDf2.index.get_level_values('Peptide')]
    labeledCytDf2['Mutant'] = [x[2] for x in labeledCytDf2.index.get_level_values('Peptide')]
    labeledCytDf2 = labeledCytDf2.set_index(['WT','Mutant'],append=True)

    fig = plt.figure(figsize=(5,10))
    # Grid: rows = mutant residue, columns = WT position in sequence order
    plottingDf = labeledCytDf2.droplevel(['Peptide']).loc[:,'EC50 (M)'].unstack('WT')
    plottingDf = plottingDf[[x+str(i+1) for i,x in enumerate(wtseq)]]
    mutantOrder = 'WFYCMLIVDENHRKQGPSAT'
    plottingDf = plottingDf.droplevel(['Antigen','TCR']).reindex([x for x in mutantOrder],axis=0,level=0).astype(float)
    # log10 EC50 (M), clipped to the colorbar range [1e-9, 1] M
    plottingDf = np.log10(plottingDf)
    plottingDf = np.clip(plottingDf,a_min=-9,a_max=0)
    g = sns.heatmap(plottingDf,cbar_kws={'label':'EC$_{50}$ (M)', 'shrink':0.8,'pad':0.15}, cmap='magma_r',square=True,vmin=-9,vmax=0)
    # Outline the WT residue at each position (red) and the F8L peptide (blue)
    wtposes = [[i,plottingDf.index.unique('Mutant').tolist().index(x)] for i,x in enumerate(wtseq)]
    self_pep = 'F8L'
    selfposes = [[int(self_pep[1])-1,plottingDf.index.unique('Mutant').tolist().index(self_pep[2])]]
    ax = plt.gca()
    # Raw string: the non-raw '$\mathrm...' literal relied on an invalid escape sequence
    ax.set_title(antigen+r'$\mathrm{^{L8F}}$'+', TCR '+tcrDict2[tcr])
    for wtpose in wtposes:
        ax.add_patch(Rectangle((wtpose[0], wtpose[1]), 1, 1, fill=False, edgecolor='r', lw=3))
    for wtpose in selfposes:
        ax.add_patch(Rectangle((wtpose[0], wtpose[1]), 1, 1, fill=False, edgecolor='b', lw=3))
    g.set_xticklabels([x.get_text()[0] for x in g.get_xticklabels()],rotation=270)
    g.set_yticklabels([x.get_text()[0] for x in g.get_yticklabels()],rotation=270,va='center')
    ax.set_xlabel('AA in '+pAg)
    ax.set_ylabel('AA in '+pSelf,rotation=270,va='top')
    ax.collections[0].colorbar.ax.tick_params(labelsize=14)

    # Colorbar ticks (log10 values) relabeled as powers of 10
    ogyticks = g.collections[0].colorbar.get_ticks()
    newyticks = list(pd.unique([int(x) for x in ogyticks]))
    newyticklabels = ['10$^{'+str(x)+'}$' for x in newyticks]
    g.collections[0].colorbar.set_ticks(newyticks)
    g.collections[0].colorbar.set_ticklabels(newyticklabels,rotation=270)
    g.collections[0].colorbar.set_label('EC$_{50}$ (M)',rotation=270,va='bottom',ha='center')
    g.collections[0].colorbar.ax.yaxis.set_ticks_position('left')

    if do_save_plots:
        # pj() adds the path separator missing from fig_dir + name
        fig.savefig(pj(fig_dir, '5M-MSKCCExperimentalEC50-'+antigen+','+tcr+'-'+cyt+'_horiz.pdf'),
                    bbox_inches='tight',transparent=True, dpi=300)
    if tcr == "8":
        plt.show()
    plt.close()

# Panel N: heatmap of model FC predictions

In [None]:
# Model FC samples at 1 uM antigen for the MSKCC and HHAT libraries
# (HHAT: 'HHAT-L8F' -> 'HHAT', controls in badPeps dropped, 'WT' <-> 'F8L'/'L8F' relabeled)
dm1 = (pd.read_hdf(pj(res_dir, 'mskcc_antagonism_fc_predictions_corrected_revised.h5'),
                   key='fc_samples')
       .query("TCR_Antigen_Pulse_uM == '1uM'"))
dm2 = (pd.read_hdf(pj(res_dir, 'hhatv4_antagonism_fc_predictions_corrected_revised.h5'), key='fc_samples')
       .query("TCR_Antigen_Pulse_uM == '1uM'")
       .rename({'HHAT-L8F':'HHAT'})
       .query("Peptide != @badPeps")
       .rename({'WT':'F8L','L8F':'WT'}))
dm = pd.concat([dm1,dm2])

dfList = []
dfList2 = []
# FC categories as half-open intervals [lower, upper)
fc_bounds = {'antagonists': (-np.inf, 0.5), 'null': (0.5, 2), 'agonists': (2, np.inf)}
for cat, (lowerBound, upperBound) in fc_bounds.items():
    # Per TCR and per sample column, count the peptides whose FC is in the category
    nulls = (lowerBound <= dm) & (dm < upperBound)
    fullData = nulls.groupby(["Antigen", "TCR"]).sum()
    for tcr in dm.index.unique('TCR'):
        counts = np.array(fullData.query("TCR == @tcr"))[0]
        # Column with the median count (the lower middle one for an even number of columns)
        original_median_index = np.argsort(counts)[(len(counts) - 1) // 2]
        tcrDf = np.log2(dm.query("TCR == @tcr").iloc[:,original_median_index].to_frame('FC'))
        dfList.append(pd.concat([tcrDf],keys=[cat],names=['peptide_type']))
        if cat == 'null':
            dfList2.append(tcrDf)
fullFCDf = pd.concat(dfList)
balachandranModelDf = np.power(2,pd.concat(dfList2))
balachandranModelDf.index.names = [
    'TCR_Antigen_Density_uM' if level == 'TCR_Antigen_Pulse_uM' else level
    for level in balachandranModelDf.index.names
]
balachandranModelDf

In [None]:
#@title Plot
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.patches import Rectangle

# Plot context and the marker named in this panel
sns.set_context('talk')
cyt = '41BB+'

def create_custom_colormap():
    # Get the PuOr colormap
    puor = plt.get_cmap('PuOr_r')

    # Calculate the proportion of each segment
    total_range = 6  # from -2.5 to 3.5
    purple_range = 1.5  # from -2.5 to -1
    white_range = 2   # from -1 to 1
    orange_range = 2.5  # from 1 to 3.5

    purple_proportion = purple_range / total_range
    white_proportion = white_range / total_range
    orange_proportion = orange_range / total_range

    # Extract colors from PuOr colormap
    purple_colors = puor(np.linspace(0, 0.45, int(256 * purple_proportion)))
    white_colors = np.array([[1, 1, 1, 1]] * int(256 * white_proportion))
    orange_colors = puor(np.linspace(0.55, 1, int(256 * orange_proportion)))

    # Combine color lists
    colors = np.vstack((purple_colors, white_colors, orange_colors))

    # Create the new colormap
    custom_cmap = LinearSegmentedColormap.from_list('custom', colors)

    # Set the range for each segment
    custom_cmap.set_under(purple_colors[0])  # Color for values < -2
    custom_cmap.set_over(orange_colors[-1])  # Color for values > 3

    return custom_cmap

# Colormap for log2 FC: purple below -1, white in [-1, 1], orange above 1
cmap_custom = create_custom_colormap()

dens_name = 'TCR_Antigen_Density_uM'
tcrDict = {'CMV':['1','2','3'],'gp100':['4','5','6'],'Neoantigen':['7'],'HHAT':['8']}
tcrDict2 = {'7':'N1','4':'G1','6':'G2','5':'G3','1':'C1','2':'C2','3':'C3','8':'N2'}
wtseqDict = {'CMV': 'NLVPMVATV','gp100': 'IMDQVPFSV', 'Neoantigen': 'GRLKALCQR','HHAT':'KQWLVWLFL'}
agDensity = '1uM'
# Row order of the substituted amino acids in the heatmap
mutantOrder = 'WFYCMLIVDENHRKQGPSAT'
# Self peptide outlined in blue: F -> L at position 8 ("<WT aa><pos><new aa>")
selfPeptide = 'F8L'

for antigen in ['HHAT']:
    for tcr in tcrDict[antigen]:
        labeledCytDf = balachandranModelDf.query("Antigen == @antigen and TCR == @tcr")
        wtseq = wtseqDict[antigen]
        # Copy the WT prediction under an 'X{pos}X' name for every position so it
        # fills the WT-residue cell of each heatmap column
        wtDf = pd.concat([labeledCytDf.query("Peptide == 'WT'")]*len(wtseq),
                         keys=[x+str(i+1)+x for i, x in enumerate(wtseq)]).droplevel('Peptide')
        wtDf.index.names = ['Peptide']+list(wtDf.index.names)[1:]
        wtDf = wtDf.reset_index().set_index(labeledCytDf.index.names)
        labeledCytDf2 = labeledCytDf.query("Peptide != ['Irrelevant','WT']")
        labeledCytDf2 = pd.concat([labeledCytDf2, wtDf])
        # Split names like 'K1A' into WT residue+position ('K1') and substituted
        # residue ('A'); assumes single-digit positions (9-mer peptides)
        peptideNames = labeledCytDf2.index.get_level_values('Peptide')
        labeledCytDf2['WT'] = [x[:2] for x in peptideNames]
        labeledCytDf2['Mutant'] = [x[2] for x in peptideNames]
        labeledCytDf2 = labeledCytDf2.set_index(['WT','Mutant'], append=True)

        fig = plt.figure(figsize=(5,10))
        # FC matrix: rows = substituted residue, columns = WT residue+position
        plottingDf = labeledCytDf2.droplevel(['Peptide']).loc[:,'FC'].xs(agDensity, level=dens_name).unstack('WT')
        plottingDf = plottingDf[[x+str(i+1) for i, x in enumerate(wtseq)]]
        plottingDf = plottingDf.droplevel(['Antigen','TCR']).reindex([x for x in mutantOrder], axis=0, level=0).astype(float)
        plottingDf = np.log2(plottingDf)
        display(plottingDf.max().max())
        display(plottingDf.min().min())

        g = sns.heatmap(plottingDf, cbar_kws={'label':r'$FC_{\mathrm{TCR/CAR}}$', 'shrink':0.8,'pad':0.15},
                        vmin=-2.5, vmax=3.5, cmap=cmap_custom, square=True)
        mutantLabels = plottingDf.index.unique('Mutant').tolist()
        # Cell coordinates (column, row) of the WT residue at each position
        wtposes = [[i, mutantLabels.index(x)] for i, x in enumerate(wtseq)]
        selfposes = [[int(selfPeptide[1])-1, mutantLabels.index(selfPeptide[2])]]
        ax = plt.gca()
        # Raw string: '\m' is an invalid escape sequence in a normal literal
        ax.set_title(antigen + r'$\mathrm{^{L8F}}$' + ', TCR ' + tcrDict2[tcr])
        for wtpose in wtposes:
            ax.add_patch(Rectangle((wtpose[0], wtpose[1]), 1, 1, fill=False, edgecolor='r', lw=3))
        for selfpose in selfposes:
            ax.add_patch(Rectangle((selfpose[0], selfpose[1]), 1, 1, fill=False, edgecolor='b', lw=3))
        g.set_xticklabels([x.get_text()[0] for x in g.get_xticklabels()], rotation=270)
        g.set_yticklabels([x.get_text()[0] for x in g.get_yticklabels()], rotation=270, va='center')
        ax.set_xlabel('AA in '+pAg)
        ax.set_ylabel('AA in '+pSelf, rotation=270, va='top')

        # Colorbar: integer log2 ticks labelled as powers of 2, drawn rotated
        cbar = g.collections[0].colorbar
        ax.collections[0].colorbar.ax.tick_params(labelsize=14)
        ogyticks = cbar.get_ticks()[1:]
        newyticks = list(pd.unique([int(x) for x in ogyticks]))[:-1]
        newyticklabels = ['2$^{'+str(x)+'}$' for x in newyticks]
        cbar.set_ticks(newyticks)
        cbar.set_ticklabels(newyticklabels, rotation=270)
        cbar.set_label(r'$FC_{\mathrm{TCR/CAR}}$', rotation=270, va='bottom', ha='center')
        cbar.ax.yaxis.set_ticks_position('left')
        cbar.ax.invert_yaxis()
        if do_save_plots:
            fig.savefig(pj(fig_dir, '5N-ModelFC-'+antigen+','+tcr+'-'+cyt+'_horiz.pdf'), 
                        bbox_inches='tight', transparent=True, dpi=300)
        plt.show()
        plt.close()

# Panel O: regression of peptide-type fractions (agonists, null, antagonists) versus WT peptide strength (1/EC50)

In [None]:
from matplotlib.ticker import LogLocator, LogFormatter
import matplotlib.ticker as mticker
from scipy.optimize import curve_fit

sns.set_context('poster')

# Model-prediction files (each is read for several keys below)
mskccPredFile = pj(res_dir, 'mskcc_antagonism_fc_predictions_corrected_revised.h5')
hhatPredFile = pj(res_dir, 'hhatv4_antagonism_fc_predictions_corrected_revised.h5')
# Hard-coded EC50 (M) used for the HHAT TCR ('8') instead of a fitted value
hhatTcr8EC50 = 1.3e-8

# Peptide-type fraction statistics (median and 95% percentiles) per TCR
plottingDf1 = pd.read_hdf(mskccPredFile, key='fracs_stats')
plottingDf2 = pd.read_hdf(hhatPredFile, key='fracs_stats')
plottingDf = pd.concat([plottingDf1, plottingDf2])

# MAP log EC50 (presumably log10, cf. np.log10 below) of the WT peptide per (Antigen, TCR)
EC50df1 = pd.read_hdf(mskccPredFile, key='EC50_fits').loc[:,'MAP'].loc[:,['log_ec50_M']].query("Peptide == 'WT'")
EC50df2 = pd.read_hdf(hhatPredFile, key='EC50_fits').loc[:,'MAP'].loc[:,['log_ec50_M']].query("Peptide == 'WT'")
EC50df = pd.concat([EC50df1, EC50df2])
EC50Dict = {x+'-'+y: z for x, y, z in zip(EC50df.reset_index()['Antigen'],
                                         EC50df.reset_index()['TCR'],
                                         EC50df['log_ec50_M'])}

plottingDf['EC50'] = [EC50Dict[x+'-'+y] if y != '8' else np.log10(hhatTcr8EC50)
                      for x, y in zip(plottingDf.index.get_level_values('Antigen'),
                                      plottingDf.index.get_level_values('TCR'))]

# Reshape into x (log EC50), y (median fraction) and its 2.5/97.5 percentile bounds
temp = plottingDf.loc[:,['median','EC50','percentile_2.5','percentile_97.5']].rename({'null':'Null'})
temp.columns.name = ''
temp.columns = ['y','x','lowerError','upperError']
df = temp.reset_index().loc[:,['peptide_type','Antigen','x','y','lowerError','upperError']]

# x: log EC50 -> -log EC50 = log(1/EC50); fractions -> percentages
df['x'] = -1*df['x']
df['y'] *= 100
df['lowerError'] *= 100
df['upperError'] *= 100

hue_order = ['agonists', 'Null', 'antagonists']
color_order = ['orange', 'grey', 'purple']
palette = dict(zip(hue_order, color_order))

# Define a mapping from Antigen types to marker styles
marker_styles = {
    'CMV': 'X',
    'Neoantigen': 'D',
    'gp100': 'o',
    'HHAT-L8F': 'P',
    # Add more types if needed
}

# Quadratic model for the fraction-versus-strength regressions
def quadratic(x, a, b, c):
    """Evaluate the polynomial ``a * x**2 + b * x + c``.

    Evaluation is element-wise, so it works for scalars, NumPy arrays and
    pandas Series alike and can be passed straight to ``curve_fit``.
    """
    squared_term = a * x**2
    linear_term = b * x
    return squared_term + linear_term + c

# Fit a quadratic to one peptide type and draw its data points and fitted curve
def fit_and_plot(peptide_type):
    """Fit ``quadratic`` to one peptide type of ``df`` and plot it on the current axes.

    Reads the notebook-level ``df`` (columns 'peptide_type', 'Antigen', 'x', 'y',
    'lowerError', 'upperError', '10x'), ``palette`` and ``marker_styles``.
    The fit is done in x = -log EC50 (i.e. log(1/EC50)) and drawn against
    10**x to match the log-scaled '10x' axis.

    Returns:
        ``(popt, pcov)`` from ``scipy.optimize.curve_fit``, or None when
        ``df`` has no rows for ``peptide_type``.
    """
    data = df[df['peptide_type'] == peptide_type]
    x = data['x']
    y = data['y']
    c = palette[peptide_type]

    if len(x) == 0 or len(y) == 0:
        print(f"No data for {peptide_type}")
        return None

    # Unweighted least squares. A weighted fit (sigma = half the CI width,
    # absolute_sigma=True) is deliberately not used.
    popt, pcov = curve_fit(quadratic, x, y)

    # Smooth curve over the observed x range, mapped back to 1/EC50
    x_fit = np.linspace(np.min(x), np.max(x), 100)
    y_fit = quadratic(x_fit, *popt)
    x_fit = np.power(10, x_fit)

    # One error-bar series per antigen so each gets its own marker
    for antigen_type in data['Antigen'].unique():
        antigen_data = data[data['Antigen'] == antigen_type]
        marker = marker_styles.get(antigen_type, 'o')
        y_pts = np.array(antigen_data['y'].values)
        ymin = np.array(antigen_data['lowerError'].values)
        ymax = np.array(antigen_data['upperError'].values)
        # Asymmetric bars: distance from the point down/up to its interval bounds
        plt.errorbar(antigen_data['10x'], y_pts, yerr=[y_pts - ymin, ymax - y_pts],
                     fmt=marker, label=f'{peptide_type} - {antigen_type} (data)', color=c,
                     capsize=5, elinewidth=2, markeredgewidth=2, markersize=15)

    plt.plot(x_fit, y_fit, label=f'{peptide_type} (fit)', color=c)

    return popt, pcov

# Create a new figure: relplot provides the log-x FacetGrid, then fit_and_plot
# draws error bars and fitted curves onto its current axes.
df['10x'] = np.power(10, df['x'])
h = 6.7
# NOTE(review): hue_order here says 'agonist'/'antagonist', while `palette`
# (used by fit_and_plot) is keyed by 'agonists'/'antagonists'. Those scatter
# points are therefore presumably not coloured by this call; the error-bar
# markers from fit_and_plot draw every point anyway. Left as-is to avoid
# changing the figure; confirm before aligning the labels.
g = sns.relplot(data=df,
                x='10x', y='y', hue='peptide_type', palette=['orange','grey','purple'],
                legend=False, hue_order=['agonist','Null','antagonist'], height=h, aspect=7.5/h)
g.set(xscale='log')

# Fit and plot each peptide type; a failed fit is reported and skipped
peptide_types = df['peptide_type'].unique()

for pt in peptide_types:
    try:
        result = fit_and_plot(pt)
        if result is not None:
            popt, pcov = result
            perr = np.sqrt(np.diag(pcov))  # standard errors of the fit parameters
    except Exception as e:
        print(f"Could not fit {pt}: {str(e)}")

plt.xlabel('1/EC$_{50}$ (M$^{-1}$)')
plt.ylabel('Probability (%)')
plt.xlim(df['10x'].min()*0.7, df['10x'].max()*1.42)

axis = plt.gca()

import matplotlib.patches as patches

if do_save_plots:
    # Join with pj like the other panels; plain string concatenation dropped the
    # path separator and wrote 'panels55O-...pdf' outside fig_dir.
    g.figure.savefig(pj(fig_dir, '5O-hhatlibrary_summary3b2-noBox.pdf'),
                     bbox_inches='tight', transparent=True)
plt.show()
plt.close()

### Associated legend

In [None]:
#@title Plot
sns.set_context('poster')
balachandranDf = pd.read_hdf(pj(data_dir, "dose_response", 'MSKCC_originalEC50df.hdf'))
# Duplicate the neoantigen TCR ('7') rows as placeholder HHAT TCR ('8') rows
balachandranDf = pd.concat([balachandranDf,
                            balachandranDf.query("Antigen == 'Neoantigen'").rename({'7':'8','Neoantigen':'HHAT'})])

tcr_to_antigen_map = {"1":"CMV", "2":"CMV", "3":"CMV", "4":"gp100", "5":"gp100", "6":"gp100", "7":"Neoantigen","8":"HHAT"}
# WT EC50 per TCR (second column, presumably 'EC50 (M)' as read below), one copy per peptide type
tempDf1 = balachandranDf.query("Peptide == 'WT'").droplevel(['Peptide','Antigen']).iloc[:,[1]]
tempDf1 = pd.concat([tempDf1]*3, keys=['Antagonist','Null','Enhancer'], names=['Type']).swaplevel(0,1)
# antagonistPercentageDf1 comes from earlier in the notebook. Its 'Null' and
# 'Enhancer' columns appear to be cumulative; subtract to get per-type values.
tempDf2 = antagonistPercentageDf1.stack("sample").unstack('Type')
tempDf2.loc[:,"Null"] = tempDf2.loc[:,"Null"]-tempDf2.loc[:,"Antagonist"]
tempDf2.loc[:,"Enhancer"] = tempDf2.loc[:,"Enhancer"]-tempDf2.loc[:,"Null"]-tempDf2.loc[:,"Antagonist"]
tempDf2 = tempDf2.stack("Type")
tempDf2.index = tempDf2.index.reorder_levels(["TCR", "Type", "sample"])
tempDf2 = tempDf2.sort_index()

ec50Column = tempDf1.loc[tempDf2.index.droplevel("sample")]
ec50Column["sample"] = tempDf2.index.get_level_values("sample")
ec50Column = ec50Column.set_index("sample", append=True, drop=True)
tempDf_full = pd.concat({"EC50 (M)":ec50Column["EC50 (M)"], "Percentage":tempDf2}, axis=1, names=["Feature"])
tempDf_full['Antigen'] = list(map(tcr_to_antigen_map.get, tempDf_full.index.get_level_values("TCR")))
tempDf_full['1/EC50 (1/M)'] = [1/(x) for x in tempDf_full['EC50 (M)']]

tempDf_full2 = tempDf_full.copy()
tempDf_full2.index.names = ['Peptide Type' if x == 'Type' else x for x in tempDf_full2.index.names]

hue_order = ['Enhancer','Null','Antagonist']
color_order = ['orange','grey','purple']
palette = dict(zip(hue_order, color_order))
antigen_order = ["CMV", "gp100", "Neoantigen","HHAT"]
markers_map = dict(zip(antigen_order, ['X','o','D','P']))
# Raw strings: '\m' is an invalid escape sequence in a normal string literal
antigen_order2 = ["CMV", "gp100", r"PDAC$\mathrm{^{S6L}}$", r"HHAT$\mathrm{^{L8F}}$"]
markers_map2 = dict(zip(antigen_order2, ['X','o','D','P']))

# Extra HHAT TCR ('8') rows with fixed peptide-type fractions. 'EC50 (M)' holds
# the EC50 itself (1.3e-8 M, as in Panel O) and '1/EC50 (1/M)' its inverse,
# consistent with the other rows of tempDf_full2.
hhatTcr8EC50 = 1.3e-8
hhatTypeFractions = {'Enhancer': 0.104651, 'Antagonist': 0.441860, 'Null': 0.453488}
newDf = pd.DataFrame({'TCR': ['8'] * 3,
                      'Peptide Type': list(hhatTypeFractions),
                      'sample': ['best'] * 3,
                      'EC50 (M)': [hhatTcr8EC50] * 3,
                      'Antigen': ['HHAT'] * 3,
                      '1/EC50 (1/M)': [1/hhatTcr8EC50] * 3,
                      'Percentage': [frac*100 for frac in hhatTypeFractions.values()]}
                     ).set_index(['TCR','Peptide Type','sample'])
tempDf_full2 = pd.concat((tempDf_full2, newDf))

# Dummy plot: only its legend handles are used below
g = sns.relplot(data=tempDf_full2.set_index('Antigen', append=True)
                                 .rename({'HHAT': r'HHAT$\mathrm{^{L8F}}$', 'Neoantigen': r'PDAC$\mathrm{^{S6L}}$'}),
                x='EC50 (M)', y='Percentage', hue='Peptide Type',
                palette=palette, hue_order=hue_order, kind='scatter', height=5, aspect=1.5,
                style='Antigen', markers=markers_map2, style_order=antigen_order2)

# From this dummy plot, make a stand-alone legend
handles, labels = g.figure.axes[0].get_legend_handles_labels()
plt.close()
fig, ax, leg = standalone_legend(handles=handles, frameon=False)
if do_save_plots:
    fig.savefig(pj(fig_dir, '5O-hhatlibrary_summary-legend.pdf'), 
                transparent=True, bbox_inches="tight", bbox_extra_artists=(leg,))
plt.show()
plt.close()