In [1]:
import swan_vis as swan
import pandas as pd
import anndata
import scanpy as sc
from scipy import sparse
import numpy as np
import time
import copy
import matplotlib.pyplot as plt
import matplotlib as mpl

In [2]:
sg = swan.read('/Users/fairliereese/mortazavi_lab/data/mousewg/adrenal/lr_splitseq/swan/swan.p')

c_dict = {'Medulla_NE':'#97578a',
                    'Medulla_EPI':'#339470',
                    'Sox10+':'#753b74',
                    'Stromal':'#e2969b',
                    'Adipocytes':'#e25e2c',
                    'Hepatocyte':'#ad7797',
                    'Smooth_muscle':'#86b84d',
                    'Macrophage':'#da5774',
                    'Endothelial':'#a63b4c',
                    'Cortex/Endothelial':'#f99d26',
                    'Cortex_ZF':'#f0c130',
                    'Cortex_ZG':'#b2373a',
                    'Cortex_cycling':'#3f4075',
                    'X_zone':'#fade7c',
                    'Y_zone':'#e47381',
                    'Capsule':'#236d88',
                    'Other':'grey'}
order = order =  ['Medulla_NE','Medulla_EPI','Sox10+',
                  'Stromal','Adipocytes','Hepatocyte',
                  'Smooth_muscle','Macrophage','Endothelial',
                  'Cortex/Endothelial','Cortex_ZF','Cortex_ZG',
                  'Cortex_cycling','X_zone','Y_zone',
                  'Capsule','Other']

# temp fix for x zone and y zone names 
# till I can regenerate swan graph
m = {'X-Zone': 'X_zone', 
     'Y-Zone': 'Y_zone'}
zones = ['X-Zone', 'Y-Zone']
sg.adata.obs.loc[sg.adata.obs.lr_celltype.isin(zones), 'lr_celltype'] = sg.adata.obs.loc[sg.adata.obs.lr_celltype.isin(zones)].lr_celltype.map(m)
sg.set_metadata_colors('lr_celltype', c_dict)

Read in graph from /Users/fairliereese/mortazavi_lab/data/mousewg/adrenal/lr_splitseq/swan/swan.p


In [3]:
def subset_on_gene_sg(sg, gid=None, datasets=None):
    """
    Subset the swan Graph on a given gene and return the subset graph.

    Parameters:
        gid (str): Gene ID to subset on
        datasets (list of str): List of datasets to keep in the subset

    returns:
        subset_sg (swan Graph): Swan Graph subset on the input gene.
    """

    # didn't ask for either
    if not gid and not datasets:
        return sg

    # subset on gene
    if gid:
        # make sure this gid is even in the Graph
        sg.check_gene(gid)

        # get the strand
        strand = sg.get_strand_from_gid(gid)

        # subset t_df first, it's the easiest
        tids = sg.t_df.loc[sg.t_df.gid == gid].index.tolist()
        t_df = sg.t_df.loc[tids].copy(deep=True)
        t_df['path'] = sg.t_df.loc[tids].apply(
                lambda x: copy.deepcopy(x.path), axis=1)
        t_df['loc_path'] = sg.t_df.loc[tids].apply(
                lambda x: copy.deepcopy(x.loc_path), axis=1)
        
        # since we don't keep all transcripts in adata, make
        # sure to pare that down
        tids = list(set(tids)&set(sg.adata.var.index.tolist()))

        # subset loc_df based on all the locs that are in the paths from
        # the already-subset t_df
        paths = t_df['loc_path'].tolist()
        locs = [node for path in paths for node in path]
        locs = np.unique(locs)
        loc_df = sg.loc_df.loc[locs].copy(deep=True)

        # subset edge_df based on all the edges that are in the paths from
        # the alread-subset t_df
        paths = t_df['path'].tolist()
        edges = [node for path in paths for node in path]
        edges = np.unique(edges)
        edge_df = sg.edge_df.loc[edges].copy(deep=True)
    if not gid:
        t_df = sg.t_df.copy(deep=True)
        edge_df = sg.edge_df.copy(deep=True)
        loc_df = sg.loc_df.copy(deep=True)

    # also subset anndata
    # if obs_col and obs_cats:
    # 	# adatas = [sg.adata, sg.edge_adata,
    # 	# 		  sg.tss_adata, sg.tes_adata]
    # 	# for adata in adatas:
    # 	# print(obs_col)
    # 	# print(obs_cats)
    # 	obs_vars = sg.adata.obs.loc[sg.adata.obs[obs_col].isin(obs_cats)]
    # 	obs_vars = obs_vars.index.tolist()
    # 	# print(obs_vars)
    # 	# print(tids)
    # 	adata = sg.adata[obs_vars, tids]
    new_adatas = dict()
    # adatas = {'iso': sg.adata, 'edge': sg.edge_adata,
    # 		  'tss': sg.tss_adata, 'tes': sg.tes_adata}
    adatas = {'iso': sg.adata}
    for key, adata in adatas.items():
        if datasets and gid:
            new_adatas[key] = adata[datasets, tids]
        elif gid:
            new_adatas[key] = adata[:, tids]
        elif datasets:
            new_adatas[key] = adata[datasets, :]
        else:
            new_adatas[key] = adata

    # create a new graph that's been subset
    subset_sg = swan.SwanGraph()
    subset_sg.loc_df = loc_df
    subset_sg.edge_df = edge_df
    subset_sg.t_df = t_df
    subset_sg.adata = new_adatas['iso']
    # subset_sg.edge_adata = new_adatas['edge']
    # subset_sg.tss_adata = new_adatas['tss']
    # subset_sg.tes_adata = new_adatas['tes']
    subset_sg.datasets = subset_sg.adata.obs.index.tolist()
    subset_sg.abundance = sg.abundance
    subset_sg.sc = sg.sc
    subset_sg.pg = sg.pg
    subset_sg.annotation = sg.annotation

    # renumber locs if using a gene
    if gid:
        if strand == '-':
            id_map = subset_sg.get_ordered_id_map(rev_strand=True)
            subset_sg.update_ids(id_map)
        else:
            subset_sg.update_ids()

        subset_sg.get_loc_types()

    # finally create the graph
    subset_sg.create_graph_from_dfs()

    return subset_sg

In [None]:
# making gen report work 
# check if groupby column is present
def gen_report(self,
               gid,
               prefix,
               datasets=None,
               groupby=None,
               metadata_cols=None,
               novelty=False,
               layer='tpm', # choose from tpm, pi
               cmap='Spectral_r',
               include_qvals=False,
               q=0.05,
               log2fc=1,
               qval_obs_col=None,
               qval_obs_conditions=None,
               include_unexpressed=False,
               indicate_novel=False,
               display_numbers=False,
               transcript_name=False,
               browser=False,
               order='expression'):
    multi_groupby = False
    indicate_dataset = False
    if groupby:
        # grouping by more than one column
        if type(groupby) == list and len(groupby) > 1:
            for g in groupby:
                if g not in self.adata.obs.columns.tolist():
                    raise Exception('Groupby column {} not found'.format(g))
            groupby = self.add_multi_groupby(groupby)
            multi_groupby = True
        elif groupby not in self.adata.obs.columns.tolist():
            raise Exception('Groupby column {} not found'.format(groupby))


    # check if metadata columns are present
    if metadata_cols:
        for c in metadata_cols:
            if c not in self.adata.obs.columns.tolist():
                raise Exception('Metadata column {} not found'.format(c))

            # if we're grouping by a certain variable, make sure
            # the other metadata cols we plan on plotting have unique
            # mappings to the other columns. if just grouping by dataset,
            # since each dataset is unique, that's ok
            if groupby and groupby != 'dataset':
                if groupby == c:
                    continue

                temp = self.adata.obs[[groupby, c, 'dataset']].copy(deep=True)
                temp = temp.groupby([groupby, c]).count().reset_index()
                temp = temp.loc[~temp.dataset.isnull()]

                # if there are duplicates from the metadata column, throw exception
                if temp[groupby].duplicated().any():
                        raise Exception('Metadata column {} '.format(c)+\
                            'not compatible with groupby column {}. '.format(groupby)+\
                            'Groupby column has more than 1 unique possible '+\
                            'value from metadata column.')

    # check to see if input gene is in the graph
    if gid not in self.t_df.gid.tolist():
        gid = self.get_gid_from_gname(gid)
    self.check_gene(gid)

    # check to see if these plotting settings will play together
    self.check_plotting_args(indicate_dataset,
        indicate_novel, browser)

    # get the list of columns to include from the input datasets dict
    if datasets:
        # get a df that is subset of metadata
        # also sort the datasets based on the order they appear in "datasets"
        i = 0
        sorters = []
        for meta_col, meta_cats in datasets.items():
            if meta_col not in self.adata.obs.columns.tolist():
                raise Exception('Metadata column {} not found'.format(meta_col))
            if type(meta_cats) == str:
                meta_cats = [meta_cats]
            if i == 0:
                temp = self.adata.obs.loc[self.adata.obs[meta_col].isin(meta_cats)]
            else:
                temp = temp.loc[temp[meta_col].isin(meta_cats)]
            sort_ind = dict(zip(meta_cats, range(len(meta_cats))))
            sort_col = '{}_sort'.format(meta_col)
            temp[sort_col] = temp[meta_col].map(sort_ind).astype(int)
            sorters.append(sort_col)
            i += 1

        # sort the df based on the order that different categories appear in "datasets"
        temp.sort_values(by=sorters, inplace=True, ascending=True)
        temp.drop(sorters, axis=1, inplace=True)
        columns = temp.dataset.tolist()
        del temp
    else:
        columns = None

    # if we've asked for novelty first check to make sure it's there
    if novelty:
        if not self.has_novelty():
            raise Exception('No novelty information present in the graph. '
                'Add it or do not use the "novelty" report option.')
            
