# Imports

In [None]:
import numpy as np
import pandas as pd
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import os, json, h5py
import warnings
pj = os.path.join
# Less known modules, but very useful
import corner

from scripts.preprocess import (
    write_conc_uM, 
    read_conc_uM, 
    string_to_tuple, 
    geo_mean_levels, 
    geo_mean, 
    geo_mean_apply
)
from scripts.plotting import (
    change_log_ticks, 
    data_model_handles_legend, 
    standalone_legend, 
    corner_plot_mcmc, 
    plot_tcr_tcr_fits, 
    plot_autocorr, 
    standalone_parameter_values
)
from scripts.analysis import find_best_grid_point, autocorr_func_avg, autocorr_avg

In [None]:
# Input/output folders, relative to the repository root
root_dir = ".."
data_dir = pj(root_dir, "data")
res_dir = pj(root_dir, "results", "for_plots")
mcmc_dir = pj(root_dir, "results", "mcmc")
fig_dir = "panels_sup2"

# Set to True to export every panel to fig_dir
do_save_plots = False

In [None]:
# Aesthetic parameters. Small scale figure.
scaleup = 1.0

# Global matplotlib style: small fonts, thin lines
plt.rcParams.update({
    "font.size": 6. * scaleup,  # Default, smallish
    "figure.dpi": 150.0,
    "figure.figsize": (3, 3),
    "axes.labelpad": 0.5 * scaleup,
    "axes.linewidth": 0.75 * scaleup,  # edge line width
    "lines.linewidth": 1.5 * scaleup,  # line width in points
    "lines.markersize": 2.5 * scaleup,  # marker size, in points
})
# Identical tick geometry on the x and y axes
for tick_prefix in ["xtick.", "ytick."]:
    plt.rcParams.update({
        tick_prefix + "major.size": 2.25 * scaleup,
        tick_prefix + "minor.size": 1.8 * scaleup,
        tick_prefix + "major.pad": 2.0 * scaleup,
        tick_prefix + "minor.pad": 1.9 * scaleup,
        tick_prefix + "minor.width": 0.5 * scaleup,
        tick_prefix + "major.width": 0.75 * scaleup,
    })
# Hide top and right spines by default
plt.rcParams.update({"axes.spines.top": False, "axes.spines.right": False})

# Color of each perturbation type; the unperturbed condition ("None") is black
with open(pj(res_dir, "perturbations_palette.json"), "r") as palette_file:
    perturbations_palette = json.load(palette_file)
perturbations_palette["None"] = (0.0, 0.0, 0.0, 1.0)
sns.palplot(perturbations_palette.values())

In [None]:
# Teal palette for antigen densities: three RGBA colors given in 0-255 and
# rescaled to 0-1, followed by opaque black as a fourth color
agconc_palette = np.vstack([
    np.asarray([(66, 150, 141, 255), (15, 85, 97, 255), (13, 50, 70, 255)]) / 255,
    [(0.0, 0.0, 0.0, 1.0)],
])
sns.palplot(agconc_palette)

### Generally useful quantities

In [None]:
# Combined OT-1 TCR / CAR antagonism cytokine data, used by several panels below
full_car_data = pd.read_hdf(
    pj(data_dir, "antagonism", "combined_OT1-CAR_antagonism.hdf")
)
# Best CAR (k, m, f) integer parameters
best_kmf_car = (1, 2, 1)

# TCR/CAR MCMC

## Corner plot: moved to main figure
The procedure is very similar to the one used for TCR/TCR.

## TCR/CAR autocorrelation functions: also moved to main figure

## CAR integer parameters

In [None]:
# CAR integer parameters: N^C = 3, followed by the best (k, m, f) grid point
car_integers_names = ["N^C", "k^C_I", "m^C", "f^C"]
car_integers_values = [3] + list(best_kmf_car)
fig, ax, text = standalone_parameter_values(
    car_integers_names, car_integers_values
)
if do_save_plots:
    fig.savefig(
        pj(fig_dir, "integer_parameters_annotation_car.pdf"),
        transparent=True,
        bbox_inches="tight",
    )
plt.show()
plt.close()

# Panels showing how the CAR and TCR outputs combine
Illustrate three cases: antagonist, partial agonist, and agonist TCR antigens.

The model curves were generated elsewhere, so here, we just import them.

In [None]:
#@title Tools
# Code taken from
# https://stackoverflow.com/questions/49223702/adding-a-legend-to-a-matplotlib-plot-with-a-multicolored-line
from matplotlib.legend_handler import HandlerLineCollection
from matplotlib.collections import LineCollection

class HandlerColorLineCollection(HandlerLineCollection):
    """Legend handler drawing a LineCollection as a short multicolored line.

    The handle is a horizontal line through the middle of the legend box,
    split into segments whose colors follow the artist's colormap along x.
    """

    def create_artists(self, legend, artist, xdescent, ydescent,
                       width, height, fontsize, trans):
        n_points = self.get_numpoints(legend) + 1
        xs = np.linspace(0, width, n_points)
        ys = np.full(n_points, height / 2. - ydescent)
        pts = np.column_stack([xs, ys]).reshape(-1, 1, 2)
        # Each segment joins two consecutive points
        segs = np.concatenate([pts[:-1], pts[1:]], axis=1)
        handle_lc = LineCollection(segs, cmap=artist.cmap, transform=trans)
        # Segment colors are mapped from their x position
        handle_lc.set_array(xs)
        handle_lc.set_linewidth(artist.get_linewidth())
        return [handle_lc]

def strip_names(df, char, axis=0):
    """Return a copy of df with `char` stripped from both ends of every label along `axis`."""
    return df.rename(mapper=lambda label: label.strip(char), axis=axis)

def pad_names(df, char, axis=0):
    """Return a copy of df with `char` added at both ends of every label along `axis`."""
    def _surround(label):
        return "".join((char, label, char))
    return df.rename(mapper=_surround, axis=axis)

def pad_index_names(idx, char, after=True):
    """Return idx with `char` prepended to every level name.

    When `after` is True, `char` is also appended, so each name is wrapped.
    """
    suffix = char if after else ""
    return idx.set_names([char + name + suffix for name in idx.names])

In [None]:
#@title
# Load model output contributions and turn column / index level names into
# LaTeX strings usable directly as plot labels.
# Columns are wrapped in "$...$" ("$Ratio$" is turned back into "Ratio" below).
df_model2 = pad_names(pd.read_hdf(pj(res_dir, "output_contributions.h5"), key="df"), "$", axis=1)
# Prefix each index level name with a backslash (LaTeX command form,
# presumably for a "tau_..." level -> "\tau_..."; TODO confirm raw level names)
df_model2.index = pad_index_names(df_model2.index, "\\", after=False)
# Underscores become superscripts, e.g. "_T" -> "^T"
df_model2.index = df_model2.index.set_names([a.replace("_", "^") for a in df_model2.index.names])
print(df_model2.index.names)
# Drop the backslash prefix from the "\L^T" level: it becomes plain "L^T"
df_model2.index = df_model2.index.set_names("L^T", level=df_model2.index.names.index(r"\L^T"))
# Wrap every level name in math mode: "$...$"
df_model2.index = pad_index_names(df_model2.index, "$")
df_model2 = df_model2.rename({"$Ratio$":"Ratio"}, axis=1)
# Z_C output for CD19 alone ("ag_alone" group), used to normalize curves below
with h5py.File(pj(res_dir, "output_contributions.h5"), "r") as h:
    ag_alone = float(h.get("ag_alone").get("Z_C")[()])
print("CD19 alone:", ag_alone)
# Sorted unique TCR antigen taus and densities present in the model table
tau_tcr_range = np.sort(df_model2.index.get_level_values(r"$\tau^T$").unique().values)
l_tcr_range = np.sort(df_model2.index.get_level_values(r"$L^T$").unique().values)
print("taus:", tau_tcr_range)

In [None]:
# Display the processed model output table
df_model2

In [None]:
# Output contributions versus TCR antigen density: one subplot per TCR antigen
# tau, stacked vertically, with a shared legend above the first subplot.
# Z^C, Z^T and their sum ("Total") are normalized by the CD19-alone output Z^C.
tcr_car_palette = {"TCR":"#009C4B", "CAR": "#016792"}
car_cd19_color = "grey"
labelsize = 6
linewidth = 1.5
antagonism_palette = np.asarray(sns.color_palette('PuOr_r', n_colors=100))[[13, 87]]  # Antagonism, enhancement
print(sns.color_palette(antagonism_palette).as_hex()[:])

# The total output is drawn as a LineCollection whose segments take the first
# palette color (antagonism) below 1.0 and the second (enhancement) above 1.0.
# The mapping is identical for every subplot, so it is built once.
# Adapted from https://stackoverflow.com/questions/38051922/how-to-get-differents-colors-in-a-single-line-in-a-matplotlib-figure
cmap = mpl.colors.ListedColormap(antagonism_palette)
norm = mpl.colors.BoundaryNorm([-np.inf, 1.0], cmap.N)

