# IT'S A PYTHON SCRIPT exported from a Jupyter notebook; the `In [n]:` lines are cell markers, not Python code.

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl

In [2]:
# Apply matplotlib's bundled "ggplot" style sheet to every figure drawn below.
plt.style.use(style='ggplot')

# Functions

In [3]:
def calc_binding_concordance(df):
    """Annotate each row with how consistently its clonotype binds its epitope.

    Returns a new frame (``df`` itself is not modified) with three extra columns:

    * ``gems_per_specificity`` -- number of GEMs sharing this (clonotype, epitope),
    * ``gems_per_clonotype``   -- number of GEMs sharing this clonotype,
    * ``binding_concordance``  -- their ratio, i.e. the fraction of the
      clonotype's GEMs annotated with this epitope.

    GEMs are counted with ``count()``, so rows with a missing ``gem`` are not
    counted. Rows keep the order of ``df``; the result is re-indexed by
    ``pd.merge``.

    Raises:
        AssertionError: if ``df`` is empty.
    """
    assert df.size > 0, "df empty"

    specificity_counts = (
        df.groupby(['clonotype', 'epitope'])['gem']
        .count()
        .rename('gems_per_specificity')
        .reset_index()
    )
    clonotype_counts = (
        df.groupby(['clonotype'])['gem']
        .count()
        .rename('gems_per_clonotype')
        .reset_index()
    )

    # suffixes=('_total', '') keeps the freshly computed count under its plain
    # name if df already carries a column with the same name.
    annotated = pd.merge(df, specificity_counts, on=['clonotype', 'epitope'],
                         how='left', suffixes=('_total', ''))
    annotated = annotated.merge(clonotype_counts, on='clonotype',
                                how='left', suffixes=('_total', ''))
    annotated['binding_concordance'] = (annotated['gems_per_specificity']
                                        / annotated['gems_per_clonotype'])
    return annotated

In [4]:
def epitope_sorter_index(df):
    """Return the 0-based plotting rank of each row's ``epitope``.

    The rank is the epitope's position in the fixed order below (CLYBL first,
    then the v-peptides). Epitopes not listed map to NaN. The result is meant
    to be stored as a sort key, e.g. ``df['epitope_rank'] = epitope_sorter_index(df)``.
    """
    epitope_order = (
        'CLYBL',
        'v9', 'v15', 'v19', 'v3', 'v5', 'v6', 'v10', 'v13',
        'v16', 'v24', 'v41', 'v2', 'v11', 'v18', 'v23', 'v25',
        'v26', 'v27', 'v4', 'v7', 'v8', 'v12', 'v14', 'v20',
        'v21', 'v1', 'v17', 'v22', 'v30', 'v31', 'v36', 'v37',
        'v38', 'v39', 'v40', 'v32', 'v33', 'v34', 'v35',
    )
    rank_of = {epitope: rank for rank, epitope in enumerate(epitope_order)}
    return df.epitope.map(rank_of)

# Peptide per GEM

In [5]:
def peptides_per_gem(credible_df, show=True, save_tuba=False, save_sund=False):
    """Make specificity plots: which peptide-HLA barcodes were seen in each GEM.

    One figure is drawn per combination of filtering criteria (minimum barcode
    read count, minimum TCR UMI count, exclusion of clonotype singlets and of
    specificity singlets). Each point is a (GEM, peptide-HLA) row of
    ``credible_df``, sized by ``read_counts_mhc`` and coloured by clonotype.

    Parameters:
        credible_df: DataFrame with columns ``gem``, ``clonotype``,
            ``num_clonotype``, ``epitope``, ``peptide_HLA``,
            ``read_counts_mhc`` and ``umis_tcr``. It is not modified; a sorted
            copy is used internally.
        show: if True, display the first figure and return without saving.
        save_tuba, save_sund: save each figure as PDF below the module-level
            ``FIG_DIR`` / ``FIG_SUND`` paths (sub-folder ``peptide_per_gem/``).

    Raises:
        AssertionError: if a GEM is annotated with more than one clonotype.
    """
    import matplotlib as mpl
    from matplotlib.lines import Line2D

    mpl.rcParams['axes.grid.axis'] = 'y'

    sortby = 'num_clonotype'
    # Sort a copy rather than the caller's frame in place.
    credible_df = credible_df.sort_values(by=[sortby])

    # GEM to clonotype: every GEM must carry exactly one clonotype.
    gem_to_clonotype = dict()
    for gem, gem_rows in credible_df.groupby('gem', sort=False):
        clonotypes = gem_rows.clonotype.values
        assert len(np.unique(clonotypes)) == 1, \
            "GEM %s has several clonotypes: %s" % (gem, clonotypes)
        gem_to_clonotype[gem] = clonotypes[0]

    # Clonotype to color, cycling through the palette when it runs out.
    palette = ['#9e0142', '#d53e4f', '#f46d43', '#fdae61', '#fee08b', '#ffffbf',
               '#e6f598', '#abdda4', '#66c2a5', '#3288bd', '#5e4fa2']
    all_clonotypes = credible_df.clonotype.unique()
    clonotype_to_color = {clonotype: palette[i % len(palette)]
                          for i, clonotype in enumerate(all_clonotypes)}

    # Plotting
    project = "peptide_per_gem/"

    for read_threshold in [1, 10]:
        for tcr_threshold in [1, 10]:
            for exclude_clonotype_singlets in [False, True]:
                for exclude_specificity_singlets in [False, True]:
                    # Threshold and singlet-filter the whole data set before
                    # splitting by peptide, so a "singlet" is a clonotype (or
                    # clonotype/epitope pair) with a single row overall -- the
                    # same meaning as in the per-clonotype plots -- rather than
                    # a single row within one peptide.
                    filtered_df = credible_df[(credible_df.read_counts_mhc >= read_threshold) &
                                              (credible_df.umis_tcr >= tcr_threshold)]
                    if exclude_clonotype_singlets:
                        filtered_df = filtered_df[filtered_df.duplicated(subset='num_clonotype', keep=False)]
                    if exclude_specificity_singlets:
                        filtered_df = filtered_df[filtered_df.duplicated(subset=['num_clonotype', 'epitope'], keep=False)]

                    unique_gems = set()
                    fig, ax = plt.subplots(figsize=(20, 10))

                    for mhc_barcode in credible_df.peptide_HLA.unique():
                        sub_df = filtered_df[filtered_df.peptide_HLA == mhc_barcode]
                        if sub_df.empty:
                            continue

                        gems = sub_df.gem.values
                        colors = [clonotype_to_color[gem_to_clonotype[gem]] for gem in gems]
                        unique_gems.update(gems)

                        ax.scatter(gems, [mhc_barcode] * len(gems), s=sub_df.read_counts_mhc.values,
                                   c=colors, edgecolors='face')

                    # Count the GEMs actually plotted (previously: tick labels).
                    plt.xlabel("%i GEMs" % len(unique_gems), fontsize=16)
                    plt.tick_params(labelbottom=False, labelright=True, labelsize=8)

                    # Criteria
                    textstr = '\n'.join((
                        "Criteria",
                        "Min. barcode read count \t %i".expandtabs() % read_threshold,
                        "Min. TCR read count \t\t    %i".expandtabs() % tcr_threshold,
                        "Exclude clonotype singlets \t %s".expandtabs() % str(exclude_clonotype_singlets),
                        "Exclude specificity singlets \t   %s".expandtabs() % str(exclude_specificity_singlets)))
                    props = dict(boxstyle='square', fc='white', ec='grey', alpha=0.5)
                    ax.text(0.05, 0.95, textstr,
                            fontsize=12,
                            horizontalalignment='left',
                            verticalalignment='top',
                            clip_on=False,
                            transform=ax.transAxes,
                            bbox=props)

                    # Clonotype color legend below the axes; add_artist keeps it
                    # when plt.legend() below installs the read-size legend.
                    legend_elements = [Line2D([0], [0], marker='o', color='w', label=clonotype_label,
                                              markerfacecolor=clonotype_to_color[clonotype_label], markersize=10)
                                       for clonotype_label in all_clonotypes]
                    legend1 = ax.legend(handles=legend_elements, ncol=9, loc=2, bbox_to_anchor=(0.02, -0.03))
                    ax.add_artist(legend1)

                    # Empty scatters act as size keys for the read-count legend.
                    for read_size in [10, 50, 100]:
                        plt.scatter([], [], c='k', alpha=0.3, s=read_size, label=str(read_size) + ' reads')
                    plt.legend(scatterpoints=1, frameon=False, labelspacing=1, title='Barcode reads', loc='lower right')

                    plt.ylabel("Peptide, HLA", fontsize=16)
                    plt.title("Peptide specificity per GEM", fontsize=20)

                    if show:
                        # Preview mode: show only the first criteria combination.
                        plt.show()
                        print("OBS! Figures are not saved!")
                        return

                    file_name = "sortby_%s.b%i.t%i.ecs_%s.ess_%s.pdf" % (
                        sortby, read_threshold, tcr_threshold,
                        exclude_clonotype_singlets, exclude_specificity_singlets)
                    if save_tuba:
                        plt.savefig(FIG_DIR + project + file_name, bbox_inches='tight')
                    if save_sund:
                        plt.savefig(FIG_SUND + project + file_name, bbox_inches='tight')

                    plt.close(fig)