#     print('finished checking all plotting args')

    # abundance info to calculate TPM on - subset on datasets that will
    # be included
    if columns or datasets:
        subset_adata = subset_on_gene_sg(sg, datasets=columns).adata
    else:
        subset_adata = self.adata

    # small SwanGraph with only this gene's data
    sg = subset_on_gene_sg(sg, gid=gid, datasets=columns)

    # if we're grouping data, calculate those new numbers
    # additionally order transcripts
    if groupby:
        if layer == 'tpm':
            # use whole adata to calc tpm
            t_df = tpm_df = calc_tpm(subset_adata, obs_col=groupby).transpose()
#             print('finished calculating tpm')
        elif layer == 'pi':
            # calc tpm just so we can order based on exp
            tpm_df = calc_tpm(subset_adata, obs_col=groupby).transpose()
#             print('finished calculating pi')
            t_df, _ = calc_pi(self.adata, self.t_df, obs_col=groupby)
#             print('finished calculating tpm and pi')
            t_df = t_df.transpose()
            
    else:
        if layer == 'tpm':
            # use whole adata to calc tpm
            t_df = tpm_df = self.get_tpm().transpose()
            t_df = t_df[subset_adata.obs.dataset.tolist()]
#             print('finished calculating tpm')
        elif layer == 'pi':
            # calc tpm just so we can order based on exp
            t_df = tpm_df = self.get_tpm().transpose()
            t_df = t_df[subset_adata.obs.dataset.tolist()]
            t_df, _ = calc_pi(self.adata, self.t_df)
#             print('finished calculating and pi')
            t_df = t_df.transpose()

    # order transcripts by user's preferences
    if order == 'expression' and self.abundance == False:
        order = 'tid'
    elif order == 'expression':
        order = 'log2tpm'
    tids = self.t_df.loc[self.t_df.gid == gid].index.tolist()
    tids = list(set(tids)&set(tpm_df.index.tolist()))
    tpm_df = tpm_df.loc[tids]
    _, tids = self.order_transcripts_subset(tpm_df, order=order)
#     print('finished ordering transcripts')
    del tpm_df
    t_df = t_df.loc[tids]

    # remove unexpressed transcripts if desired
    if not include_unexpressed:
        t_df = t_df.loc[t_df.any(axis=1)]

    # make sure de has been run if needed
    if include_qvals:
        uns_key = make_uns_key(kind='det',
                               obs_col=qval_obs_col,
                               obs_conditions=qval_obs_conditions)
        qval_df = self.adata.uns[uns_key].copy(deep=True)
        qval_df['significant'] = (qval_df.qval <= q)&(qval_df.log2fc >= log2fc)
    else:
        qval_df = None

    # get tids in this report
    report_tids = t_df.index.tolist()

    # plot each transcript with these settings
    print()
    print('Plotting transcripts for {}'.format(gid))
    self.plot_each_transcript(report_tids, prefix,
                              indicate_dataset,
                              indicate_novel,
                              browser=browser)
#     print('finished plotting each transcript')

    # get a different prefix for saving colorbars and scales
    gid_prefix = prefix+'_{}'.format(gid)

    # if we're plotting tracks, we need a scale as well
    # also set what type of report this will be, 'swan' or 'browser'
    if browser:
        self.pg.plot_browser_scale()
        save_fig(gid_prefix+'_browser_scale.png')
        report_type = 'browser'
    else:
        report_type = 'swan'

    # plot colorbar for either tpm or pi
    if layer == 'tpm':

        # take log2(tpm) (add pseudocounts)
        t_df = np.log2(t_df+1)

        # min and max tpm vals
        g_max = t_df.max().max()
        g_min = t_df.min().min()

        # create a colorbar
        plt.rcParams.update({'font.size': 30})
        fig, ax = plt.subplots(figsize=(14, 1.5))
        fig.subplots_adjust(bottom=0.5)
        fig.patch.set_visible(False)
        ax.patch.set_visible(False)

        try:
            cmap = plt.get_cmap(cmap)
        except:
            raise ValueError('Colormap {} not found'.format(cmap))

        norm = mpl.colors.Normalize(vmin=g_min, vmax=g_max)

        cb = mpl.colorbar.ColorbarBase(ax,
                            cmap=cmap,
                            norm=norm,
                            orientation='horizontal')
        cb.set_label('log2(TPM)')
        plt.savefig(gid_prefix+'_colorbar_scale.png', format='png',
            bbox_inches='tight', dpi=200)
        plt.clf()
        plt.close()

    elif layer == 'pi':

        # min and max pi vals
        g_max = 100
        g_min = 0

        # create a colorbar between 0 and 1
        plt.rcParams.update({'font.size': 30})
        fig, ax = plt.subplots(figsize=(14, 1.5))
        fig.subplots_adjust(bottom=0.5)
        fig.patch.set_visible(False)
        ax.patch.set_visible(False)

        try:
            cmap = plt.get_cmap(cmap)
        except:
            raise ValueError('Colormap {} not found'.format(cmap))

        norm = mpl.colors.Normalize(vmin=0, vmax=100)

        cb = mpl.colorbar.ColorbarBase(ax,
                            cmap=cmap,
                            norm=norm,
                            orientation='horizontal')
        cb.set_label('Percent of isoform use (' +'$\pi$'+')')
        plt.savefig(gid_prefix+'_colorbar_scale.png', format='png',
            bbox_inches='tight', dpi=200)
        plt.clf()
        plt.close()

    # merge with self.t_df to get additional columns
    print('tdf')
    print(t_df.columns)
    datasets = t_df.columns
    cols = ['novelty', 'tname']
    t_df = t_df.merge(self.t_df[cols], how='left', left_index=True, right_index=True)

    # create report
    print('Generating report for {}'.format(gid))
    pdf_name = create_fname(prefix,
                 indicate_dataset,
                 indicate_novel,
                 browser,
                 ftype='report',
                 gid=gid)
    if transcript_name:
        t_disp = 'Transcript Name'
    else:
        t_disp = 'Transcript ID'
    report = Report(gid_prefix,
                    report_type,
                    self.adata.obs,
                    self.adata.uns,
                    datasets=datasets,
                    groupby=groupby,
                    metadata_cols=metadata_cols,
                    novelty=novelty,
                    layer=layer,
                    cmap=cmap,
                    g_min=g_min,
                    g_max=g_max,
                    include_qvals=include_qvals,
                    qval_df=qval_df,
                    display_numbers=display_numbers,
                    t_disp=t_disp)
    report.add_page()

    # loop through each transcript and add it to the report
    for ind, entry in t_df.iterrows():
        tid = ind

        # display name for transcript
        if transcript_name:
            t_disp = entry['tname']
        else:
            t_disp = tid
        fname = create_fname(prefix,
                             indicate_dataset,
                             indicate_novel,
                             browser,
                             ftype='path',
                             tid=tid)
        report.add_transcript(entry, fname, t_disp)
    report.write_pdf(pdf_name)

    # remove multi groupby column if necessary
    if multi_groupby:
        self.rm_multi_groupby(groupby)

