In [None]:
import os
import json, h5py
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.colors as clr
import numpy as np
import pandas as pd
import seaborn as sns
#sns.set_context('poster')
import warnings
# Silence every warning for the rest of the notebook.
# NOTE(review): this also hides pandas/numpy runtime warnings (e.g. division by zero).
warnings.filterwarnings('ignore')
pj = os.path.join    # short alias for building file paths
idx = pd.IndexSlice  # shorthand for MultiIndex slicing

In [None]:
# Paths (relative to the working directory; every data path is built from root_dir)
do_save = False  # set True to write the figure panels (PDF) into fig_dir
root_dir = ".."
data_dir = pj(root_dir, "results", "for_plots")
mcmc_folder = pj(root_dir, "results", "mcmc")
fig_dir = "panels3/"

In [None]:
#Specific to this figure
import matplotlib.gridspec as gridspec
import matplotlib.patches as patches

import matplotlib.text as mtext
from matplotlib.legend_handler import HandlerBase
from matplotlib.legend import Legend
from matplotlib.lines import Line2D

import itertools

In [None]:
#@title Aesthetic parameters
# Colour for each perturbation label, read from perturbations_palette.json;
# the unperturbed condition ("None") is then forced to opaque black (RGBA).
with open(pj(data_dir, "perturbations_palette.json"), "r") as palette_file:
    perturb_palette = json.load(palette_file)
perturb_palette["None"] = [0., 0., 0., 1.]
sns.palplot(perturb_palette.values())

# Three-step palettes (light -> dark), e.g. for antigen density
teal3 = ['#D1EEEA', '#68ABB8', '#2A5674']
magenta3 = ['#F3CBD3', '#CA699D', '#6C2167']
sns.palplot(teal3)

# Panel A: Modeling perturbation experiment schematic

# Panel B: Perturbation experiment crosstalk matrix

In [None]:
#@title Tools
# Use seaborn's "talk" plotting context for the figures that follow.
sns.set_context(context='talk')
def zero_centered_min_max_scaling(dataframe):
    """
    Scale each column by its maximum absolute value.

    Every value is divided by the largest absolute value of its column, so the
    result lies in [-1, 1] and the sign of every entry is preserved.
    Columns whose values are all zero are left as zeros rather than turned into
    NaN by a 0/0 division. NaN entries are ignored when computing the maximum and
    stay NaN.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Numeric data. It is not modified.

    Returns
    -------
    pandas.DataFrame
        A new frame with the same index and columns.
    """
    df_copy = dataframe.copy(deep=True)
    for column in df_copy.columns:
        max_absolute_value = df_copy[column].abs().max()
        if max_absolute_value == 0:
            # All-zero column: dividing would produce NaN; zeros are already scaled.
            continue
        df_copy[column] = df_copy[column] / max_absolute_value
    return df_copy
def zero_centered_min_max_scaling2(dataframe):
    """
    Scale each column to [-1, 1], treating negative and positive values separately.

    Within each column:
    - non-positive values are divided by |column min|, so the minimum maps to -1;
    - positive values are divided by the column max, so the maximum maps to 1.

    Zero always maps to 0. That includes columns whose minimum is 0, where the
    former element-wise 0/0 produced NaN. NaN entries stay NaN.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Numeric data. It is not modified.

    Returns
    -------
    pandas.DataFrame
        A new frame with the same index and columns.
    """
    df_copy = dataframe.copy(deep=True)
    for column in df_copy.columns:
        values = df_copy[column]
        max_pos = values.max()
        min_pos = values.min()
        if min_pos < 0:
            negative_part = values / -min_pos
        else:
            # min >= 0: the only non-positive value possible is 0, which stays 0.
            negative_part = values * 0.0
        # Division by a non-positive max only affects entries masked out below.
        positive_part = values / max_pos
        df_copy[column] = negative_part.where(values <= 0, positive_part)
    return df_copy

In [None]:
#@title Load data
# Version of the CAR antagonism data with the ratio pre-computed for the heatmap plot.
# The path is built from root_dir, like the other data paths in this notebook.
antagonismDf = pd.read_hdf(pj(root_dir, "data", "antagonism", "ot1_car_antagonism_df_with_ratio.h5"))
antagonismDf

In [None]:
#@title Plot matrix colorbar label
# Builds a (condition rows x TCR antigen columns) table of log2 ratios for the
# CD19 CAR, which the rest of this cell rescales and draws as a clustermap
# with a perturbation colour key.
from matplotlib.patches import Patch