# Peptide per clonotype (GEM counts)

In [6]:
def peptide_per_clonotype_by_gem_size(credible_df, show=True, save_tuba=False, save_sund=False):
    """Plot, per peptide-HLA, the clonotypes detected with it.

    Point size is the clonotype's ``gems_per_specificity`` and colour its
    ``binding_concordance`` (both from ``calc_binding_concordance``) on a fixed
    0-1 ``viridis_r`` scale. One figure is drawn per combination of filtering
    criteria (read thresholds and singlet exclusion).

    Parameters:
        credible_df: DataFrame with columns ``clonotype``, ``num_clonotype``,
            ``epitope``, ``epitope_rank``, ``peptide_HLA``, ``gem``,
            ``read_counts_mhc`` and ``umis_tcr``. Not modified; a sorted copy
            is used.
        show: if True, display the first figure and return without saving.
        save_tuba, save_sund: save each figure as PDF below the module-level
            ``FIG_DIR`` / ``FIG_SUND`` paths.
    """
    import matplotlib as mpl

    mpl.rcParams['axes.grid.axis'] = 'both'

    sortby = 'clonotype'
    # Sort a copy rather than the caller's frame in place.
    credible_df = credible_df.sort_values(by=['epitope_rank', sortby])

    project = "peptide_per_clonotype_by_gem_size/"

    for read_threshold in [1, 10]:
        for tcr_threshold in [1, 10]:
            for exclude_clonotype_singlets in [False, True]:
                for exclude_specificity_singlets in [False, True]:
                    # Thresholds, singlet filters and concordance do not depend
                    # on the peptide, so compute them once per figure.
                    filtered_df = credible_df[(credible_df.read_counts_mhc >= read_threshold) &
                                              (credible_df.umis_tcr >= tcr_threshold)]
                    if exclude_clonotype_singlets:
                        filtered_df = filtered_df[filtered_df.duplicated(subset='num_clonotype', keep=False)]
                    if exclude_specificity_singlets:
                        filtered_df = filtered_df[filtered_df.duplicated(subset=['num_clonotype', 'epitope'], keep=False)]
                    # calc_binding_concordance rejects empty input; an empty
                    # selection just yields an empty figure.
                    if not filtered_df.empty:
                        filtered_df = calc_binding_concordance(filtered_df)

                    unique_gems, unique_tcrs = set(), set()

                    fig, ax = plt.subplots(figsize=(20, 10))

                    for mhc_barcode in credible_df.peptide_HLA.unique():
                        sub_df = filtered_df[filtered_df.peptide_HLA == mhc_barcode]
                        if sub_df.empty:
                            continue

                        # Aggregate per clonotype and take the x labels from the
                        # same index, so labels, sizes and colours stay aligned.
                        per_clonotype = sub_df.groupby('clonotype').agg(
                            gems=('gems_per_specificity', 'mean'),
                            concordance=('binding_concordance', 'mean'))
                        tcrs = per_clonotype.index.str.split('clonotype').str[1].tolist()
                        mhcs = [mhc_barcode] * len(tcrs)

                        unique_gems.update(sub_df.gem)
                        unique_tcrs.update(sub_df.clonotype)

                        ax.scatter(tcrs, mhcs, c=per_clonotype.concordance.values, cmap='viridis_r',
                                   norm=plt.Normalize(vmin=0, vmax=1), s=per_clonotype.gems.values,
                                   edgecolors='face')

                    plt.tick_params(labelsize=8)
                    plt.xticks(rotation=90, size=2)

                    # Stand-alone mappable for a fixed 0-1 colorbar; ax= is
                    # required by newer matplotlib for unattached mappables.
                    sm = plt.cm.ScalarMappable(cmap='viridis_r', norm=plt.Normalize(vmin=0, vmax=1))
                    plt.colorbar(sm, ax=ax)

                    # Criteria
                    textstr = '\n'.join((
                        "Criteria",
                        "Min. barcode read count \t %i".expandtabs() % read_threshold,
                        "Min. TCR read count \t\t    %i".expandtabs() % tcr_threshold,
                        "Exclude clonotype singlets \t %s".expandtabs() % str(exclude_clonotype_singlets),
                        "Exclude specificity singlets \t   %s".expandtabs() % str(exclude_specificity_singlets)))
                    props = dict(boxstyle='square', fc='white', ec='grey', alpha=0.5)
                    ax.text(0.05, 0.95, textstr,
                            fontsize=12,
                            horizontalalignment='left',
                            verticalalignment='top',
                            clip_on=False,
                            transform=ax.transAxes,
                            bbox=props)

                    # Empty scatters act as size keys for the GEM-count legend.
                    for number_gems in [2, 10, 50]:
                        plt.scatter([], [], c='k', alpha=0.3, s=number_gems, label=str(number_gems) + ' GEMs')
                    plt.legend(scatterpoints=1, frameon=False, labelspacing=1, title='GEMs', loc='lower right')

                    plt.xlabel("%i Clonotypes (of %i GEMs)" % (len(unique_tcrs), len(unique_gems)), fontsize=16)
                    plt.ylabel("Peptide, HLA", fontsize=16)
                    plt.title("Specificity concordance per clonotype", fontsize=20)

                    if show:
                        # Preview mode: show only the first criteria combination.
                        plt.show()
                        print("OBS! Figures are not saved!")
                        return

                    file_name = "sortby_%s.b%i.t%i.ecs_%s.ess_%s.pdf" % (
                        sortby, read_threshold, tcr_threshold,
                        exclude_clonotype_singlets, exclude_specificity_singlets)
                    if save_tuba:
                        plt.savefig(FIG_DIR + project + file_name, bbox_inches='tight')
                    if save_sund:
                        plt.savefig(FIG_SUND + project + file_name, bbox_inches='tight')

                    plt.close(fig)