In [4]:
# # making gen report work 
# # check if groupby column is present
# def gen_report(sg,
#                gid,
#                prefix,
#                datasets=None,
#                groupby=None,
#                metadata_cols=None,
#                novelty=False,
#                layer='tpm', # choose from tpm, pi
#                cmap='Spectral_r',
#                include_qvals=False,
#                q=0.05,
#                log2fc=1,
#                qval_obs_col=None,
#                qval_obs_conditions=None,
#                include_unexpressed=False,
#                indicate_novel=False,
#                display_numbers=False,
#                transcript_name=False,
#                browser=False,
#                order='expression'):
#     multi_groupby = False
#     indicate_dataset = False
#     if groupby:
#         # grouping by more than one column
#         if type(groupby) == list and len(groupby) > 1:
#             for g in groupby:
#                 if g not in self.adata.obs.columns.tolist():
#                     raise Exception('Groupby column {} not found'.format(g))
#             groupby = self.add_multi_groupby(groupby)
#             multi_groupby = True
#         elif groupby not in self.adata.obs.columns.tolist():
#             raise Exception('Groupby column {} not found'.format(groupby))


#     # check if metadata columns are present
#     if metadata_cols:
#         for c in metadata_cols:
#             if c not in self.adata.obs.columns.tolist():
#                 raise Exception('Metadata column {} not found'.format(c))

#             # if we're grouping by a certain variable, make sure
#             # the other metadata cols we plan on plotting have unique
#             # mappings to the other columns. if just grouping by dataset,
#             # since each dataset is unique, that's ok
#             if groupby and groupby != 'dataset':
#                 if groupby == c:
#                     continue

#                 temp = self.adata.obs[[groupby, c, 'dataset']].copy(deep=True)
#                 temp = temp.groupby([groupby, c]).count().reset_index()
#                 temp = temp.loc[~temp.dataset.isnull()]

#                 # if there are duplicates from the metadata column, throw exception
#                 if temp[groupby].duplicated().any():
#                         raise Exception('Metadata column {} '.format(c)+\
#                             'not compatible with groupby column {}. '.format(groupby)+\
#                             'Groupby column has more than 1 unique possible '+\
#                             'value from metadata column.')

#     # check to see if input gene is in the graph
#     if gid not in self.t_df.gid.tolist():
#         gid = self.get_gid_from_gname(gid)
#     self.check_gene(gid)

#     # check to see if these plotting settings will play together
#     self.check_plotting_args(indicate_dataset,
#         indicate_novel, browser)

#     # get the list of columns to include from the input datasets dict
#     if datasets:
#         # get a df that is subset of metadata
#         # also sort the datasets based on the order they appear in "datasets"
#         i = 0
#         sorters = []
#         for meta_col, meta_cats in datasets.items():
#             if meta_col not in self.adata.obs.columns.tolist():
#                 raise Exception('Metadata column {} not found'.format(meta_col))
#             if type(meta_cats) == str:
#                 meta_cats = [meta_cats]
#             if i == 0:
#                 temp = self.adata.obs.loc[self.adata.obs[meta_col].isin(meta_cats)]
#             else:
#                 temp = temp.loc[temp[meta_col].isin(meta_cats)]
#             sort_ind = dict(zip(meta_cats, range(len(meta_cats))))
#             sort_col = '{}_sort'.format(meta_col)
#             temp[sort_col] = temp[meta_col].map(sort_ind).astype(int)
#             sorters.append(sort_col)
#             i += 1

#         # sort the df based on the order that different categories appear in "datasets"
#         temp.sort_values(by=sorters, inplace=True, ascending=True)
#         temp.drop(sorters, axis=1, inplace=True)
#         columns = temp.dataset.tolist()
#         del temp
#     else:
#         columns = None

#     # if we've asked for novelty first check to make sure it's there
#     if novelty:
#         if not self.has_novelty():
#             raise Exception('No novelty information present in the graph. '
#                 'Add it or do not use the "novelty" report option.')
            
# #     print('finished checking all plotting args')

#     # abundance info to calculate TPM on - subset on datasets that will
#     # be included
#     if columns or datasets:
#         subset_adata = subset_on_gene_sg(sg, datasets=columns).adata
#     else:
#         subset_adata = self.adata

#     # small SwanGraph with only this gene's data
#     sg = subset_on_gene_sg(sg, gid=gid, datasets=columns)

#     # if we're grouping data, calculate those new numbers
#     # additionally order transcripts
#     if groupby:
#         if layer == 'tpm':
#             # use whole adata to calc tpm
#             t_df = tpm_df = calc_tpm(subset_adata, obs_col=groupby).transpose()
# #             print('finished calculating tpm')
#         elif layer == 'pi':
#             # calc tpm just so we can order based on exp
#             tpm_df = calc_tpm(subset_adata, obs_col=groupby).transpose()
# #             print('finished calculating pi')
#             t_df, _ = calc_pi(self.adata, self.t_df, obs_col=groupby)
# #             print('finished calculating tpm and pi')
#             t_df = t_df.transpose()
            
#     else:
#         if layer == 'tpm':
#             # use whole adata to calc tpm
#             t_df = tpm_df = self.get_tpm().transpose()
#             t_df = t_df[subset_adata.obs.dataset.tolist()]
# #             print('finished calculating tpm')
#         elif layer == 'pi':
#             # calc tpm just so we can order based on exp
#             t_df = tpm_df = self.get_tpm().transpose()
#             t_df = t_df[subset_adata.obs.dataset.tolist()]
#             t_df, _ = calc_pi(self.adata, self.t_df)
# #             print('finished calculating and pi')
#             t_df = t_df.transpose()

#     # order transcripts by user's preferences
#     if order == 'expression' and self.abundance == False:
#         order = 'tid'
#     elif order == 'expression':
#         order = 'log2tpm'
#     tids = self.t_df.loc[self.t_df.gid == gid].index.tolist()
#     tids = list(set(tids)&set(tpm_df.index.tolist()))
#     tpm_df = tpm_df.loc[tids]
#     _, tids = self.order_transcripts_subset(tpm_df, order=order)
# #     print('finished ordering transcripts')
#     del tpm_df
#     t_df = t_df.loc[tids]

#     # remove unexpressed transcripts if desired
#     if not include_unexpressed:
#         t_df = t_df.loc[t_df.any(axis=1)]

#     # make sure de has been run if needed
#     if include_qvals:
#         uns_key = make_uns_key(kind='det',
#                                obs_col=qval_obs_col,
#                                obs_conditions=qval_obs_conditions)
#         qval_df = self.adata.uns[uns_key].copy(deep=True)
#         qval_df['significant'] = (qval_df.qval <= q)&(qval_df.log2fc >= log2fc)
#     else:
#         qval_df = None

#     # get tids in this report
#     report_tids = t_df.index.tolist()

#     # plot each transcript with these settings
#     print()
#     print('Plotting transcripts for {}'.format(gid))
#     self.plot_each_transcript(report_tids, prefix,
#                               indicate_dataset,
#                               indicate_novel,
#                               browser=browser)
# #     print('finished plotting each transcript')

#     # get a different prefix for saving colorbars and scales
#     gid_prefix = prefix+'_{}'.format(gid)