sns.set_context('talk')
#@title Plot matrix of data with color key
# NOTE(review): roughAntigenEC50Dict is not used anywhere in this cell.
roughAntigenEC50Dict = {'None':1e6,'E1':6e4,'G4':1e4,'V4':9e2,'T4':9e1,'Q4':2e1,'Y3':8e0,'A2':3e0,'N4':1}
# CD19-CAR samples whose CAR_ITAM_Number is not '0'.
temp = antagonismDf.query("CAR_Antigen == 'CD19' and `CAR_ITAM_Number` != '0'")
#temp = temp.query("Cytokine == ['IL-2','TNFa']")
# IL-2 and TNFa at every time point; the other cytokines only at Time 1, 3 and 6.
temp1 = temp.query("Cytokine == ['IL-2','TNFa']")
temp2 = temp.query("Cytokine != ['IL-2','TNFa'] and Time == [1,3,6]")
#temp2 = temp.query("Cytokine != ['IL-2','TNFa']")
#temp1 = temp1[temp1['Ratio'] < 2e2]
#temp2 = temp2[temp2['Ratio'] < 2e1]
temp = pd.concat([temp1,temp2])
# Regroup rows by TCR antigen, in reverse order of first appearance.
o = temp.index.unique('TCR_Antigen').tolist()[::-1]
temp = pd.concat([temp.query("TCR_Antigen == @x") for x in o])
# Name each "<TCR_Antigen_Density>-<TCR_ITAM_Number>" combination; a combination
# missing from this dict (e.g. '1nM-4') raises KeyError below.
pertDict = {'1uM-10':'None','1uM-4':'Fewer TCR ITAMs','1nM-10':'Less TCR antigen density'}
temp['Condition'] = [pertDict['-'.join([x,y])] for x,y in zip(temp.index.get_level_values('TCR_Antigen_Density'),temp.index.get_level_values('TCR_ITAM_Number'))]
# Move TCR antigens into the columns.
# NOTE(review): `.iloc[:,1]` picks the second data column by position --
# presumably the pre-computed ratio; confirm against the file's column order.
temp = temp.set_index(['Condition'],append=True).iloc[:,1].unstack('TCR_Antigen')
temp = temp[['E1','G4','V4','T4','Q4','A2','N4']]
temp = np.log2(temp)

from matplotlib.colors import TwoSlopeNorm
# Clip log2 ratios to [-2, 6], then rescale to [-1, 1] with
# zero_centered_min_max_scaling2. In the frame passed to the scaler, each
# cytokine is one column, so negatives and positives are scaled separately per cytokine.
clippedValues = np.clip(temp,a_min=-2,a_max=6)
clippedValues = zero_centered_min_max_scaling2(clippedValues.unstack('Cytokine').stack('TCR_Antigen')).unstack('TCR_Antigen').stack('Cytokine')
clippedValues = clippedValues[['E1','G4','V4','T4','Q4','A2','N4']]
clippedValues = clippedValues.droplevel(['Condition','CAR_Antigen','Tumor'])
clippedValues = clippedValues.swaplevel(-1,0).swaplevel(-1,-2)
# Only IL-2 is kept; the other cytokines get display names in case they are shown again.
# Raw string: '\g' is an invalid escape sequence in a normal string literal.
clippedValues = pd.concat([clippedValues.query("Cytokine == @x") for x in ['IL-2','TNFa','IFNg']]).rename({'TNFa':'TNF','IFNg':r'IFN-$\gamma$'}).query("Cytokine == 'IL-2'")

# Average over time points, then keep only the perturbation levels as the row key.
clippedValues = clippedValues.groupby([x for x in clippedValues.index.names if x != 'Time']).mean()
clippedValues = clippedValues.droplevel(['Spleen','Data']).swaplevel(-1,-3).swaplevel(-2,-3)
clippedValues = clippedValues.rename({'1nM':'Low','1uM':'High'},level='TCR_Antigen_Density').droplevel('Cytokine')
# One row-colour strip per remaining index level (perturbation type): each value
# of a level gets a colour from that level's sequential palette.
levels = list(clippedValues.index.names)
palettes = ['Greens_r','Blues_r','Reds_r']  # paired with `levels` by position
# Order in which each level's values receive palette colours.
speciesDict = {'CAR_ITAM_Number':['3','1'],'TCR_ITAM_Number':['10','4'],'TCR_Antigen_Density':['High','Low']}
# Index levels as plain columns on a RangeIndex, matching the reset_index()
# frame that is passed to the clustermap.
levelValues = clippedValues.index.to_frame(index=False)
dfDict = {}
lutList = []
for level, p in zip(levels, palettes):
    species = levelValues[level]
    lut = dict(zip(speciesDict[level], sns.color_palette(p, len(species.unique()))))
    lutList.append(lut)
    dfDict[level] = species.map(lut)

row_colors = pd.DataFrame(dfDict)
# Rows are perturbation combinations (clustered); columns are the TCR antigens
# E1 ... N4 (not clustered). Missing cells are drawn as 0.
cm = sns.clustermap(clippedValues.fillna(value=0).reset_index().loc[:,['E1','G4','V4','T4','Q4','A2','N4']],row_colors=row_colors,
                    col_cluster=False,cmap='PuOr_r', cbar_kws={'shrink': 0.5,'orientation':'horizontal'},figsize=(5,25))