# One row per tau instead of a hard-coded 3; squeeze=False keeps a 2D array
# of axes even when there is a single tau.
n_taus = len(tau_tcr_range)
fig, axes = plt.subplots(n_taus, 1, squeeze=False)
axes = axes.flatten()
fig.set_size_inches(1.5, 1.15*n_taus)
for i, tcr_tau in enumerate(tau_tcr_range):
    ax = axes[i]
    y_c = df_model2.loc[(l_tcr_range, tcr_tau), r"$Z^C$"]
    y_t = df_model2.loc[(l_tcr_range, tcr_tau), r"$Z^T$"]
    y_tot = (y_c + y_t) / ag_alone
    y_c = y_c / ag_alone
    y_t = y_t / ag_alone

    # Split the total output curve into consecutive segments colored by value
    xy = np.array([l_tcr_range, y_tot]).T.reshape(-1, 1, 2)
    segments = np.hstack([xy[:-1], xy[1:]])
    lc = mpl.collections.LineCollection(segments, cmap=cmap, norm=norm,
                                        label="Total", lw=2.5)
    lc.set_array(y_tot)
    ax.add_collection(lc)
    ax.plot(l_tcr_range, y_c, label=r"$Z^C$", color=tcr_car_palette["CAR"], ls=":", lw=linewidth)
    ax.plot(l_tcr_range, y_t, label=r"$Z^T$", color=tcr_car_palette["TCR"], ls="--", lw=linewidth)

    # Show at least up to twice the CD19-alone output
    if ax.get_ylim()[1] < 2.0:
        ax.set_ylim([ax.get_ylim()[0], 2.0])
    ax.set_xlabel("TCR Ag density (#)", size=labelsize, labelpad=0.1)
    ax.set_xscale("log")
    # Reference level: CD19-alone output
    ax.axhline(1.0, ls=':', color="k", lw=0.75)
    ax.set_ylabel(r"$\frac{Z(\mathrm{mix})}{Z^C(\mathrm{CD19})}$", size=labelsize+2, labelpad=0.2)
    # Create final legend above the first plot; the custom handler draws the
    # multicolored "Total" line collection
    if i == 0:
        handles, labels = ax.get_legend_handles_labels()
        legend = ax.legend(handles=handles, labels=labels,
                  handler_map={lc: HandlerColorLineCollection(numpoints=2)},
                  loc="lower center", bbox_to_anchor=(0.5, 1.0), frameon=False,
                  handlelength=1.5,
                  labelspacing=0.2, handletextpad=0.3, columnspacing=0.7, borderaxespad=0.3, borderpad=0.2,
                  fontsize=labelsize, title="Output", title_fontsize=labelsize, ncol=3)

    # Force more log ticks
    locmaj = mpl.ticker.LogLocator(base=10,numticks=5)
    ax.xaxis.set_major_locator(locmaj)
    # Force minor log ticks
    locmin = mpl.ticker.LogLocator(base=10.0,subs=np.arange(0.1, 1.0, 0.1), numticks=5)
    ax.xaxis.set_minor_locator(locmin)
    ax.xaxis.set_minor_formatter(mpl.ticker.NullFormatter())
    # Change label size
    ax.tick_params(axis='both', which='major', labelsize=labelsize*0.85, pad=2.0, length=2.0)
    ax.tick_params(axis='both', which='minor', pad=1.0, length=1.5)

    # Hide unnecessary spines to make room for diagrams.
    for s in ["top", "right"]:
        ax.spines[s].set_visible(False)

fig.tight_layout(w_pad=0.3, h_pad=1.2)
if do_save_plots:
    fig.savefig(pj(fig_dir, "tcr_car_output_contributions.pdf"),
            bbox_inches="tight", transparent=True, bbox_extra_artists=[legend])
plt.show()
plt.close()

In [None]:
# Horizontal layout of this figure, for PhD thesis
# One subplot per TCR antigen tau (antagonist, partial agonist, agonist), side
# by side, with a shared legend to the right of the last subplot.
# Z^C, Z^T and their sum ("Total") are normalized by the CD19-alone output Z^C.
#tcr_ags_palette = ["#4c78bbff", "#d83ad8ff", "#DA3833"]  # antagonist, partial agonist, agonist
tcr_car_palette = {"TCR":"#009C4B", "CAR": "#016792"}
car_cd19_color = "grey"
labelsize = 6
linewidth = 1.5
antagonism_palette = np.asarray(sns.color_palette('PuOr_r', n_colors=100))[[13, 87]]  # Antagonism, enhancement
print(sns.color_palette(antagonism_palette).as_hex()[:])
# NOTE(review): 3 columns and 3 titles are hard-coded; assumes
# len(tau_tcr_range) == 3 — TODO confirm if the model grid changes
fig, axes = plt.subplots(1, 3)
axes = axes.flatten()
fig.set_size_inches(1.5*3, 1.25)
axes_titles = ["Antagonist", "Partial agonist", "Agonist"]
for i, tcr_tau in enumerate(tau_tcr_range):
    ax = axes[i]
    y_c = df_model2.loc[(l_tcr_range, tcr_tau), r"$Z^C$"]
    y_t = df_model2.loc[(l_tcr_range, tcr_tau), r"$Z^T$"]
    y_tot = (y_c + y_t) / ag_alone
    y_c = y_c / ag_alone
    y_t = y_t / ag_alone

    # Total output drawn as a LineCollection: segments take the first palette
    # color (antagonism) below 1.0 and the second (enhancement) above 1.0.
    # Code from https://stackoverflow.com/questions/38051922/how-to-get-differents
    # -colors-in-a-single-line-in-a-matplotlib-figure
    cmap = mpl.colors.ListedColormap(antagonism_palette)
    norm = mpl.colors.BoundaryNorm([-np.inf, 1.0], cmap.N)
    # get segments: consecutive point pairs of the total output curve
    xy = np.array([l_tcr_range, y_tot]).T.reshape(-1, 1, 2)
    segments = np.hstack([xy[:-1], xy[1:]])
    # make line collection, colored by the total output value
    lc = mpl.collections.LineCollection(segments, cmap=cmap, norm=norm,
                                        label="Total", lw=2.5)
    lc.set_array(y_tot)
    ax.add_collection(lc)
    ax.plot(l_tcr_range, y_c, label=r"$Z^C$", color=tcr_car_palette["CAR"], ls=":", lw=linewidth)
    ax.plot(l_tcr_range, y_t, label=r"$Z^T$", color=tcr_car_palette["TCR"], ls="--", lw=linewidth)

    # Show at least up to twice the CD19-alone output
    if ax.get_ylim()[1] < 2.0:
        ax.set_ylim([ax.get_ylim()[0], 2.0])
    #ax.set_ylim(-0.1, df_model2["Ratio"].max()*1.1)
    ax.set_xlabel("TCR Ag density (#)", size=labelsize, labelpad=0.1)
    ax.set_xscale("log")
    # Reference level: CD19-alone output
    ax.axhline(1.0, ls=':', color="k", lw=0.75)
    ax.set_ylabel(r"$\frac{Z(\mathrm{mix})}{Z^C(\mathrm{CD19})}$", size=labelsize+2, labelpad=0.2)

    # Force more log ticks
    locmaj = mpl.ticker.LogLocator(base=10,numticks=5)
    ax.xaxis.set_major_locator(locmaj)
    # Force minor log ticks
    locmin = mpl.ticker.LogLocator(base=10.0,subs=np.arange(0.1, 1.0, 0.1), numticks=5)
    ax.xaxis.set_minor_locator(locmin)
    ax.xaxis.set_minor_formatter(mpl.ticker.NullFormatter())
    # Change label size
    ax.tick_params(axis='both', which='major', labelsize=labelsize*0.85, pad=2.0, length=2.0)
    ax.tick_params(axis='both', which='minor', pad=1.0, length=1.5)

    # Hide unnecessary spines to make room for diagrams.
    for s in ["top", "right"]:
        ax.spines[s].set_visible(False)
    ax.set_title(axes_titles[i], fontsize=labelsize+1, y=0.95)

# Lay out the subplots before adding the outside legend (presumably so the
# legend does not influence tight_layout — TODO confirm)
fig.tight_layout(w_pad=0.3, h_pad=1.2)
# Final legend outside of the rightmost plot. After the loop, `ax` is the
# last subplot (axes[-1]) and `lc` its total-output line collection.
handles, labels = axes[-1].get_legend_handles_labels()
legend = ax.legend(handles=handles, labels=labels,
          handler_map={lc: HandlerColorLineCollection(numpoints=2)},
          loc="upper left", bbox_to_anchor=(1.0, 1.0), frameon=False,
          #frameon=True, edgecolor=(1, 1, 1, 0), facecolor=(1, 1, 1, 0.9), framealpha=0.9,
          handlelength=2.0,
          labelspacing=0.2, handletextpad=0.8, columnspacing=0.7,  borderaxespad=0.3, borderpad=0.2,
          fontsize=labelsize, title="Output", title_fontsize=labelsize, ncol=1)
if do_save_plots:
    fig.savefig(pj(fig_dir, "tcr_car_output_contributions_horizontal.pdf"),
            bbox_inches="tight", transparent=True, bbox_extra_artists=[legend])