#     # if we're plotting tracks, we need a scale as well
#     # also set what type of report this will be, 'swan' or 'browser'
#     if browser:
#         self.pg.plot_browser_scale()
#         save_fig(gid_prefix+'_browser_scale.png')
#         report_type = 'browser'
#     else:
#         report_type = 'swan'

#     # plot colorbar for either tpm or pi
#     if layer == 'tpm':

#         # take log2(tpm) (add pseudocounts)
#         t_df = np.log2(t_df+1)

#         # min and max tpm vals
#         g_max = t_df.max().max()
#         g_min = t_df.min().min()

#         # create a colorbar
#         plt.rcParams.update({'font.size': 30})
#         fig, ax = plt.subplots(figsize=(14, 1.5))
#         fig.subplots_adjust(bottom=0.5)
#         fig.patch.set_visible(False)
#         ax.patch.set_visible(False)

#         try:
#             cmap = plt.get_cmap(cmap)
#         except:
#             raise ValueError('Colormap {} not found'.format(cmap))

#         norm = mpl.colors.Normalize(vmin=g_min, vmax=g_max)

#         cb = mpl.colorbar.ColorbarBase(ax,
#                             cmap=cmap,
#                             norm=norm,
#                             orientation='horizontal')
#         cb.set_label('log2(TPM)')
#         plt.savefig(gid_prefix+'_colorbar_scale.png', format='png',
#             bbox_inches='tight', dpi=200)
#         plt.clf()
#         plt.close()

#     elif layer == 'pi':

#         # min and max pi vals
#         g_max = 100
#         g_min = 0

#         # create a colorbar between 0 and 1
#         plt.rcParams.update({'font.size': 30})
#         fig, ax = plt.subplots(figsize=(14, 1.5))
#         fig.subplots_adjust(bottom=0.5)
#         fig.patch.set_visible(False)
#         ax.patch.set_visible(False)

#         try:
#             cmap = plt.get_cmap(cmap)
#         except:
#             raise ValueError('Colormap {} not found'.format(cmap))

#         norm = mpl.colors.Normalize(vmin=0, vmax=100)

#         cb = mpl.colorbar.ColorbarBase(ax,
#                             cmap=cmap,
#                             norm=norm,
#                             orientation='horizontal')
#         cb.set_label('Percent of isoform use (' +'$\pi$'+')')
#         plt.savefig(gid_prefix+'_colorbar_scale.png', format='png',
#             bbox_inches='tight', dpi=200)
#         plt.clf()
#         plt.close()

#     # merge with self.t_df to get additional columns
#     print('tdf')
#     print(t_df.columns)
#     datasets = t_df.columns
#     cols = ['novelty', 'tname']
#     t_df = t_df.merge(self.t_df[cols], how='left', left_index=True, right_index=True)

#     # create report
#     print('Generating report for {}'.format(gid))
#     pdf_name = create_fname(prefix,
#                  indicate_dataset,
#                  indicate_novel,
#                  browser,
#                  ftype='report',
#                  gid=gid)
#     if transcript_name:
#         t_disp = 'Transcript Name'
#     else:
#         t_disp = 'Transcript ID'
#     report = Report(gid_prefix,
#                     report_type,
#                     self.adata.obs,
#                     self.adata.uns,
#                     datasets=datasets,
#                     groupby=groupby,
#                     metadata_cols=metadata_cols,
#                     novelty=novelty,
#                     layer=layer,
#                     cmap=cmap,
#                     g_min=g_min,
#                     g_max=g_max,
#                     include_qvals=include_qvals,
#                     qval_df=qval_df,
#                     display_numbers=display_numbers,
#                     t_disp=t_disp)
#     report.add_page()

#     # loop through each transcript and add it to the report
#     for ind, entry in t_df.iterrows():
#         tid = ind

#         # display name for transcript
#         if transcript_name:
#             t_disp = entry['tname']
#         else:
#             t_disp = tid
#         fname = create_fname(prefix,
#                              indicate_dataset,
#                              indicate_novel,
#                              browser,
#                              ftype='path',
#                              tid=tid)
#         report.add_transcript(entry, fname, t_disp)
#     report.write_pdf(pdf_name)

#     # remove multi groupby column if necessary
#     if multi_groupby:
#         self.rm_multi_groupby(groupby)

In [5]:
# make some swan reports
gen_report(sg, 'Srrm2',
              prefix='figures/srrm2',
              layer='tpm',
              cmap='viridis',
              novelty=True,
              groupby='lr_celltype',
              transcript_name=True,
              metadata_cols=['lr_celltype'],
              datasets={'lr_celltype': order})

gen_report(sg, 'Srrm2',
              prefix='figures/srrm2',
              layer='pi',
              cmap='magma',
              novelty=True,
              groupby='lr_celltype',
              transcript_name=True,
              metadata_cols=['lr_celltype'],
              datasets={'lr_celltype': order},
              browser=True)

subset adata
['ATTGGCTCACAAGCTATCTATTAC-ont_b', 'AAGAGATCCGCATACACTGTCCCG-ont_b', 'CATACCAACGAACTTATTACCTCG-ont_b', 'AAGAGATCTGAAGAGACGCGACTA-pb_2ka', 'GTACGCAACCTCTATCTAAATATC-ont_b']
sg
['ATTGGCTCACAAGCTATCTATTAC-ont_b', 'AAGAGATCCGCATACACTGTCCCG-ont_b', 'CATACCAACGAACTTATTACCTCG-ont_b', 'AAGAGATCTGAAGAGACGCGACTA-pb_2ka', 'GTACGCAACCTCTATCTAAATATC-ont_b']
in calc_tpm
             ENCODEMT000142671  ENCODEMT000144512  ENCODEMT000145637  \
lr_celltype                                                            
Medulla_NE                 0.0                0.0                0.0   
Medulla_NE                 0.0                0.0                0.0   
Medulla_NE                 0.0                0.0                0.0   
Medulla_NE                 0.0                0.0                0.0   
Medulla_NE                 0.0                0.0                0.0   

             ENCODEMT000147256  ENCODEMT000152127  ENCODEMT000153402  \
lr_celltype                                        

  view_to_actual(adata)


                ENCODEMT000142671  ENCODEMT000144512  ENCODEMT000145637  \
lr_celltype                                                               
Adipocytes                    0.0           28.02298          20.751194   
Capsule                       0.0            0.00000           0.000000   
Cortex_ZF                     0.0            0.00000           0.000000   
Cortex_ZG                     0.0            0.00000           0.000000   
Cortex_cycling                0.0            0.00000           0.000000   

                ENCODEMT000147256  ENCODEMT000152127  ENCODEMT000153402  \
lr_celltype                                                               
Adipocytes               0.000000           0.000000           0.000000   
Capsule                  0.000000           0.000000           0.000000   
Cortex_ZF               26.217817           0.000000          32.375031   
Cortex_ZG                0.000000           0.000000           5.496197   
Cortex_cycling          

  view_to_actual(adata)


                ENCODEMT000142671  ENCODEMT000144512  ENCODEMT000145637  \
lr_celltype                                                               
Adipocytes                    0.0           28.02298          20.751194   
Capsule                       0.0            0.00000           0.000000   
Cortex_ZF                     0.0            0.00000           0.000000   
Cortex_ZG                     0.0            0.00000           0.000000   
Cortex_cycling                0.0            0.00000           0.000000   

                ENCODEMT000147256  ENCODEMT000152127  ENCODEMT000153402  \
lr_celltype                                                               
Adipocytes               0.000000           0.000000           0.000000   
Capsule                  0.000000           0.000000           0.000000   
Cortex_ZF               26.217817           0.000000          32.375031   
Cortex_ZG                0.000000           0.000000           5.496197   
Cortex_cycling          

In [None]:
# # fixing groupbys in calc_pi and calc_tpm (and maybe create_edge_adata and create_end_adata)
# # sg = swan.read('/Users/fairliereese/mortazavi_lab/data/mousewg/adrenal/lr_splitseq/swan/swan.p')
# keys   = ['dataset1', 'dataset2', 'dataset3', 'dataset4', 'dataset5', 'dataset6']
# vals = [1,2,3,4,5,6]

