In [1]:
# loading libraries
import pandas as pd
import networkx as nx
import os
from tqdm import tqdm
import statistics
import seaborn as sns
import matplotlib.pyplot as plt
import ray

In [2]:
def createGeneConverter():
    test = []
    geneId_geneName = {}
    with open('Homo_sapiens.GRCh37.74.gtf', 'r') as file:
        for line in file:
            line = line.strip()
            data = line.split('\t')[-1]
            test.append(data)
            if 'gene_name' in data:
                attributes = data.split(';')
                geneId = attributes[0].split(' ')[1].strip('"')
                for attr in attributes:
                    if 'gene_name' in attr:
                        geneName = attr.split(' ')[2].strip('"')
                        if geneId not in geneId_geneName:
                            geneId_geneName[geneId] = geneName
    geneName_geneId = {v: k for k, v in geneId_geneName.items()}
    return geneName_geneId, geneId_geneName

def createRegulatory(regulatory_filepath):
    geneName_geneId, geneId_geneName = createGeneConverter()
    # loading, cleaning, and permutating regulatory dataset
    print(f'Loading: {regulatory_filepath}')
    regulatory = pd.read_csv(regulatory_filepath, index_col=0)
    
    exceptions = []
    for name in regulatory.index:
        try:
            regulatory = regulatory.rename(index={name: geneName_geneId[name]})
        except:
            exceptions.append(name)
    
    print(f'Row exception count: {len(exceptions)}')
    
    for exc in exceptions:
        regulatory = regulatory.drop(exc)

    # if columns are ENSG IDs, uncomment this part
    exceptions = []
    for ID in regulatory.columns.tolist():
        if ID not in geneId_geneName:
            exceptions.append(ID)

    # if columns are gene names, uncomment this part
    # exceptions = []
    # for name in regulatory.columns.tolist():
    #     try:
    #         regulatory = regulatory.rename(columns={name: geneName_geneId[name]})
    #     except:
    #         exceptions.append(name)

    
    print(f'Column exception count: {len(exceptions)}')

    for exc in exceptions:
        regulatory = regulatory.drop(exc)

    
    def inverse(x):
        return 1/x
    
    def absolute(x):
        return abs(x)
    
    regulatory = regulatory.map(inverse)
    regulatory = regulatory.map(absolute)

    return regulatory

def createEndpoints(filepath):
    geneName_geneId, geneId_geneName = createGeneConverter()
    # loading and cleaning dataset
    endpoints = pd.read_csv(filepath, sep='\t', index_col=0)
    
    exceptions = []
    for name in endpoints.index:
        try:
            endpoints = endpoints.rename(index={name: geneName_geneId[name.upper()]})
        except:
            exceptions.append(name)
    
    print(f'Endpoints exception count: {len(exceptions)}')
    
    for exc in exceptions:
        endpoints = endpoints.drop(exc)

    return list(endpoints.index)

def createRandomEndpoints(regulatory, num, seed):
    rand = regulatory.sample(n=num, random_state=seed)
    return list(rand.index)

def createGraph(regulatory, endpoints):
    # creating network
    regMatrix = regulatory.to_numpy().tolist()
    
    G = nx.Graph()
    nodes = list(set(list(regulatory.index) + regulatory.columns.tolist() + endpoints))
    G.add_nodes_from(nodes)
    edgeCount = 0
    for rowName, row in zip(regulatory.index, regMatrix):
        for columnName, cell in zip(regulatory.columns.tolist(), row):
            # if cell < 10:
            G.add_edge(rowName, columnName, weight=cell)
            edgeCount += 1
    
    print(f'EdgeCount: {edgeCount}')

    return G

@ray.remote
def processData(G, params):
    origin, endpoints = params
    count = {}
    for i, endpoint in enumerate(endpoints):
        print(i)
        try:
            path = nx.shortest_path(G, origin, endpoint, weight="weight")
            path.pop(0)
            path.pop()
            for node in path:
                if node not in count:
                    count[node] = 1
                else:
                    count[node] += 1
        except:
            pass
        
    return count