# Peptide per clonotype (read counts)

In [7]:
def peptide_per_clonotype_read_counts(credible_df, show=True, save_tuba=False, save_sund=False):
    """Plot, per peptide-HLA, the clonotypes detected with it, sized by reads.

    Point size is the clonotype's mean ``read_counts_mhc`` and colour its mean
    ``binding_concordance`` (0-1, ``viridis_r``). A translucent red halo of
    size mean + standard deviation of the read counts is drawn behind each
    point. One figure is drawn per combination of filtering criteria.

    Parameters:
        credible_df: DataFrame with columns ``clonotype``, ``num_clonotype``,
            ``epitope``, ``epitope_rank``, ``peptide_HLA``, ``gem``,
            ``read_counts_mhc``, ``umis_tcr`` (and ``single_tcell`` when
            ``show_tcr_multiplicity`` is enabled). Not modified; a sorted copy
            is used.
        show: if True, display the first figure and return without saving.
        save_tuba, save_sund: save each figure as PDF below the module-level
            ``FIG_DIR`` / ``FIG_SUND`` paths.
    """
    import matplotlib as mpl

    mpl.rcParams['axes.grid.axis'] = 'both'

    sortby = 'clonotype'
    # Sort a copy rather than the caller's frame in place.
    credible_df = credible_df.sort_values(by=['epitope_rank', sortby])

    project = "peptide_per_clonotype_by_read_size/"

    # Hard-coded switches for optional plot features.
    exclude_low_concordance_clonotypes = False
    show_tcr_multiplicity = False

    for read_threshold in [1, 10, 20]:
        for tcr_threshold in [1, 10, 20]:
            for exclude_clonotype_singlets in [False, True]:
                for exclude_specificity_singlets in [False, True]:
                    # Thresholds, singlet filters and concordance do not depend
                    # on the peptide, so compute them once per figure.
                    filtered_df = credible_df[(credible_df.read_counts_mhc >= read_threshold) &
                                              (credible_df.umis_tcr >= tcr_threshold)]
                    if exclude_clonotype_singlets:
                        filtered_df = filtered_df[filtered_df.duplicated(subset='num_clonotype', keep=False)]
                    if exclude_specificity_singlets:
                        filtered_df = filtered_df[filtered_df.duplicated(subset=['num_clonotype', 'epitope'], keep=False)]
                    # calc_binding_concordance rejects empty input; an empty
                    # selection just yields an empty figure.
                    if not filtered_df.empty:
                        filtered_df = calc_binding_concordance(filtered_df)

                    unique_gems, unique_tcrs = set(), set()

                    fig, ax = plt.subplots(figsize=(20, 10))

                    for mhc_barcode in credible_df.peptide_HLA.unique():
                        sub_df = filtered_df[filtered_df.peptide_HLA == mhc_barcode]
                        if sub_df.empty:
                            continue

                        if exclude_low_concordance_clonotypes:
                            # The original placeholder (sub_df.drop()) would raise
                            # TypeError; fail explicitly until a cut-off is chosen.
                            raise NotImplementedError("exclude_low_concordance_clonotypes has no threshold defined")

                        # Aggregate per clonotype and take the x labels from the
                        # same index, so labels, sizes and colours stay aligned.
                        per_clonotype = sub_df.groupby('clonotype').agg(
                            reads_mean=('read_counts_mhc', 'mean'),
                            reads_std=('read_counts_mhc', 'std'),
                            concordance=('binding_concordance', 'mean'))
                        # std is NaN for single-GEM clonotypes; treat as no spread
                        # so the halo size is not NaN.
                        per_clonotype['reads_std'] = per_clonotype.reads_std.fillna(0)
                        tcrs = per_clonotype.index.str.split('clonotype').str[1].tolist()
                        mhcs = [mhc_barcode] * len(tcrs)

                        unique_gems.update(sub_df.gem)
                        unique_tcrs.update(sub_df.clonotype)

                        ax.scatter(tcrs, mhcs, c='red',
                                   s=(per_clonotype.reads_mean + per_clonotype.reads_std).values,
                                   edgecolors='face', alpha=0.3)
                        ax.scatter(tcrs, mhcs, c=per_clonotype.concordance.values, cmap='viridis_r',
                                   norm=plt.Normalize(vmin=0, vmax=1), s=per_clonotype.reads_mean.values,
                                   edgecolors='face')

                        if show_tcr_multiplicity:
                            # Overlay '+' on clonotypes with rows whose single_tcell is False.
                            multi_df = sub_df[sub_df.single_tcell.eq(False)]
                            multi_concordance = multi_df.groupby('clonotype').binding_concordance.mean()
                            multi_tcrs = multi_concordance.index.str.split('clonotype').str[1].tolist()
                            # '+' is an unfilled marker, so only the face colour is used.
                            ax.scatter(multi_tcrs, [mhc_barcode] * len(multi_tcrs),
                                       c=np.where(multi_concordance.values > 0.5, 'white', 'k'),
                                       marker='+')

                    plt.tick_params(labelsize=8)
                    plt.xticks(rotation=90, size=2)

                    # Stand-alone mappable for a fixed 0-1 colorbar; ax= is
                    # required by newer matplotlib for unattached mappables.
                    sm = plt.cm.ScalarMappable(cmap='viridis_r', norm=plt.Normalize(vmin=0, vmax=1))
                    plt.colorbar(sm, ax=ax)

                    # Empty scatters act as legend keys.
                    for number_gems in [10, 50, 100]:
                        plt.scatter([], [], c='k', alpha=0.3, s=number_gems, label=str(number_gems) + ' reads')
                    for marker, response in [('+', "TCR singlet")]:
                        plt.scatter([], [], c='k', alpha=0.3, s=50, label=str(response), marker=marker)
                    plt.legend(scatterpoints=1, frameon=False, labelspacing=1, loc='lower right')

                    # Criteria
                    textstr = '\n'.join((
                        "Criteria",
                        "Min. barcode read count \t %i".expandtabs() % read_threshold,
                        "Min. TCR read count \t\t    %i".expandtabs() % tcr_threshold,
                        "Exclude clonotype singlets \t %s".expandtabs() % str(exclude_clonotype_singlets),
                        "Exclude specificity singlets \t   %s".expandtabs() % str(exclude_specificity_singlets)))
                    props = dict(boxstyle='square', fc='white', ec='grey', alpha=0.5)
                    ax.text(0.63, 0.146, textstr,
                            fontsize=12,
                            horizontalalignment='left',
                            verticalalignment='top',
                            clip_on=False,
                            transform=ax.transAxes,
                            bbox=props)

                    plt.xlabel("%i Clonotypes (of %i GEMs)" % (len(unique_tcrs), len(unique_gems)), fontsize=16)
                    plt.ylabel("Peptide, HLA", fontsize=16)
                    plt.title("Specificity concordance per clonotype", fontsize=20)

                    if show:
                        # Preview mode: show only the first criteria combination.
                        plt.show()
                        print("OBS! Figures are not saved!")
                        return

                    file_name = "sortby_%s.b%i.t%i.ecs_%s.ess_%s.pdf" % (
                        sortby, read_threshold, tcr_threshold,
                        exclude_clonotype_singlets, exclude_specificity_singlets)
                    if save_tuba:
                        plt.savefig(FIG_DIR + project + file_name, bbox_inches='tight')
                    if save_sund:
                        plt.savefig(FIG_SUND + project + file_name, bbox_inches='tight')

                    plt.close(fig)

