# Imports

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import os, json, h5py
pj = os.path.join
import warnings

from scripts.preprocess import write_conc_uM, read_conc_uM, string_to_tuple
from scripts.plotting import (
    change_log_ticks, 
    data_model_handles_legend, 
    standalone_legend, 
    corner_plot_mcmc, 
    plot_tcr_tcr_fits, 
    plot_autocorr, 
    standalone_parameter_values, 
    put_log_ticks
)
from scripts.analysis import find_best_grid_point

In [None]:
# When True, figure panels are written to disk; otherwise they are only displayed.
do_save = False

# Input/output locations, relative to this notebook's directory.
root_dir = ".."
data_dir = pj(root_dir, "data", "antagonism")      # experimental antagonism datasets (.h5)
res_folder = pj(root_dir, "results", "for_plots")  # palettes and model/data frames for plotting
mcmc_folder = pj(root_dir, "results", "mcmc")      # MCMC samples and analysis summaries
panels_folder = "panels2/"                          # output panels, relative to the notebook

In [None]:
# Aesthetic parameters for a small-scale figure.
# Every size below is multiplied by `scaleup`. It is currently 1.0 (no scaling).
# Setting it to 2.0 doubles all sizes, which makes Matplotlib place more ticks;
# the exported panels can then be shrunk to 50 % in InDesign.
scaleup = 1.0

# Page width 180 mm (~7 in): 4 panels + legend/cartoon + 4 axes labels
# + 0.5 in of whitespace, so about 5.5 / 5 in per panel.
# Page height 170 mm (~6.7 in): 2 rows TCR/CAR, 2 rows TCR/TCR, 1 row 6F and
# 2 rows of diagrams. Removing 5 rows of labels (1.25 in) and ~0.5 in of
# whitespace leaves about 5 in for 7 panel rows.
panel_dimensions = dict(
    panel_width=5.5 / 5.0 * scaleup,       # inches
    axes_label_width=0.25 * scaleup,       # inches
    panel_height=(5/7 + 0.05) * scaleup,   # inches
)
panel_dimensions["legend_width"] = panel_dimensions["panel_width"]

# Global Matplotlib style: small fonts, thin lines, no top/right spines.
plt.rcParams.update({
    "font.size": 6. * scaleup,          # default text size, smallish
    "figure.dpi": 150.0,
    "axes.labelpad": 0.5 * scaleup,
    "axes.linewidth": 0.75 * scaleup,   # axes edge line width
    "lines.linewidth": 1.5 * scaleup,   # line width, in points
    "lines.markersize": 2.5 * scaleup,  # marker size, in points
    "axes.spines.top": False,
    "axes.spines.right": False,
})
# Identical tick geometry on the x and y axes.
plt.rcParams.update({
    axis_prefix + setting: size * scaleup
    for axis_prefix in ("xtick.", "ytick.")
    for setting, size in (
        ("major.size", 2.25), ("minor.size", 1.8),
        ("major.pad", 2.0), ("minor.pad", 1.9),
        ("minor.width", 0.5), ("major.width", 0.75),
    )
})

# Per-perturbation colours shared across panels; the "None" entry is drawn in black.
with open(pj(res_folder, "perturbations_palette.json"), "r") as palette_file:
    perturbations_palette = json.load(palette_file)
perturbations_palette["None"] = (0.0, 0.0, 0.0, 1.0)
sns.palplot(perturbations_palette.values())

In [None]:
# Special palettes for figure 1.
# Agonist density: teal shades from low density (first) to high density (last),
# written as RGBA in 0-255 and rescaled to 0-1.
agconc_palette = np.asarray([
    (66, 150, 141, 255),
    (15, 85, 97, 255),
    (13, 50, 70, 255),
]) / 255
sns.palplot(agconc_palette)

# Antagonist density: lighter shade for low density, darker for high density.
antagconc_palette = ["#bd5653", "#641c1e"]

# Three-step teal sequence, from light to dark.
teal3 = ['#D1EEEA', '#68ABB8', '#2A5674']
sns.palplot(teal3)

# TCR/CAR model typical behaviours (moved to supplementary)
Simple graphs: two subplots and one legend.

# TCR/TCR antagonism fits and CIs
These are 90 % confidence intervals (CIs), computed from the 5th and 95th percentiles.

Alternate layout: three plots on the same row, where a shared y axis helps comparison.

Above the fits: the MCMC corner plot, with the TCR cartoon and legend in the white space. Use an upper-triangle corner plot so that the legend in the blank space sits close to the fit plots.


In [None]:
# Read the best (k, m, f) grid point for revised AKPR and the best m for SHP-1,
# each selected with the "best" strategy from its MCMC analysis summary.
with open(pj(mcmc_folder, "mcmc_analysis_akpr_i.json"), "r") as h:
    best_kmf_akpr, _, _ = find_best_grid_point(json.load(h), strat="best")
with open(pj(mcmc_folder, "mcmc_analysis_shp1.json"), "r") as h:
    best_m_shp1, _, _ = find_best_grid_point(json.load(h), strat="best")

# For 6F, restrict to k <= klim: large k values can be overfitted.
klim = 1
with open(pj(mcmc_folder, "mcmc_analysis_tcr_tcr_6f.json"), "r") as h:
    lysis_all = json.load(h)
# Keys are string-encoded grid points; the first tuple entry is k.
# Building a filtered dict (instead of popping inside a loop) keeps insertion
# order and binds no loop variables, so the `del` below cannot raise NameError
# when the analysis file has no grid points.
lysis = {p: res for p, res in lysis_all.items() if string_to_tuple(p)[0] <= klim}
best_kmf_6f, _, _ = find_best_grid_point(lysis, strat="best")
del lysis, lysis_all

In [None]:
# This alternate layout changes the panel widths and heights.
# After the diagrams and CAR/TCR panels, about 110 mm (~4.3 in) of height is left.
# It holds the corner plots (2 panel heights) plus one row of model fits
# (1 panel height), i.e. 3 panel heights, plus 3 axes-label heights and ~0.5 in
# of blank space. 6F TCR is moved to the second figure.
# Panel height = (4.3 - 3*0.25 - 0.5) / 3 ≈ 1.0 in before scaling.
# Every length is computed in unscaled inches and multiplied by `scaleup` exactly
# once, so the layout stays proportional for any scaleup value.
axes_label_width_in = 0.25  # unscaled inches
panel_dimensions2 = {
    "panel_width": 6.0 / 6.0 * scaleup,                  # inches
    "axes_label_width": axes_label_width_in * scaleup,   # inches
}
panel_dimensions2["panel_height"] = ((4.3 - 3*axes_label_width_in - 0.5)
                                     / 3.0 * scaleup)   # inches
panel_dimensions2["legend_width"] = panel_dimensions2["panel_width"]

