In [1]:
import glob

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import scipy.stats as st
import seaborn as sns


def load_gene_ontology_data(path='~/jupyter-notebooks/rpgroup/data/GO_annotations_ecoli.csv'):

    """Load the GO annotation dataset of E. coli K-12.

    Inputs

    * path: location of the GO annotation CSV file. Defaults to the location
            this notebook has always read from, so calling the function with
            no arguments behaves exactly as before.

    Output
    * dataframe with the contents of the CSV file, as parsed by pd.read_csv
    """

    return pd.read_csv(path)

In [2]:
# Load the GO annotation table once; later cells refer to it through the name `GO`.
GO = load_gene_ontology_data()

In [3]:
GO.head()

Unnamed: 0,GO_ID,GO_term,Category,gene_name
0,GO:0003723,RNA binding,Function,pcnB
1,GO:0003723,RNA binding,Function,cspE
2,GO:0003723,RNA binding,Function,cspB
3,GO:0003723,RNA binding,Function,thrS
4,GO:0003723,RNA binding,Function,cspC


In [6]:
def lower_strings(string_list):

    """Return a new list with every element converted to a lower-case string.

    Inputs

    * string_list: iterable of values; each one is passed through str() first,
                   so non-string entries are accepted

    Output
    * list of lower-cased strings, in the same order as the input
    """

    lowered = []

    for item in string_list:
        lowered.append(str(item).lower())

    return lowered

In [12]:
def get_gene_data(data, gene_name_column, test_gene_list):

    """Extract data from specific genes given a larger dataframe.

    Inputs

    * data: large dataframe from where to filter
    * gene_name_column: column to filter from
    * test_gene_list : a list of genes you want to get

    Output
    * dataframe with the rows of `data` whose gene name is in test_gene_list.
      Fully duplicated rows are removed, the original index labels are kept,
      and rows are grouped per gene in order of each gene's first appearance
      in `data`. When no gene matches, an empty dataframe that still carries
      the columns of `data` is returned.
    """

    # One vectorised membership test instead of a Python loop that re-filters
    # and re-concatenates the frame for every row of `data`.
    is_wanted = data[gene_name_column].isin(list(test_gene_list))
    selected = data.loc[is_wanted].drop_duplicates()

    if selected.empty:
        return selected

    # Rank each gene by where it first shows up, then stable-sort on that rank
    # so all rows of one gene sit together while keeping their relative order.
    first_seen_rank = pd.factorize(selected[gene_name_column])[0]
    row_order = pd.Series(first_seen_rank).sort_values(kind='mergesort').index.values

    return selected.iloc[row_order]

In [13]:
def get_GO_gene_set(gene_ontology_data, test_gene_list):

    """
    Given a list of genes of interest and the Gene Ontology annotation dataset,
    filter the Gene Ontology dataset for E. coli to make an enrichment analysis.
    _____________________________________________________________________________

    inputs~

    gene_ontology_data: GO annotation dataset. It is not modified. If None is
                        passed, the dataset is loaded with load_gene_ontology_data().
    test_gene_list: List of genes of interest. Matching is case-insensitive.

    outputs~

    GO_gene_set: Filtered GO annotation dataset corresponding to the test gene set.
                 Its gene_name column holds lower-cased names.

    """
    # Use the dataset the caller handed in; only fall back to disk when there is none
    if gene_ontology_data is None:
        gene_ontology_data = load_gene_ontology_data()

    # Work on a copy so the caller's dataframe keeps its original gene names
    annotations = gene_ontology_data.copy()
    annotations['gene_name'] = lower_strings(annotations['gene_name'].values)

    # Lower-case the query too, otherwise names such as 'lacI' can never match
    wanted_genes = lower_strings(test_gene_list)

    # Keep only the GO rows that belong to the test gene list
    GO_gene_set = get_gene_data(annotations, 'gene_name', wanted_genes)

    return GO_gene_set

In [15]:
GO.head()

