<a href="https://colab.research.google.com/github/dtabuena/Workshop/blob/main/upset_DEG_Selection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import itertools
import numpy as np
import pandas as pd
import os


In [2]:
"""LOCAL CONFIG"""
# NOTE(review): the variable name says "01" but the path ends in DGC_2 — confirm intended.
dgc_01_dir = (
    'C:/Users/dennis.tabuena/Dropbox (Gladstone)/0_Projects/_Hyper+Crisper/DEG_Selection/DGC_2'
)
# Machine-specific working directory; the relative paths below resolve against it.
os.chdir(dgc_01_dir)

# E4-vs-E3 DEG tables, one per age (5, 10, 15, 20 months) judging by the file names.
# NOTE(review): this list is not read by any visible cell.
deg_file_loc_lists = [
    './E4vsE3_DEGs_wilcox_5mo_02_Dentate_Gyrus_Granule_Cells.csv',
    './E4vsE3_DEGs_wilcox_10mo_02_Dentate_Gyrus_Granule_Cells.csv',
    './E4vsE3_DEGs_wilcox_15mo_02_Dentate_Gyrus_Granule_Cells.csv',
    './E4vsE3_DEGs_wilcox_20mo_02_Dentate_Gyrus_Granule_Cells.csv',
]

In [None]:
"""
Criteria:
E4 > E3 at 5mo
E4 > E3 at 10mo
High_E4 > High_E3
fE4_NKO !> fE4
E4 !> E3 at 20mo
E3_20mo > E3_5mo
"""


In [None]:
def membership_dict_to_df(group_dict: dict) -> "tuple[pd.DataFrame, dict]":
    """
    Build a membership table from a mapping of group label -> list of items.

    Every distinct item found in any group becomes exactly one row, in
    first-seen order. Each group becomes a boolean column that is True when
    the item belongs to that group. A final ``emb_combo`` column holds an
    integer code for the item's full membership pattern, so overlaps can be
    counted by code.

    Parameters
    ----------
    group_dict : dict
        Keys are group labels; values are iterables of hashable items.

    Returns
    -------
    membership_df : pd.DataFrame
        Index: unique items. Columns: one bool column per group (in
        ``group_dict`` order) followed by ``emb_combo`` (int code).
    embed_combos : dict
        Maps every possible membership tuple (one bool per group, in
        ``group_dict`` order) to its integer code. Codes increase with the
        number of groups in the combination: all-False is 0 and all-True is
        last. Ties keep ``itertools.product`` order.
    """
    group_names = list(group_dict.keys())
    # Sets give O(1) membership tests instead of scanning each list per item.
    group_sets = [set(members) for members in group_dict.values()]

    # An item in several groups must appear once; otherwise it would be
    # counted once per group it belongs to when intersections are tallied.
    items = list(dict.fromkeys(
        item for members in group_dict.values() for item in members
    ))

    # Stable sort by group count so the code assignment is deterministic.
    possible_combos = sorted(
        itertools.product([True, False], repeat=len(group_names)), key=sum
    )
    embed_combos = {combo: code for code, combo in enumerate(possible_combos)}

    rows = [tuple(item in members for members in group_sets) for item in items]
    membership_df = pd.DataFrame(rows, index=items, columns=group_names, dtype=bool)
    membership_df['emb_combo'] = [embed_combos[row] for row in rows]
    return membership_df, embed_combos