# MHC barcode read count per clonotype

In [8]:
def mhc_read_count_per_clonotype(credible_df, show=True, save_tuba=False, save_sund=False):
    """Plot pMHC barcode read counts per GEM, grouped into clonotype sections.

    GEMs are laid out clonotype by clonotype along the x axis (alternate
    sections are shaded). The y value is ``read_counts_mhc`` and the colour is
    the row's ``peptide_HLA``. One figure is drawn per combination of
    filtering criteria.

    Parameters:
        credible_df: DataFrame with columns ``num_clonotype``, ``clonotype``,
            ``epitope``, ``peptide_HLA``, ``gem``, ``read_counts_mhc`` and
            ``umis_tcr``. Not modified; a sorted copy is used.
        show: if True, display the first figure and return without saving.
        save_tuba, save_sund: save each figure as PDF below the module-level
            ``FIG_DIR`` / ``FIG_SUND`` paths.
    """
    import matplotlib as mpl
    from matplotlib.lines import Line2D

    mpl.rcParams['axes.grid.axis'] = 'y'

    # Epitope to color, cycling through the palette when it runs out.
    palette = ['#9e0142', '#d53e4f', '#f46d43', '#fdae61', '#fee08b', '#ffffbf',
               '#e6f598', '#abdda4', '#66c2a5', '#3288bd', '#5e4fa2']
    all_epitopes = credible_df.peptide_HLA.unique()
    epitope_to_color = {epitope: palette[i % len(palette)] for i, epitope in enumerate(all_epitopes)}

    sortby = 'read_counts_mhc'
    # Sort a copy rather than the caller's frame in place.
    credible_df = credible_df.sort_values(by=['num_clonotype', sortby])

    project = "read_count_per_clonotype/"

    for read_threshold in [1, 10, 20]:
        for tcr_threshold in [1, 10]:
            for exclude_clonotype_singlets in [False, True]:
                for exclude_specificity_singlets in [False, True]:
                    unique_gems, unique_tcrs = set(), set()

                    fig, ax = plt.subplots(figsize=(20, 10))

                    # [xmin, xmax - 0.5] is the categorical x span of the current
                    # clonotype. n_sections counts only non-empty clonotypes so
                    # the shading alternates even when some are filtered out.
                    xmin, xmax = -0.5, 0
                    n_sections = 0

                    for clonotype in credible_df.num_clonotype.unique():
                        sub_df = credible_df[(credible_df.num_clonotype == clonotype) &
                                             (credible_df.read_counts_mhc >= read_threshold) &
                                             (credible_df.umis_tcr >= tcr_threshold)]
                        if exclude_clonotype_singlets:
                            sub_df = sub_df[sub_df.duplicated(subset='num_clonotype', keep=False)]
                        if exclude_specificity_singlets:
                            sub_df = sub_df[sub_df.duplicated(subset=['num_clonotype', 'epitope'], keep=False)]
                        if sub_df.empty:
                            continue

                        gems = sub_df.gem.to_list()
                        # Each distinct GEM takes one categorical x slot.
                        xmax += len(np.unique(gems))
                        unique_gems.update(gems)
                        unique_tcrs.update(sub_df.clonotype)

                        colors = [epitope_to_color[ep] for ep in sub_df.peptide_HLA]
                        ax.scatter(gems, sub_df.read_counts_mhc.to_list(), c=colors)

                        if n_sections % 2 == 0:
                            ax.axvspan(xmin, xmax - 0.5, facecolor='0.7', alpha=0.1)
                        n_sections += 1
                        xmin = xmax - 0.5

                    legend_elements = [Line2D([0], [0], marker='o', color='w', label=epitope_label,
                                              markerfacecolor=epitope_to_color[epitope_label], markersize=10)
                                       for epitope_label in all_epitopes]
                    legend1 = ax.legend(handles=legend_elements, ncol=7, loc=2, bbox_to_anchor=(0.02, -0.03))
                    ax.add_artist(legend1)

                    plt.tick_params(labelbottom=False, labelright=True, labelsize=8)
                    plt.xticks(rotation=90, size=2)

                    # Criteria
                    textstr = '\n'.join((
                        "Criteria",
                        "Min. barcode read count \t %i".expandtabs() % read_threshold,
                        "Min. TCR read count \t\t    %i".expandtabs() % tcr_threshold,
                        "Exclude clonotype singlets \t %s".expandtabs() % str(exclude_clonotype_singlets),
                        "Exclude specificity singlets \t   %s".expandtabs() % str(exclude_specificity_singlets)))
                    props = dict(boxstyle='square', fc='white', ec='grey', alpha=0.5)
                    ax.text(0.12, 0.98, textstr,
                            fontsize=12,
                            horizontalalignment='left',
                            verticalalignment='top',
                            clip_on=False,
                            transform=ax.transAxes,
                            bbox=props)

                    plt.xlabel("%i GEMs (sectioned per clonotype (%i))" % (len(unique_gems), len(unique_tcrs)), fontsize=16)
                    plt.ylabel("pMHC barcode read counts", fontsize=16)
                    plt.title("MHC barcode read counts per GEM per clonotype", fontsize=20)

                    if show:
                        # Preview mode: show only the first criteria combination.
                        plt.show()
                        print("OBS! Figures are not saved!")
                        return

                    file_name = "sortby_%s.b%i.t%i.ecs_%s.ess_%s.pdf" % (
                        sortby, read_threshold, tcr_threshold,
                        exclude_clonotype_singlets, exclude_specificity_singlets)
                    if save_tuba:
                        plt.savefig(FIG_DIR + project + file_name, bbox_inches='tight')
                    if save_sund:
                        plt.savefig(FIG_SUND + project + file_name, bbox_inches='tight')

                    plt.close(fig)