g = cm.ax_heatmap
cbar = g.collections[0].colorbar
antagonismRatioYAxislabel = r'$FC_{\mathrm{TCR/CAR}}$'
cbar.set_label(antagonismRatioYAxislabel)
# Values were rescaled to [-1, 1], so the colorbar carries no numeric ticks.
cbar.set_ticks([])

g.set_yticks([])
for spine in g.spines.values():
    spine.set_visible(True)
cbar.outline.set_edgecolor('k')
cbar.outline.set_linewidth(2)
g.set_xlabel('')

# Legend: a subtitle per colour-strip level followed by its colour swatches.
# Subtitles are looked up by level name so they always describe the right strip,
# whatever the order of the index levels.
levelTitles = {'TCR_ITAM_Number':'TCR ITAM #','CAR_ITAM_Number':'CAR ITAM #','TCR_Antigen_Density':'TCR Ag Density'}
handles = [Patch(facecolor='w')]
labels = [r"$\bf{Perturbation:}$"]
for level, lut in zip(levels, lutList):
    handles.append(Patch(facecolor='w'))
    labels.append(levelTitles.get(level, level))
    for name in lut:
        handles.append(Patch(facecolor=lut[name],edgecolor='k'))
        labels.append(name)
plt.legend(handles, labels,
           bbox_to_anchor=(0.95, 0.55),loc='center left', bbox_transform=cm.fig.transFigure,frameon=False)

# Move the horizontal colorbar onto the column-dendrogram axes' footprint
# (that axes is unused because col_cluster=False), slightly raised.
cm.ax_cbar.set_position([cm.ax_col_dendrogram.get_position().x0, cm.ax_col_dendrogram.get_position().y0+0.02,cm.ax_col_dendrogram.get_position().width, 0.02])
cm.ax_cbar.set_title('')
cm.ax_cbar.tick_params(axis='x', length=10)

# Mark the two halves above the heatmap: "Antagonism" (left, purple) and
# "Enhancement" (right, orange).
cm.ax_heatmap.annotate('',xytext=(0.5,1.12),xy=(1,1.12),arrowprops=dict(arrowstyle='-|>',color='darkorange',lw=3), xycoords='axes fraction')
cm.ax_heatmap.annotate('',xytext=(0,1.12),xy=(0.5,1.12),arrowprops=dict(arrowstyle='|-|, widthB=0,widthA=0.5',color='purple',lw=3), xycoords='axes fraction')
cm.ax_heatmap.text(0.25,1.16,'Antagonism',va='center',ha='center',fontsize=12,color='purple',zorder=100,transform=cm.ax_heatmap.transAxes,fontweight='bold')
cm.ax_heatmap.text(0.75,1.16,'Enhancement',va='center',ha='center',fontsize=12,color='darkorange',zorder=100,transform=cm.ax_heatmap.transAxes,fontweight='bold')

cm.ax_row_colors.set_xticks([])
cm.ax_row_colors.set_xticklabels([])

# X-axis arrow with its label on a white box.
cm.ax_heatmap.annotate('',xytext=(0,-0.12),xy=(1,-0.12),arrowprops=dict(arrowstyle='->',color='k',lw=3), xycoords='axes fraction')
t = cm.ax_heatmap.text(0.5,-0.12,'TCR Ag Strength',va='center',ha='center',fontsize=16,color='k',zorder=100,transform=cm.ax_heatmap.transAxes)
t.set_bbox(dict(facecolor='white', edgecolor='white'))

# Wrap the antigen tick labels as "( ... )+CD19".
cm.ax_heatmap.annotate('(',xy=(-0.02,-0.065), xycoords='axes fraction',fontsize=18)
cm.ax_heatmap.annotate(')+CD19',xy=(0.99,-0.065), xycoords='axes fraction',fontsize=16)

if do_save:
    cm.savefig(pj(fig_dir, '3B-synergismAntagonismMatrix-label.pdf'),bbox_inches='tight',transparent=True)
# Close (rather than only clear) the clustermap figure so it is released from pyplot.
plt.close(cm.fig)