In [None]:
# Figure 2F: SHP-1 (francois2013) model fits with confidence intervals against the
# TCR/TCR antagonism data, evaluated at the best grid point m from the MCMC analysis.
fig, axes, handles, labels = plot_tcr_tcr_fits(pj(res_folder, "dfs_model_data_ci_mcmc_tcr_tcr.h5"), 
                        "shp1", best_m_shp1, panel_dimensions2, perturbations_palette)
# Colour each subplot title with the matching agonist-density teal
# (IndexError if the figure has more axes than agconc_palette entries).
for i, ax in enumerate(fig.axes):
    ax.title.set_color(agconc_palette[i])
if do_save:
    fig.savefig(pj(panels_folder, "2F-fit_confidence_tcr_tcr_francois2013.pdf"), 
                        transparent=True, bbox_inches="tight")
plt.show()
plt.close()

# Make a stand-alone legend
# NOTE(review): `labels` is passed as `handler_map`; confirm this matches
# standalone_legend's signature (Matplotlib's own handler_map expects a dict
# of legend handlers, not a list of labels).
fig, ax, leg = standalone_legend(handles=handles, handler_map=labels, frameon=False)
if do_save:
    fig.savefig(pj(panels_folder, "2F-fit_confidence_tcr_tcr_francois2013_legend.pdf"),
            transparent=True, bbox_inches="tight", bbox_extra_artists=(leg,))
plt.show()
plt.close()

In [None]:
# Figure 2I: revised AKPR model fits with confidence intervals against the
# TCR/TCR antagonism data, evaluated at the best (k, m, f) grid point.
fig, axes, handles, labels = plot_tcr_tcr_fits(pj(res_folder, "dfs_model_data_ci_mcmc_tcr_tcr.h5"), 
                            "akpr_i", best_kmf_akpr, panel_dimensions2, perturbations_palette)
# Colour each subplot title with the matching agonist-density teal
# (IndexError if the figure has more axes than agconc_palette entries).
for i, ax in enumerate(fig.axes):
    ax.title.set_color(agconc_palette[i])
if do_save:
    fig.savefig(pj(panels_folder, "2I-fit_confidence_tcr_tcr_revised_akpr.pdf"), 
            transparent=True, bbox_inches="tight")
plt.show()
plt.close()

# Make a stand-alone legend
# NOTE(review): `labels` is passed as `handler_map`; confirm this matches
# standalone_legend's signature.
fig, ax, leg = standalone_legend(handles=handles, handler_map=labels, frameon=False)
if do_save:
    fig.savefig(pj(panels_folder, "2I-fit_confidence_tcr_tcr_revised_akpr_legend_ec50.pdf"),
            transparent=True, bbox_inches="tight", bbox_extra_artists=(leg,))
plt.show()
plt.close()

In [None]:
# corner_plot_mcmc(samples_fname, analysis_fname, kmf, pdims, sizes_kwargs={}, **kwargs)
# Figure 2E: MCMC corner plot of the SHP-1 model posterior at its best grid point m.
# Size/style options, passed to corner_plot_mcmc as sizes_kwargs.
corner_plot_kwargs = {
    "scaleup": scaleup,
    "small_lw": 0.8, 
    "truth_lw": 1.25, 
    "small_markersize": 1.0, 
    "truth_color": np.asarray((0.0, 156.0, 75.0, 255.0)) / 255.0,  # deep key lime green
    "reverse_plots": False, 
    "labelpad": 0.1,
    "n_times_height":2, 
    "n_extra_x_labels": 2, 
    "n_extra_y_labels": 0
}
tick_label_size_mcmc = 5.5
# Display names for the fitted parameters; presumably in the same order as the
# parameters stored with the samples — TODO confirm.
nice_pnames = [r"$\varphi$", r"$C_{m, th}$", r"$I_{tot}$"]
fig = corner_plot_mcmc(pj(mcmc_folder, "mcmc_results_shp1.h5"), 
                       pj(mcmc_folder, "mcmc_analysis_shp1.json"), 
                 best_m_shp1, panel_dimensions2, pnames=nice_pnames, sizes_kwargs=corner_plot_kwargs)

# Adjust tick labels to gain a bit of space
# (smaller tick labels with tight padding; y labels nudged right by 0.1).
for ax in fig.axes:
    ax.xaxis.set_tick_params(labelsize=tick_label_size_mcmc*scaleup, pad=0.5*scaleup)
    # NOTE(review): Axis.get_label() returns a Text object, so this check is
    # always true and does not skip axes that have no y label.
    if ax.yaxis.get_label() is not None:
        # NOTE(review): get_position() is in the label's own transform, while
        # set_label_coords() defaults to axes coordinates — confirm the shift.
        ylbl_coords = ax.yaxis.get_label().get_position()
        ax.yaxis.set_label_coords(ylbl_coords[0]+0.1, ylbl_coords[1])
    ax.yaxis.set_tick_params(labelsize=tick_label_size_mcmc*scaleup, pad=0.5*scaleup)

if do_save:
    fig.savefig(pj(panels_folder, "2E-mcmc_cornerplot_shp1.pdf"), 
            transparent=True, bbox_inches="tight", dpi=600)
plt.show()
plt.close()

In [None]:
# Figure 2H: MCMC corner plot of the revised AKPR model posterior at its best
# (k, m, f) grid point. Size/style options, passed as sizes_kwargs.
corner_plot_kwargs = {
    "scaleup": scaleup,
    "small_lw": 0.8, 
    "truth_lw": 1.25, 
    "small_markersize": 1.0, 
    "truth_color": np.asarray((0.0, 156.0, 75.0, 255.0)) / 255.0,  # deep key lime green
    "reverse_plots": False, 
    "labelpad": 0.25, 
    "n_times_height": 2, 
    "n_extra_x_labels": 2, 
    "n_extra_y_labels": 0
}
# Slightly smaller tick labels than for SHP-1, since there are four parameters.
tick_label_size_mcmc = 5.0
# Display names for the fitted parameters; presumably in the same order as the
# parameters stored with the samples — TODO confirm.
nice_pnames = [r"$\varphi$", r"$C_{m, th}$", 
               r"$I_{th}$", r"$\psi_0$"]
fig = corner_plot_mcmc(pj(mcmc_folder, "mcmc_results_akpr_i.h5"), 
                       pj(mcmc_folder, "mcmc_analysis_akpr_i.json"), 
                 best_kmf_akpr, panel_dimensions2, pnames=nice_pnames, sizes_kwargs=corner_plot_kwargs)