## Response

In [9]:
def mhc_read_count_per_clonotype_response(credible_df, show=True, save_tuba=False, save_sund=False):
    """Plot pMHC barcode read counts per GEM per clonotype, marked by response.

    Same layout as ``mhc_read_count_per_clonotype``. Rows with
    ``detected_response == True`` are drawn as 'o' and rows with ``False`` as
    '+'. One figure is drawn per combination of filtering criteria.

    Parameters:
        credible_df: DataFrame with columns ``num_clonotype``, ``clonotype``,
            ``epitope``, ``peptide_HLA``, ``gem``, ``read_counts_mhc``,
            ``umis_tcr`` and ``detected_response``. Not modified; a sorted copy
            is used.
        show: if True, display the first figure and return without saving.
        save_tuba, save_sund: save each figure as PDF below the module-level
            ``FIG_DIR`` / ``FIG_SUND`` paths.
    """
    import matplotlib as mpl
    from matplotlib.lines import Line2D

    mpl.rcParams['axes.grid.axis'] = 'y'

    # Epitope to color, cycling through the palette when it runs out.
    palette = ['#9e0142', '#d53e4f', '#f46d43', '#fdae61', '#fee08b', '#ffffbf',
               '#e6f598', '#abdda4', '#66c2a5', '#3288bd', '#5e4fa2']
    all_epitopes = credible_df.peptide_HLA.unique()
    epitope_to_color = {epitope: palette[i % len(palette)] for i, epitope in enumerate(all_epitopes)}

    sortby = 'read_counts_mhc'
    # Sort a copy rather than the caller's frame in place.
    credible_df = credible_df.sort_values(by=['num_clonotype', sortby])

    project = "read_count_per_clonotype_response/"
    response_markers = [('o', True), ('+', False)]

    for read_threshold in [1, 10]:
        for tcr_threshold in [1, 10]:
            for exclude_clonotype_singlets in [False, True]:
                for exclude_specificity_singlets in [False, True]:
                    unique_gems, unique_tcrs = set(), set()

                    fig, ax = plt.subplots(figsize=(20, 10))

                    # [xmin, xmax - 0.5] is the categorical x span of the current
                    # clonotype; n_sections counts only non-empty clonotypes.
                    xmin, xmax = -0.5, 0
                    n_sections = 0

                    for clonotype in credible_df.num_clonotype.unique():
                        clone_df = credible_df[(credible_df.num_clonotype == clonotype) &
                                               (credible_df.read_counts_mhc >= read_threshold) &
                                               (credible_df.umis_tcr >= tcr_threshold)]
                        # Singlet filters look at all of the clonotype's rows, not
                        # each response subset, so a clonotype split across
                        # True/False responses is not mistaken for a singlet.
                        if exclude_clonotype_singlets:
                            clone_df = clone_df[clone_df.duplicated(subset='num_clonotype', keep=False)]
                        if exclude_specificity_singlets:
                            clone_df = clone_df[clone_df.duplicated(subset=['num_clonotype', 'epitope'], keep=False)]

                        plotted = []
                        for marker, response in response_markers:
                            sub_df = clone_df[clone_df.detected_response == response]
                            if sub_df.empty:
                                continue
                            plotted.append(sub_df)
                            colors = [epitope_to_color[ep] for ep in sub_df.peptide_HLA]
                            ax.scatter(sub_df.gem.to_list(), sub_df.read_counts_mhc.to_list(),
                                       marker=marker, c=colors)
                        if not plotted:
                            continue

                        shown_df = pd.concat(plotted)
                        # A GEM with rows in both subsets still occupies a single
                        # categorical x slot, so count distinct GEMs over both.
                        xmax += len(np.unique(shown_df.gem.to_list()))
                        unique_gems.update(shown_df.gem)
                        unique_tcrs.update(shown_df.clonotype)

                        if n_sections % 2 == 0:
                            ax.axvspan(xmin, xmax - 0.5, facecolor='0.7', alpha=0.1)
                        n_sections += 1
                        xmin = xmax - 0.5

                    legend_elements = [Line2D([0], [0], marker='o', color='w', label=epitope_label,
                                              markerfacecolor=epitope_to_color[epitope_label], markersize=10)
                                       for epitope_label in all_epitopes]
                    legend1 = ax.legend(handles=legend_elements, ncol=7, loc=2, bbox_to_anchor=(0.02, -0.03))
                    ax.add_artist(legend1)

                    plt.tick_params(labelbottom=False, labelright=True, labelsize=8)
                    plt.xticks(rotation=90, size=2)

                    # Empty scatters act as keys for the marker legend.
                    for marker, response in response_markers:
                        plt.scatter([], [], c='k', label=str(response), marker=marker)
                    plt.legend(scatterpoints=1, frameon=False, labelspacing=1, title='Detected response', loc='upper left')

                    # Criteria
                    textstr = '\n'.join((
                        "Criteria",
                        "Min. barcode read count \t %i".expandtabs() % read_threshold,
                        "Min. TCR read count \t\t    %i".expandtabs() % tcr_threshold,
                        "Exclude clonotype singlets \t %s".expandtabs() % str(exclude_clonotype_singlets),
                        "Exclude specificity singlets \t   %s".expandtabs() % str(exclude_specificity_singlets)))
                    props = dict(boxstyle='square', fc='white', ec='grey', alpha=0.5)
                    ax.text(0.12, 0.98, textstr,
                            fontsize=12,
                            horizontalalignment='left',
                            verticalalignment='top',
                            clip_on=False,
                            transform=ax.transAxes,
                            bbox=props)

                    plt.xlabel("%i GEMs (sectioned per clonotype (%i))" % (len(unique_gems), len(unique_tcrs)), fontsize=16)
                    plt.ylabel("pMHC barcode read counts", fontsize=16)
                    plt.title("MHC barcode read counts per GEM per clonotype", fontsize=20)

                    if show:
                        # Preview mode: show only the first criteria combination.
                        plt.show()
                        print("OBS! Figures are not saved!")
                        return

                    file_name = "sortby_%s.b%i.t%i.ecs_%s.ess_%s.pdf" % (
                        sortby, read_threshold, tcr_threshold,
                        exclude_clonotype_singlets, exclude_specificity_singlets)
                    if save_tuba:
                        plt.savefig(FIG_DIR + project + file_name, bbox_inches='tight')
                    if save_sund:
                        plt.savefig(FIG_SUND + project + file_name, bbox_inches='tight')

                    plt.close(fig)