In [None]:
sns.set_context('talk')
#@title Plot matrix of data with color key (wide)
# Same selection as the previous heatmap, but it keeps every time point as a
# separate heatmap column and filters the ratios first.
# NOTE(review): roughAntigenEC50Dict is not used anywhere in this cell.
roughAntigenEC50Dict = {'None':1e6,'E1':6e4,'G4':1e4,'V4':9e2,'T4':9e1,'Q4':2e1,'Y3':8e0,'A2':3e0,'N4':1}
# CD19-CAR samples whose CAR_ITAM_Number is not '0'.
temp = antagonismDf.query("CAR_Antigen == 'CD19' and `CAR_ITAM_Number` != '0'")
#temp = temp.query("Cytokine == ['IL-2','TNFa']")
# IL-2 and TNFa at every time point; the other cytokines only at Time 1, 3 and 6.
temp1 = temp.query("Cytokine == ['IL-2','TNFa']")
temp2 = temp.query("Cytokine != ['IL-2','TNFa'] and Time == [1,3,6]")
#temp2 = temp.query("Cytokine != ['IL-2','TNFa']")
# Keep only ratios below 200 (IL-2, TNFa) and below 20 (other cytokines).
temp1 = temp1[temp1['Ratio'] < 2e2]
temp2 = temp2[temp2['Ratio'] < 2e1]
temp = pd.concat([temp1,temp2])
# Regroup rows by TCR antigen, in reverse order of first appearance.
o = temp.index.unique('TCR_Antigen').tolist()[::-1]
temp = pd.concat([temp.query("TCR_Antigen == @x") for x in o])
# Name each "<TCR_Antigen_Density>-<TCR_ITAM_Number>" combination; a combination
# missing from this dict (e.g. '1nM-4') raises KeyError below.
pertDict = {'1uM-10':'None','1uM-4':'Fewer TCR ITAMs','1nM-10':'Less TCR antigen density'}
temp['Condition'] = [pertDict['-'.join([x,y])] for x,y in zip(temp.index.get_level_values('TCR_Antigen_Density'),temp.index.get_level_values('TCR_ITAM_Number'))]
# Move TCR antigens into the columns.
# NOTE(review): `.iloc[:,1]` picks the second data column by position --
# presumably 'Ratio'; confirm against the file's column order.
temp = temp.set_index(['Condition'],append=True).iloc[:,1].unstack('TCR_Antigen')
temp = temp[['E1','G4','V4','T4','Q4','A2','N4']]
temp = np.log2(temp)

from matplotlib.colors import TwoSlopeNorm
# Clip log2 ratios to [-2, 6], rescale to [-1, 1] with zero_centered_min_max_scaling2
# (each cytokine is one column of the frame passed to the scaler), then move the
# time points into the columns. Unlike the previous heatmap, there is no time averaging.
clippedValues = np.clip(temp,a_min=-2,a_max=6)
clippedValues = zero_centered_min_max_scaling2(clippedValues.unstack('Cytokine').stack('TCR_Antigen')).unstack('Time').stack('Cytokine')
clippedValues = clippedValues.droplevel(['Condition','CAR_Antigen','Tumor'])
clippedValues = clippedValues.swaplevel(-1,0).swaplevel(-1,-2)
# Only IL-2 is kept; the other cytokines get display names in case they are shown again.
# Raw string: '\g' is an invalid escape sequence in a normal string literal.
clippedValues = pd.concat([clippedValues.query("Cytokine == @x") for x in ['IL-2','TNFa','IFNg']]).rename({'TNFa':'TNF','IFNg':r'IFN-$\gamma$'}).query("Cytokine == 'IL-2'")

clippedValues = clippedValues.droplevel(['Spleen','Data']).swaplevel(-1,-3).swaplevel(-2,-3)
clippedValues = clippedValues.rename({'1nM':'Low','1uM':'High'},level='TCR_Antigen_Density').droplevel('Cytokine')
# One row-colour strip per index level: each value of a level gets a colour from
# that level's sequential palette.
# Reverse the index-level order, then pick the strip order used by the clustermap.
levels = list(clippedValues.index.names)[::-1]
levels = [levels[1],levels[3],levels[0],levels[2]]
palettes = ['Greys_r','Reds_r','Greens_r','Blues_r']  # paired with `levels` by position
# Order in which each level's values receive palette colours.
speciesDict = {'CAR_ITAM_Number':['3','1'],'TCR_ITAM_Number':['10','4'],'TCR_Antigen_Density':['High','Low'],'TCR_Antigen':['N4','A2','Q4','T4','V4','G4','E1']}
# Time points are the heatmap columns.
times = clippedValues.columns
# Index levels as plain columns on a RangeIndex, matching the reset_index()
# frame that is passed to the clustermap.
levelValues = clippedValues.index.to_frame(index=False)
dfDict = {}
lutList = []
for level, p in zip(levels, palettes):
    species = levelValues[level]
    lut = dict(zip(speciesDict[level], sns.color_palette(p, len(species.unique()))))
    lutList.append(lut)
    dfDict[level] = species.map(lut)

row_colors = pd.DataFrame(dfDict)
# Rows are conditions (clustered); columns are time points in `times` order (not clustered).
cm = sns.clustermap(clippedValues.fillna(value=0).reset_index().loc[:,list(times)],row_colors=row_colors,
                    col_cluster=False,cmap='PuOr_r', cbar_kws={'shrink': 0.5,'orientation':'horizontal'},figsize=(7,10))