# Adjust tick labels to gain a bit of space
# (smaller tick labels with tight padding; y labels nudged right by 0.15).
for ax in fig.axes:
    ax.xaxis.set_tick_params(labelsize=tick_label_size_mcmc*scaleup, pad=0.5*scaleup)
    # NOTE(review): Axis.get_label() returns a Text object, so this check is
    # always true and does not skip axes that have no y label.
    if ax.yaxis.get_label() is not None:
        ylbl_coords = ax.yaxis.get_label().get_position()
        ax.yaxis.set_label_coords(ylbl_coords[0]+0.15, ylbl_coords[1])
    ax.yaxis.set_tick_params(labelsize=tick_label_size_mcmc*scaleup, pad=0.5*scaleup)

if do_save:
    fig.savefig(pj(panels_folder, "2H-mcmc_cornerplot_revised_akpr.pdf"), 
            transparent=True, bbox_inches="tight", dpi=600)
plt.show()
plt.close()

## Labels of the integer parameters N, k, m, f

In [None]:
# Text annotation of the SHP-1 (francois2013) integer parameters:
# N^T is fixed to 6 here; m^T is the first entry of the best grid point.
francois2013_integers_names = [r"N^T", r"m^T"]
francois2013_integers_values = [6, string_to_tuple(best_m_shp1)[0]]
fig, ax, text = standalone_parameter_values(francois2013_integers_names, francois2013_integers_values)
if do_save:
    fig.savefig(pj(panels_folder, "integer_parameters_annotation_francois2013.pdf"), 
                transparent=True, bbox_inches="tight")
plt.show()
plt.close()

In [None]:
# Text annotation of the revised AKPR integer parameters, one per line.
# N^T is fixed to 6; the rest are unpacked from the best grid point, assumed to
# encode (k, m, f) in the same order as the names — TODO confirm.
akpr_integers_names = [r"N^T", r"k^T_I", r"m^T", r"f^T"]
akpr_integers_values = [6, *string_to_tuple(best_kmf_akpr)]
fig, ax, text = standalone_parameter_values(akpr_integers_names, akpr_integers_values, n_per_line=1)
if do_save:
    fig.savefig(pj(panels_folder, "integer_parameters_annotation_revised_akpr.pdf"), 
            transparent=True, bbox_inches="tight")
plt.show()
plt.close()

# Cartoon illustrating MCMC
Illustrate the differential evolution move: build a KDE surface of the real sampled distribution, plot 32 random walkers on it, select two at random, and propose a move along the vector separating those two. (Version 2 below instead traces the paths of two chosen walkers over a short window of steps.)

In [None]:
from scipy.stats import gaussian_kde

In [None]:
def plot_xy_arrow(ax, vec, length_scale=1.3, **kwargs):
    """Draw an arrow lying in the z = 0 plane of a 3D axes.

    Args:
        ax: 3D Matplotlib axes (anything exposing a 3D-style ``quiver``).
        vec: array with 2 rows (x, y); column 0 is the start point and
            column 1 the end point.
        length_scale: arrow length as a multiple of the start-to-end distance.
            The default 1.3 makes the arrow extend past the end point.
        **kwargs: forwarded to ``ax.quiver`` (e.g. ``color``).

    Returns:
        The same ``ax``, for chaining.
    """
    x, y = vec[:, 0]
    u, v = vec[:, 1] - vec[:, 0]
    distance = np.linalg.norm(vec[:, 1] - vec[:, 0])
    # Flat arrow: starts at z = 0 and has no vertical component.
    ax.quiver(x, y, 0, u, v, 0, normalize=True, length=distance*length_scale, **kwargs)
    return ax

def plot_xyz_arrow(ax, vec, length_scale=1.0, **kwargs):
    """Draw a 3D arrow from one point toward another.

    Args:
        ax: 3D Matplotlib axes (anything exposing a 3D-style ``quiver``).
        vec: array with 3 rows (x, y, z); column 0 is the start point and
            column 1 the end point.
        length_scale: arrow length as a multiple of the start-to-end distance.
            The default 1.0 makes the arrow end exactly at the end point.
        **kwargs: forwarded to ``ax.quiver`` (e.g. ``color``).

    Returns:
        The same ``ax``, for chaining.
    """
    start = vec[:, 0]
    delta = vec[:, 1] - vec[:, 0]
    distance = np.linalg.norm(delta)
    ax.quiver(*start, *delta, normalize=True, length=distance*length_scale, **kwargs)
    return ax

In [None]:
# Version 2: illustrating the path of one walker -- simplest possible MCMC algorithm
# Make a KDE of the desired two parameters.
# Load relevant data: samples, parameter names, burn-in.
# Samples are indexed as [parameter, walker, step] (see the slicing below).
chosen_params_idx = [0, 1]
# Read everything needed inside a context manager so the HDF5 file is closed.
# Indexing an h5py Dataset with a list returns an in-memory numpy array, so the
# data stay valid after the file is closed.
with h5py.File(pj(mcmc_folder, "mcmc_results_akpr_i.h5"), "r") as samples_file:
    samples_group = samples_file.get("samples")
    samples = samples_group.get(str(best_kmf_akpr))[chosen_params_idx]
    chosen_param_names = samples_group.attrs.get("param_names")[chosen_params_idx]
print(chosen_param_names)

# Prepare nicer parameter labels: "$\log_{10}\,<name>$", dropping any "\log "
# prefix and shortening "thresh" to "th".
# NOTE(review): assumes param_names are str (not bytes) — .strip("$") needs str.
param_labels = list(map(lambda a: r"$\log_{10}\," + a.strip("$").replace(r"\log ", "") + "$", 
                        chosen_param_names))
param_labels = list(map(lambda a: a.replace("thresh", "th"), param_labels))

# Drop the burn-in steps.
with open(pj(mcmc_folder,  "mcmc_analysis_akpr_i.json"), "r") as h:
    results_dict = json.load(h).get(str(best_kmf_akpr))
burn_in_steps = results_dict["burn_in_steps"]
processed_samples = samples[:, :, burn_in_steps:]

# Build KDE on the flattened (walker, step) samples; keep every 4th step for speed.
kde = gaussian_kde(processed_samples[:, :, ::4].reshape([len(chosen_params_idx), -1]))

# Select the walkers and the window of post-burn-in steps to draw.
chosen_snapshots = slice(5001, 5020, 1)
chosen_walkers = [11, 25]
# A window past the end would silently yield fewer (or zero) snapshots.
if processed_samples.shape[2] < chosen_snapshots.stop:
    raise ValueError(
        f"Only {processed_samples.shape[2]} post-burn-in steps available; "
        f"snapshot window {chosen_snapshots} does not fit."
    )
walker_points = processed_samples[:, chosen_walkers, chosen_snapshots]

# Densify each walker's path: 20 linearly interpolated points per step between
# consecutive snapshots, then the final snapshot itself.
walker_vectors = np.diff(walker_points, axis=2)
walker_lines = []
for j in range(walker_points.shape[2]-1):
    walker_lines.append(walker_points[:, :, j:j+1] + walker_vectors[:, :, j:j+1] * np.arange(0.0, 1.0, 0.05))