Unnamed: 0,GO_ID,GO_term,Category,gene_name
0,GO:0003723,RNA binding,Function,pcnB
1,GO:0003723,RNA binding,Function,cspE
2,GO:0003723,RNA binding,Function,cspB
3,GO:0003723,RNA binding,Function,thrS
4,GO:0003723,RNA binding,Function,cspC


In [41]:
# Pull the GO annotation rows for a small hand-picked gene list straight from `GO`.
# Names are matched exactly as written here (get_gene_data does no case folding).
go_gene_set = get_gene_data(GO, 'gene_name', ['lacI', 'lacZ', 'rhaR', 
                                             'thrS', 'fliZ', 'fliL', 
                                              'waaA', 'ori'])

In [42]:
go_gene_set.GO_ID.value_counts()

GO:0005887    2
GO:0042802    2
GO:0006435    1
GO:0016989    1
GO:0008144    1
GO:0043565    1
GO:0071978    1
GO:0045892    1
GO:0016020    1
GO:0008270    1
GO:0009341    1
GO:1902021    1
GO:0048027    1
GO:0000986    1
GO:0004565    1
GO:0004829    1
GO:0005990    1
GO:0009245    1
GO:0001047    1
GO:0004812    1
GO:0043039    1
GO:0045947    1
GO:0003723    1
GO:0005524    1
GO:0001217    1
GO:0005886    1
GO:0009425    1
GO:0006418    1
GO:0016740    1
GO:0031420    1
GO:0005737    1
GO:0000287    1
GO:0005515    1
GO:0006417    1
GO:0006351    1
GO:0000900    1
GO:0002161    1
Name: GO_ID, dtype: int64

In [43]:
def get_hi_GOs(GO_gene_set, threshold=0.10):

    """
    Get the GO IDs whose counts are above a fraction (10% by default) of the
    total entries of the GO_gene_set.

    This allows to reduce our search space and only calculate enrichment p-values for highly
    represented GOs.

    * GO: gene ontology

    -------------------------------------------------------
    input~ GO_gene_set : Filtered GO annotation dataset corresponding to the test gene set.
           threshold : fraction of the total number of rows a GO ID must exceed
                       to be kept. The cutoff actually applied is
                       int(number_of_rows * threshold), i.e. it is rounded down.

    output ~ array of GO IDs whose count is strictly above the cutoff, or None
             (after printing a message) when GO_gene_set has fewer than two rows
             or no GO ID passes the cutoff.
    """
    n_annotations = GO_gene_set.shape[0]

    # Gene sets with fewer than two annotation rows are reported as having nothing to test
    if n_annotations <= 1:
        print('No enriched functions found.')
        return None

    # Cutoff on the raw count, rounded down to a whole number of annotations
    thr = int(n_annotations * threshold)

    # Count every GO ID once and keep the ones strictly above the cutoff
    go_counts = GO_gene_set.GO_ID.value_counts()
    hi_GO_ids = go_counts[go_counts > thr].index.values

    if len(hi_GO_ids) == 0:
        print('No enriched functions found.')
        return None

    return hi_GO_ids

In [44]:
# Keep the GO IDs that pass the count threshold inside get_hi_GOs; the function prints a
# message and yields None when no GO ID qualifies, so hi_go_ids may be None afterwards.
hi_go_ids = get_hi_GOs(go_gene_set)

No enriched functions found.


In [45]:
# NOTE(review): get_hyper_test_p_value is defined in a cell further down this notebook, so on a
# fresh top-to-bottom run this call is reached before the definition exists.
get_hyper_test_p_value(GO, go_gene_set, hi_go_ids)

Enrichment test did not run.


In [38]:
GO.shape

(15126, 4)

In [37]:
# Number of rows in the full GO table annotated with the first high-count GO ID.
# NOTE(review): this needs hi_go_ids to be a non-empty array; it fails when get_hi_GOs
# returned None, as it does for the gene list used in this notebook.
GO['GO_ID'].value_counts()[hi_go_ids[0]]

383

In [31]:
# Cross-check of the same count by filtering the rows directly; the first element of the
# shape is the number of matching annotation rows. Also requires a non-empty hi_go_ids.
GO[GO['GO_ID']==hi_go_ids[0] ].shape