plt.show()
plt.close()

# Further fits for different TCR/CAR types

## TCR/TCR parameters for 6F
These plots were already generated in the Figure 1 notebook, where the relevant plotting functions are defined.

In [None]:
# Read the best (k, m, f) grid point for the 6F TCR, restricted to k <= k_max_6f,
# because large k values can be overfitted.
k_max_6f = 1
with open(pj(mcmc_dir, "mcmc_analysis_tcr_tcr_6f.json"), "r") as h:
    lysis = json.load(h)
# Keys are string-encoded (k, m, f) tuples: keep those with k <= k_max_6f.
# A comprehension avoids leaving (or failing to delete) loop variables.
lysis = {kmf: res for kmf, res in lysis.items()
         if string_to_tuple(kmf)[0] <= k_max_6f}
best_kmf_6f, _, _ = find_best_grid_point(lysis, strat="best")
del lysis

In [None]:
# TODO: tune panel sizes for 6F TCR-TCR to the needs of Extended Data Figure 2
# Panel geometry (inches) for the 6F TCR/TCR fit panels; legend as wide as a panel
panel_dimensions3 = dict(
    panel_width=1.25 * scaleup,
    axes_label_width=0.25 * scaleup,
    panel_height=0.68,
)
panel_dimensions3["legend_width"] = panel_dimensions3["panel_width"]

In [None]:
# We keep 6F for the next figure. 
# Might be exactly the same as before, since it fitted nicely in one row. 
# Model fits with confidence intervals for the 6F TCR/TCR data, at the best
# 6F grid point, laid out in 2 columns.
fig, axes, handles, labels = plot_tcr_tcr_fits(pj(res_dir, "dfs_model_data_ci_mcmc_tcr_tcr.h5"), 
                    "tcr_tcr_6f", best_kmf_6f, panel_dimensions3, perturbations_palette, col_wrap=2)
# Color each panel title with the antigen-density palette; agconc_palette has
# 4 colors, so this assumes at most 4 panels
for i, ax in enumerate(axes.flat):
    ax.title.set_color(agconc_palette[i])
if do_save_plots:
    fig.savefig(pj(fig_dir, "fit_confidence_tcr_tcr_6f.pdf"), transparent=True, bbox_inches="tight")
plt.show()
plt.close()

# Make a stand-alone legend
# NOTE(review): the 4th value returned by plot_tcr_tcr_fits (`labels`) is
# passed as handler_map; confirm it is a handler map and not label strings.
fig, ax, leg = standalone_legend(handles=handles, handler_map=labels, 
                                 frameon=False, handlelength=3.5)
if do_save_plots:
    fig.savefig(pj(fig_dir, "fit_confidence_tcr_tcr_6f_legend.pdf"), 
            transparent=True, bbox_inches="tight", bbox_extra_artists=(leg,))
plt.show()
plt.close()

## 6F TCR/TCR MCMC cornerplot and autocorrelation

In [None]:
# TODO: tune panel sizes for 6F TCR-TCR to the needs of Extended Data Figure 2
# Panel geometry (inches) for the 6F TCR/TCR MCMC corner plot
panel_dimensions4 = {
    "panel_width": 1.25 * scaleup,     # inches
    "axes_label_width": 0.25 * scaleup,    # inches
}
panel_dimensions4["panel_height"] = 0.9
# Legend as wide as one panel of this same figure
panel_dimensions4["legend_width"] = panel_dimensions4["panel_width"]

In [None]:
# Sizing options passed to corner_plot_mcmc for the 6F TCR/TCR corner plot
corner_plot_kwargs = {
    "scaleup": scaleup,
    "small_lw": 0.8,
    "truth_lw": 1.25,
    "small_markersize": 1.0,
    "truth_color": np.asarray((0.0, 156.0, 75.0, 255.0)) / 255.0,  # deep key lime green
    "reverse_plots": False,
    "labelpad": 0.15,
    "n_times_height": 2,
    "n_extra_x_labels": 2,
    "n_extra_y_labels": 2
}
tick_label_size_mcmc = 5.0
nice_pnames = [r"$\varphi^T$", r"$C^T_{m, th}$",
               r"$I^T_{th}$", r"$\psi^T_0$"]

fig = corner_plot_mcmc(pj(mcmc_dir, "mcmc_results_tcr_tcr_6f.h5"),
                       pj(mcmc_dir, "mcmc_analysis_tcr_tcr_6f.json"),
                 best_kmf_6f, panel_dimensions4, pnames=nice_pnames, sizes_kwargs=corner_plot_kwargs)

# Adjust tick labels to gain a bit of space, and shift y axis labels by +0.12
# in x (label coordinates)
for ax in fig.axes:
    ax.xaxis.set_tick_params(labelsize=tick_label_size_mcmc*scaleup, pad=0.5*scaleup)
    # Axis.get_label() always returns a Text instance (never None), so check
    # whether it actually holds a label before moving it.
    ylbl = ax.yaxis.get_label()
    if ylbl.get_text():
        ylbl_coords = ylbl.get_position()
        ax.yaxis.set_label_coords(ylbl_coords[0] + 0.12, ylbl_coords[1])
    ax.yaxis.set_tick_params(labelsize=tick_label_size_mcmc*scaleup, pad=0.5*scaleup)

if do_save_plots:
    fig.savefig(pj(fig_dir, "mcmc_cornerplot_tcr_tcr_6f.pdf"),
            transparent=True, bbox_inches="tight", dpi=600)
plt.show()
plt.close()

In [None]:
# Square panels (inches); the legend is as wide as a panel
panel_dimensions2 = dict(
    panel_width=5.5 / 5.0 * scaleup,
    axes_label_width=0.25 * scaleup,
    panel_height=5.5 / 5.0 * scaleup,
)
panel_dimensions2["legend_width"] = panel_dimensions2["panel_width"]

# Panel geometry (inches) for MCMC autocorrelation plots; not scaled by scaleup
panel_dimensions_autocorr = dict(
    panel_width=1.1,
    axes_label_width=0.25,
    panel_height=1.0,
    legend_width=0.4,
)

In [None]:
nice_pnames = [r"$\varphi^T$", r"$C^T_{m, th}$", r"$I^T_{th}$", r"$\psi^T_0$"]

# Autocorrelation of the 6F TCR/TCR MCMC chains at the best 6F grid point
# (best_kmf_6f, as for the 6F corner plot; best_kmf_car is the CAR grid point
# and does not belong to these 6F TCR/TCR files).
# Honestly, so clear we ran for long enough, we don't need to show
# the autocorrelation estimated from 50 % of time
fig, ax, leg, _, _ = plot_autocorr(pj(mcmc_dir, "mcmc_results_tcr_tcr_6f.h5"),
                pj(mcmc_dir, "mcmc_analysis_tcr_tcr_6f.json"),
                best_kmf_6f, panel_dimensions_autocorr, pnames=nice_pnames,
                show_50=False, show_tmax=False)
if do_save_plots:
    fig.savefig(pj(fig_dir, "mcmc_autocorrelation_tcr_tcr_6f.pdf"), transparent=True, bbox_inches="tight",
            bbox_extra_artists=(leg,))
plt.show()
plt.close()

## Autocorrelation functions for all main figures together
These fit in one column of three plots in the supplementary model figure.

In [None]:
# Best (k, m, f) grid point of the revised AKPR model
with open(pj(mcmc_dir, "mcmc_analysis_akpr_i.json"), "r") as h:
    akpr_analysis = json.load(h)
best_kmf_akpr, _, _ = find_best_grid_point(akpr_analysis, strat="best")

# Best m of the SHP-1 (classical AKPR) model
with open(pj(mcmc_dir, "mcmc_analysis_shp1.json"), "r") as h:
    shp1_analysis = json.load(h)
best_m_shp1, _, _ = find_best_grid_point(shp1_analysis, strat="best")
del akpr_analysis, shp1_analysis

In [None]:
# Autocorrelation functions of the three main MCMC fits, stacked in one column:
# classical AKPR (SHP-1, Francois 2013 model) TCR/TCR, revised AKPR TCR/TCR,
# and revised AKPR TCR/CAR.
pdim = panel_dimensions_autocorr
fig, axes = plt.subplots(3, 1)
fig.set_size_inches(pdim["panel_width"] + pdim["legend_width"]*2 + pdim["axes_label_width"],
                    pdim["panel_height"]*3 + pdim["axes_label_width"]*1.75)