# def sparse_groupby(keys, vals):
#     print(keys)
#     unique_keys, row = np.unique(keys, return_inverse=True)
#     col = np.arange(len(keys))
#     mat = sparse.coo_matrix((vals, (row, col)))
#     print(mat.toarray())
#     return dict(zip(unique_keys, mat.sum(1).flat))

# sparse_groupby(keys, vals)

# # # vals = [[1,2,3,4,5,6], [1,2,3,4,5,7]]
# df = pd.DataFrame(index=keys, data=vals)
# df.head()
# # df.groupby(keys).sum()

# # def pandas_groupby(keys, vals):
# #     return pd.Series(vals).groupby(keys).sum().to_dict()
# # pandas_groupby(keys, vals)

In [None]:
# # debugging / speeding up get_edge_ab
# sg = swan.SwanGraph()
# sg.add_annotation('../testing/files/test_full_annotation.gtf')
# sg.add_transcriptome('../testing/files/test_full.gtf')
# sg.add_abundance('../testing/files/test_ab_talon_1.tsv')

In [24]:
t_df = sg.t_df.copy(deep=True)
t_df = t_df.loc[sg.adata.var.index.tolist()]
# print('subset t_df')
# print(t_df)
edge_exp_df = swan.pivot_path_list(t_df, 'path')

print('finished running pivot_path_list')

# get a mergeable transcript expression df
tid = sg.adata.var.index.tolist()
obs = sg.adata.obs.index.tolist()
data = sg.adata.layers['counts'].transpose()
t_exp_df = pd.DataFrame.sparse.from_spmatrix(columns=obs,
                                             data=data,
                                             index=tid)

print('finished making t_exp_df')

# merge counts per transcript with edges
edge_exp_df = edge_exp_df.merge(t_exp_df, how='left',
    left_index=True, right_index=True)

print('finished merging')

# sum the counts per transcript / edge / dataset

# THIS IS THE STEP that takes forever
# ensure that this is sparse befre the operation
# idea - subset by those that have 1 or multiple edge ids, only sum up
# those that have multiple, then concat
print('hwllo')
# print(edge_exp_df)
# print(edge_exp_df.dtypes)
edge_exp_df.reset_index(inplace=True, drop=True)
edge_exp_df.set_index('edge_id', inplace=True)

# sparse
print('sparse')
start_time = time.time()
edge_exp_df = edge_exp_df.groupby(by='edge_id', as_index=True).sum()
print('finished summing')
print("--- %s seconds ---" % (time.time() - start_time))

# dense
print('dense')
start_time = time.time()
edge_exp_df[sg.datasets] = edge_exp_df[sg.datasets].sparse.to_dense()
edge_exp_df = edge_exp_df.groupby(by='edge_id', as_index=True).sum()
print('finished summing')
print("--- %s seconds ---" % (time.time() - start_time))

finished running pivot_path_list
finished making t_exp_df
finished merging
hwllo
sparse
finished summing
--- 0.014503955841064453 seconds ---
dense
finished summing
--- 0.003910064697265625 seconds ---


In [23]:
print(edge_exp_df.dtypes)
edge_exp_df.head()

dataset1    Sparse[float64, 0]
dataset2    Sparse[float64, 0]
dtype: object


Unnamed: 0_level_0,dataset1,dataset2
edge_id,Unnamed: 1_level_1,Unnamed: 2_level_1
0,5.0,5.0
1,5.0,5.0
2,5.0,5.0
3,5.0,5.0
4,5.0,5.0


In [17]:
kind = 'tss'

# limit to only expresed transcripts
t_df = sg.t_df.copy(deep=True)
t_df = t_df.loc[sg.adata.var.index.tolist()]
print(len(sg.t_df.index))
print(len(t_df.index))

df = swan.get_ends(sg.t_df, kind)
print('finished getting ends')

# get a mergeable transcript expression df
tid = sg.adata.var.index.tolist()
obs = sg.adata.obs.index.tolist()
data = sg.adata.layers['counts'].transpose()
t_exp_df = pd.DataFrame.sparse.from_spmatrix(columns=obs, data=data, index=tid)
t_exp_df = t_exp_df.merge(t_df, how='left',
    left_index=True, right_index=True)

print('finished merging exp')

# merge counts per transcript with end expression
df = df.merge(t_exp_df, how='left',
    left_index=True, right_index=True)
print('finished merging')

# sort based on vertex id
df.sort_index(inplace=True, ascending=True)
print('finished sorting')

# set index to gene ID, gene name, and vertex id
df.reset_index(drop=True, inplace=True)
df.set_index(['gid', 'gname', 'vertex_id'], inplace=True)
df = df[sg.datasets]

# groupby on gene and assign each unique TSS / gene combo an ID
id_col = '{}_id'.format(kind)
name_col = '{}_name'.format(kind)
df.reset_index(inplace=True)

# sparse
print('sparse')
start_time = time.time()
df = df.groupby(['gid', 'gname', 'vertex_id']).sum().reset_index()
print('finished summing')
print("--- %s seconds ---" % (time.time() - start_time))

# dense
print('dense')
start_time = time.time()
df[sg.datasets] = df[sg.datasets].sparse.to_dense()
df = df.groupby(['gid', 'gname', 'vertex_id']).sum().reset_index()
print('finished summing')
print("--- %s seconds ---" % (time.time() - start_time))

6
5
finished getting ends
finished merging exp
finished merging
finished sorting
sparse
finished summing
--- 0.010179758071899414 seconds ---
dense
finished summing
--- 0.006715059280395508 seconds ---


In [9]:
df.dtypes

gid                      object
gname                    object
vertex_id                 int64
dataset1     Sparse[float64, 0]
dataset2     Sparse[float64, 0]
dtype: object

In [6]:
edge_exp_df.reset_index(inplace=True, drop=True)
edge_exp_df.set_index('edge_id', inplace=True)
print(edge_exp_df.dtypes)
edge_exp_df.head()

dataset1    Sparse[float64, 0]
dataset2    Sparse[float64, 0]
dtype: object


Unnamed: 0_level_0,dataset1,dataset2
edge_id,Unnamed: 1_level_1,Unnamed: 2_level_1
0,5.0,5.0
1,5.0,5.0
2,5.0,5.0
3,5.0,5.0
4,5.0,5.0


In [11]:
edge_exp_df

Unnamed: 0_level_0,dataset1,dataset2
edge_id,Unnamed: 1_level_1,Unnamed: 2_level_1
0,5.0,5.0
1,5.0,5.0
2,5.0,5.0
3,5.0,5.0
4,5.0,5.0
5,10.0,0.0
6,10.0,0.0
7,10.0,0.0
8,10.0,0.0
9,10.0,0.0


In [10]:
edge_exp_df.groupby(by='edge_id', as_index=True).sum()

Unnamed: 0_level_0,dataset1,dataset2
edge_id,Unnamed: 1_level_1,Unnamed: 2_level_1
0,5.0,5.0
1,5.0,5.0
2,5.0,5.0
3,5.0,5.0
4,5.0,5.0
5,15.0,15.0
6,10.0,10.0
7,10.0,0.0
8,10.0,0.0
9,10.0,10.0


In [422]:
# added 
# utils.py
def calc_total_counts(adata, obs_col='dataset', layer='counts'):
    
    # turn into a sparse dataframe
    cols = adata.var.index.tolist()
    inds = adata.obs[obs_col].tolist()
    data = adata.layers[layer]
    data = scipy.sparse.csr_matrix(data)
    df = pd.DataFrame.sparse.from_spmatrix(data, index=inds, columns=cols)
    df.index.name = obs_col 

    # add up values on condition (row)
    df = df.groupby(level=0).sum()
    
    return df