## Peptide assayed

In [10]:
def mhc_read_count_per_clonotype_peptide_assayed(credible_df, show=True, save_tuba=False, save_sund=False):
    """Plot pMHC barcode read counts per GEM, sectioned per clonotype.

    One figure is drawn for every combination of the filtering criteria
    (min. pMHC barcode read count, min. ``umis_tcr``, clonotype-singlet
    exclusion, specificity-singlet exclusion). Each point is one row of
    ``credible_df``: x = ``gem`` (categorical axis), y = ``read_counts_mhc``,
    colour = ``peptide_HLA`` and marker = ``peptide_assayed``
    (``'o'`` for True, ``'+'`` for False). Consecutive clonotypes are
    separated by alternating background shading.

    Parameters
    ----------
    credible_df : pandas.DataFrame
        Must contain the columns ``gem``, ``clonotype``, ``num_clonotype``,
        ``epitope``, ``peptide_HLA``, ``peptide_assayed``, ``read_counts_mhc``
        and ``umis_tcr``. It is sorted *in place* by
        ``['num_clonotype', 'read_counts_mhc']``.
    show : bool
        If True, display the first figure and return without saving anything.
    save_tuba, save_sund : bool
        Save each figure as PDF under the module-level ``FIG_DIR`` /
        ``FIG_SUND`` directory prefix (plus ``project``).

    Side effect: sets ``mpl.rcParams['axes.grid.axis']`` globally.
    """
    import matplotlib as mpl
    from matplotlib.lines import Line2D

    mpl.rcParams['axes.grid.axis'] = 'y'

    # Epitope (peptide_HLA) -> colour; the palette is recycled once it runs out.
    palette = ['#9e0142', '#d53e4f', '#f46d43', '#fdae61', '#fee08b', '#ffffbf',
               '#e6f598', '#abdda4', '#66c2a5', '#3288bd', '#5e4fa2']
    all_epitopes = credible_df.peptide_HLA.unique()
    epitope_to_color = {epitope: palette[i % len(palette)]
                        for i, epitope in enumerate(all_epitopes)}

    # The epitope legend is identical for every figure: build its handles once.
    legend_elements = [Line2D([0], [0], marker='o', color='w', label=epitope_label,
                              markerfacecolor=epitope_to_color[epitope_label], markersize=10)
                       for epitope_label in all_epitopes]

    # Detected response
    sortby = 'read_counts_mhc'
    credible_df.sort_values(by=['num_clonotype', sortby], inplace=True)

    project = "read_count_per_clonotype_peptide_assayed/"

    for read_threshold in [1, 10]:
        for tcr_threshold in [1, 10]:
            # Both thresholds are row-wise filters, so apply them once per
            # threshold pair rather than once per clonotype.
            passed_df = credible_df[(credible_df.read_counts_mhc >= read_threshold) &
                                    (credible_df.umis_tcr >= tcr_threshold)]

            for exclude_clonotype_singlets in [False, True]:
                for exclude_specificity_singlets in [False, True]:
                    unique_gems, unique_tcrs = set(), set()

                    fig, ax = plt.subplots(figsize=(20, 10))

                    # [xmin, xmax - 0.5] is the x-range of the current clonotype
                    # on the categorical GEM axis (one integer slot per GEM).
                    xmin, xmax = -0.5, 0
                    n_sections = 0  # clonotypes actually drawn; drives shading parity

                    # sort=False keeps the num_clonotype order established above.
                    for _, clonotype_df in passed_df.groupby('num_clonotype', sort=False):
                        clonotype_gems = set()
                        for marker, response in [('o', True), ('+', False)]:
                            sub_df = clonotype_df[clonotype_df.peptide_assayed == response]
                            # NOTE(review): singlets are counted per marker subset and per row
                            # (not per GEM), so a clonotype with one assayed and one
                            # non-assayed row is removed entirely. Confirm this is intended.
                            if exclude_clonotype_singlets:
                                sub_df = sub_df[sub_df.duplicated(subset='num_clonotype', keep=False)]
                            if exclude_specificity_singlets:
                                sub_df = sub_df[sub_df.duplicated(subset=['num_clonotype', 'epitope'], keep=False)]

                            gems = sub_df.gem.to_list()
                            clonotype_gems.update(gems)
                            unique_gems.update(gems)
                            unique_tcrs.update(sub_df.clonotype.to_list())

                            colors = [epitope_to_color[ep] for ep in sub_df.peptide_HLA.to_list()]
                            ax.scatter(gems, sub_df.read_counts_mhc.to_list(), marker=marker, c=colors)

                        if not clonotype_gems:
                            # Nothing drawn for this clonotype: it must not consume an
                            # x-range or flip the alternating shading.
                            continue

                        # A GEM occupies a single categorical x position even when it
                        # has rows in both marker subsets, so count unique GEMs over
                        # the whole clonotype (assumes each GEM belongs to one clonotype).
                        xmax += len(clonotype_gems)
                        if n_sections % 2 == 0:
                            ax.axvspan(xmin, xmax - 0.5, facecolor='0.7', alpha=0.1)
                        n_sections += 1
                        xmin = xmax - 0.5

                    legend1 = ax.legend(handles=legend_elements, ncol=7, loc=2, bbox_to_anchor=(0.02, -0.03))
                    # Keep the epitope legend when plt.legend() below adds a second one.
                    ax.add_artist(legend1)

                    plt.tick_params(labelbottom=False, labelright=True, labelsize=8)
                    plt.xticks(rotation=90, size=2)

                    # Empty scatters only serve as legend entries for the marker shapes.
                    for marker, response in [('o', True), ('+', False)]:
                        plt.scatter([], [], c='k', label=str(response), marker=marker)
                    plt.legend(scatterpoints=1, frameon=False, labelspacing=1, title='Peptide assay', loc='upper left')

                    # Criteria
                    textstr = '\n'.join((
                        "Criteria",
                        "Min. barcode read count \t %i".expandtabs() %read_threshold,
                        "Min. TCR read count \t\t    %i".expandtabs() %tcr_threshold,
                        "Exclude clonotype singlets \t %s".expandtabs() %str(exclude_clonotype_singlets),
                        "Exclude specificity singlets \t   %s".expandtabs() %str(exclude_specificity_singlets)))
                    props = dict(boxstyle='square', fc='white', ec='grey', alpha=0.5)
                    ax.text(0.12, 0.98, textstr,
                            fontsize=12,
                            horizontalalignment='left',
                            verticalalignment='top',
                            clip_on=False,
                            transform=ax.transAxes,
                            bbox=props)

                    plt.xlabel("%i GEMs (sectioned per clonotype (%i))" %(len(unique_gems), len(unique_tcrs)), fontsize=16)
                    plt.ylabel("pMHC barcode read counts", fontsize=16)
                    plt.title("MHC barcode read counts per GEM per clonotype", fontsize=20)

                    if show:
                        # Preview mode: only the first criteria combination is shown.
                        plt.show()
                        print("OBS! Figures at not saved!")
                        return
                    filename = project + "sortby_%s.b%i.t%i.ecs_%s.ess_%s.pdf" %(sortby, read_threshold, tcr_threshold, exclude_clonotype_singlets, exclude_specificity_singlets)
                    if save_tuba:
                        plt.savefig(FIG_DIR + filename, bbox_inches='tight')
                    if save_sund:
                        plt.savefig(FIG_SUND + filename, bbox_inches='tight')
                    plt.cla()   # Clear axis
                    plt.clf()   # Clear figure
                    plt.close(fig) # Close a figure window

