In [44]:
import pandas as pd
import itertools
import obonet
import networkx as nx
import gzip

def parse_gaf(file_path, namespace_filter="biological_process", exclude_not=False):
    """Parse a GO Annotation File (GAF) into gene <-> GO-term lookup tables.

    Each non-comment row is split on tabs and read positionally: column 3
    (index 2) is taken as the gene symbol, column 4 (index 3) as the
    qualifier, column 5 (index 4) as the GO term id and column 9 (index 8)
    as the one-letter aspect code.

    Args:
        file_path: Path to the GAF file (``str`` or ``pathlib.Path``). Names
            ending in ``.gz`` are opened with ``gzip``.
        namespace_filter: Keep only rows whose aspect maps to this namespace
            ("biological_process", "molecular_function" or
            "cellular_component"). A falsy value keeps every row, including
            rows with an unrecognised aspect code.
        exclude_not: When True, skip rows whose qualifier column contains the
            ``NOT`` token (e.g. ``NOT`` or ``NOT|involved_in``), i.e. negated
            annotations. Defaults to False, which keeps them.

    Returns:
        ``(gene2go, go2genes)``: two dicts of sets, mapping gene symbol ->
        GO term ids and GO term id -> gene symbols.
    """
    aspect_map = {
        "P": "biological_process",
        "F": "molecular_function",
        "C": "cellular_component"
    }

    gene2go = {}
    go2genes = {}

    # str() lets pathlib.Path inputs through the suffix test; the open
    # functions themselves accept the path object unchanged.
    open_func = gzip.open if str(file_path).endswith(".gz") else open

    with open_func(file_path, "rt") as f:
        for line in f:
            if line.startswith("!"):
                continue  # Header / comment line

            # Remove only the line terminator. A bare strip() would also eat
            # trailing tabs, collapsing empty trailing columns and shifting
            # the column count used by the length check below.
            fields = line.rstrip("\r\n").split("\t")
            if len(fields) < 15:
                continue  # Malformed or truncated row

            gene = fields[2]
            qualifier = fields[3]
            go_term = fields[4]
            namespace = aspect_map.get(fields[8])

            if namespace_filter and namespace != namespace_filter:
                continue

            # Qualifiers are pipe-separated tokens; match the exact token so a
            # qualifier that merely contains the letters "NOT" is not dropped.
            if exclude_not and "NOT" in qualifier.split("|"):
                continue

            gene2go.setdefault(gene, set()).add(go_term)
            go2genes.setdefault(go_term, set()).add(gene)

    return gene2go, go2genes

class GeneSetAnalyzer:
    """Score gene pairs by whether they share a GO biological-process term.

    Attributes:
        gene2go: Mapping gene -> collection of GO term ids annotated to it.
        go2genes: Mapping GO term id -> genes. Stored, but not read by any
            method of this class.
        graph: Ontology graph queried with ``nx.descendants``. Presumably an
            obonet graph whose edges point from child term to parent term --
            TODO confirm against the code that builds it.
        bp_terms: Set of term ids; ancestor sets are intersected with it.
        go_df: DataFrame with a ``GO_bp_term`` column and a
            ``number_of_genes`` column.
        md: DataFrame of gene pairs with ``Gene_A`` and ``Gene_B`` columns
            plus whatever per-pair values it carries.
    """

    # Term id that the lookups always remove from their result.
    BP_ROOT = 'GO:0008150'

    def __init__(self, gene2go, go2genes, graph, bp_terms, go_df, md):
        self.gene2go = gene2go
        self.go2genes = go2genes
        self.graph = graph
        self.bp_terms = bp_terms
        self.go_df = go_df
        self.md = md

    def get_all_go_ancestors(self, term):
        """Return every node reachable from ``term`` in ``self.graph``.

        Returns an empty set when ``term`` is not in the graph. The result is
        called "ancestors" because the graph is assumed to have child ->
        parent edges, which makes graph descendants the more general terms --
        NOTE(review): confirm the edge direction of the supplied graph.
        """
        if term not in self.graph:
            return set()
        return nx.descendants(self.graph, term)

    def get_biological_process_ancestors(self, term):
        """Return the ancestors of ``term`` that are also in ``self.bp_terms``."""
        return self.get_all_go_ancestors(term) & self.bp_terms

    def _terms_to_frame(self, terms):
        """Join ``terms`` with ``self.go_df`` and drop the ``BP_ROOT`` row."""
        # Sorting gives a reproducible row order; set iteration order is not
        # stable between interpreter sessions.
        term_df = pd.DataFrame({'GO_bp_term': sorted(terms)})
        annotated = term_df.merge(self.go_df, on='GO_bp_term')
        return annotated[annotated['GO_bp_term'] != self.BP_ROOT]

    def gene_lookup(self, gene):
        """Return the ``go_df`` rows for ``gene``'s terms and their ancestors.

        The gene's direct terms are expanded with
        ``get_biological_process_ancestors``; terms absent from ``go_df`` and
        the ``BP_ROOT`` term are not part of the result. A gene missing from
        ``gene2go`` is treated as having no terms.
        """
        direct_terms = self.gene2go.get(gene, [])
        expanded = set(direct_terms)
        for go in direct_terms:
            expanded.update(self.get_biological_process_ancestors(go))
        return self._terms_to_frame(expanded)

    def gene_lookup_shallow(self, gene):
        """Like ``gene_lookup`` but without adding any ancestor terms."""
        return self._terms_to_frame(self.gene2go.get(gene, []))

    def _lookup(self, gene, shallow, cache=None):
        """Run the shallow or deep lookup, reusing ``cache`` when one is given."""
        lookup = self.gene_lookup_shallow if shallow else self.gene_lookup
        if cache is None:
            return lookup(gene)
        if gene not in cache:
            cache[gene] = lookup(gene)
        return cache[gene]

    def _score_pair(self, gene1, gene2, shallow, filter_threshold, cache):
        """Return 1/0 for shared/no shared term, or None after a KeyError."""
        try:
            # number_of_genes is dropped on one side so the merged frame has a
            # single, unsuffixed number_of_genes column to filter on.
            left = self._lookup(gene1, shallow, cache).drop(columns=['number_of_genes'])
            right = self._lookup(gene2, shallow, cache)

            shared = left.merge(right, on='GO_bp_term')
            if filter_threshold is not None:
                shared = shared[shared['number_of_genes'] < filter_threshold]
            return 0 if shared.empty else 1
        except KeyError as e:
            print(f"KeyError during lookup for {gene1}, {gene2}: {e}")
            return None

    def gene_lookup_and_merge(self, gene1, gene2, shallow=False, filter_threshold=100):
        """Report whether two genes share a qualifying GO term.

        Args:
            gene1, gene2: Genes to compare.
            shallow: Use ``gene_lookup_shallow`` instead of ``gene_lookup``.
            filter_threshold: Only shared terms whose ``number_of_genes`` is
                strictly below this value count; ``None`` disables the filter.

        Returns:
            1 if at least one shared term remains, 0 if none does, and None
            if a ``KeyError`` was raised (the error is printed).
        """
        return self._score_pair(gene1, gene2, shallow, filter_threshold, cache=None)

    def analyze_gene_set(self, gene_set, shallow=False, filter_threshold=100):
        """Score every unordered pair drawn from ``gene_set``.

        Pairs follow ``itertools.combinations(gene_set, 2)``, so which gene
        ends up as ``Gene_A`` depends on the iteration order of ``gene_set``.
        Pairs whose score is None are left out.

        Returns:
            DataFrame with columns ``Gene_A``, ``Gene_B`` and ``Match``.
        """
        # Each gene takes part in many pairs; look it up only once per call.
        cache = {}
        g1, g2, match = [], [], []
        for gene1, gene2 in itertools.combinations(gene_set, 2):
            val = self._score_pair(gene1, gene2, shallow, filter_threshold, cache)
            if val is not None:
                g1.append(gene1)
                g2.append(gene2)
                match.append(val)
        return pd.DataFrame({'Gene_A': g1, 'Gene_B': g2, 'Match': match})

    def get_filtered_md(self, gene_set):
        """Return the rows of ``self.md`` whose two genes are both in ``gene_set``."""
        both_inside = self.md['Gene_A'].isin(gene_set) & self.md['Gene_B'].isin(gene_set)
        return self.md[both_inside]

    def get_background(self, gene_set, sample_n=100, shallow=False, filter_threshold=100, random_seed=42):
        """Score a random sample of ``md`` pairs that avoid ``gene_set``.

        Only rows where neither gene is in ``gene_set`` are eligible. If fewer
        than ``sample_n`` rows are eligible, all of them are used instead of
        raising. Pairs whose score is None are left out.

        Returns:
            The sampled ``md`` rows with an added ``Match`` column.
        """
        outside = self.md[
            ~self.md['Gene_A'].isin(gene_set) & ~self.md['Gene_B'].isin(gene_set)
        ]
        if sample_n is not None:
            # sample() refuses to draw more rows than exist without replacement.
            sample_n = min(sample_n, len(outside))
        non_gene_df = outside.sample(n=sample_n, random_state=random_seed)

        cache = {}
        g1, g2, match = [], [], []
        for gene_a, gene_b in zip(non_gene_df['Gene_A'], non_gene_df['Gene_B']):
            val = self._score_pair(gene_a, gene_b, shallow, filter_threshold, cache)
            if val is not None:
                g1.append(gene_a)
                g2.append(gene_b)
                match.append(val)

        match_df = pd.DataFrame({'Gene_A': g1, 'Gene_B': g2, 'Match': match})
        return non_gene_df.merge(match_df, on=['Gene_A', 'Gene_B'], how='inner')

    def run_analysis(self, gene_set, shallow=False, filter_threshold=100, sample_n=100, random_seed=42):
        """Score all pairs in ``gene_set`` and attach their ``md`` rows.

        A pair is first matched against ``md`` in the orientation produced by
        ``analyze_gene_set``. Pairs not found that way are retried with the
        two genes swapped, so a pair that ``md`` stores as (B, A) is kept and
        reported in ``md``'s orientation rather than dropped.

        ``sample_n`` and ``random_seed`` are accepted but currently unused:
        this method does not compute a background. Call ``get_background``
        directly for that.

        Returns:
            DataFrame with ``Gene_A``, ``Gene_B``, ``Match`` and the remaining
            ``md`` columns, restricted to pairs present in ``md``.
        """
        pair_keys = ['Gene_A', 'Gene_B']
        gene_df = self.analyze_gene_set(gene_set, shallow=shallow, filter_threshold=filter_threshold)
        md_subset = self.get_filtered_md(gene_set)

        direct = gene_df.merge(md_subset, on=pair_keys, how='inner')

        # Retry only the pairs that found no partner, so a table holding both
        # orientations of a pair does not produce duplicate rows.
        found = set(zip(direct['Gene_A'], direct['Gene_B']))
        not_found = pd.Series(
            [pair not in found for pair in zip(gene_df['Gene_A'], gene_df['Gene_B'])],
            index=gene_df.index,
            dtype=bool,
        )
        flipped = gene_df[not_found].rename(columns={'Gene_A': 'Gene_B', 'Gene_B': 'Gene_A'})
        swapped = flipped[gene_df.columns.tolist()].merge(md_subset, on=pair_keys, how='inner')

        if swapped.empty:
            return direct
        if direct.empty:
            return swapped
        return pd.concat([direct, swapped], ignore_index=True)


In [None]:
# Needs gene2go, go2genes, graph, bp_terms, df and md to exist already.
# Arguments are passed in the constructor's declared order:
# gene2go, go2genes, graph, bp_terms, go_df, md.
analyzer = GeneSetAnalyzer(gene2go, go2genes, graph, bp_terms, df, md)

# Every annotated gene whose symbol starts with "IFT".
IFT_gene_set = set(filter(lambda symbol: symbol.startswith('IFT'), gene2go))

# run_analysis returns a single table (its background step is commented out),
# so there is no (ift_df, background_df) pair to unpack here.
analyzer.run_analysis(gene_set=IFT_gene_set, sample_n=400)

In [25]:
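# """
# CONFIGURATION
# Builds every input of the GeneSetAnalyzer class: gene2go, go2genes, graph,
# bp_terms, df (passed as go_df) and md.
# These names must exist before GeneSetAnalyzer is constructed, so uncomment
# and run this cell first on a fresh kernel.
# """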

# # gene2go and go2genes dictionaries
# gaf_path = "goa_human.gaf"  # or .gz
# gene2go, go2genes = parse_gaf(gaf_path, namespace_filter="biological_process")

# # graph
# graph = obonet.read_obo("go-basic.obo")

# # bp terms
# bp_root = "GO:0008150"
# bp_terms = {
#     term for term in graph.nodes
#     if nx.has_path(graph, term, bp_root)
# }

# # go_df
# df=pd.DataFrame()
# golist = []; genelist = []
# for go in list(go2genes.keys()):
#     golist.append(go)
#     genelist.append(len(list(go2genes[go])))
# df['GO_bp_term']=golist
# df['number_of_genes']=genelist

# # md - cosine similarity gene pairs
# md = pd.read_pickle("/work2/05515/bflynn/frontera/gene_similarity_gobp_value.pkl").reset_index().drop(columns={'index'})