g = cm.ax_heatmap
cbar = g.collections[0].colorbar
cbar.set_ticks([])
g.set_yticks([])
g.set_ylabel('')
for spine in g.spines.values():
    spine.set_visible(True)
cbar.outline.set_edgecolor('k')
cbar.outline.set_linewidth(2)
g.set_xlabel('')

# Legend: a subtitle per colour-strip level followed by its colour swatches.
# Subtitles are looked up by level name so they always describe the right strip.
levelTitles = {'CAR_ITAM_Number':'CAR ITAM #','TCR_ITAM_Number':'TCR ITAM #',
               'TCR_Antigen_Density':'TCR Ag Density','TCR_Antigen':'TCR Ag Strength'}
handles = [Patch(facecolor='w')]
labels = [r"$\bf{Perturbation:}$"]
for level, lut in zip(levels, lutList):
    handles.append(Patch(facecolor='w'))
    labels.append(levelTitles.get(level, level))
    for name in lut:
        handles.append(Patch(facecolor=lut[name],edgecolor='k'))
        labels.append(name)
plt.legend(handles, labels,
           bbox_to_anchor=(0.9, 0.6),loc='center left', bbox_transform=cm.fig.transFigure,frameon=False)

# Move the horizontal colorbar onto the (unused) column-dendrogram axes' footprint.
cm.ax_cbar.set_position([cm.ax_col_dendrogram.get_position().x0, cm.ax_col_dendrogram.get_position().y0+0.02,cm.ax_col_dendrogram.get_position().width, 0.02])
cm.ax_cbar.set_title('')
cm.ax_cbar.tick_params(axis='x', length=10)

# Mark the two halves above the heatmap: "Antagonism" (left, purple) and
# "Enhancement" (right, orange).
cm.ax_heatmap.annotate('',xytext=(0.5,1.12),xy=(1,1.12),arrowprops=dict(arrowstyle='-|>',color='darkorange',lw=3), xycoords='axes fraction')
cm.ax_heatmap.annotate('',xytext=(0,1.12),xy=(0.5,1.12),arrowprops=dict(arrowstyle='<-',color='purple',lw=3), xycoords='axes fraction')
cm.ax_heatmap.text(0.25,1.16,'Antagonism',va='center',ha='center',fontsize=12,color='purple',zorder=100,transform=cm.ax_heatmap.transAxes,fontweight='bold')
cm.ax_heatmap.text(0.75,1.16,'Enhancement',va='center',ha='center',fontsize=12,color='darkorange',zorder=100,transform=cm.ax_heatmap.transAxes,fontweight='bold')

cm.ax_row_colors.set_xticks([])
cm.ax_row_colors.set_xticklabels([])

# X-axis arrow with its label on a white box.
cm.ax_heatmap.annotate('',xytext=(0,-0.12),xy=(1,-0.12),arrowprops=dict(arrowstyle='->',color='k',lw=3), xycoords='axes fraction')
t = cm.ax_heatmap.text(0.5,-0.12,'Time (h)',va='center',ha='center',fontsize=16,color='k',zorder=100,transform=cm.ax_heatmap.transAxes)
t.set_bbox(dict(facecolor='white', edgecolor='white'))

# Label a subset of time points at the centre of their heatmap column.
# Times missing from the data are skipped instead of raising ValueError.
timesOfInterest = [1,3,6,12,18,24,30,36,42,48,60,72]
timeList = list(times)
timesOfInterest = [x for x in timesOfInterest if x in timeList]
ticksOfInterest = [(timeList.index(x)+0.5) for x in timesOfInterest]
cm.ax_heatmap.set_xticks(ticksOfInterest)
cm.ax_heatmap.set_xticklabels([str(int(x)) for x in timesOfInterest],rotation=0)

if do_save:
    cm.savefig(pj(fig_dir, '3B-synergismAntagonismMatrix-expanded_wide.pdf'),bbox_inches='tight',transparent=True)
sns.set_context('poster')

# Load data
In vitro mouse data used for fitting, model fits, and model predictions for the conditions that were not fitted.

In [None]:
# Restore seaborn's default ("notebook") plotting context.
sns.set_context(context='notebook')

In [None]:
def perturb_concatenator(x):
    """Build a compact perturbation label from one row of index values.

    Parameters
    ----------
    x : mapping (e.g. a pandas Series row)
        Must provide "TCR_Antigen_Density", "CAR_ITAMs" and "TCR_ITAMs".

    Returns
    -------
    str
        "_"-joined tags, in this order:
        - "AgDens" when the TCR antigen density is "1nM";
        - "CARNum" when there is "1" CAR ITAM;
        - "TCRNum" when there are "4" TCR ITAMs.
        Returns "None" when no tag applies.
    """
    tag_rules = (
        ("TCR_Antigen_Density", "1nM", "AgDens"),
        ("CAR_ITAMs", "1", "CARNum"),
        ("TCR_ITAMs", "4", "TCRNum"),
    )
    tags = [tag for key, value, tag in tag_rules if x[key] == value]
    return "_".join(tags) or "None"