# TCR read count per clonotype

In [11]:
def tcr_read_count_per_clonotype_detected_response(credible_df, show=True, save_tuba=False, save_sund=False):
    """Plot TCR counts (``umis_tcr``) per GEM, sectioned per clonotype.

    One figure is drawn for every combination of the filtering criteria
    (min. ``read_counts_mhc``, min. ``umis_tcr``, clonotype-singlet exclusion,
    specificity-singlet exclusion). Each point is one row of ``credible_df``:
    x = ``gem`` (categorical axis), y = ``umis_tcr`` (labelled
    "TCR barcode read counts" in the figure), colour = ``peptide_HLA`` and
    marker = ``detected_response`` (``'o'`` for True, ``'+'`` for False).
    Consecutive clonotypes are separated by alternating background shading.

    Parameters
    ----------
    credible_df : pandas.DataFrame
        Must contain the columns ``gem``, ``clonotype``, ``num_clonotype``,
        ``epitope``, ``peptide_HLA``, ``detected_response``,
        ``read_counts_mhc`` and ``umis_tcr``. It is sorted *in place* by
        ``['num_clonotype', 'umis_tcr']``.
    show : bool
        If True, display the first figure and return without saving anything.
    save_tuba, save_sund : bool
        Save each figure as PDF under the module-level ``FIG_DIR`` /
        ``FIG_SUND`` directory prefix (plus ``project``).

    Side effect: sets ``mpl.rcParams['axes.grid.axis']`` globally.
    """
    from matplotlib.lines import Line2D

    mpl.rcParams['axes.grid.axis'] = 'y'

    # Epitope (peptide_HLA) -> colour; the palette is recycled once it runs out.
    palette = ['#9e0142', '#d53e4f', '#f46d43', '#fdae61', '#fee08b', '#ffffbf',
               '#e6f598', '#abdda4', '#66c2a5', '#3288bd', '#5e4fa2']
    all_epitopes = credible_df.peptide_HLA.unique()
    epitope_to_color = {epitope: palette[i % len(palette)]
                        for i, epitope in enumerate(all_epitopes)}

    # The epitope legend is identical for every figure: build its handles once.
    legend_elements = [Line2D([0], [0], marker='o', color='w', label=epitope_label,
                              markerfacecolor=epitope_to_color[epitope_label], markersize=10)
                       for epitope_label in all_epitopes]

    sortby = 'umis_tcr'
    credible_df.sort_values(by=['num_clonotype', sortby], inplace=True)

    project = "tcr_read_count_per_clonotype_detected_response/"

    for read_threshold in [1, 10, 20]:
        for tcr_threshold in [1, 10]:
            # Both thresholds are row-wise filters, so apply them once per
            # threshold pair rather than once per clonotype.
            passed_df = credible_df[(credible_df.read_counts_mhc >= read_threshold) &
                                    (credible_df.umis_tcr >= tcr_threshold)]

            for exclude_clonotype_singlets in [False, True]:
                for exclude_specificity_singlets in [False, True]:
                    unique_gems, unique_tcrs = set(), set()

                    fig, ax = plt.subplots(figsize=(20, 10))

                    # [xmin, xmax - 0.5] is the x-range of the current clonotype
                    # on the categorical GEM axis (one integer slot per GEM).
                    xmin, xmax = -0.5, 0
                    n_sections = 0  # clonotypes actually drawn; drives shading parity

                    # sort=False keeps the num_clonotype order established above.
                    for _, clonotype_df in passed_df.groupby('num_clonotype', sort=False):
                        clonotype_gems = set()
                        for marker, response in [('o', True), ('+', False)]:
                            sub_df = clonotype_df[clonotype_df.detected_response == response]
                            # NOTE(review): singlets are counted per marker subset and per row
                            # (not per GEM), so a clonotype with one responding and one
                            # non-responding row is removed entirely. Confirm this is intended.
                            if exclude_clonotype_singlets:
                                sub_df = sub_df[sub_df.duplicated(subset='num_clonotype', keep=False)]
                            if exclude_specificity_singlets:
                                sub_df = sub_df[sub_df.duplicated(subset=['num_clonotype', 'epitope'], keep=False)]

                            gems = sub_df.gem.to_list()
                            clonotype_gems.update(gems)
                            unique_gems.update(gems)
                            unique_tcrs.update(sub_df.clonotype.to_list())

                            colors = [epitope_to_color[ep] for ep in sub_df.peptide_HLA.to_list()]
                            ax.scatter(gems, sub_df.umis_tcr.values, marker=marker, c=colors)

                        if not clonotype_gems:
                            # Nothing drawn for this clonotype: it must not consume an
                            # x-range or flip the alternating shading.
                            continue

                        # A GEM occupies a single categorical x position even when it
                        # has rows in both marker subsets, so count unique GEMs over
                        # the whole clonotype (assumes each GEM belongs to one clonotype).
                        xmax += len(clonotype_gems)
                        if n_sections % 2 == 0:
                            ax.axvspan(xmin, xmax - 0.5, facecolor='0.7', alpha=0.1)
                        n_sections += 1
                        xmin = xmax - 0.5

                    legend1 = ax.legend(handles=legend_elements, ncol=8, loc=2, bbox_to_anchor=(0.02, -0.03))
                    # Keep the epitope legend when plt.legend() below adds a second one.
                    ax.add_artist(legend1)

                    plt.tick_params(labelbottom=False, labelright=True, labelsize=8)
                    plt.xticks(rotation=90, size=2)

                    # Empty scatters only serve as legend entries for the marker shapes.
                    for marker, response in [('o', True), ('+', False)]:
                        plt.scatter([], [], c='k', label=str(response), marker=marker)
                    plt.legend(scatterpoints=1, frameon=False, labelspacing=1, title='Detected response', loc='upper right')

                    # Criteria
                    textstr = '\n'.join((
                        "Criteria",
                        "Min. barcode read count \t %i".expandtabs() %read_threshold,
                        "Min. TCR read count \t\t    %i".expandtabs() %tcr_threshold,
                        "Exclude clonotype singlets \t %s".expandtabs() %str(exclude_clonotype_singlets),
                        "Exclude specificity singlets \t   %s".expandtabs() %str(exclude_specificity_singlets)))
                    props = dict(boxstyle='square', fc='white', ec='grey', alpha=0.5)
                    ax.text(0.05, 0.95, textstr,
                            fontsize=12,
                            horizontalalignment='left',
                            verticalalignment='top',
                            clip_on=False,
                            transform=ax.transAxes,
                            bbox=props)

                    plt.xlabel("%i GEMs (sectioned per clonotype (%i))" %(len(unique_gems), len(unique_tcrs)), fontsize=16)
                    plt.ylabel("TCR barcode read counts", fontsize=16)
                    plt.title("TCR barcode read counts per GEM per clonotype", fontsize=20)

                    if show:
                        # Preview mode: only the first criteria combination is shown.
                        plt.show()
                        print("OBS! Figures at not saved!")
                        return
                    filename = project + "sortby_%s.b%i.t%i.ecs_%s.ess_%s.pdf" %(sortby, read_threshold, tcr_threshold, exclude_clonotype_singlets, exclude_specificity_singlets)
                    if save_tuba:
                        plt.savefig(FIG_DIR + filename, bbox_inches='tight')
                    if save_sund:
                        plt.savefig(FIG_SUND + filename, bbox_inches='tight')
                    plt.cla()   # Clear axis
                    plt.clf()   # Clear figure
                    plt.close(fig) # Close a figure window