def connectionEnrichment(origins, endpoints):
    datasets = os.listdir('data')
    datasets = [dataset for dataset in datasets if dataset != '.DS_Store']
    count = {}

    for dataset in datasets:
        regulatory = createRegulatory(f'data/{dataset}')
        G = createGraph(regulatory, endpoints)
        dataId = ray.put(G)
        resultIds = [processData.remote(dataId, (origin, endpoints)) for origin in origins]
        output = ray.get(resultIds)
        
        for dic in output:
            for key, val in dic.items():
                if key not in count:
                    count[key] = 1
                else:
                    count[key] += 1
    return count
    
def sortDic(dic):
    return dict(reversed(sorted(dic.items(), key=lambda item: item[1])))

def printResultsWithStats(dic):
    print(f'Size: {len(dic)}')
    print(f'Average: {statistics.mean(dic.values())}')
    print(f'Median: {statistics.median(dic.values())}')
    print(dic)

In [3]:
ray.init(num_cpus=8, ignore_reinit_error=True)

2025-03-27 14:24:48,231	INFO worker.py:1843 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m


0,1
Python version:,3.12.2
Ray version:,2.44.0
Dashboard:,http://127.0.0.1:8265


[36m(processData pid=49690)[0m 0
[36m(processData pid=49692)[0m 0
[36m(processData pid=49692)[0m 1
[36m(processData pid=49694)[0m 0
[36m(processData pid=49696)[0m 0
[36m(processData pid=49695)[0m 0
[36m(processData pid=49695)[0m 1
[36m(processData pid=49695)[0m 2
[36m(processData pid=49695)[0m 3
[36m(processData pid=49695)[0m 4
[36m(processData pid=49695)[0m 5
[36m(processData pid=49695)[0m 6
[36m(processData pid=49695)[0m 7
[36m(processData pid=49695)[0m 8
[36m(processData pid=49695)[0m 9
[36m(processData pid=49695)[0m 10
[36m(processData pid=49695)[0m 11
[36m(processData pid=49695)[0m 12
[36m(processData pid=49695)[0m 13
[36m(processData pid=49695)[0m 14
[36m(processData pid=49695)[0m 15
[36m(processData pid=49695)[0m 16
[36m(processData pid=49695)[0m 17
[36m(processData pid=49695)[0m 18
[36m(processData pid=49695)[0m 19
[36m(processData pid=49695)[0m 20
[36m(processData pid=49695)[0m 21
[36m(processData pid=49695)[0m 22
[36m(p

In [None]:
# multi-dataset aging gene enrichment
globalAgingGenes = createEndpoints('global_aging_genes.tsv')
agingCount = connectionEnrichment(globalAgingGenes, globalAgingGenes)

Endpoints exception count: 15
Loading: data/Testis.csv
Row exception count: 0
Column exception count: 0
EdgeCount: 19476492


In [None]:
## printing results
printResultsWithStats(sortDic(agingCount))

In [None]:
# control experiment 1
datasets = os.listdir('data')
datasets = [dataset for dataset in datasets if dataset != '.DS_Store']
regulatory = createRegulatory(f'data/{datasets[0]}')
globalAgingGenes = createEndpoints('global_aging_genes.tsv')
randomGenes = createRandomEndpoints(regulatory, len(globalAgingGenes), 42)
controlCount1 = connectionEnrichment(randomGenes, randomGenes)

In [None]:
# printing results
printResultsWithStats(sortDic(controlCount1))

In [None]:
# control experiment 2
datasets = os.listdir('data')
datasets = [dataset for dataset in datasets if dataset != '.DS_Store']
regulatory = createRegulatory(f'data/{datasets[0]}')
globalAgingGenes = createEndpoints('global_aging_genes.tsv')
randomGenes = createRandomEndpoints(regulatory, len(globalAgingGenes), 43)
controlCount2 = connectionEnrichment(randomGenes, randomGenes)

In [None]:
# printing results
printResultsWithStats(sortDic(controlCount2))

In [None]:
# control experiment 3
datasets = os.listdir('data')
datasets = [dataset for dataset in datasets if dataset != '.DS_Store']
regulatory = createRegulatory(f'data/{datasets[0]}')
globalAgingGenes = createEndpoints('global_aging_genes.tsv')
randomGenes = createRandomEndpoints(regulatory, len(globalAgingGenes), 44)
controlCount3 = connectionEnrichment(randomGenes, randomGenes)

In [None]:
# printing results
printResultsWithStats(sortDic(controlCount3))

In [None]:
# creating graph
agingVal = list(agingCount.values())
controlVal1 = list(controlCount1.values())
controlVal2 = list(controlCount2.values())
controlVal3 = list(controlCount3.values())
agingName = ['aging' for i in range(len(agingVal))]
controlName1 = ['control1' for i in range(len(controlVal1))]
controlName2 = ['control2' for i in range(len(controlVal2))]
controlName3 = ['control3' for i in range(len(controlVal3))]

counts = agingVal + controlVal1 + controlVal2 + controlVal3
names = agingName + controlName1 + controlName2 + controlName3

df = pd.DataFrame({'counts': counts, 'names': names})

sns.set_theme(style="ticks")

f, ax = plt.subplots(figsize=(7, 5))
sns.despine(f)

sns.histplot(
    df,
    x="counts", hue="names",
    multiple="dodge",
    palette="light:m_r",
    edgecolor=".3",
    linewidth=.5,
    bins = 16,
)

plt.yscale('log')
plt.xlabel("Number of overlaps")

In [None]:
for name1, dataset1 in zip(['aging', 'control1', 'control2', 'control3'], [agingCount, controlCount1, controlCount2, controlCount3]):
    print(f'Origin: {name1}')
    for name2, dataset2 in zip(['aging', 'control1', 'control2', 'control3'], [agingCount, controlCount1, controlCount2, controlCount3]):
        print(f'Endpoint: {name2}')
        intersection = set(dataset1).intersection(set(dataset2))
        sideTotal2 = sum([dataset2[i] for i in intersection])
        print(f'{len(intersection)} // {sideTotal2} // {int(sideTotal2 / len(intersection))}')

In [None]:
genageDF = pd.read_csv("genage_human.csv")
genageList = list(genageDF['symbol'])
genageENSGList = [geneName_geneId[name] for name in genageList]

In [None]:
# experiment
filteredAgingCount = {k:v for k,v in agingCount.items() if k in genageENSGList}

In [None]:
printResultsWithStats(sortDic(filteredAgingCount))

In [None]:
# control 1
filteredControlCount1 = {k:v for k,v in controlCount1.items() if k in genageENSGList}

In [None]:
printResultsWithStats(sortDic(filteredControlCount1))

In [None]:
# control 1
filteredControlCount2 = {k:v for k,v in controlCount2.items() if k in genageENSGList}

In [None]:
printResultsWithStats(sortDic(filteredControlCount2))

In [None]:
# control 1
filteredControlCount3 = {k:v for k,v in controlCount3.items() if k in genageENSGList}

In [None]:
printResultsWithStats(sortDic(filteredControlCount3))

In [None]:
# creating graph
agingVal = list(filteredAgingCount.values())
controlVal1 = list(filteredControlCount1.values())
controlVal2 = list(filteredControlCount2.values())
controlVal3 = list(filteredControlCount3.values())
agingName = ['aging' for i in range(len(agingVal))]
controlName1 = ['control1' for i in range(len(controlVal1))]
controlName2 = ['control2' for i in range(len(controlVal2))]
controlName3 = ['control3' for i in range(len(controlVal3))]

counts = agingVal + controlVal1 + controlVal2 + controlVal3
names = agingName + controlName1 + controlName2 + controlName3

df = pd.DataFrame({'counts': counts, 'names': names})

sns.set_theme(style="ticks")

f, ax = plt.subplots(figsize=(7, 5))
sns.despine(f)

sns.histplot(
    df,
    x="counts", hue="names",
    multiple="dodge",
    palette="light:m_r",
    edgecolor=".3",
    linewidth=.5,
    bins = 16,
)

plt.yscale('log')
plt.xlabel("Number of overlaps between central genes and GenAge")