walker_lines.append(walker_points[:, :, -1:])
walker_lines = np.concatenate(walker_lines, axis=2)

# Evaluate the KDE on a 40 x 40 grid over the sampled range, with the upper
# limits trimmed by 0.1 (in sample units).
xlims = [processed_samples[0].min(), processed_samples[0].max()-0.1]
ylims = [processed_samples[1].min(), processed_samples[1].max()-0.1]
xx, yy = np.meshgrid(np.linspace(*xlims, 40), np.linspace(*ylims, 40), indexing="ij")
xygridpoints = np.vstack([xx.ravel(), yy.ravel()])
pdf = kde(xygridpoints).reshape(xx.shape)

# KDE height of every snapshot and path point; shape (n_walkers, n_points).
walker_points_z = np.stack([kde(walker_points[:, i]) for i in range(walker_points.shape[1])], axis=0)
walker_lines_z = np.stack([kde(walker_lines[:, i]) for i in range(walker_lines.shape[1])], axis=0)

# Add jumps?
#periodic_jumps = 0.3*np.abs(np.sin(np.linspace(0.0, np.pi*walker_points.shape[2]-1, walker_lines_z.shape[1])))
#walker_lines_z += periodic_jumps[None, :]

In [None]:
# Plot in 3D the pdf, walkers, and proposed move. 
# Figure 2C (bottom): KDE surface of the posterior with the chosen walkers'
# snapshot positions and interpolated paths drawn on it.
fig = plt.figure()
# One panel plus room for three axes labels in each direction (inches).
fig.set_size_inches(panel_dimensions["panel_width"]+panel_dimensions["axes_label_width"]*3, 
                    panel_dimensions["panel_height"]+panel_dimensions["axes_label_width"]*3)
ax = fig.add_subplot(projection='3d')
# Faint, semi-transparent surface sampled every 4th grid row/column.
ax.plot_surface(xx, yy, pdf, color="lightgrey", edgecolor=(0.8,)*4, lw=0.35, rstride=4, cstride=4,
                alpha=0.2, zorder=0)
# One colour per walker (IndexError if more than two walkers are chosen).
colors = ["k", "k"]
for i in range(walker_points.shape[1]):
    c = colors[i]
    # Builds each step's (3, 2) start/end vector for the arrow call, which is
    # commented out; this inner loop currently has no effect on the figure.
    for j in range(walker_points.shape[2]-1):
        vec = np.stack([*walker_points[:, i, [j, j+1]], walker_points_z[i, [j, j+1]]], axis=0)
        #plot_xyz_arrow(ax, vec, color=c)
    # Snapshot positions, lifted to the KDE height at each point.
    ax.plot(*walker_points[:, i], walker_points_z[i], marker="o", ls="none", 
            ms=1.5, mfc=c, mec=c, zorder=10+2*i)
    # Interpolated path between consecutive snapshots.
    ax.plot(*walker_lines[:, i], walker_lines_z[i], color=c, ls="-", lw=0.8, zorder=11+2*i)


#ax.set(xlabel=chosen_param_names[0], ylabel=chosen_param_names[1])
# Schematic axes: no ticks, and generic labels pulled in close to the axes.
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])
ax.set_xlabel(r"Parameter $\theta_i$", labelpad=-17)
ax.set_ylabel(r"Parameter $\theta_j$", labelpad=-17)
ax.set_zlabel(r"Posterior ($p(\vec{\theta}|D)$)", labelpad=-17, rotation=-180)
ax.view_init(azim=30, elev=30)
fig.tight_layout(pad=0)
if do_save:
    fig.savefig(pj(panels_folder, "2C_bottom-mcmc_illustration.pdf"), transparent=True, bbox_inches="tight")
plt.show()
plt.close()

In [None]:
### Version with transparent background for graphical abstract
# Same 3D cartoon as Figure 2C (bottom), but without axis labels and with fully
# transparent panes, saved as a 300 dpi PNG.
fig = plt.figure()
fig.set_size_inches(panel_dimensions["panel_width"]+panel_dimensions["axes_label_width"]*3, 
                    panel_dimensions["panel_height"]+panel_dimensions["axes_label_width"]*3)
ax = fig.add_subplot(projection='3d')
ax.plot_surface(xx, yy, pdf, color="lightgrey", edgecolor=(0.8,)*4, lw=0.35, rstride=4, cstride=4,
                alpha=0.2, zorder=0)
colors = ["k", "k"]
for i in range(walker_points.shape[1]):
    c = colors[i]
    # Only feeds the commented-out arrow call; has no effect on the figure.
    for j in range(walker_points.shape[2]-1):
        vec = np.stack([*walker_points[:, i, [j, j+1]], walker_points_z[i, [j, j+1]]], axis=0)
        #plot_xyz_arrow(ax, vec, color=c)
    ax.plot(*walker_points[:, i], walker_points_z[i], marker="o", ls="none", 
            ms=1.5, mfc=c, mec=c, zorder=10+2*i)
    ax.plot(*walker_lines[:, i], walker_lines_z[i], color=c, ls="-", lw=0.8, zorder=11+2*i)


#ax.set(xlabel=chosen_param_names[0], ylabel=chosen_param_names[1])
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])
#ax.set_xlabel(r"Parameter $\theta_i$", labelpad=-17)
#ax.set_ylabel(r"Parameter $\theta_j$", labelpad=-17)
#ax.set_zlabel(r"Posterior ($p(\vec{\theta}|D)$)", labelpad=-17, rotation=-180)
ax.view_init(azim=30, elev=30)
fig.tight_layout(pad=0)
# Alpha 0 on the three background panes, so only the surface and paths show.
ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
if do_save:
    fig.savefig(pj(panels_folder, "graphicalAbstract-mcmc_illustration_transparent.png"), 
                transparent=True, bbox_inches="tight", dpi=300)
plt.show()
plt.close()

# Panel B: TCR/TCR data carpet