# One entry per panel, top to bottom:
# (results file, analysis file, best grid point, parameter labels, title)
autocorr_panels = [
    ("mcmc_results_shp1.h5", "mcmc_analysis_shp1.json", best_m_shp1,
     [r"$\varphi^T$", r"$C^T_{m, th}$", r"$I_{tot}$"],
     "Classical AKPR, TCR/TCR"),
    ("mcmc_results_akpr_i.h5", "mcmc_analysis_akpr_i.json", best_kmf_akpr,
     [r"$\varphi^T$", r"$C^T_{m, th}$", r"$I^T_{th}$", r"$\psi^T_0$"],
     "Revised AKPR, TCR/TCR"),
    ("mcmc_results_tcr_car_both_conc.h5", "mcmc_analysis_tcr_car_both_conc.json", best_kmf_car,
     [r"$C^C_{m, th}$", r"$I^C_{th}$", r"${\gamma^T}_C$", r"${\gamma^C}_T$",
      r"$\tau^T_c$", r"$\tau^C_c$"],
     "Revised AKPR, TCR/CAR"),
]
autocorr_legends = []
for panel_ax, (results_file, analysis_file, best_point, nice_pnames, panel_title) in zip(axes, autocorr_panels):
    _, _, panel_leg, _, _ = plot_autocorr(
        pj(mcmc_dir, results_file), pj(mcmc_dir, analysis_file),
        best_point, pdim, pnames=nice_pnames,
        show_50=False, show_tmax=False, figax=[fig, panel_ax])
    autocorr_legends.append(panel_leg)
    # Only the bottom panel keeps its x axis label
    if panel_ax is not axes[-1]:
        panel_ax.set_xlabel("")
    panel_ax.set_title(panel_title, fontsize=6, fontweight="semibold")
leg1, leg2, leg3 = autocorr_legends

if do_save_plots:
    fig.savefig(pj(fig_dir, "mcmc_autocorrelation_combined.pdf"), transparent=True, bbox_inches="tight",
                bbox_extra_artists=(leg1, leg2, leg3))
plt.show()
plt.close()

## 6F TCR integer parameters

In [None]:
# 6F TCR integer parameters: N^T = 4, followed by the best 6F (k, m, f) grid
# point, whose key is stored as a string
tcr_6f_integers_names = ["N^T", "k^T_I", "m^T", "f^T"]
tcr_6f_integers_values = [4] + list(string_to_tuple(best_kmf_6f))
fig, ax, text = standalone_parameter_values(
    tcr_6f_integers_names, tcr_6f_integers_values, n_per_line=2
)
if do_save_plots:
    fig.savefig(
        pj(fig_dir, "integer_parameters_annotation_tcr_6f.pdf"),
        transparent=True,
        bbox_inches="tight",
    )
plt.show()
plt.close()

## CAR/TCR parameters for 6F
Threshold $\tau_c^T$ and maximum amplitude of $Z^T$, fitted from the response of 6F T cells to TCR antigen alone. The fits are reproduced here.

In [None]:
def process_6f_tcr_ampli(full_df):
    """ Find the amplitude of the 6F TCR output relative to the 6Y TCR output.

    Model-independent: compares the IL-2 responses of 4-ITAM (6F) and 10-ITAM
    (6Y) TCR T cells to saturating TCR antigens (N4 and A2, 1 uM) without CAR
    antigen (CAR_Antigen == "None"), on E2APBX tumor cells.

    Args:
        full_df (pd.DataFrame): raw cytokine dataframe

    Returns:
        relative_ampli_6f (float): ratio of the averaged response of the 6F TCR
            (TCR_ITAMs "4") over that of the 6Y TCR (TCR_ITAMs "10"), each
            averaged with geo_mean_apply across times, repeats and agonists.
    """
    df = full_df.xs("None", level="CAR_Antigen")

    # Levels left in this df: "Cytokine", "Tumor", "TCR_ITAMs", "CAR_ITAMs",
    # "TCR_Antigen", "TCR_Antigen_Density", "Time".
    df = (df.xs("1uM", level="TCR_Antigen_Density")
            .xs("E2APBX", level="Tumor")
            .xs("IL-2", level="Cytokine")
            .stack("Time"))
    # Compute max amplitude based on A2 and N4 peptides, which both saturate
    # the TCR response in blast (mock) T cells
    df = df.loc[df.index.isin(['N4', 'A2'], level="TCR_Antigen")]
    # Look at the ratio of averaged logs across times, repeats and agonists
    df = df.groupby("TCR_ITAMs").apply(geo_mean_apply)
    relative_ampli_6f = df["4"] / df["10"]
    return relative_ampli_6f


def loglin_hill_fit(x, a, b, h, k):
    """ Return the log of a Hill function with background.

    Computes log(a * (x/h)**k / ((x/h)**k + 1) + b). Used as the model in
    curve_fit by process_6f_tcr_thresh_fact, where all four parameters
    a, b, h and k are fitted against log responses.

    Args:
        x (float or np.ndarray): input values (there, TCR antigen taus)
        a (float): amplitude of the Hill term
        b (float): background, i.e. the value of the Hill function at x = 0
        h (float): threshold, x at which the Hill term reaches a / 2
        k (float): Hill power
    Returns:
        float or np.ndarray: log of the Hill function, same shape as x
    """
    xnormk = (x / h)**k
    hill = a * xnormk / (xnormk + 1.0) + b
    return np.log(hill)


def process_6f_tcr_thresh_fact(full_df, pep_tau_map):
    """ Fit tau thresholds of the IL-2 response to TCR antigen for 6F and 6Y TCRs.

    For each data-spleen replicate and each (TCR_ITAMs, CAR_ITAMs) pair, the
    time-averaged IL-2 response to TCR antigen alone (no CAR antigen, 1 uM
    TCR antigen density, E2APBX tumor) is fitted, in log space, with a Hill
    function of the peptide binding time tau (loglin_hill_fit). The fitted
    thresholds are then averaged per TCR ITAM number.
    These thresholds may be different from those used to compute the threshold
    on C_N, because C_N is a complicated function of tau; yet we assume the
    relative tau thresholds between 6F and 6Y are approx. conserved.
    This is simpler than computing C_N for both receptor types.

    Args:
        full_df (pd.DataFrame): raw cytokine dataframe
        pep_tau_map (dict): maps TCR antigen (peptide) names to model taus

    Returns:
        df_mean: averaged (geo_mean_apply) response per (TCR_ITAMs, tau)
        df_std: standard deviation of log10 responses per (TCR_ITAMs, tau)
        tau_threshs (pd.Series): mean fitted Hill threshold per TCR_ITAMs,
            named "tau^T_c"
        tcr_amplis: averaged (geo_mean_apply) N4 and A2 responses per TCR_ITAMs
    """
    df = full_df.xs("None", level="CAR_Antigen")
    # Merge the "Data" and "Spleen" levels into a single replicate level
    data_spleen_lvl = pd.Series(df.index.get_level_values("Data")
                                + "-" + df.index.get_level_values("Spleen"),
                                name="Data-spleen")
    df = df.set_index(data_spleen_lvl, append=True)
    df = df.droplevel(["Data", "Spleen"])

    # Levels left in this df: "Cytokine", "Tumor", "TCR_ITAMs", "CAR_ITAMs",
    # "TCR_Antigen", "TCR_Antigen_Density", "Time", "Data-spleen".
    df = (df.xs("1uM", level="TCR_Antigen_Density")
            .xs("E2APBX", level="Tumor")
            .xs("IL-2", level="Cytokine"))
    # Geometric average of IL-2 over time is what we fit against tau
    df = geo_mean(df, axis=1)

    # Determine amplitudes from N4 and A2 only, as I actually do
    tcr_amplis = (df.loc[df.index.isin(["N4", "A2"], level="TCR_Antigen")]
                .groupby(["TCR_ITAMs"]).apply(geo_mean_apply))

    # Convert peptide names to taus for what follows
    df = df.rename(pep_tau_map, level="TCR_Antigen")

    # Fit a Hill curve on each data-spleen replicate for each TCR and CAR ITAM number
    df = df.reorder_levels(["Data-spleen", "TCR_ITAMs", "CAR_ITAMs", "TCR_Antigen"])
    gps = df.groupby(["Data-spleen", "TCR_ITAMs", "CAR_ITAMs"])
    tau_threshs = pd.Series(np.zeros(len(gps.groups)),
            index=pd.MultiIndex.from_tuples(gps.groups.keys(), names=gps.keys))
    for k in gps.groups.keys():
        # Response of one replicate, indexed by TCR antigen tau only
        curve = df.loc[k]
        lower, upper = curve.min(), curve.max()
        ampli = upper - lower
        # Limits on ampli, background, tau_thresh, hill power k
        pbounds = [(0.25*ampli, 4.0*ampli), (0.8*lower, 1.2*lower), (0.1, 10.0), (1.0, 16.0)]
        # Reshape into curve_fit's (lower bounds, upper bounds) format
        pbounds = np.asarray(list(zip(*pbounds)))
        # Start from the middle of the allowed parameter ranges
        p0 = (pbounds[0] + pbounds[1]) / 2.0
        popt, _ = curve_fit(loglin_hill_fit,
                            xdata=curve.index.get_level_values("TCR_Antigen").values,
                            ydata=np.log(curve.values), p0=p0, bounds=pbounds)
        tau_threshs.loc[k] = popt[2]  # fitted Hill threshold h

    # Average the fitted tau thresholds over replicates and CAR types, per TCR ITAM number
    tau_threshs = tau_threshs.groupby("TCR_ITAMs").mean()
    tau_threshs.name = "tau^T_c"

    # Compute average response curves and log std
    df_mean = df.groupby(["TCR_ITAMs", "TCR_Antigen"]).apply(geo_mean_apply)
    df_std = np.log10(df).groupby(["TCR_ITAMs", "TCR_Antigen"]).std()

    # Also return the actual fitted thresholds
    return df_mean, df_std, tau_threshs, tcr_amplis

