In [91]:
import swan_vis as swan
import pandas as pd
import anndata
import scanpy as sc
import scipy
import numpy

In [32]:
# utils.py
def calc_total_counts(adata, obs_col='dataset', layer='counts'):
    """
    Calculate cumulative expression per adata entry based on condition given
    by `obs_col`. Default column to use is `adata.obs` index column, `dataset`.

    The input `adata` is only read; in particular `adata.X` is left untouched.

    Parameters:
        adata (anndata AnnData): Annotated data object from the SwanGraph
        obs_col (str): Column name from adata.obs table to group on.
            Default: 'dataset'
        layer (str): Layer of AnnData to pull from. Default = 'counts'

    Returns:
        df (pandas DataFrame): Pandas DataFrame where rows are the different
            conditions from `obs_col` (sorted, as produced by groupby) and the
            columns are the entries of `adata.var.index`, and values represent
            the cumulative counts per column per condition.

    Raises:
        KeyError: If `layer` is not in `adata.layers` or `obs_col` is not
            in `adata.obs`.
    """
    # read straight from the requested layer instead of assigning it to
    # adata.X, which would silently change the caller's object
    data = adata.layers[layer]

    # sparse matrices can't be handed to the DataFrame constructor directly
    if hasattr(data, 'toarray'):
        data = data.toarray()

    df = pd.DataFrame(data=data,
                      index=adata.obs[obs_col].tolist(),
                      columns=adata.var.index.tolist())

    # add up values on condition (row)
    return df.groupby(level=0).sum()

def calc_tpm(adata, obs_col='dataset'):
    """
    Calculate the TPM per condition given by `obs_col`.
    Default column to use is `adata.obs` index column, `dataset`.

    NOTE: this definition is a placeholder. It has no body beyond this
    docstring, so calling it returns None. The sections below describe the
    intended contract, not current behavior.

    Parameters:
        adata (anndata AnnData): Annotated data object from the SwanGraph
        obs_col (str or list of str): Column name from adata.obs table to group on.
            Default: 'dataset'

    Returns (intended):
        df (pandas DataFrame): Pandas DataFrame where rows are the different
            conditions from `obs_col` and the columns are observed transcript ids in the
            SwanGraph, and values represent the TPM value per isoform per
            condition.
    """

    # TODO: first normalize across individual datasets / cells
    
    
    
    

# def calc_tpm(adata, sg_df=None, obs_col='dataset'):
# 	"""
# 	Calculate the TPM per condition given by `obs_col`.
# 	Default column to use is `adata.obs` index column, `dataset`.

# 	Parameters:
# 		adata (anndata AnnData): Annotated data object from the SwanGraph
# 		sg_df (pandas DataFrame): Pandas DataFrame from SwanGraph that will
# 			be used to order the rows of resultant TPM DataFrame
# 		obs_col (str or list of str): Column name from adata.obs table to group on.
# 			Default: 'dataset'

# 	Returns:
# 		df (pandas DataFrame): Pandas DataFrame where rows are the different
# 			conditions from `obs_col` and the columns are observed transcript ids in the
# 			SwanGraph, and values represent the TPM value per isoform per
# 			condition.
# 	"""

# 	# calculate cumulative counts across obs_col
# 	id_col = adata.var.index.name
# 	conditions = adata.obs[obs_col].unique().tolist()
# 	df = swan.calc_total_counts(adata, obs_col=obs_col)
# 	df = df.transpose()
    
# 	print(df.shape)

# 	# we use ints to index edges and locs
# 	if id_col == 'vertex_id' or id_col == 'edge_id':
# 		df.index = df.index.astype('int')

# 	# calculate tpm per isoform per condition
# 	tpm_cols = []
# 	for c in conditions:
# 		cond_col = '{}_tpm'.format(c)
# 		total_col = '{}_total'.format(c)
# 		df[total_col] = df[c].sum()
# 		df[cond_col] = (df[c]*1000000)/df[total_col]
# 		tpm_cols.append(cond_col)