In [423]:
# added
# utils.py
def calc_tpm(adata, obs_col='dataset'):
    
    # calculate tpm using scanpy
    d = sc.pp.normalize_total(adata,
                              layer='counts',
                              target_sum=1e6,
                              key_added='total_counts',
                              inplace=False)
    adata.obs['total_counts'] = d['norm_factor']
    
    # turn into a sparse dataframe
    cols = adata.var.index.tolist()
    inds = adata.obs[obs_col].tolist()
    data = d['X']
    data = scipy.sparse.csr_matrix(data)
    df = pd.DataFrame.sparse.from_spmatrix(data, index=inds, columns=cols)
    df.index.name = obs_col    

    # average across tpm
    if obs_col != 'dataset':
        df.reset_index(inplace=True)
        df = df.groupby(obs_col).mean()
        
    return df

In [526]:
# added
# swangraph.py
def add_abundance(sg, counts_file):
    """
    Adds abundance from a counts matrix to the SwanGraph. Transcripts in the
    SwanGraph but not in the counts matrix will be assigned 0 counts.
    Transcripts in the abundance matrix but not in the SwanGraph will not
    have expression added.

    Parameters:
        counts_file (str): Path to TSV expression file where first column is
            the transcript ID and following columns name the added datasets and
            their counts in each dataset, OR to a TALON abundance matrix.
    """

    # read in abundance file
    swan.check_file_loc(counts_file, 'abundance matrix')
    try:
        df = pd.read_csv(counts_file, sep='\t')
    except:
        raise ValueError('Problem reading expression matrix {}'.format(counts_file))

    # check if abundance matrix is a talon abundance matrix
    cols = ['gene_ID', 'transcript_ID', 'annot_gene_id', 'annot_transcript_id',
        'annot_gene_name', 'annot_transcript_name', 'n_exons', 'length',
        'gene_novelty', 'transcript_novelty', 'ISM_subtype']
    if df.columns.tolist()[:11] == cols:
        df = swan.reformat_talon_abundance(counts_file)

    # rename transcript ID column
    col = df.columns[0]
    df.rename({col: 'tid'}, axis=1, inplace=True)

    # limit to just the transcripts already in the graph
    sg_tids = sg.t_df.tid.tolist()
    ab_tids = df.tid.tolist()
    tids = list(set(sg_tids)&set(ab_tids))
    df = df.loc[df.tid.isin(tids)]
    
    # transpose to get adata format
    df.set_index('tid', inplace=True)
    df = df.T
    
    # get adata components - obs, var, and X
    var = df.columns.to_frame()
    var.columns = ['tid']
    obs = df.index.to_frame()
    obs.columns = ['dataset']
    X = sparse.csr_matrix(df.to_numpy())
    
    # create transcript-level adata object and filter out unexpressed transcripts
    adata = anndata.AnnData(var=var, obs=obs, X=X)
    genes, _  = sc.pp.filter_genes(adata, min_counts=1, inplace=False)
    adata = adata[:, genes]
    adata.layers['counts'] = adata.X

    # add each dataset to list of "datasets", check if any are already there!
    datasets = adata.obs.dataset.tolist()
    for d in datasets:
        if d in sg.datasets:
            raise ValueError('Dataset {} already present in the SwanGraph.'.format(d))
    sg.datasets.extend(datasets)

    print()
    if len(datasets) <= 5:
        print('Adding abundance for datasets {} to SwanGraph.'.format(', '.join(datasets)))
    else:
        mini_datasets = datasets[:5]
        n = len(datasets) - len(mini_datasets)
        print('Adding abundance for datasets {}... (and {} more) to SwanGraph'.format(', '.join(mini_datasets), n))

    # if there is preexisting abundance data in the SwanGraph, concatenate
    # otherwise, adata is the new transcript level adata
    if not sg.has_abundance():

        # create transcript-level adata object
        sg.adata = adata

        # add counts as layers
        sg.adata.layers['counts'] = sg.adata.X
        print('Calculating transcript TPM...')
        sg.adata.layers['tpm'] = sparse.csr_matrix(calc_tpm(sg.adata).to_numpy())

        if not sg.sc:
            print('Calculating PI...') 
            sg.adata.layers['pi'] = sparse.csr_matrix(calc_pi(sg.adata, sg.t_df)[0].to_numpy())
    else:
        
        # first set current layer to be counts
        sg.adata.X = sg.adata.layers['counts']
        
        # concatenate existing adata with new one
        # outer join to add all new transcripts (that are from added
        # annotation or transcriptome) to the abundance
        uns = sg.adata.uns
        sg.adata = sg.adata.concatenate(adata, join='outer', index_unique=None)
        sg.adata.uns = uns
        
        # recalculate pi and tpm
        print('Calculating transcript TPM...')
        sg.adata.layers['tpm'] = sparse.csr_matrix(calc_tpm(sg.adata).to_numpy())

        if not sg.sc:
            print('Calculating PI...')
            sg.adata.layers['pi'] = sparse.csr_matrix(calc_pi(sg.adata, sg.t_df)[0].to_numpy())

    # add abundance for edges, TSS per gene, and TES per gene
    sg = create_edge_adata(sg)
    print('Calculating TSS usage...')
    sg = create_end_adata(sg, kind='tss')
    print('Calculating TES usage...')
    sg = create_end_adata(sg, kind='tes')

    # set abundance flag to true
    sg.abundance = True
    
    return sg

In [527]:
# added
# swangraph.py
def create_end_adata(sg, kind):
    """
    Create a tss / tes-level adata object. Enables calculating tss / tes
    usage across samples.

    Parameters:
        kind (str): Choose from 'tss' or 'tes'
    """

    df = swan.get_ends(sg.t_df, kind)

    # get a mergeable transcript expression df
    tid = sg.adata.var.index.tolist()
    obs = sg.adata.obs.index.tolist()
    data = sg.adata.layers['counts'].transpose()
    t_exp_df = pd.DataFrame.sparse.from_spmatrix(columns=obs, data=data, index=tid)
    t_exp_df = t_exp_df.merge(sg.t_df, how='left',
        left_index=True, right_index=True)

    # merge counts per transcript with end expression
    df = df.merge(t_exp_df, how='left',
        left_index=True, right_index=True)

    # sort based on vertex id
    df.sort_index(inplace=True, ascending=True)

    # set index to gene ID, gene name, and vertex id 
    df.reset_index(drop=True, inplace=True)
    df.set_index(['gid', 'gname', 'vertex_id'], inplace=True)
    df = df[sg.datasets]

    # groupby on gene and assign each unique TSS / gene combo an ID
    id_col = '{}_id'.format(kind)
    name_col = '{}_name'.format(kind)
    df.reset_index(inplace=True)
    df = df.groupby(['gid', 'gname', 'vertex_id']).sum().reset_index()
    df['end_gene_num'] = df.sort_values(['gid', 'vertex_id'],
                    ascending=[True, True])\
                    .groupby(['gid']) \
                    .cumcount() + 1
    df[id_col] = df['gid']+'_'+df['end_gene_num'].astype(str)
    df[name_col] = df['gname']+'_'+df['end_gene_num'].astype(str)
    df.drop('end_gene_num', axis=1, inplace=True)

    # obs, var, and X tables for new data
    var_cols = ['gid', 'gname', 'vertex_id', id_col, name_col]
    var = df[var_cols]
    var.set_index('{}_id'.format(kind), inplace=True)
    df.drop(var_cols, axis=1, inplace=True)
    df = df[sg.adata.obs.index.tolist()]
    X = sparse.csr_matrix(df.transpose().values)
    obs = sg.adata.obs
    
    # create anndata
    adata = anndata.AnnData(var=var, obs=obs, X=X)
    
    # add counts and tpm as layers
    adata.layers['counts'] = adata.X
    adata.layers['tpm'] = sparse.csr_matrix(calc_tpm(adata).to_numpy())
    if not sg.sc:
        adata.layers['pi'] = sparse.csr_matrix(calc_pi(adata,
                adata.var)[0].to_numpy())

    # assign adata and clean up unstructured data if needed
    if kind == 'tss':
        if sg.has_abundance():
            adata.uns = sg.tss_adata.uns
        sg.tss_adata = adata
        
    elif kind == 'tes':
        if sg.has_abundance():
            adata.uns = sg.tss_adata.uns
        sg.tes_adata = adata
    
    return sg