In [None]:
def upset_plot(group_dict, figsize=(3,3), exclude_all_none=True):
    """
    Draw an UpSet-style plot of the overlaps between several groups.

    The figure is a 2x2 grid:
      * bottom-left:  dot matrix. There is one column per membership
        combination and one row per group. Black dots mark the groups in the
        combination and are joined by a vertical line.
      * top-left:     bar of how many items fall in exactly that combination.
      * bottom-right: horizontal bar of each group's size (unique items).
      * top-right:    unused (axis hidden).

    Parameters
    ----------
    group_dict : dict
        Keys are group labels; values are iterables of hashable items.
    figsize : tuple, optional
        Figure size passed to ``plt.subplots``.
    exclude_all_none : bool, optional
        If True, drop the "in no group" combination from the plot and from the
        returned ``embed_combos``.

    Returns
    -------
    (fig, ax) : tuple
        The matplotlib figure and its 2x2 array of axes.
    membership_df : pd.DataFrame
        Table returned by ``membership_dict_to_df``.
    embed_combos : dict
        Membership-tuple -> integer code mapping (without the all-False entry
        when ``exclude_all_none`` is True).
    """
    # NOTE(review): `plt` (matplotlib.pyplot) must be imported elsewhere; no
    # import for it is visible in this notebook.
    membership_df, embed_combos = membership_dict_to_df(group_dict)
    group_names = list(group_dict.keys())
    n_groups = len(group_names)

    possible_combos = list(embed_combos.keys())
    if exclude_all_none:
        # Items are gathered from the groups themselves, so the "no group"
        # combination is always empty.
        possible_combos = [combo for combo in possible_combos if any(combo)]
        embed_combos.pop((False,) * n_groups, None)
    combo_matrix = np.array(possible_combos, dtype=bool).reshape(
        len(possible_combos), n_groups
    )

    fig, ax = plt.subplots(2, 2, figsize=figsize, width_ratios=[5, 2],
                           height_ratios=[5, 2], dpi=300)
    ax[0, 1].axis('off')
    overlap_ax, combo_ax, set_size_ax = ax[0, 0], ax[1, 0], ax[1, 1]

    # Shared integer positions keep dots, intersection bars and group bars aligned.
    x_positions = np.arange(len(possible_combos))
    x_limits = (-0.5, len(possible_combos) - 0.5)
    y_limits = (-0.5, n_groups - 0.5)

    # Draw dots and connect members of each combination.
    dot_size = 12
    in_x, in_y = np.nonzero(combo_matrix)
    out_x, out_y = np.nonzero(~combo_matrix)
    combo_ax.scatter(in_x, in_y, color='k', s=dot_size)
    combo_ax.scatter(out_x, out_y, color='lightgrey', s=dot_size)
    for x, combo in enumerate(combo_matrix):
        member_rows = np.flatnonzero(combo)
        if member_rows.size > 1:
            combo_ax.plot([x, x], [member_rows.min(), member_rows.max()],
                          color='k', linewidth=.5)
    combo_ax.set_yticks(range(n_groups), labels=group_names)
    combo_ax.set_xticks([])
    combo_ax.set_xlim(x_limits)
    combo_ax.set_ylim(y_limits)

    # Plot overlap bars: count each item once, then tally items per exact code.
    # Dropping repeated index labels guards against tables that list an
    # item once per group it belongs to.
    unique_codes = membership_df.loc[
        ~membership_df.index.duplicated(), 'emb_combo'
    ].to_numpy()
    intersection_counts = [int(np.sum(unique_codes == embed_combos[combo]))
                           for combo in possible_combos]
    overlap_ax.bar(x_positions, intersection_counts, color='k')
    overlap_ax.set_xlim(x_limits)
    overlap_ax.set_xticks([])
    overlap_ax.set_ylabel('Intersection (#)')

    # Plot group sizes (unique members), one bar per dot-matrix row.
    set_sizes = [len(set(members)) for members in group_dict.values()]
    set_size_ax.barh(range(n_groups), set_sizes, color='k')
    set_size_ax.set_ylim(y_limits)
    set_size_ax.set_yticks([])
    set_size_ax.set_xlabel('Group Size (#)')

    return (fig, ax), membership_df, embed_combos


# NOTE(review): `go_group_dict` is not defined in any visible cell. It must be
# a dict of group label -> list of members (presumably genes built from the
# tables in deg_file_loc_lists). TODO: define it before running this cell.
fig_ax, mebership_df, embed_combos = upset_plot(
    group_dict=go_group_dict,
    figsize=(4, 2.5),
)