In [None]:
#@title Plot matrix colorbar label
# Using the TCR/CAR data for the colorbar
from matplotlib.patches import Patch
def zero_centered_min_max_scaling2(dataframe):
    """
    Scale the numerical values in the dataframe to be between -1 and 1,
    column by column, keeping 0 at 0.

    Non-positive values are divided by the absolute value of the column
    minimum (so the minimum maps to -1), and positive values by the column
    maximum (so the maximum maps to 1). The two signs are scaled
    independently, so this is not a single linear map of the column.

    Args:
        dataframe (pd.DataFrame): numeric data. NaN entries are ignored when
            computing column extrema and stay NaN in the output.

    Returns:
        pd.DataFrame: new float DataFrame with the same index and columns;
            the input is not modified.
    """
    values = dataframe.to_numpy(dtype=float)
    # Per-column extrema (NaN-skipping), broadcast along the rows below.
    col_max = dataframe.max().to_numpy(dtype=float)
    col_min = dataframe.min().to_numpy(dtype=float)
    # Divisions are computed for every entry but only used on the matching sign,
    # so divide-by-zero/invalid warnings from unused branches are suppressed.
    with np.errstate(divide="ignore", invalid="ignore"):
        # A column with min >= 0 has no negative values; its zeros map to 0
        # instead of 0 / 0 = NaN.
        nonpos_scaled = np.where(col_min < 0, values / -col_min, 0.0)
        pos_scaled = values / col_max
    scaled = np.where(values <= 0, nonpos_scaled,
                      np.where(values > 0, pos_scaled, np.nan))
    return pd.DataFrame(scaled, index=dataframe.index, columns=dataframe.columns)

# Load the TCR/CAR antagonism dataset. Here it is only used to draw the
# colorbar/label artwork of the Figure 2B data matrix (see the cell title).
antagonismDf = pd.read_hdf(pj(data_dir, 'ot1_car_antagonism_df_with_ratio.h5'))

# NOTE: sns.set_context('talk') changes global rcParams for all later figures.
sns.set_context('talk')
#@title Plot matrix of data with color key
# NOTE(review): roughAntigenEC50Dict is not used anywhere in this notebook.
roughAntigenEC50Dict = {'None':1e6,'E1':6e4,'G4':1e4,'V4':9e2,'T4':9e1,'Q4':2e1,'Y3':8e0,'A2':3e0,'N4':1}
# CD19 CAR conditions with a nonzero number of CAR ITAMs.
temp = antagonismDf.query("CAR_Antigen == 'CD19' and `CAR_ITAM_Number` != '0'")
#temp = temp.query("Cytokine == ['IL-2','TNFa']")
# IL-2 and TNFa at every time point; other cytokines only at times 1, 3 and 6.
temp1 = temp.query("Cytokine == ['IL-2','TNFa']")
temp2 = temp.query("Cytokine != ['IL-2','TNFa'] and Time == [1,3,6]")
#temp2 = temp.query("Cytokine != ['IL-2','TNFa']")
#temp1 = temp1[temp1['Ratio'] < 2e2]
#temp2 = temp2[temp2['Ratio'] < 2e1]
temp = pd.concat([temp1,temp2])
# Regroup rows by TCR antigen, in reverse order of first appearance.
o = temp.index.unique('TCR_Antigen').tolist()[::-1]
temp = pd.concat([temp.query("TCR_Antigen == @x") for x in o])
# "<TCR antigen density>-<TCR ITAM number>" -> perturbation label.
pertDict = {'1uM-10':'None','1uM-4':'Fewer TCR ITAMs','1nM-10':'Less TCR antigen density'}
temp['Condition'] = [pertDict['-'.join([x,y])] for x,y in zip(temp.index.get_level_values('TCR_Antigen_Density'),temp.index.get_level_values('TCR_ITAM_Number'))]
# Column 1 is presumably the antagonism ratio (per the file name) — TODO confirm.
# Pivot TCR antigens into columns and take log2 of the ratio.
temp = temp.set_index(['Condition'],append=True).iloc[:,1].unstack('TCR_Antigen')
temp = temp[['E1','G4','V4','T4','Q4','A2','N4']]
temp = np.log2(temp)

# NOTE(review): this figure is never drawn on (clustermap creates its own) and
# is never closed; TwoSlopeNorm below is also unused.
fig = plt.figure(figsize=(5,10))
from matplotlib.colors import TwoSlopeNorm
#display(temp)
#print(temp.groupby([x for x in temp.index.names if x != 'Time']).mean().min().min())
#print(temp.groupby([x for x in temp.index.names if x != 'Time']).mean().max().max())
# Clip log2 ratios to [-2, 6], then rescale to [-1, 1] with 0 kept at 0. After the
# reshape, the columns being scaled are the cytokines.
clippedValues = np.clip(temp,a_min=-2,a_max=6)
with warnings.catch_warnings():
    warnings.simplefilter(action='ignore', category=FutureWarning)
    clippedValues = zero_centered_min_max_scaling2(
        clippedValues.unstack('Cytokine').stack('TCR_Antigen')).unstack('TCR_Antigen').stack('Cytokine')
clippedValues = clippedValues[['E1','G4','V4','T4','Q4','A2','N4']]
#clippedValues = clippedValues.droplevel(['Perturbation','CAR_Antigen','Tumor'])
clippedValues = clippedValues.droplevel(['Condition','CAR_Antigen','Tumor'])
clippedValues = clippedValues.swaplevel(-1,0).swaplevel(-1,-2)
# Order cytokines IL-2, TNF, IFN-gamma, rename them, then keep only IL-2.
clippedValues = pd.concat([clippedValues.query("Cytokine == @x") for x in ['IL-2','TNFa','IFNg']]).rename({'TNFa':'TNF','IFNg':r'IFN-$\gamma$'}).query("Cytokine == 'IL-2'")
#clippedValues = clippedValues.droplevel(['TCR_ITAM_Number','TCR_Antigen_Density','Tumor','CAR_Antigen'])
#g = sns.heatmap(clippedValues.values,cmap='PuOr_r', cbar_kws={'shrink': 0.5})
# cbar = g.collections[0].colorbar
# cbar.set_ticks([])
# g.set_xticks([])
# g.set_yticks([])
# for spine in g.spines.values():
#     spine.set_visible(True)
# cbar.outline.set_edgecolor('k')
# cbar.outline.set_linewidth(2)