# 	# formatting
# 	df.index.name = id_col
# 	df = df[tpm_cols]
# 	for col in tpm_cols:
# 		new_col = col[:-4]
# 		df.rename({col: new_col}, axis=1, inplace=True)

# 	# reorder columns like adata.obs
# 	df = df[adata.obs[obs_col].unique().tolist()]
# 	df = df.transpose()

# 	# reorder in adata.var / t_df order
# 	if not isinstance(sg_df, type(None)):
# 		ids = [tid for tid in sg_df[id_col].tolist() if tid in df.columns.tolist()]
# 		df = df[sg_df[id_col].tolist()]

# 	return df

In [61]:
def add_abundance(sg, counts_file):
    """
    Adds abundance from a counts matrix to the SwanGraph. Transcripts in the
    SwanGraph but not in the counts matrix will be assigned 0 counts.
    Transcripts in the abundance matrix but not in the SwanGraph will not
    have expression added.

    NOTE(review): the code below keeps only transcripts present in BOTH the
    SwanGraph and the counts matrix; it does not itself zero-fill transcripts
    that are missing from the counts matrix. Confirm which behavior is wanted.

    Parameters:
        sg (SwanGraph): SwanGraph to add abundance to. Modified in place:
            `sg.adata`, `sg.datasets` and `sg.abundance` are updated.
        counts_file (str): Path to TSV expression file where first column is
            the transcript ID and following columns name the added datasets and
            their counts in each dataset, OR to a TALON abundance matrix.

    Returns:
        sg (SwanGraph): The same SwanGraph object that was passed in.

    Raises:
        ValueError: If the counts file cannot be parsed, or if one of its
            datasets is already present in `sg.datasets`.
    """

    # read in abundance file
    swan.check_file_loc(counts_file, 'abundance matrix')
    try:
        df = pd.read_csv(counts_file, sep='\t')
    except Exception as e:
        # keep the underlying parser error attached for debugging
        raise ValueError('Problem reading expression matrix {}'.format(counts_file)) from e

    # check if abundance matrix is a talon abundance matrix
    cols = ['gene_ID', 'transcript_ID', 'annot_gene_id', 'annot_transcript_id',
        'annot_gene_name', 'annot_transcript_name', 'n_exons', 'length',
        'gene_novelty', 'transcript_novelty', 'ISM_subtype']
    if df.columns.tolist()[:len(cols)] == cols:
        df = swan.reformat_talon_abundance(counts_file)

    # rename transcript ID column (always the first column)
    df = df.rename({df.columns[0]: 'tid'}, axis=1)

    # limit to just the transcripts already in the graph
    sg_tids = set(sg.t_df.tid.tolist())
    df = df.loc[df.tid.isin(sg_tids)]

    # transpose to get adata format (datasets as rows, transcripts as columns);
    # non-inplace so we never write into a slice of the original frame
    df = df.set_index('tid').T

    # get adata components - obs, var, and X
    var = df.columns.to_frame()
    var.columns = ['tid']
    obs = df.index.to_frame()
    obs.columns = ['dataset']
    X = df.to_numpy()

    # create transcript-level adata object
    adata = anndata.AnnData(var=var, obs=obs, X=X)

    # add each dataset to list of "datasets", check if any are already there!
    datasets = adata.obs.dataset.tolist()
    for d in datasets:
        if d in sg.datasets:
            raise ValueError('Dataset {} already present in the SwanGraph.'.format(d))
    sg.datasets.extend(datasets)

    print()
    if len(datasets) <= 5:
        print('Adding abundance for datasets {} to SwanGraph.'.format(', '.join(datasets)))
    else:
        mini_datasets = datasets[:5]
        n = len(datasets) - len(mini_datasets)
        print('Adding abundance for datasets {}... (and {} more) to SwanGraph'.format(', '.join(mini_datasets), n))

    # if there is preexisting abundance data in the SwanGraph, concatenate
    # otherwise, adata is the new transcript level adata
    if not sg.has_abundance():
        sg.adata = adata
    else:
        # first set current layer to be counts
        sg.adata.X = sg.adata.layers['counts']

        # concatenate existing adata with new one
        # outer join to add all new transcripts (that are from added
        # annotation or transcriptome) to the abundance
        uns = sg.adata.uns
        sg.adata = sg.adata.concatenate(adata, join='outer', index_unique=None)
        sg.adata.uns = uns

    # (re)register the raw counts as a layer in BOTH branches: the incoming
    # adata carries no layers, so the 'counts' layer is not guaranteed to
    # survive the concatenation. Copy so that later in-place changes to X
    # cannot alter the stored counts.
    sg.adata.layers['counts'] = sg.adata.X.copy()

    # TODO: recompute the 'tpm' / 'pi' layers and the TSS / TES adatas here
    # once those helpers are updated (previously done at this point).

    # set abundance flag to true
    sg.abundance = True

    return sg