In [None]:
# Model curves ("model"), measured ratios ("data") and their log2-scale errors
# ("ci") are stored under separate keys of one HDF5 file.
model_data_file = pj(data_dir, "dfs_model_data_ci_mcmc_both_conc.h5")
df_model = pd.read_hdf(model_data_file, key="model")
df_data = pd.read_hdf(model_data_file, key="data")
df_err = pd.read_hdf(model_data_file, key="ci")  # log2-scale standard error of mean
# The perturbation labels computed from df_data are attached to df_err by
# position, so both series must share the same index in the same order.
assert df_err.index.equals(df_data.index), "df_err and df_data indices differ"


def _perturbation_level(index):
    """Return a "Perturbation" Index: perturb_concatenator applied to each entry of `index`."""
    return pd.Index(index.to_frame().apply(perturb_concatenator, axis=1), name="Perturbation")


# Add level for perturbation type, used for hues
df_model = df_model.set_index(_perturbation_level(df_model.index), append=True)
data_perturbations = _perturbation_level(df_data.index)
df_data = df_data.to_frame().set_index(data_perturbations, append=True)
df_err = df_err.to_frame().set_index(data_perturbations, append=True)

# Rename levels to nicer labels for the plot
rename_pairs = {
    "Subset": "Origin",
    "CAR_ITAMs": "CAR ITAMs",
    "TCR_ITAMs": "TCR ITAMs",
    "TCR_Antigen": r"TCR Antigen model $\tau$ (s)"
}
for old_name, new_name in rename_pairs.items():
    df_model.index.set_names(new_name, level=old_name, inplace=True)
    if old_name == "Subset":
        continue  # the Subset level is renamed on df_model only
    df_data.index.set_names(new_name, level=old_name, inplace=True)
    df_err.index.set_names(new_name, level=old_name, inplace=True)

In [None]:
# Peptide -> tau lookup for OT-1, stored as JSON (not used by the cells shown here).
with open(pj(root_dir, "data", "pep_tau_map_ot1.json"), "r") as tau_file:
    pep_tau_map = json.load(tau_file)

## Choose one k, m, f for CAR
I take $k_S=1$, $m=2$, $f=1$, because it is biologically meaningful to change to $m=1$ while keeping $f=1$ for the 1-ITAM CAR.

In [None]:
# Keep only the CAR parameter set kmf = "(1, 2, 1)" (k_S=1, m=2, f=1).
# Guarded so that re-running this cell does not fail once the "kmf" level is gone.
if "kmf" in df_model.index.names:
    df_model = df_model.xs("(1, 2, 1)", level="kmf", axis=0)
df_model

# Model vs data figure
**TODO:** distinguish fitted vs. predicted conditions, e.g. with line styles?
Something even clearer would be better, but we stick with this for now and gather feedback from others.

**Open question:** what should the x axis show — the model $\tau$ or discrete labels for the experimental peptides? If $\tau$, how do we identify the peptides? One option: use the $\tau$ axis implicitly, but only place x ticks labelled with peptide names.

In [None]:
#@title Tools
# For adding subtitles in legends.
# Special class compared to the one in scripts.plotting, to match the present figure
class LegendSubtitle(object):
    """Legend handle that is just a piece of text, used as a subtitle inside a legend.

    The text is drawn by LegendSubtitleHandler. The handle's label is a run of
    blanks about 1.33x as long as the text, so the legend reserves space for it.
    """

    def __init__(self, message, **text_properties):
        self.text = message
        self.text_props = text_properties
        blank_count = int(len(message) * 1.33)
        self.labelwidth = " " * blank_count

    def get_label(self, *args, **kwargs):
        # The handler draws the text itself; the label only provides padding.
        return self.labelwidth

# Legend handler that renders LegendSubtitle handles as text.
class LegendSubtitleHandler(HandlerBase):
    """Draw a LegendSubtitle's text in the legend instead of a marker or line."""

    def legend_artist(self, legend, orig_handle, fontsize, handlebox):
        """Create the subtitle Text artist inside ``handlebox`` and return it."""
        # The text starts at the handle box's (xdescent, ydescent) offset. It uses
        # the legend font size plus any text properties stored on the handle.
        subtitle = mtext.Text(
            handlebox.xdescent,
            handlebox.ydescent,
            orig_handle.text,
            size=fontsize,
            **orig_handle.text_props,
        )
        handlebox.add_artist(subtitle)
        return subtitle

from scripts.plotting import (prepare_styles, prepare_hues, handles_properties_legend, 
unique_markers, unique_dashes, prepare_markers, prepare_subplots)
from scripts.preprocess import read_conc_uM