In [None]:
# Process response curves like for the threshold factor determination
# Average across CAR types, put error bars
# Plot them on top of one another, annotate amplitudes and thresholds ratios

# x axis has to be model tau: fine, process_6f_tcr_thresh_fact does it
# Map from OT-1 TCR antigen (peptide) names to model taus
with open(pj(data_dir, "pep_tau_map_ot1.json"), "r") as h:
    pep_tau_map_ot1 = json.load(h)

# Mean response curves, log10 std, fitted tau thresholds and N4/A2 amplitudes,
# each indexed by TCR ITAM number ("10" for 6Y, "4" for 6F)
res = process_6f_tcr_thresh_fact(full_car_data, pep_tau_map_ot1)
df_6y6f, df_6y6f_log10std, threshs, amplis = res

# Plot response curves and annotate appropriately
fig, ax = plt.subplots()
palette = {"10": perturbations_palette["None"], "4": perturbations_palette["TCRNum"]}
markers = {"10":"o", "4":"s"}
for n in df_6y6f.index.get_level_values("TCR_ITAMs").unique():
    y = df_6y6f.loc[n].values
    x = df_6y6f.loc[n].index.get_level_values("TCR_Antigen").values
    # Symmetric error bars in log10 space, converted to asymmetric linear offsets
    yerr = df_6y6f_log10std.loc[(n, x)].values
    y_upper = 10.0**(np.log10(y) + yerr) - y
    y_lower = y - 10.0**(np.log10(y) - yerr)
    ax.errorbar(x, y, yerr=np.stack([y_lower, y_upper]), color=palette[n], 
                label=n, marker=markers[n], linewidth=1.0, alpha=0.5)

# Annotate amplitudes and thresholds with vertical and horizontal lines
ax.set(yscale="log", xlabel=r"TCR antigen $\tau^T$ (s)", 
       ylabel=r"$\langle$IL-2$\rangle_t$ response to TCR" + "\nantigen, no CD19 (nM)")
# Extend the x range on the right to leave room for the amplitude annotation
xlims = list(ax.get_xlim())
xlims[1] += 1.5
ylims = ax.get_ylim()
for n in df_6y6f.index.get_level_values("TCR_ITAMs").unique():
    # Horizontal dashed line at the fitted amplitude, from just left of A2's tau
    ax.plot([pep_tau_map_ot1["A2"]-1.0, xlims[1]], [amplis.loc[n]]*2, ls="--", color=palette.get(n))
    # Vertical line at the fitted tau threshold, up to half the amplitude;
    # ymax is an axes fraction, computed in log scale since the y axis is log
    yfrac = float((np.log(amplis.at[n]*.5) - np.log(ylims[0])) / np.diff(np.log(ylims))[0])
    ax.axvline(threshs.loc[n], ls="-", color=palette.get(n), ymin=0, ymax=yfrac)
#reset x lims
ax.set_xlim(xlims)
# Indicate that we are fitting the ratios between 6Y and 6F:
# arrow from the 6Y (10-ITAM) to the 6F (4-ITAM) threshold, at the lowest mean response
ax.annotate("", xy=(threshs["4"], df_6y6f.min()), xytext=(threshs["10"], df_6y6f.min()),
            arrowprops={"arrowstyle":"->", "color":palette["4"]}, ha="center", va="center")
ax.annotate("Fit\n" + r"$\tau^{T,4}_c / \tau^{T,10}_c$", xy=(np.mean(threshs), df_6y6f.min()*1.5), 
            ha="center", va="bottom", color=palette["4"])

# Amplitudes ratio: vertical arrow between the 6F and 6Y amplitude levels,
# placed just right of N4's tau
xn4 = pep_tau_map_ot1["N4"] + 1.5
ax.annotate("", xy=(xn4, amplis["10"]*1.25), xytext=(xn4, amplis["4"]*.8),
            arrowprops={"arrowstyle":"->", "color":palette["4"]}, ha="center", va="center")
ax.annotate("Fit\n" + r"$Z^{T,4} / Z^{T,10}$", xy=(xn4, amplis["4"]*0.75), 
            ha="center", va="top", color=palette["4"])

ax.legend(title="TCR ITAM #", frameon=False, loc="lower left", bbox_to_anchor=(0.65, 0.0), 
         labelspacing=0.25)
fig.set_size_inches(panel_dimensions2["panel_width"] + 2*panel_dimensions2["axes_label_width"]
                    + 0.5*panel_dimensions2["legend_width"], 
                    panel_dimensions2["panel_height"] + panel_dimensions2["axes_label_width"])
fig.tight_layout()
if do_save_plots:
    fig.savefig(pj(fig_dir, "fit_tcr_6f_amplitude_threshold.pdf"), transparent=True)
plt.show()
plt.close()

In [None]:
# Display the N4/A2 response amplitudes per TCR ITAM number
amplis

## CAR parameters for 1-ITAM CAR

### Print parameters

In [None]:
# 1-ITAM CAR integer parameters: N^C = m^C = f^C = 1, while k^C_I is the
# k of the best CAR grid point
car_1_integers_names = ["N^C", "k^C_I", "m^C", "f^C"]
car_1_integers_values = [1, best_kmf_car[0], 1, 1]
fig, ax, text = standalone_parameter_values(
    car_1_integers_names, car_1_integers_values, n_per_line=1
)
if do_save_plots:
    fig.savefig(
        pj(fig_dir, "integer_parameters_annotation_car_1itam.pdf"),
        transparent=True,
        bbox_inches="tight",
    )
plt.show()
plt.close()

### Amplitude of $Z^{C,1}$
Adjust the amplitude of $Z^{C,1}$ to match the ratio of 1-ITAM to 3-ITAM CAR responses to CD19 alone (no TCR antigen).

In [None]:
# Process the data as in the functions used to compute correction factors
def process_1itam_car_ampli(full_df):
    """ Extract time-averaged IL-2 responses of 1-ITAM and 3-ITAM CAR T cells
    to CD19 alone (no TCR antigen), to compare the amplitudes of both CARs.

    The ratio itself is not computed here: the returned data keep replicate
    levels so they can be plotted as individual points.

    Args:
        full_df (pd.DataFrame): raw cytokine dataframe

    Returns:
        df: IL-2 responses on E2APBX tumor cells, averaged over time with
            geo_mean_levels, restricted to CAR_ITAMs "1" and "3". Remaining
            index levels (e.g. TCR_ITAMs, TCR_Antigen_Density, replicates)
            are kept.
    """
    df = (full_df.xs("None", level="TCR_Antigen")
                        .xs("CD19", level="CAR_Antigen"))

    # Levels left in this df: "Cytokine", "Tumor", "TCR_ITAMs", "CAR_ITAMs",
    # "TCR_Antigen_Density", "Time"
    # Look at coupling with both 4- and 10-ITAM TCR
    df = (df.xs("E2APBX", level="Tumor")
            .xs("IL-2", level="Cytokine")
            .stack("Time"))
    # Look at time-averaged series: single time points are too spread out.
    # Will plot dots for different replicates, TCR_ITAMs, TCR_Antigen_Densities
    df = geo_mean_levels(df, ["Time"])
    # Drop T cells without CAR
    df = df.loc[df.index.isin(["1", "3"], level="CAR_ITAMs")]

    return df