In [62]:
# Load a previously saved SwanGraph; the path is relative to the notebook's
# working directory.
sg = swan.read('test_mousewg.p')
# NOTE(review): absolute, machine-specific path -- this cell will not run on
# another machine without editing. Presumably a TALON abundance matrix,
# judging by the file name; confirm.
ab = '/Users/fairliereese/mortazavi_lab/data/mousewg/lr_bulk/talon/mouse_talon_abundance_filtered.tsv'

Read in graph from test_mousewg.p


In [63]:
# Test adding abundance de novo, using the notebook-local add_abundance
# defined above on the SwanGraph loaded in the previous cell.
# (Assumes that graph has no abundance yet -- TODO confirm.)
sg = add_abundance(sg, ab)


Adding abundance for datasets gastroc_14d_f_2, gastroc_14d_f_1, heart_18-20mo_m_1, heart_18-20mo_m_2, heart_18-20mo_f_1... (and 86 more) to SwanGraph


In [221]:
def calc_tpm(adata, obs_col='dataset'):
    """
    Calculate the TPM per condition given by `obs_col`.

    The 'counts' layer is normalized with scanpy to a total of 1e6 per row
    of `adata`. When `obs_col` is anything other than 'dataset', the
    normalized values are then averaged over the rows sharing a value of
    `obs_col`.

    Side effect: writes the scanpy 'norm_factor' result into
    `adata.obs['total_counts']`.

    Parameters:
        adata (anndata AnnData): Annotated data object from the SwanGraph;
            must have a 'counts' layer
        obs_col (str): Column name from adata.obs table to group on.
            Default: 'dataset'

    Returns:
        df (pandas DataFrame): Sparse-backed DataFrame whose index is named
            `obs_col` and whose columns are the entries of `adata.var.index`.
            For obs_col='dataset' there is one row per row of `adata`;
            otherwise one row per distinct value of `obs_col`.
    """
    # scale every row of the counts layer to one million, without touching adata.X
    normed = sc.pp.normalize_total(
        adata,
        layer='counts',
        target_sum=1e6,
        key_added='total_counts',
        inplace=False,
    )
    adata.obs['total_counts'] = normed['norm_factor']

    # wrap the normalized matrix in a sparse DataFrame labelled by obs_col / var index
    tpm = pd.DataFrame.sparse.from_spmatrix(
        scipy.sparse.csr_matrix(normed['X']),
        index=adata.obs[obs_col].tolist(),
        columns=adata.var.index.tolist(),
    )
    tpm.index.name = obs_col

    # per-dataset values need no further aggregation
    if obs_col == 'dataset':
        return tpm

    # otherwise average the TPM values within each group
    return tpm.reset_index().groupby(obs_col).mean()

In [219]:
# test_calc_tpm_1
# Build a fresh SwanGraph from the test GTF, add the test abundance matrix,
# and compute TPM with the default grouping column ('dataset').
sg = swan.SwanGraph()
sg.add_transcriptome('../testing/files/test_full.gtf')
sg = add_abundance(sg, '../testing/files/test_ab_1.tsv')
# 'cluster' is added for parity with test_calc_tpm_2; the call below does not
# pass obs_col, so it is not used for grouping here.
sg.adata.obs['cluster'] = ['c1', 'c1']