In [None]:
# Copies of the data and error series with an extra "Origin" index level
# ("Fit" or "Prediction"), so they can be split into the same rows as
# df_model's Origin level.
df_data2 = df_data.copy()
df_data2["Origin"] = "Prediction"
# Mark the fitted conditions manually
# NOTE(review): ("10", "3") selects on the first two index levels -- presumably
# (TCR ITAMs, CAR ITAMs); confirm the level order of df_data's index.
df_data2.loc[("10", "3"), "Origin"] = "Fit"
# Move Origin into the index; squeeze() turns the single remaining column into a Series.
df_data2 = df_data2.set_index("Origin", append=True).squeeze()
df_err2 = df_err.copy()
df_err2 = df_err2.loc[:, "Ratio"]
# Positional relabelling: assumes df_err's rows are in the same order as df_data's.
df_err2.index = df_data2.index

In [None]:
def perturb_decoder(x):
    """Turn a perturbation label from perturb_concatenator into readable legend text.

    "None" is returned unchanged. Otherwise each "_"-separated tag is translated,
    and the translations are joined with ",\\n". Unknown tags are skipped.
    """
    if x == "None":
        return x
    readable = {
        "AgDens": "1 nM TCR Ag",
        "CARNum": "1 CAR ITAM",
        "TCRNum": "4 TCR ITAMs",
    }
    pieces = (readable.get(token) for token in x.split("_"))
    return ",\n".join(piece for piece in pieces if piece is not None)

In [None]:
# One panel for the fit(s): 6Y, and maybe 6F if necessary later.
# Then two panels for predictions.
# Place legend in space left by missing panel
# NOTE(review): the legend handles built later in this cell are never drawn.

# Hues: one colour per perturbation combination, taken from perturb_palette.
hue_lvl = "Perturbation"
def sort_pert(x):
    """Sort key for perturbation labels: by length, then by first character."""
    return (len(x), x[0])
hue_vals, _ = prepare_hues(df_model, hue_lvl,
                sortkws={"key":sort_pert, "reverse":False})
palette = perturb_palette

# Line styles are also keyed by perturbation. NOTE(review): the plotting loop
# currently draws every model curve with ls="-" and ignores `styles` there.
sty_lvl = "Perturbation"
sty_vals, styles = prepare_styles(df_model, sty_lvl)
# Data markers keyed by perturbation (values sorted by label length).
mark_lvl = "Perturbation"
mark_vals, markers = prepare_markers(df_data2, mark_lvl,
                        sortkws={"key":len, "reverse":False})
marksize = 6
lwidth = 3.0
# Subplot grid: rows = Origin (fit vs prediction), columns = number of TCR ITAMs
# (presumably sorted as integers, descending), x axis = model tau.
row_lvl = "Origin"
col_lvl = rename_pairs["TCR_ITAMs"]
x_lvl = rename_pairs["TCR_Antigen"]
row_vals, col_vals, fig, axes = prepare_subplots(df_model,
            row_lvl=row_lvl, col_lvl=col_lvl,
            sortkws_col={"key":int, "reverse":True},
            sharey=False)
# 2.5 inches per panel; legwidth reserves extra width for a legend (currently none).
legwidth = 0.0
figwidth = max(1, len(col_vals))*2.5 + legwidth
fig.set_size_inches(figwidth, max(1, len(row_vals))*2.5)