In [None]:
def plot_car_1vs3itams(df_ampli, ylabel):
    """ Plot responses of 1-ITAM vs 3-ITAM CAR T cells as jittered dots, with
    geometric means (geo_mean) and log-std error bars joined by a line.

    Dots are colored by TCR ITAM number ("10" grey, "4" sage). Uses the
    module-level perturbations_palette and panel_dimensions2.

    Args:
        df_ampli: responses indexed by at least "CAR_ITAMs" (with "1" and "3")
            and "TCR_ITAMs"; presumably a pd.Series — TODO confirm.
        ylabel (str): y axis label.

    Returns:
        fig, ax: matplotlib figure and axes
        itam_means (list): geometric means for the 1-ITAM and 3-ITAM CARs,
            in that order
        palette (dict): colors for CAR ITAM numbers "1" and "3"
    """
    palette = {"1":perturbations_palette["CARNum"], "3":perturbations_palette["None"]}
    jitter_ampli = 0.1
    fig, ax = plt.subplots()
    fig.set_size_inches(panel_dimensions2["panel_width"] + 2*panel_dimensions2["axes_label_width"],
                        panel_dimensions2["panel_height"])

    dots_palette = {"10":"grey", "4":"xkcd:sage"}
    itam1_points = df_ampli.xs("1", level="CAR_ITAMs")
    itam3_points = df_ampli.xs("3", level="CAR_ITAMs")

    # Use the TCR ITAM numbers of the data being plotted (previously read
    # from the global df_1itam_ampli, whatever df_ampli was passed)
    for tcr_num in df_ampli.index.get_level_values("TCR_ITAMs").unique():
        for x_pos, points in ((1.0, itam1_points), (3.0, itam3_points)):
            y = points.xs(tcr_num, level="TCR_ITAMs").values
            jitter = jitter_ampli*np.random.normal(size=y.size)
            ax.plot(x_pos + jitter, y, marker="o", ls="none", color=dots_palette[tcr_num], alpha=0.7)

    # Show geometric means and line in between; error bars span one std of
    # the log responses around the geometric mean
    itam_means, itam_lowers, itam_uppers = [], [], []
    for points in (itam1_points, itam3_points):
        xmean = geo_mean(points, axis=0)
        itam_means.append(xmean)
        log_std = np.std(np.log(points))
        itam_lowers.append(xmean - np.exp(np.log(xmean) - log_std))
        itam_uppers.append(np.exp(np.log(xmean) + log_std) - xmean)
    for i, n in enumerate((1, 3)):
        ax.errorbar([n], itam_means[i], yerr=np.asarray([[itam_lowers[i], itam_uppers[i]]]).T,
                   marker="s", color=palette[str(n)], zorder=20)
    ax.plot([1, 3], itam_means, ls="-", color=palette["1"], zorder=21)
    ax.set_xlim([0, 4])
    ax.invert_xaxis()
    ax.set(xlabel="CAR ITAM #", yscale="log", ylabel=ylabel)

    # Manual legend for the TCR ITAM dot colors
    leg_handles = [mpl.lines.Line2D([0], [0], color=clr, label=lbl, marker="o", alpha=0.7, ls="none")
                   for lbl, clr in dots_palette.items()]
    ax.set_xticks([1, 3])
    ax.set_xticklabels([1, 3])
    ax.legend(handles=leg_handles, frameon=False, title="TCR ITAMs", handlelength=2.0,
              handletextpad=0.05, loc="upper left", bbox_to_anchor=(0.8, 1.1),)
    fig.tight_layout()
    return fig, ax, itam_means, palette

In [None]:
# Time-averaged IL-2 responses to CD19 alone for 1-ITAM and 3-ITAM CAR T cells
df_1itam_ampli = process_1itam_car_ampli(full_car_data)
print("Replicates are from levels:", df_1itam_ampli.droplevel("CAR_ITAMs").index.names)
ylabel = r"$\langle$IL-2$\rangle_t$ response to" + "\nCD19, no TCR Ag (nM)"
fig, ax, itam_means, palette = plot_car_1vs3itams(df_1itam_ampli, ylabel)

# Annotate "Fit ratio" between points. Interpolate linearly in log y scale:
# line through (1, log10 mean_1) and (3, log10 mean_3), evaluated at x = 2
xmid = (1.0+3.0)/2.0
ymid = 10**((np.log10(itam_means[1]) - np.log10(itam_means[0])) / 2 * (xmid - 1.0) + np.log10(itam_means[0]))
ax.annotate("Fit\n"+r"$Z^C_1 / Z^C_3$", xy=(xmid-0.25, ymid*1.25), xycoords="data", ha="center", va="bottom", 
           color=palette["1"], fontweight="semibold")
if do_save_plots:
    fig.savefig(pj(fig_dir, "fit_1itamcar_cd19only_response_ratio.pdf"), transparent=True, bbox_inches="tight", 
           bbox_extra_artists=(ax.get_legend(),))
plt.show()
plt.close()

### Amplitude of $Z^T$ affected by presence of 1-ITAM CAR

In [None]:
def process_1itam_effect_tcr_ampli(full_df):
    """ Select the data used to estimate how a 1-ITAM CAR (without CAR
    antigen) changes the amplitude of the TCR output, relative to a 3-ITAM
    CAR. This looks at response to TCR antigen only, no CD19: separate
    data from what we are trying to predict.

    The 1-ITAM / 3-ITAM ratio itself is NOT computed here: this function
    returns time-averaged IL-2 responses, and callers derive the ratio from
    them (e.g. by passing them to plot_car_1vs3itams, which computes one
    mean per CAR_ITAMs value). The resulting amplitude factor is meant to be
    multiplied with the 6F TCR amplitude effect for the double perturbed
    condition; any averaging across 6F/6Y TCR genotypes or CAR types must
    happen downstream, since all remaining index levels are kept here.

    Args:
        full_df (pd.DataFrame): raw cytokine dataframe, with index levels
            including "CAR_Antigen", "TCR_Antigen_Density", "Tumor",
            "Cytokine", "TCR_Antigen" and "CAR_ITAMs", and "Time" as a
            column level.

    Returns:
        df: IL-2 responses to the N4 and A2 TCR antigens at 1uM antigen
            density, on E2APBX tumor cells without CAR antigen ("None"),
            averaged over the "Time" level with geo_mean_levels. All other
            index levels (including "CAR_ITAMs") are kept and act as
            replicates. Presumably a pd.Series or single-column
            pd.DataFrame depending on geo_mean_levels — TODO confirm.
    """
    df = full_df.xs("None", level="CAR_Antigen")

    # Levels left in this df: "Cytokine", "Tumor", "TCR_ITAMs", "CAR_ITAMs",
    #  "TCR_Antigen", "TCR_Antigen_Density", and "Time" (a column level,
    #  moved into the index by the stack below).
    df = (df.xs("1uM", level="TCR_Antigen_Density")
            .xs("E2APBX", level="Tumor")
            .xs("IL-2", level="Cytokine")
            .stack("Time"))
    # Compute max amplitude based on A2 and N4 peptides, which both saturate
    # the TCR response in blast (mock) T cells
    df = df.loc[df.index.isin(['N4', 'A2'], level="TCR_Antigen")]
    # Average over time; other levels will be plotted as replicates.
    df = geo_mean_levels(df, ["Time"])
    return df


In [None]:
# IL-2 response to saturating TCR agonists (N4, A2) without CD19, compared
# between cells carrying a 1-ITAM vs a 3-ITAM CAR.
df_1tcr_ampli = process_1itam_effect_tcr_ampli(full_car_data)
print("Replicates are from levels:", df_1tcr_ampli.droplevel("CAR_ITAMs").index.names)

ylabel = r"$\langle$IL-2$\rangle_t$ response to TCR" + "\nagonist, no CD19 (nM)"
fig, ax, itam_means, palette = plot_car_1vs3itams(df_1tcr_ampli, ylabel)

# Annotate "Fit ratio" between points. Interpolate linearly in log y scale
# between itam_means[0] (plotted at x=1) and itam_means[1] (plotted at x=3);
# at the midpoint x=2 this equals the geometric mean of the two means.
xmid = (1.0+3.0)/2.0
ymid = 10**((np.log10(itam_means[1]) - np.log10(itam_means[0])) / 2 * (xmid - 1.0) + np.log10(itam_means[0]))
# Label is offset from the midpoint, this time below the line (va="top", ymid*0.8)
ax.annotate("Fit\n"+r"$Z^T_1/Z^T_3$", xy=(xmid-0.2, ymid*0.8), xycoords="data", ha="center", va="top", 
           color=palette["1"], fontweight="semibold")
if do_save_plots:
    # The legend is anchored partly outside the axes, so include it in the tight bbox
    fig.savefig(pj(fig_dir, "fit_1itamcar_tcronly_response_ratio.pdf"), transparent=True, bbox_inches="tight", 
           bbox_extra_artists=(ax.get_legend(),))
plt.show()
plt.close()

# Antagonism heatmaps: data and model predictions
First make a heatmap of the antagonism data. Then load the model predictions plotted in main Figure 2 and make a matching heatmap from them, to put next to the data heatmap. Both heatmaps use a single color code for perturbations.

In [None]:
# Load the experimental antagonism dataframe. Judging by the file name, it holds
# OT-1 TCR / CAR antagonism measurements; the 'Ratio' column is used below.
antagDfData = pd.read_hdf(pj(data_dir, "antagonism", "ot1_car_antagonism_df_with_ratio.h5"))
antagDfData

In [None]:
## Processing from Sooraj
#@title Tools
def zero_centered_min_max_scaling(dataframe):
    """
    Divide every column by its largest absolute value.

    Each column is mapped into [-1, 1] while keeping the sign of every
    entry. A column made only of zeros becomes NaN (0 / 0). The input
    dataframe is not modified; a new dataframe is returned.
    """
    def _divide_by_max_abs(col):
        # Same scale factor for positive and negative entries of the column
        return col / col.abs().max()

    return dataframe.copy(deep=True).apply(_divide_by_max_abs)

def _scale_column_by_sign(column):
    """Scale one column: positives by their max, negatives by |their min|.

    Zeros stay 0 and NaNs stay NaN. Returns a float Series with the same
    index and name as ``column``.
    """
    values = column.to_numpy(dtype=float)
    scaled = values.copy()
    # NaN compares False in both masks, so NaNs are left untouched
    with np.errstate(invalid="ignore"):
        positive = values > 0
        negative = values < 0
    if positive.any():
        scaled[positive] = values[positive] / values[positive].max()
    if negative.any():
        # Divide by |min| (a positive number) so negative values stay negative
        scaled[negative] = values[negative] / -values[negative].min()
    return pd.Series(scaled, index=column.index, name=column.name)