# Average over time points, drop bookkeeping levels, and reorder the remaining levels.
clippedValues = clippedValues.groupby([x for x in clippedValues.index.names if x != 'Time']).mean()
clippedValues = clippedValues.droplevel(['Spleen','Data']).swaplevel(-1,-3).swaplevel(-2,-3)
#clippedValues = clippedValues.rename({'1':'Low','3':'High'},level='CAR_ITAM_Number').rename({'4':'Low','10':'High'},level='TCR_ITAM_Number').rename({'1nM':'Low','1uM':'High'},level='TCR_Antigen_Density').droplevel('Cytokine')
clippedValues = clippedValues.rename({'1nM':'Low','1uM':'High'},level='TCR_Antigen_Density').droplevel('Cytokine')
# Add a single level containing all perturbations
#clippedValues["Perturbation"] = (clippedValues.index.get_level_values("TCR_Antigen_Density")
#                              + "+" + clippedValues.index.get_level_values("CAR_ITAM_Number")
#                              + "+" + clippedValues.index.get_level_values("TCR_ITAM_Number") )
#clippedValues = clippedValues.set_index("Perturbation", append=True)
# Prepare a color mapping each value in the Perturbation level to a color
# in the palette
# One row-colour strip per remaining index level; lutList keeps the value->colour
# lookup tables for the legend below.
dfDict = {}
levels = list(clippedValues.index.names)
#palettes = ['Greys']*len(clippedValues.index.names)
palettes = ['Greens_r','Blues_r','Reds_r']
lutList = []
# NOTE(review): setPalette is only referenced in commented-out code.
setPalette = sns.color_palette("Set2")
setPalette = [setPalette[x] for x in [3,5,6]]
speciesDict = {'CAR_ITAM_Number':['3','1'],'TCR_ITAM_Number':['10','4'],'TCR_Antigen_Density':['High','Low']}
for level,p in zip(levels,palettes):
  subsetDf = clippedValues.reset_index()[levels+list(clippedValues.columns)]
  species = subsetDf.pop(level)
  #lut = dict(zip(speciesDict[level], sns.color_palette(p,len(species.unique()))))
  if level != 'Cytokine':
    lut = dict(zip(speciesDict[level], sns.color_palette(p,len(species.unique()))))
  else:
    #lut = dict(zip(speciesDict[level], setPalette))
    # NOTE(review): cytokinePalette is not defined in this notebook; this branch
    # is unreachable only because the 'Cytokine' level was dropped above.
    lut = dict(zip(speciesDict[level], cytokinePalette))
  lutList.append(lut)
  dfDict[level] = species.map(lut)

row_colors = pd.DataFrame(dfDict)
#cm = sns.clustermap(clippedValues.reset_index().loc[:,['E1','G4','V4','T4','Q4','A2','N4']],row_colors=row_colors,row_cluster=False,col_cluster=False,cmap='PuOr_r', cbar_kws={'shrink': 0.5})
# Rows are clustered, columns keep antigen order; missing values are shown as 0.
cm = sns.clustermap(clippedValues.fillna(value=0).reset_index().loc[:,['E1','G4','V4','T4','Q4','A2','N4']],row_colors=row_colors,
                    col_cluster=False,cmap='PuOr_r', cbar_kws={'shrink': 0.5,'orientation':'horizontal'},figsize=(5,25))
g = cm.ax_heatmap
cbar = g.collections[0].colorbar
# The colorbar is labelled with the TCR/TCR fold change because only this
# colorbar/label artwork is meant to be reused (per the cell title).
antagonismRatioYAxislabel = r'$FC_{\mathrm{TCR/TCR}}$'
cbar.set_label(antagonismRatioYAxislabel)
#g.set_xticks([])
# -2.1892270492399706
# 6.270750880443189
#cbar.set_ticks([-1,0,1])
cbar.set_ticks([])

g.set_yticks([])
for spine in g.spines.values():
    spine.set_visible(True)
cbar.outline.set_edgecolor('k')
cbar.outline.set_linewidth(2)
#g.set_xlabel('Stimulation')
g.set_xlabel('')

# Legend: bold header, then one title entry per level followed by its colour swatches.
handles = [Patch(facecolor='w')]
labels = [r"$\bf{Perturbation:}$"]
# NOTE(review): titles are zipped with `levels` by position — confirm that the
# level order after the swaplevel calls matches these titles.
titleList = ['TCR ITAM #','CAR ITAM #','TCR Ag Density']
for level,lut,title in zip(levels,lutList,titleList):
  handles.append(Patch(facecolor='w'))
  labels.append(title)
  for name in lut:
    handles.append(Patch(facecolor=lut[name],edgecolor='k'))
    labels.append(name)
# plt.legend attaches to the current axes, positioned in figure coordinates.
plt.legend(handles, labels,
           bbox_to_anchor=(0.95, 0.55),loc='center left', bbox_transform=cm.fig.transFigure,frameon=False)

# Move the colorbar into a thin strip where the (unused) column dendrogram sits.
cm.ax_cbar.set_position([cm.ax_col_dendrogram.get_position().x0, cm.ax_col_dendrogram.get_position().y0+0.02,cm.ax_col_dendrogram.get_position().width, 0.02])
cm.ax_cbar.set_title('')
cm.ax_cbar.tick_params(axis='x', length=10)

# Arrows and text above the heatmap: left half "Antagonism" (purple), right half
# "Enhancement" (orange), matching the two ends of the PuOr_r colormap.
cm.ax_heatmap.annotate('',xytext=(0.5,1.12),xy=(1,1.12),arrowprops=dict(arrowstyle='-|>',color='darkorange',lw=3), xycoords='axes fraction')
# cm.ax_heatmap.annotate('',xytext=(0,1.12),xy=(0.5,1.12),arrowprops=dict(arrowstyle='<-',color='purple',lw=3), xycoords='axes fraction')
cm.ax_heatmap.annotate('',xytext=(0,1.12),xy=(0.5,1.12),arrowprops=dict(arrowstyle='|-|, widthB=0,widthA=0.5',color='purple',lw=3), xycoords='axes fraction')
t = cm.ax_heatmap.text(0.25,1.16,'Antagonism',va='center',ha='center',fontsize=12,color='purple',zorder=100,transform=cm.ax_heatmap.transAxes,fontweight='bold')
#t.set_bbox(dict(facecolor='white', edgecolor='white'))
t = cm.ax_heatmap.text(0.75,1.16,'Enhancement',va='center',ha='center',fontsize=12,color='darkorange',zorder=100,transform=cm.ax_heatmap.transAxes,fontweight='bold')
#t.set_bbox(dict(facecolor='white', edgecolor='white'))
# NOTE(review): c1 is created but never added to the axes.
c1 = plt.Circle((0.5, 1.12), 0.01, color='k', clip_on=False,transform=cm.ax_heatmap.transAxes,zorder=100)
#cm.ax_heatmap.add_patch(c1)

cm.ax_row_colors.set_xticks([])
cm.ax_row_colors.set_xticklabels([])

# Bottom arrow: TCR antigen strength increases to the right.
cm.ax_heatmap.annotate('',xytext=(0,-0.12),xy=(1,-0.12),arrowprops=dict(arrowstyle='->',color='k',lw=3), xycoords='axes fraction')
t = cm.ax_heatmap.text(0.5,-0.12,'TCR Ag Strength',va='center',ha='center',fontsize=16,color='k',zorder=100,transform=cm.ax_heatmap.transAxes)
t.set_bbox(dict(facecolor='white', edgecolor='white'))

#cm.ax_heatmap.annotate('(',xy=(0.028,-0.065), xycoords='axes fraction',fontsize=18)
#cm.ax_heatmap.annotate(')+CD19',xy=(0.957,-0.065), xycoords='axes fraction',fontsize=18)
# Bracket the antigen tick labels as "(...)+CD19".
cm.ax_heatmap.annotate('(',xy=(-0.02,-0.065), xycoords='axes fraction',fontsize=18)
cm.ax_heatmap.annotate(')+CD19',xy=(0.99,-0.065), xycoords='axes fraction',fontsize=16)