In [557]:
# added
# swangraph.py
def create_edge_adata(sg):
    """
    Create an edge-level adata object. Enables calculating edge usage across
    samples.
    """

    # get table what edges are in each transcript
    edge_exp_df = swan.pivot_path_list(sg.t_df, 'path')

    # get a mergeable transcript expression df
    tid = sg.adata.var.index.tolist()
    obs = sg.adata.obs.index.tolist()
    data = sg.adata.layers['counts'].transpose()
    t_exp_df = pd.DataFrame.sparse.from_spmatrix(columns=obs,
                                                 data=data,
                                                 index=tid)

    # merge counts per transcript with edges
    edge_exp_df = edge_exp_df.merge(t_exp_df, how='left',
        left_index=True, right_index=True)

    # sum the counts per transcript / edge / dataset
    edge_exp_df = edge_exp_df.groupby('edge_id').sum()

    # order based on order of edges in sg.edge_df
    edge_exp_df = edge_exp_df.merge(sg.edge_df[['v1', 'v2']],
        how='left', left_index=True, right_index=True)
    edge_exp_df.sort_values(by=['v1', 'v2'], inplace=True)
    edge_exp_df.drop(['v1', 'v2'], axis=1, inplace=True)
    
    # drop edges that are unexpressed
    edge_exp_df = edge_exp_df.loc[edge_exp_df.sum(1) > 0]

    # obs, var, and X tables for new data
    var = edge_exp_df.index.to_frame()
    X = sparse.csr_matrix(edge_exp_df.transpose().values)
    obs = sg.adata.obs

    # create edge-level adata object
    adata = anndata.AnnData(var=var, obs=obs, X=X)

    # add counts and tpm as layers
    adata.layers['counts'] = adata.X
    adata.layers['tpm'] = sparse.csr_matrix(calc_tpm(adata).to_numpy())
    # can't make pi for edges unless I make a new edge for 
    # each gene that the edge is in
    # could just have sg.edge_adata var separate from sg.edge_df for now tho
#     sg.edge_adata.layers['pi'] = sparse.csr_matrix(calc_pi(sg.adata, sg.edge_df)[0].to_numpy())

    # assign adata and clean up unstructured data if needed
    if sg.has_abundance():
        adata.uns = sg.edge_adata.uns
    sg.edge_adata = adata
    
    return sg

In [555]:
# added
# utils.py
def calc_pi(adata, t_df, obs_col='dataset'):

    # calculate cumulative counts across obs_col
    id_col = adata.var.index.name
    conditions = adata.obs[obs_col].unique().tolist()
    df = calc_total_counts(adata, obs_col=obs_col)
    df = df.transpose()
    # we use ints to index edges and locs
    if id_col == 'vertex_id' or id_col == 'edge_id':
        df.index = df.index.astype('int')

    sums = df.copy(deep=True)
    sums = sums[conditions]
    sums = sums.transpose()

    # add gid
    df = df.merge(t_df['gid'], how='left', left_index=True, right_index=True)
    t_counts = df.melt(id_vars=['gid'],
                       value_vars=conditions,
                       var_name=obs_col,
                       value_name='t_counts',
                       ignore_index=False)
    t_counts.index.name = id_col
    t_counts.reset_index(inplace=True)

    # calculate total number of reads per gene per condition
    temp = df.copy(deep=True)
    temp.reset_index(drop=True, inplace=True)
    totals = temp.groupby('gid').sum().reset_index()

    # merge back in
    df.reset_index(inplace=True)
    df.rename({'index':id_col}, axis=1, inplace=True)
    df = df.merge(totals, on='gid', suffixes=('_t_counts', None))
    del totals

    df = df.melt(id_vars=['gid'], 
                 value_vars=conditions, 
                 var_name=obs_col,
                 value_name='gene_counts')
    df = df.drop_duplicates()
    df = t_counts.merge(df, how='left', on=['gid', obs_col])


    df['pi'] = (df.t_counts/df.gene_counts)*100
    df = df.pivot(columns=obs_col, index=id_col, values='pi')

    # order based on order in adata
    ids = adata.var.index.tolist()
    df = df.loc[ids]
    cols = adata.obs[obs_col].unique().tolist()
    df = df[cols]

    # convert to sparse
    df = df.transpose()
    df = pd.DataFrame.sparse.from_spmatrix(data=sparse.csr_matrix(df.values),
                                           index=df.index.tolist(),
                                           columns=df.columns)
    return df, sums


In [558]:
sg = swan.SwanGraph()
sg.add_annotation('../testing/files/test_full_annotation.gtf')
sg.add_transcriptome('../testing/files/test_full.gtf')
sg = add_abundance(sg, '../testing/files/test_ab_talon_1.tsv')

print(sg.tss_adata.var.head())
print(sg.t_df.loc_path)
print(sg.adata.var.head())
print(sg.adata.layers['counts'].toarray())
print(sg.tss_adata.var.head())
print(sg.tss_adata.layers['counts'].toarray())
print(sg.tss_adata.layers['tpm'].toarray())
print(sg.tss_adata.layers['pi'].toarray())

print(sg.tes_adata.var.head())
print(sg.t_df[['gid', 'tid', 'loc_path']])
print(sg.adata.var.head())
print(sg.adata.layers['counts'].toarray())
print(sg.tes_adata.var)
print(sg.tes_adata.layers['counts'].toarray())
print(sg.tes_adata.layers['tpm'].toarray())
print(sg.tes_adata.layers['pi'].toarray())

print(type(sg.tes_adata.layers['counts']))
print(type(sg.tes_adata.layers['tpm']))
print(type(sg.tes_adata.layers['pi']))


Adding annotation to the SwanGraph

Adding transcriptome to the SwanGraph

Adding abundance for datasets dataset1, dataset2 to SwanGraph.
Calculating transcript TPM...
Calculating PI...
df
       dataset1  dataset2
test1       5.0       5.0
test2      10.0       0.0
test3       0.0      10.0
test4      10.0      10.0
test5       5.0       5.0
t_df
             tname        gid        gname               path    tid  \
tid                                                                    
test1  test1_tname  test1_gid  test1_gname    [0, 1, 2, 3, 4]  test1   
test2  test2_tname  test2_gid  test2_gname    [5, 6, 7, 8, 9]  test2   
test3  test3_tname  test2_gid  test2_gname  [5, 6, 14, 15, 9]  test3   
test4  test4_tname  test4_gid  test4_gname               [10]  test4   
test5  test5_tname  test2_gid  test2_gname        [5, 11, 12]  test5   

                    loc_path  annotation    novelty  
tid                                                  
test1     [0, 1, 2, 3, 4, 5]        



In [1]:
import swan_vis as swan
import pandas as pd
import anndata
import scanpy as sc
from scipy import sparse
import numpy as np
import time
import copy
import matplotlib.pyplot as plt
import matplotlib as mpl

In [2]:
# test_add_abundance_3
sg = swan.SwanGraph()
sg.add_annotation('../testing/files/test_full_annotation.gtf')
sg.add_transcriptome('../testing/files/test_full.gtf')
sg.add_abundance('../testing/files/test_ab_talon_1.tsv')
print(sg.t_df.index.tolist())
print(sg.adata.var.index.tolist())
print(sg.adata.layers['counts'].toarray())
print(sg.adata.layers['tpm'].toarray())
print(sg.adata.layers['pi'].toarray())


Adding annotation to the SwanGraph

Adding transcriptome to the SwanGraph

Adding abundance for datasets dataset1, dataset2 to SwanGraph.
Calculating transcript TPM...
Calculating PI...
   dataset1  dataset2        gid
0       5.0       5.0  test1_gid
1      10.0       0.0  test2_gid
2       0.0      10.0  test2_gid
3      10.0      10.0  test4_gid
4       5.0       5.0  test2_gid
dataset1    float32
dataset2    float32
gid          object
dtype: object
Calculating edge usage...
Calculating TSS usage...
   dataset1  dataset2        gid