def zero_centered_min_max_scaling2(dataframe):
    """
    Scale each column to [-1, 1], treating positive and negative values separately.

    In every column, positive values are divided by the column maximum (the
    largest becomes 1) and negative values by the absolute column minimum
    (the most negative becomes -1). Zeros stay 0 and NaNs stay NaN. Unlike a
    single max-abs division, each sign spans its full half-range
    independently of the other.

    Args:
        dataframe (pd.DataFrame): numeric values; not modified.

    Returns:
        pd.DataFrame: new float dataframe with the same index and columns.
    """
    # Vectorized per column. Zeros are left at 0 instead of being divided by
    # the column minimum, which produced NaN (0/0) whenever that minimum was 0.
    return dataframe.copy(deep=True).apply(_scale_column_by_sign)

In [None]:
def find_nearest(array, value):
    """Return the element of ``array`` closest to ``value``.

    ``array`` may be any array-like (scalar, list, ndarray of any shape);
    it is flattened before the search, so the result is always a single
    element. Ties resolve to the first element in flattened (C) order, as
    with ``np.argmin``.

    Raises:
        ValueError: if ``array`` is empty.
    """
    # Flatten so the flat index from argmin addresses one element even for N-D input
    flat = np.ravel(np.asarray(array))
    idx = np.argmin(np.abs(flat - value))
    return flat[idx]

# Peptide -> tau mapping for OT-1 antigens (keys are peptide names and values
# are taus, per the commented-out find_nearest usage in the model cell below)
with open(pj(data_dir, "pep_tau_map_ot1.json"), "r") as h:
    pep_tau_map_ot1 = json.load(h)

In [None]:
def _perturbation_colors(perturbations, perturb_palette):
    """Map each "<TCR Ag density>+<CAR ITAMs>+<TCR ITAMs>" label to a palette color.

    The palette key joins with "_" the perturbations a label carries:
    "AgDens" if the density is "Low", "CARNum" if there are fewer than 3 CAR
    ITAMs, "TCRNum" if there are fewer than 10 TCR ITAMs; a label with none
    of these maps to the key "None". Keys missing from the palette give None.
    """
    colors = {}
    for label in perturbations:
        tcr_ag_dens, car_nb, tcr_nb = label.split("+")
        flags = []
        if tcr_ag_dens == "Low":
            flags.append("AgDens")
        if int(car_nb) < 3:
            flags.append("CARNum")
        if int(tcr_nb) < 10:
            flags.append("TCRNum")
        colors[label] = perturb_palette.get("_".join(flags) or "None")
    return colors


def antagonism_heatmap(antagDf, perturb_palette, source="data", dendro=None,
                 show_dendro=True, size_inches=None):
    """ Plot a clustered heatmap of antagonism values, one row per condition.

    Args:
        antagDf (pd.DataFrame): either continuous curves from the model or
            discrete values at peptides from experiment. Columns should be
            TCR antigen quality; rows are conditions, perturbations, etc.
            The index must have a "Perturbation" level with labels of the
            form "<TCR Ag density>+<CAR ITAM number>+<TCR ITAM number>".
        perturb_palette (dict): row colors keyed by perturbation combination
            ("AgDens", "CARNum", "TCRNum" joined with "_", or "None").
        source (str): "data" (fixed OT-1 peptide columns, NaN filled with 0)
            or "model" (columns sorted along the "TCR_Antigen" column level,
            tick labels rounded to one decimal).
        dendro: precomputed row linkage passed to sns.clustermap as
            row_linkage (e.g. from a previous call); None lets seaborn
            compute it.
        show_dendro (bool): whether to display the row dendrogram.
        size_inches (tuple or None): figure size; None uses the current
            plt.rcParams["figure.figsize"] at call time.

    Returns:
        cm: the seaborn ClusterGrid.
        linkage: the row linkage matrix used (cm.dendrogram_row.linkage).

    Raises:
        ValueError: if source is neither "data" nor "model".
    """
    # Validate early: an unknown source used to leave the x-label offset
    # undefined and fail only after the figure had been drawn.
    if source not in ("data", "model"):
        raise ValueError(f"source must be 'data' or 'model', got {source!r}")
    if size_inches is None:
        size_inches = plt.rcParams["figure.figsize"]

    # One color per row, in the row order of the reset index used for plotDf
    colors = _perturbation_colors(
        antagDf.index.get_level_values("Perturbation").unique(), perturb_palette)
    perturbs = antagDf.reset_index().pop("Perturbation")
    row_colors = perturbs.map(colors)

    if source == "data":
        xrange = ['E1','G4','V4','T4','Q4','A2','N4']
        plotDf = antagDf.fillna(value=0).reset_index().loc[:, xrange]
    else:
        xrange = np.sort(antagDf.columns.get_level_values("TCR_Antigen").unique())
        plotDf = antagDf.reset_index().loc[:, xrange]
    cm = sns.clustermap(plotDf, row_colors=row_colors, col_cluster=False, cmap='PuOr_r',
                        cbar_kws={'shrink': 0.5,'orientation':'horizontal'},
                        figsize=size_inches, dendrogram_ratio=0.1,
                        colors_ratio=0.05, tree_kws={"lw":0.5},
                        row_linkage=dendro)
    cm.ax_row_dendrogram.set_visible(show_dendro)

    g = cm.ax_heatmap
    cbar = g.collections[0].colorbar
    cbar.set_ticks([])
    g.set_yticks([])
    for spine in g.spines.values():
        spine.set_visible(True)
    cbar.outline.set_edgecolor('k')
    cbar.outline.set_linewidth(plt.rcParams["axes.linewidth"])
    g.set_xlabel('')

    # Move the colorbar into the column-dendrogram slot (columns are not clustered)
    col_dendro_pos = cm.ax_col_dendrogram.get_position()
    cm.ax_cbar.set_position([col_dendro_pos.x0,
                             col_dendro_pos.y0+0.02,
                             col_dendro_pos.width, 0.02])
    cm.ax_cbar.set_title('')
    cm.ax_cbar.tick_params(axis='x', length=10)

    # Arrows and labels above the heatmap, in axes-fraction coordinates
    cm.ax_heatmap.annotate('',xytext=(0.5,1.135),xy=(1,1.135), xycoords='axes fraction', va="top",
                           arrowprops=dict(arrowstyle='-|>',color='darkorange'))
    cm.ax_heatmap.annotate('',xytext=(0,1.135),xy=(0.5,1.135),xycoords='axes fraction', va="top",
                           arrowprops=dict(arrowstyle='<-',color='purple'))
    cm.ax_heatmap.text(0.25,1.2,'Antagonism',va='center',ha='center',color='purple',
                       zorder=100,transform=cm.ax_heatmap.transAxes,fontweight='bold')
    cm.ax_heatmap.text(0.75,1.2,'Enhancement',va='center',ha='center',color='darkorange',
                       zorder=100,transform=cm.ax_heatmap.transAxes,fontweight='bold')

    cm.ax_row_colors.set_xticks([])
    if source == "model":
        # Model columns are numeric taus: round the tick labels to one decimal
        cm.ax_heatmap.set_xticks(cm.ax_heatmap.get_xticks())
        xticklabels = [str(round(float(a.get_text()), 1)) for a in cm.ax_heatmap.get_xticklabels()]
        cm.ax_heatmap.set_xticklabels(xticklabels)
    cm.ax_row_colors.set_xticklabels([])

    # Vertical offsets (axes fraction) of the parentheses and of the x-axis arrow
    parent_shift = -0.07
    xlb_shift = -0.15 if source == "data" else -0.2

    cm.ax_heatmap.annotate('',xytext=(0,xlb_shift),xy=(1,xlb_shift),
                    arrowprops=dict(arrowstyle='->',color='k'), xycoords='axes fraction')
    arrow_lbl = 'TCR Ag Strength' if source == "data" else r"TCR antigenicity, $\tau^T$ (s)"
    t = cm.ax_heatmap.text(0.5, xlb_shift, arrow_lbl, va='center',ha='center',
                           color='k',zorder=100,transform=cm.ax_heatmap.transAxes)
    # White box masks the arrow behind the label
    t.set_bbox(dict(facecolor='white', edgecolor='white'))

    if source == "data":
        cm.ax_heatmap.annotate('(', xy=(-0.05,parent_shift), xycoords='axes fraction')
        cm.ax_heatmap.annotate(')+CD19', xy=(1.01,parent_shift), xycoords='axes fraction')
    else:
        cm.ax_heatmap.annotate('+CD19',xy=(1.01,parent_shift), xycoords='axes fraction')
    return cm, cm.dendrogram_row.linkage

## Prepare antagonism data