if do_save:
    cm.savefig(pj(panels_folder, '2B-synergismAntagonismMatrix_tcrtcr-label.pdf'),bbox_inches='tight',transparent=True)
# Clears the current figure's contents (it is not closed).
plt.clf()
#sns.set_context('poster')

In [None]:
#@title Load data
# Load the same-cell TCR/TCR antagonism dataset (HDF5 key 'df').
data = pd.read_hdf(os.path.join(root_dir, "data", "antagonism", 'allManualSameCellAntagonismDfs_v3.h5'),key='df')
with warnings.catch_warnings():
    warnings.simplefilter(action='ignore', category=FutureWarning)
    # Move Time from the columns into the row index, drop missing entries and
    # keep the first column as a Series.
    df_fit = data.stack("Time").dropna().iloc[:, 0]
sln = slice(None)
# Partial keys on the first six index levels; they differ only in the sixth
# position ("1nM" vs "1uM"). Level meanings are not visible here — TODO confirm.
slc = (sln, sln, sln, "1nM", "None", "1nM")
slc2 = (sln, sln, sln, "1nM", "None", "1uM")
# Overwrite the rows matched by `slc` with the rows matched by `slc2`, aligned
# on the full index of the `slc` rows.
# NOTE(review): if .loc keeps all index levels here, the two indexes differ in
# the sixth level and reindex would yield all NaN — verify the copied values.
rest_of_index = df_fit.loc[slc].index
df_fit.loc[slc] = df_fit.loc[slc2].reindex(rest_of_index).values
subsetDf = df_fit.to_frame('Concentration (nM)')

# Ratio of each measurement to the matching row with the third-from-last index
# entry set to 'None' (named agAlone, presumably the agonist-alone condition).
# Note: this rebinds antagonismDf, replacing the TCR/CAR frame loaded earlier.
antagonismDf = subsetDf.copy()#query("TCR_Antigen != 'None'")
ratioList = []
for row in range(antagonismDf.shape[0]):
  name = list(antagonismDf.iloc[row,:].name)
  name2 = name.copy()
  name2[-3] = 'None'
  agAlone = subsetDf.loc[tuple(name2),:].values[0]
  agAntag = antagonismDf.iloc[row,:].values[0]
  ratio = agAntag/agAlone
  ratioList.append(ratio)
antagonismDf['Ratio'] = ratioList
# NOTE(review): displays the whole frame; consider antagonismDf.head().
antagonismDf

In [None]:
# NOTE: sns.set_context('talk') changes global rcParams for all later figures.
sns.set_context('talk')
#@title Plot matrix of data with color key (wide)
# NOTE(review): roughAntigenEC50Dict is not used anywhere in this notebook.
roughAntigenEC50Dict = {'None':1e6,'E1':6e4,'G4':1e4,'V4':9e2,'T4':9e1,'Q4':2e1,'Y3':8e0,'A2':3e0,'N4':1}
# N4 agonist at three agonist concentrations, two antagonist concentrations,
# and the IFNg / IL-2 / TNFa cytokines.
temp = antagonismDf.query("AntagonistConcentration == ['1uM','1nM'] and `Agonist` == 'N4' and AgonistConcentration == ['1nM','100pM','10pM'] and Cytokine == ['IFNg','IL-2','TNFa']")
# Regroup rows by antagonist, in reverse order of first appearance.
o = temp.index.unique('Antagonist').tolist()[::-1]
temp = pd.concat([temp.query("Antagonist == @x") for x in o])
# Column 1 is 'Ratio' (appended after 'Concentration (nM)' in the load cell);
# pivot antagonists into columns and take log2.
temp = temp.iloc[:,1].unstack('Antagonist')
temp = temp[['E1','G4','V4','T4','Q4']]
temp = np.log2(temp)

# NOTE(review): this figure is never drawn on (clustermap creates its own) and
# is never closed; TwoSlopeNorm below is also unused.
fig = plt.figure(figsize=(5,10))
from matplotlib.colors import TwoSlopeNorm
# Clip log2 ratios to [-6, 3], then rescale to [-1, 1] with 0 kept at 0.
clippedValues = np.clip(temp,a_min=-6,a_max=3)
#clippedValues = temp.copy()

with warnings.catch_warnings():
    warnings.simplefilter(action='ignore', category=FutureWarning)
    # Scale per cytokine (the columns after the reshape), then put Time in the
    # columns: the heatmap x axis is time.
    clippedValues = zero_centered_min_max_scaling2(
        clippedValues.unstack('Cytokine').stack('Antagonist')).unstack('Time').stack('Cytokine')
clippedValues = clippedValues.swaplevel(-1,0).swaplevel(-1,-2)
# Order cytokines IL-2, TNF, IFN-gamma, rename them, then keep only IL-2.
clippedValues = pd.concat([clippedValues.query("Cytokine == @x") for x in ['IL-2','TNFa','IFNg']]).rename({'TNFa':'TNF','IFNg':r'IFN-$\gamma$'}).query("Cytokine == 'IL-2'")

clippedValues = clippedValues.droplevel(['Agonist','Experiment']).swaplevel(-1,-3).swaplevel(-2,-3)

# Concentrations -> High/Med/Low labels for the row-colour legend.
clippedValues = clippedValues.rename({'1nM':'Low','1uM':'High'},level='AntagonistConcentration').rename({'1nM':'High','100pM':'Med','10pM':'Low'},level='AgonistConcentration').droplevel('Cytokine')

# Prepare a color mapping each value in the Perturbation level to a color
# in the palette
# One row-colour strip per index level; lutList keeps the value->colour tables
# for the legend below.
dfDict = {}
levels = list(clippedValues.index.names)[::-1]
levels = [levels[1],levels[2],levels[0]]
#palettes = ['Greys']*len(clippedValues.index.names)
# NOTE(review): palettes are paired with `levels` by position. `p[::-1]` below is
# only valid if AgonistConcentration is paired with the teal3 list (reversing a
# palette-name string would give an invalid name) — confirm the level order.
palettes = ['Greys_r','Reds_r', teal3]
#palettes = [palettes[0],palettes[1],palettes[2],palettes[1]]
lutList = []
# NOTE(review): setPalette is only referenced in commented-out code.
setPalette = sns.color_palette("Set2")
setPalette = [setPalette[x] for x in [3,5,6]]
speciesDict = {'AgonistConcentration':['High','Med','Low'],'AntagonistConcentration':['High','Low'],'Antagonist':['Q4','T4','V4','G4','E1']}
times = clippedValues.columns
for level,p in zip(levels,palettes):
  subsetDf = clippedValues.reset_index()[levels+list(clippedValues.columns)]
  species = subsetDf.pop(level)
  if level == 'AgonistConcentration':
    p = p[::-1]
  #lut = dict(zip(speciesDict[level], sns.color_palette(p,len(species.unique()))))
  if level != 'Cytokine':
    if level == 'Antagonist':
      # Two extra colours, then drop the first two of the sequence.
      lut = dict(zip(speciesDict[level], sns.color_palette(p,len(species.unique())+2)[2:]))
    else:
      lut = dict(zip(speciesDict[level], sns.color_palette(p,len(species.unique()))))
  else:
    #lut = dict(zip(speciesDict[level], setPalette))
    # NOTE(review): cytokinePalette is not defined in this notebook; this branch
    # is unreachable only because the 'Cytokine' level was dropped above.
    lut = dict(zip(speciesDict[level], cytokinePalette))
  lutList.append(lut)
  dfDict[level] = species.map(lut)