# Each (Origin row, TCR-ITAM column) panel shows:
# - data points with asymmetric error bars;
# - the model's "best" curve and its 2.5-97.5 percentile band;
# one colour per perturbation.
for i in range(max(1, len(row_vals))):
    if len(row_vals) > 0:
        data_row = df_data2.xs(row_vals[i], level=row_lvl, drop_level=False)
        model_row = df_model.xs(row_vals[i], level=row_lvl, drop_level=False)
        err_row = df_err2.xs(row_vals[i], level=row_lvl, drop_level=False)
    else:
        data_row, model_row, err_row = df_data2, df_model, df_err2
    for j in range(max(1, len(col_vals))):
        if len(col_vals) > 0:
            try:
                dat_loc = data_row.xs(col_vals[j], level=col_lvl, drop_level=False)
                mod_loc = model_row.xs(col_vals[j], level=col_lvl, drop_level=False)
                err_loc = err_row.xs(col_vals[j], level=col_lvl, drop_level=False)
            except KeyError:  # This combination is not available, empty plot
                continue
        else:
            dat_loc, mod_loc, err_loc = data_row, model_row, err_row
        ax = axes[i, j]
        ax.axhline(1.0, ls=":", color="k")  # reference line at fold change 1
        # Perturbations present in this panel's data, kept in the global hue order.
        present_hues = set(dat_loc.index.get_level_values(hue_lvl).unique())
        local_hue_vals = [h for h in hue_vals if h in present_hues]
        for h in local_hue_vals:
            # Plot data +- error transformed to linear scale
            err_loc2 = err_loc.xs(h, level=hue_lvl)
            dat_loc2 = dat_loc.xs(h, level=hue_lvl)
            # Make sure data and error have same index order for plotting
            err_loc2 = err_loc2.reindex(index=dat_loc2.index)
            dat_log = np.log2(dat_loc2)
            # The error is symmetric in log2 space; convert it to the asymmetric
            # linear-scale distances below (ylo) and above (yup) each point.
            yup = 2**(err_loc2 + dat_log) - dat_loc2
            ylo = dat_loc2 - 2**(-err_loc2 + dat_log)
            yerr = np.vstack([ylo.values, yup.values])
            xvals = dat_loc2.index.get_level_values(x_lvl).values
            ax.errorbar(xvals, dat_loc2, xerr=None, yerr=yerr, marker=markers.get(h, "o"),
                ecolor=palette[h], mfc=palette[h], mec=palette[h],
                ms=marksize, ls="none",
            )

            # Model: shaded 2.5-97.5 percentile band and solid "best" curve.
            mod_loc2 = mod_loc.xs(h, level=hue_lvl)
            xvals = mod_loc2.index.get_level_values(x_lvl).values
            ax.fill_between(xvals, mod_loc2["percentile_2.5"], mod_loc2["percentile_97.5"],
                color=palette[h], alpha=0.2)
            ax.plot(xvals, mod_loc2["best"], color=palette[h],
                    lw=lwidth, ls="-", label=h)
        ax.set_yscale("log", base=2)

for ax in axes.flat:
    ax.set_xlabel(r"TCR antigen strength, $\tau$ (s)", labelpad=0.2, fontsize=11)
    ax.tick_params(which="major", axis="both", length=3.5, labelsize=11)
    for side in ["top", "right"]:
        ax.spines[side].set_visible(False)
for ax in axes[:, 0]:
    ax.set_ylabel(r"$FC_{\mathrm{TCR/CAR}}$",
                  fontsize=11, labelpad=0.2)

# Panel titles (fit / prediction) are added manually in InDesign later.
fig.tight_layout(h_pad=3.0)
# Hide panel [0, 1]; presumably empty because no fitted condition exists for that column.
axes[0, 1].set_axis_off()

# Legend handles for the main-text figure, built manually:
# a "Perturbation" subtitle with one entry per perturbation (colour, line style, marker),
# then an "Origin" subtitle with generic Model / Data entries.
legend_handles = [LegendSubtitle(hue_lvl)]
palette_entries = sorted(palette, key=lambda x: len(perturb_decoder(x)))
for p in palette_entries:
    h = palette.get(p)
    if p not in styles:
        continue
    # markers.get with an "o" fallback (same default as the data points), so a
    # perturbation without a data marker does not raise KeyError.
    legend_handles.append(Line2D([0], [0], ls=styles[p], marker=markers.get(p, "o"), color=h,
                                 mec=h, mfc=h, label=perturb_decoder(p)))

legend_handles.append(LegendSubtitle("Origin"))
legend_handles.append(Line2D([0], [0], ls="-", marker=None, color="k", label="Model"))
legend_handles.append(Line2D([0], [0], ls="none", marker="o", ms=marksize, mec="k", mfc="k", label="Data"))
legend_handler_map = {LegendSubtitle: LegendSubtitleHandler()}
# NOTE(review): legend_handles / legend_handler_map are never passed to a legend
# call, so the figure is saved without this legend. Confirm whether e.g.
# axes[0, 1].legend(handles=legend_handles, handler_map=legend_handler_map) was intended.

if do_save:
    fig.savefig(pj(fig_dir, "3EF-model_data_ratio_comparison.pdf"), transparent=True, bbox_inches="tight",
               bbox_extra_artists=[axes.flat[i].xaxis.get_label() for i in range(axes.size)])
plt.show()
plt.close()

In [None]:
# Separate legend for data-model: model curve, model confidence band, data markers.
legend_handles = [
    Line2D([0], [0], ls="-", marker=None, color="k", label="Model"),
    mpl.patches.Patch(color=(0.7, 0.7, 0.7), label='Model confidence\ninterval'),
    Line2D([0], [0], ls="none", marker="o", ms=marksize, mec="k", mfc="k", label="Data"),
]

figl, axl = plt.subplots()
axl.set_axis_off()  # legend only: no frame, ticks or labels
leg = axl.legend(handles=legend_handles, frameon=False, ncol=3)
figl.set_size_inches(2.5, 0.5)
if do_save:
    figl.savefig(pj(fig_dir, "3EF-legend_model_data.pdf"), bbox_inches="tight",
                 bbox_extra_artists=(leg,), transparent=True)
plt.show()
plt.close()