0       5.0       5.0  test1_gid
1      15.0      15.0  test2_gid
2      10.0      10.0  test4_gid
dataset1    float32
dataset2    float32
gid          object
dtype: object
Calculating TES usage...
   dataset1  dataset2        gid
0       5.0       5.0  test1_gid
1      10.0      10.0  test2_gid
2       5.0       5.0  test2_gid
3      10.0      10.0  test4_gid
dataset1    float32
dataset2    float32
gid          object
dtype: object
['test1', 'test2', '



In [379]:
# test_add_abundance_2
sg = swan.SwanGraph()
sg.add_annotation('../testing/files/test_full_annotation.gtf')
sg.add_transcriptome('../testing/files/test_full.gtf')

sg = add_abundance(sg, '../testing/files/test_ab_dataset1.tsv')
sg = add_abundance(sg, '../testing/files/test_ab_dataset2.tsv')

print(sg.t_df.index.tolist())
print(sg.adata.var.index.tolist())
print(sg.adata.layers['counts'].toarray())
print(sg.adata.layers['tpm'].toarray())
print(sg.adata.layers['pi'].toarray())


Adding annotation to the SwanGraph

Adding transcriptome to the SwanGraph

Adding abundance for datasets dataset1 to SwanGraph.
Calculating transcript TPM...
Calculating PI...

Adding abundance for datasets dataset2 to SwanGraph.
Calculating transcript TPM...
Calculating PI...
['test1', 'test2', 'test3', 'test4', 'test5', 'test6']
['test1', 'test2', 'test3', 'test4', 'test5']
[[ 5. 10.  0. 10.  5.]
 [ 5.  0. 10. 10.  5.]]
[[166666.69 333333.38      0.   333333.38 166666.69]
 [166666.69      0.   333333.38 333333.38 166666.69]]
[[100.        66.66667    0.       100.        33.333336]
 [100.         0.        66.66667  100.        33.333336]]


In [380]:
# test_add_abundance_1
sg = swan.SwanGraph()
sg.add_annotation('../testing/files/test_full_annotation.gtf')
sg.add_transcriptome('../testing/files/test_full.gtf')
sg = add_abundance(sg, '../testing/files/test_ab_1.tsv')

print(sg.t_df.index.tolist())
print(sg.adata.var.index.tolist())
print(sg.adata.layers['counts'].toarray())
print(sg.adata.layers['tpm'].toarray())
print(sg.adata.layers['pi'].toarray())


# looks good but tests still needa be updated


Adding annotation to the SwanGraph

Adding transcriptome to the SwanGraph

Adding abundance for datasets dataset1, dataset2 to SwanGraph.
Calculating transcript TPM...
Calculating PI...
['test1', 'test2', 'test3', 'test4', 'test5', 'test6']
['test1', 'test2', 'test3', 'test4', 'test5']
[[ 5. 10.  0. 10.  5.]
 [ 5.  0. 10. 10.  5.]]
[[166666.69 333333.38      0.   333333.38 166666.69]
 [166666.69      0.   333333.38 333333.38 166666.69]]
[[100.        66.66667    0.       100.        33.333336]
 [100.         0.        66.66667  100.        33.333336]]


In [446]:
# test_calc_pi_2
sg = swan.SwanGraph()
sg.add_transcriptome('../testing/files/test_full.gtf')
sg = add_abundance(sg, '../testing/files/test_ab_1.tsv')
sg.adata.obs['cluster'] = 'c1'
test_df, test_sums = calc_pi(sg.adata, sg.t_df, obs_col='cluster')

test_df


Adding transcriptome to the SwanGraph

Adding abundance for datasets dataset1, dataset2 to SwanGraph.
Calculating transcript TPM...
Calculating PI...
Calculating TSS usage...
Calculating TES usage...


tid,test1,test2,test3,test4,test5
c1,100.0,33.333336,33.333336,100.0,33.333336


In [428]:
# test_calc_pi_1
sg = swan.SwanGraph()
sg.add_transcriptome('../testing/files/test_full.gtf')
sg = add_abundance(sg, '../testing/files/test_ab_1.tsv')
test_df, test_sums = calc_pi(sg.adata, sg.t_df, obs_col='dataset')

test_df


Adding transcriptome to the SwanGraph

Adding abundance for datasets dataset1, dataset2 to SwanGraph.
Calculating transcript TPM...
Calculating PI...


tid,test1,test2,test3,test4,test5
dataset1,100.0,66.666672,0.0,100.0,33.333336
dataset2,100.0,0.0,66.666672,100.0,33.333336


In [233]:
sg = swan.read('test_mousewg.p')
ab = '/Users/fairliereese/mortazavi_lab/data/mousewg/lr_bulk/talon/mouse_talon_abundance_filtered.tsv'

Read in graph from test_mousewg.p


In [234]:
# test adding de novo
sg = add_abundance(sg, ab)


Adding abundance for datasets gastroc_14d_f_2, gastroc_14d_f_1, heart_18-20mo_m_1, heart_18-20mo_m_2, heart_18-20mo_f_1... (and 86 more) to SwanGraph
Calculating transcript TPM...


In [238]:
# test_calc_tpm_1
sg = swan.SwanGraph()
sg.add_transcriptome('../testing/files/test_full.gtf')
sg = add_abundance(sg, '../testing/files/test_ab_1.tsv')
sg.adata.obs['cluster'] = ['c1', 'c1']

df = calc_tpm(sg.adata)
df


Adding transcriptome to the SwanGraph

Adding abundance for datasets dataset1, dataset2 to SwanGraph.
Calculating transcript TPM...


Unnamed: 0_level_0,test1,test2,test3,test4,test5
dataset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
dataset1,166666.671875,333333.34375,0.0,333333.34375,166666.671875
dataset2,166666.671875,0.0,333333.34375,333333.34375,166666.671875


In [240]:
sg.adata.layers['tpm'].toarray()

array([[166666.67, 333333.34,      0.  , 333333.34, 166666.67],
       [166666.67,      0.  , 333333.34, 333333.34, 166666.67]],
      dtype=float32)

In [220]:
# test_calc_tpm_2
sg = swan.SwanGraph()
sg.add_transcriptome('../testing/files/test_full.gtf')
sg = add_abundance(sg, '../testing/files/test_ab_1.tsv')
sg.adata.obs['cluster'] = ['c1', 'c1']

df = calc_tpm(sg.adata, obs_col='cluster')
df


Adding transcriptome to the SwanGraph

Adding abundance for datasets dataset1, dataset2 to SwanGraph.


Unnamed: 0_level_0,test1,test2,test3,test4,test5
cluster,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
c1,166666.671875,166666.671875,166666.671875,333333.34375,166666.671875


In [226]:
# test_calc_total_counts_1
sg = swan.SwanGraph()
sg.add_transcriptome('../testing/files/test_full.gtf')
sg = add_abundance(sg, '../testing/files/test_ab_1.tsv')
sg.adata.obs['cluster'] = ['c1', 'c1']

df = calc_total_counts(sg.adata)
df


Adding transcriptome to the SwanGraph

Adding abundance for datasets dataset1, dataset2 to SwanGraph.


Unnamed: 0_level_0,test1,test2,test3,test4,test5
dataset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
dataset1,5.0,10.0,0.0,10.0,5.0
dataset2,5.0,0.0,10.0,10.0,5.0


In [227]:
# test_calc_total_counts_2
sg = swan.SwanGraph()
sg.add_transcriptome('../testing/files/test_full.gtf')
sg = add_abundance(sg, '../testing/files/test_ab_1.tsv')
sg.adata.obs['cluster'] = ['c1', 'c1']

df = calc_total_counts(sg.adata, obs_col='cluster')
df


Adding transcriptome to the SwanGraph

Adding abundance for datasets dataset1, dataset2 to SwanGraph.


Unnamed: 0_level_0,test1,test2,test3,test4,test5
cluster,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
c1,10.0,10.0,10.0,20.0,10.0


In [None]:
# test merging when incoming adata has duplicate dataset names

In [None]:
# test merging when adding new dataset adds new transcript id to the adata - already tested with test_add_abundance_2