df = calc_tpm(sg.adata)
df


Adding transcriptome to the SwanGraph

Adding abundance for datasets dataset1, dataset2 to SwanGraph.


Unnamed: 0_level_0,test1,test2,test3,test4,test5
dataset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
dataset1,166666.671875,333333.34375,0.0,333333.34375,166666.671875
dataset2,166666.671875,0.0,333333.34375,333333.34375,166666.671875


In [220]:
# test_calc_tpm_2
# Same setup as test_calc_tpm_1, but group on 'cluster' so calc_tpm takes its
# averaging branch (obs_col != 'dataset').
sg = swan.SwanGraph()
sg.add_transcriptome('../testing/files/test_full.gtf')
sg = add_abundance(sg, '../testing/files/test_ab_1.tsv')
# both rows of obs are assigned to the single cluster 'c1'
sg.adata.obs['cluster'] = ['c1', 'c1']

df = calc_tpm(sg.adata, obs_col='cluster')
df


Adding transcriptome to the SwanGraph

Adding abundance for datasets dataset1, dataset2 to SwanGraph.


Unnamed: 0_level_0,test1,test2,test3,test4,test5
cluster,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
c1,166666.671875,166666.671875,166666.671875,333333.34375,166666.671875


In [225]:
def calc_total_counts(adata, obs_col='dataset', layer='counts'):
    """
    Sum the values of one AnnData layer over the groups given by `obs_col`.

    Parameters:
        adata (anndata AnnData): Annotated data object from the SwanGraph
        obs_col (str): Column name from adata.obs table to group on.
            Default: 'dataset'
        layer (str): Layer of AnnData to pull from. Default = 'counts'

    Returns:
        df (pandas DataFrame): DataFrame whose index is named `obs_col`, with
            one row per distinct value of `obs_col` and one column per entry
            of `adata.var.index`; values are the per-group sums.
    """
    var_names = adata.var.index.tolist()
    group_labels = adata.obs[obs_col].tolist()

    # go through a CSR matrix so the DataFrame is sparse-backed
    sparse_counts = scipy.sparse.csr_matrix(adata.layers[layer])
    totals = pd.DataFrame.sparse.from_spmatrix(
        sparse_counts, index=group_labels, columns=var_names
    )
    totals.index.name = obs_col

    # rows sharing a label are added together
    return totals.groupby(level=0).sum()

In [226]:
# test_calc_total_counts_1
# Build a fresh SwanGraph from the test GTF, add the test abundance matrix,
# and sum counts with the default grouping column ('dataset').
sg = swan.SwanGraph()
sg.add_transcriptome('../testing/files/test_full.gtf')
sg = add_abundance(sg, '../testing/files/test_ab_1.tsv')
# 'cluster' is not used by the call below (default obs_col='dataset')
sg.adata.obs['cluster'] = ['c1', 'c1']

df = calc_total_counts(sg.adata)
df


Adding transcriptome to the SwanGraph

Adding abundance for datasets dataset1, dataset2 to SwanGraph.


Unnamed: 0_level_0,test1,test2,test3,test4,test5
dataset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
dataset1,5.0,10.0,0.0,10.0,5.0
dataset2,5.0,0.0,10.0,10.0,5.0


In [None]:
# test_calc_total_counts_2
# Same setup as test_calc_total_counts_1, but sum the counts per 'cluster'
# rather than per dataset, so the grouping across rows is exercised.
sg = swan.SwanGraph()
sg.add_transcriptome('../testing/files/test_full.gtf')
sg = add_abundance(sg, '../testing/files/test_ab_1.tsv')
# both rows of obs are assigned to the single cluster 'c1'
sg.adata.obs['cluster'] = ['c1', 'c1']

df = calc_total_counts(sg.adata, obs_col='cluster')
df

In [None]:
# TODO: test merging when the incoming adata has new transcripts that the existing adata does not

In [None]:
# TODO: test merging when the incoming adata has dataset names that are already in the SwanGraph