In [None]:
# Keep conditions with the CAR antigen (CD19) present and a CAR that has ITAMs
temp = antagDfData.query("CAR_Antigen == 'CD19' and `CAR_ITAM_Number` != '0'")
#temp = temp.query("Cytokine == ['IL-2','TNFa']")
# IL-2 and TNFa at all time points; IFNg only at Time 1, 3 and 6
temp1 = temp.query("Cytokine == ['IL-2','TNFa']")
temp2 = temp.query("Cytokine == 'IFNg' and Time == [1,3,6]")
#temp2 = temp.query("Cytokine != ['IL-2','TNFa']")
# Discard rows whose antagonism ratio exceeds a cytokine-group-specific cutoff
temp1 = temp1[temp1['Ratio'] < 2e2]
temp2 = temp2[temp2['Ratio'] < 2e1]
temp = pd.concat([temp1,temp2])
# Reorder rows grouped by TCR antigen, in reverse order of first appearance
o = temp.index.unique('TCR_Antigen').tolist()[::-1]
temp = pd.concat([temp.query("TCR_Antigen == @x") for x in o])
# Label each row by its (TCR antigen density, TCR ITAM number) combination;
# a combination missing from pertDict would raise a KeyError here
pertDict = {'1uM-10':'None','1uM-4':'Fewer TCR ITAMs','1nM-10':'Less TCR antigen density'}
temp['Condition'] = [pertDict['-'.join([x,y])] for x,y in zip(temp.index.get_level_values('TCR_Antigen_Density'),
                                                              temp.index.get_level_values('TCR_ITAM_Number'))]

# NOTE(review): .iloc[:,1] selects the second column by position (presumably
# 'Ratio'); selecting it by name would be safer. TCR antigens become columns.
temp = temp.set_index(['Condition'],append=True).iloc[:,1].unstack('TCR_Antigen')
temp = temp[['E1','G4','V4','T4','Q4','A2','N4']]
temp = np.log2(temp)

# Clip log2 values to [-2, 6], then scale positives and negatives separately
# to [-1, 1]. After unstack('Cytokine').stack('TCR_Antigen') each column is one
# cytokine, so the scaling is per cytokine, shared across conditions and antigens.
clippedValues = np.clip(temp,a_min=-2,a_max=6)
with warnings.catch_warnings():
    # Presumably silences pandas FutureWarnings from stack/unstack — TODO confirm
    warnings.simplefilter(action='ignore', category=FutureWarning)
    clippedValues = (zero_centered_min_max_scaling2(clippedValues.unstack('Cytokine')
                    .stack('TCR_Antigen')).unstack('TCR_Antigen').stack('Cytokine'))
clippedValues = clippedValues[['E1','G4','V4','T4','Q4','A2','N4']]

clippedValues = clippedValues.droplevel(['Condition','CAR_Antigen','Tumor'])
clippedValues = clippedValues.swaplevel(-1,0).swaplevel(-1,-2)
# Order cytokines and rename them for display; only IL-2 is kept in the end
clippedValues = (pd.concat([clippedValues.query("Cytokine == @x") for x in ['IL-2','TNFa','IFNg']])
                .rename({'TNFa':'TNF','IFNg':r'IFN-$\gamma$'}).query("Cytokine == 'IL-2'"))

# Average over time points
clippedValues = clippedValues.groupby([x for x in clippedValues.index.names if x != 'Time']).mean()
clippedValues = clippedValues.swaplevel(-1,-3).swaplevel(-2,-3)
clippedValues = clippedValues.sort_index()
clippedValues = clippedValues.rename(mapper={'1nM':'Low', '1uM':'High'}, axis=0, level='TCR_Antigen_Density')
clippedValues = clippedValues.droplevel('Cytokine')

### DIFFERENCE HERE: average over the "Data" and "Spleen" levels (presumably
### replicate experiments). Values are already log2 (then sign-wise scaled), so an
### arithmetic mean stands in for a geometric mean of the underlying ratios.
clippedValues = clippedValues.groupby(
    [x for x in clippedValues.index.names if x not in ["Data", "Spleen"]]).mean()

# Add a single level containing all perturbations, as "<density>+<CAR ITAMs>+<TCR ITAMs>"
clippedValues["Perturbation"] = (clippedValues.index.get_level_values("TCR_Antigen_Density") 
                              + "+" + clippedValues.index.get_level_values("CAR_ITAM_Number") 
                              + "+" + clippedValues.index.get_level_values("TCR_ITAM_Number") )
clippedValues = clippedValues.set_index("Perturbation", append=True)
clippedValues

In [None]:
# Figure size (width, height) in inches of the data heatmap; the model heatmap size is scaled from it
heatmap_figsize = (1.9, 2.7)

In [None]:
# Data heatmap; its row linkage is kept so the model heatmap can reuse it
clustmap, dendrogram = antagonism_heatmap(clippedValues, perturbations_palette, size_inches=heatmap_figsize)
if do_save_plots:
    # ClusterGrid.savefig wraps Figure.savefig (the .fig attribute is deprecated);
    # same call style as for the model heatmap
    clustmap.savefig(pj(fig_dir, "heatmap_data.pdf"), transparent=True,
                     dpi=200, bbox_inches="tight")
plt.show()
plt.close()

## Model predictions heatmap
Load the model predictions plotted in main Figure 2 and make a heatmap from them, to put next to the data heatmap. Use the same single color code for perturbations as in the data heatmap.

In [None]:
def find_nearest(array, value):
    """Return the element of ``array`` closest to ``value``.

    ``array`` may be any array-like (scalar, list, ndarray of any shape);
    it is flattened before the search, so the result is always a single
    element. Ties resolve to the first element in flattened (C) order, as
    with ``np.argmin``.

    Raises:
        ValueError: if ``array`` is empty.
    """
    # Flatten so the flat index from argmin addresses one element even for N-D input
    flat = np.ravel(np.asarray(array))
    idx = np.argmin(np.abs(flat - value))
    return flat[idx]

# Load model predictions
antagDfModel = (pd.read_hdf(pj(res_dir, "dfs_model_data_ci_mcmc_both_conc.h5"), "model")
                .droplevel("TCR_Antigen_EC50"))
# Arrange in a nice way for cluster map plotting: keep the "geo_mean" column at
# the selected kmf value (best_kmf_car must be defined by an earlier cell),
# put TCR antigens in columns and use the same Low/High density labels as the data
antagDfModel = (antagDfModel.xs(str(best_kmf_car), level="kmf")["geo_mean"]
                .unstack("TCR_Antigen").droplevel("Subset")
               .rename({"1nM":"Low", "1uM":"High"}, level="TCR_Antigen_Density"))
# Rename *_ITAMs levels to the *_ITAM_Number names used in the data dataframe
for recep in ["TCR", "CAR"]:
    antagDfModel.index = antagDfModel.index.set_names(names=recep+"_ITAM_Number", level=recep+"_ITAMs")
    
# Add a single level containing all perturbations, as "<density>+<CAR ITAMs>+<TCR ITAMs>"
antagDfModel["Perturbation"] = (antagDfModel.index.get_level_values("TCR_Antigen_Density") 
                              + "+" + antagDfModel.index.get_level_values("CAR_ITAM_Number") 
                              + "+" + antagDfModel.index.get_level_values("TCR_ITAM_Number") )
antagDfModel = antagDfModel.set_index("Perturbation", append=True)

# Keep taus closest to OT1 peptides, so the sampling is identical to data
# and the clustering can work its charm in the same way.
# NOTE: this subsampling is currently disabled; all model taus are kept.
#kept_taus = {find_nearest(antagDfModel.columns.values, k):p for p, k in pep_tau_map_ot1.items()}
#antagDfModel = antagDfModel.loc[:, kept_taus.keys()]
#antagDfModel = antagDfModel.rename(kept_taus, axis=1)

# Take the log and clip, just like the data
antagDfModel = np.clip(np.log2(antagDfModel), a_min=-2, a_max=6)

# Min-max scale, just like the data. Stacking into a single column means one
# positive and one negative scale are shared by all conditions and antigens
# (the data is scaled the same way, per cytokine column).
antagDfModel = (zero_centered_min_max_scaling2(antagDfModel.stack("TCR_Antigen").to_frame())
                 .iloc[:, 0].unstack("TCR_Antigen"))
antagDfModel.index = antagDfModel.index.reorder_levels(
    ["TCR_Antigen_Density", "CAR_ITAM_Number", "TCR_ITAM_Number", "Perturbation"])
antagDfModel = antagDfModel.sort_index()

In [None]:
# Model heatmap figure slightly larger than the data one; the scale factors are
# presumably tuned by hand so the two panels line up — verify visually
heatmap_figsize_model = (heatmap_figsize[0] * 1.05, heatmap_figsize[1] * 1.18)

# Reuse the data heatmap's row linkage so rows follow the same clustering order
# (assumes both dataframes list conditions in the same row order — TODO confirm),
# and hide the row dendrogram
clustmap, _ = antagonism_heatmap(antagDfModel, perturbations_palette, source="model", 
                              dendro=dendrogram, show_dendro=False, size_inches=heatmap_figsize_model)
if do_save_plots:
    clustmap.savefig(pj(fig_dir, "heatmap_model.pdf"), transparent=True, dpi=200, bbox_inches="tight")
plt.show()
plt.close()