row_colors = pd.DataFrame(dfDict)
#cm = sns.clustermap(clippedValues.reset_index().loc[:,['E1','G4','V4','T4','Q4','A2','N4']],row_colors=row_colors,row_cluster=False,col_cluster=False,cmap='PuOr_r', cbar_kws={'shrink': 0.5})
# Rows are clustered, time columns keep their order; missing values are shown as 0.
cm = sns.clustermap(clippedValues.fillna(value=0).reset_index().loc[:,list(times)],row_colors=row_colors,
                    col_cluster=False,cmap='PuOr_r', cbar_kws={'shrink': 0.5,'orientation':'horizontal'},figsize=(5,10))
g = cm.ax_heatmap
cbar = g.collections[0].colorbar
cbar.set_ticks([])
#g.set_xticks([])
g.set_yticks([])
g.set_ylabel('')
for spine in g.spines.values():
    spine.set_visible(True)
cbar.outline.set_edgecolor('k')
cbar.outline.set_linewidth(2)
#g.set_xlabel('Stimulation')
g.set_xlabel('')

# Legend: bold header, then one title entry per level followed by its colour swatches.
handles = [Patch(facecolor='w')]
labels = [r"$\bf{Perturbation:}$"]
titleList = ['Agonist Density','Antagonist Density','Antagonist Strength'][::-1]
#titleList = [titleList[0],titleList[2],titleList[1],titleList[3]]

for level,lut,title in zip(levels,lutList,titleList):
  handles.append(Patch(facecolor='w'))
  labels.append(title)
  for name in lut:
    handles.append(Patch(facecolor=lut[name],edgecolor='k'))
    labels.append(name)
# plt.legend attaches to the current axes, positioned in figure coordinates.
plt.legend(handles, labels,
           bbox_to_anchor=(0.9, 0.6),loc='center left', bbox_transform=cm.fig.transFigure,frameon=False)

# Move the colorbar into a thin strip where the (unused) column dendrogram sits.
cm.ax_cbar.set_position([cm.ax_col_dendrogram.get_position().x0, cm.ax_col_dendrogram.get_position().y0+0.02,cm.ax_col_dendrogram.get_position().width, 0.02])
cm.ax_cbar.set_title('')
cm.ax_cbar.tick_params(axis='x', length=10)

# Arrows and text above the heatmap: left half "Antagonism" (purple), right half
# "Enhancement" (orange), matching the two ends of the PuOr_r colormap.
cm.ax_heatmap.annotate('',xytext=(0.5,1.12),xy=(1,1.12),arrowprops=dict(arrowstyle='-|>',color='darkorange',lw=3), xycoords='axes fraction')
cm.ax_heatmap.annotate('',xytext=(0,1.12),xy=(0.5,1.12),arrowprops=dict(arrowstyle='<-',color='purple',lw=3), xycoords='axes fraction')
#cm.ax_heatmap.annotate('',xytext=(0,1.12),xy=(0.5,1.12),arrowprops=dict(arrowstyle='|-|, widthB=0,widthA=0.5',color='purple',lw=3), xycoords='axes fraction')
t = cm.ax_heatmap.text(0.25,1.16,'Antagonism',va='center',ha='center',fontsize=12,color='purple',zorder=100,transform=cm.ax_heatmap.transAxes,fontweight='bold')
#t.set_bbox(dict(facecolor='white', edgecolor='white'))
t = cm.ax_heatmap.text(0.75,1.16,'Enhancement',va='center',ha='center',fontsize=12,color='darkorange',zorder=100,transform=cm.ax_heatmap.transAxes,fontweight='bold')
#t.set_bbox(dict(facecolor='white', edgecolor='white'))
# NOTE(review): c1 is created but never added to the axes.
c1 = plt.Circle((0.5, 1.12), 0.01, color='k', clip_on=False,transform=cm.ax_heatmap.transAxes,zorder=100)

cm.ax_row_colors.set_xticks([])
cm.ax_row_colors.set_xticklabels([])

#cm.ax_heatmap.add_patch(c1)

# cm.ax_row_colors.set_xticks([])
# cm.ax_row_colors.set_xticklabels([])

# Bottom arrow: time increases to the right.
cm.ax_heatmap.annotate('',xytext=(0,-0.12),xy=(1,-0.12),arrowprops=dict(arrowstyle='->',color='k',lw=3), xycoords='axes fraction')
t = cm.ax_heatmap.text(0.5,-0.12,'Time (h)',va='center',ha='center',fontsize=16,color='k',zorder=100,transform=cm.ax_heatmap.transAxes)
t.set_bbox(dict(facecolor='white', edgecolor='white'))

#cm.ax_heatmap.annotate('(',xy=(0.028,-0.065), xycoords='axes fraction',fontsize=18)
#cm.ax_heatmap.annotate(')+CD19',xy=(0.957,-0.065), xycoords='axes fraction',fontsize=18)
# cm.ax_heatmap.annotate('(',xy=(-0.02,-0.065), xycoords='axes fraction',fontsize=18)
# cm.ax_heatmap.annotate(')+CD19',xy=(0.99,-0.065), xycoords='axes fraction',fontsize=16)
# Label only selected times; ticks sit at cell centres (column index + 0.5).
# list(times).index raises ValueError if one of these times is absent.
timesOfInterest = [6,12,24,36,48,72]
ticksOfInterest = [(list(times).index(x)+0.5) for x in timesOfInterest]
cm.ax_heatmap.set_xticks(ticksOfInterest)
cm.ax_heatmap.set_xticklabels([str(int(x)) for x in timesOfInterest],rotation=0)

if do_save:
    cm.savefig(pj(panels_folder, '2B-synergismAntagonismMatrix-tcrtcr.pdf'),bbox_inches='tight',transparent=True)
#sns.set_context('poster')