(383, 4)

In [39]:
def get_hyper_test_p_value(gene_ontology_data, GO_gene_set, hi_GO_ids):

    """
    Given a list of GO IDs, calculate its p-value according to the hypergeometric distribution.
    -------------------------------------------------------
    inputs~

    gene_ontology_data: GO annotation dataset.
    GO_gene_set: Filtered GO annotation dataset corresponding to the test gene set.
    hi_GO_ids: Overrepresented GO IDs (array or list). Every ID must occur in both
               datasets; a missing ID raises KeyError.

    outputs~

    summary_df: Summary dataframe with the statistically overrepresented GO IDs
                (corrected p-value < 0.05) w/ their reported p-value and the rows of
                GO_gene_set that carry them.
                None is returned (after printing a message) when hi_GO_ids is None or empty.

    """

    if hi_GO_ids is None or len(hi_GO_ids) == 0:

        print('Enrichment test did not run.')

        return None

    # Boolean-mask indexing further down needs an array, not a plain list
    hi_GO_ids = np.asarray(hi_GO_ids)

    n = GO_gene_set.shape[0] # sample size ~ annotations in the test gene set

    M = gene_ontology_data.shape[0] # total number of balls ~ total number of annotations

    # Count every GO ID once up front instead of once per tested ID
    sample_counts = GO_gene_set['GO_ID'].value_counts()
    genome_counts = gene_ontology_data['GO_ID'].value_counts()

    p_vals = np.empty(len(hi_GO_ids))

    for i, hi_GO in enumerate(hi_GO_ids):

        # White balls drawn : counts of the hiGO in the GO_gene_set dataset
        w = sample_counts[hi_GO]

        # Total number of white balls in the bag : counts of the hiGO in the whole genome
        w_genome = genome_counts[hi_GO]

        # Overrepresentation test: P(X >= w) = sf(w - 1), i.e. the sum of the PMF from w to n
        p_val = st.hypergeom.sf(w - 1, M, w_genome, n)

        # Multiple-testing correction, scaled by the sample size as in the original analysis.
        # NOTE(review): a textbook Bonferroni factor would be the number of GO IDs tested
        # (len(hi_GO_ids)) -- confirm which one is intended.
        p_vals[i] = p_val * n

    # Filter the p_values < 0.05
    significant_indices = p_vals < 0.05
    significant_pvals = p_vals[significant_indices]

    # Get significant GO_IDs
    significant_GOs = hi_GO_ids[significant_indices]

    GO_summary_df = pd.DataFrame({ 'GO_ID': significant_GOs, 'p_val': significant_pvals })

    # Inner join: attach the gene-set rows that carry each significant GO ID
    summary_df = pd.merge(GO_summary_df, GO_gene_set, on = 'GO_ID', how = 'inner')

    print('Enrichment test ran succesfully!')

    return summary_df
        

In [None]:

        
def make_GSEA(gene_ontology_data, test_gene_list):

    """ Wrapper function to get the enriched functions in a particular test_gene_list.

    Runs the GO filtering, the high-count GO selection and the hypergeometric test,
    then left-joins the result with the EcoCyc rows of the same genes on 'gene_name'.
    The enrichment result is returned unmerged (it may be None) when the test did not
    run or when fewer than two EcoCyc rows were found for the gene list.
    """

    go_subset = get_GO_gene_set(gene_ontology_data, test_gene_list)

    enrichment = get_hyper_test_p_value(gene_ontology_data, go_subset, get_hi_GOs(go_subset))

    # The EcoCyc annotation is loaded and filtered regardless of the enrichment outcome
    ecocyc_subset = get_gene_data(load_ecocyc_annotation(), 'gene_name', test_gene_list)

    has_enrichment = enrichment is not None
    has_ecocyc = ecocyc_subset is not None and ecocyc_subset.shape[0] > 1

    # Nothing to merge: hand back whatever the enrichment step produced
    if not (has_enrichment and has_ecocyc):
        return enrichment

    return pd.merge(enrichment, ecocyc_subset, on = 'gene_name', how = 'left')