In [None]:
import numpy as np
import networkx as nx
import joblib
import pandas as pd
    
from collections import Counter

def process_thresholds(lst, N):
    """Reduce a collection of values to a sorted list of at most N + 1 thresholds.

    The smallest and the largest value of ``lst`` are always kept. The slots
    in between are filled with the (up to) ``N - 1`` values that occur most
    often in ``lst``, not counting the minimum and maximum themselves. Values
    with equal counts are preferred in order of first appearance.

    Parameters
    ----------
    lst : sequence
        Values to summarise; must be hashable and mutually comparable.
    N : int
        Number of intervals requested; must be at least 2.

    Returns
    -------
    list
        Thresholds in ascending order. The list is shorter than ``N + 1``
        when ``lst`` has fewer than ``N - 1`` distinct interior values, and
        is ``[v, v]`` when every element of ``lst`` equals ``v``.

    Raises
    ------
    ValueError
        If ``N < 2``, or (raised by ``min``) if ``lst`` is empty.
    """
    if N < 2:
        raise ValueError("N must be at least 2 to include min and max thresholds.")

    # Frequency of every distinct value, in order of first appearance.
    frequency = Counter(lst)

    lowest, highest = min(lst), max(lst)

    # The two extremes are always thresholds, so they do not compete for
    # the interior slots.
    for extreme in (lowest, highest):
        frequency.pop(extreme, None)

    # Stable sort: among equally frequent values the earliest-seen one wins.
    most_frequent = sorted(frequency, key=frequency.get, reverse=True)[:N - 1]

    thresholds = [lowest, *most_frequent, highest]
    thresholds.sort()
    return thresholds

In [None]:
from GraphRicciCurvature.OllivierRicci import OllivierRicci

def get_thresh(num_graph, edgedata, graph_ind, n_thresholds=20, alpha=0.5):
    """Build one curvature-weighted graph per graph id and derive shared thresholds.

    For every graph id in ``1..num_graph`` the nodes are the 1-based positions
    in ``graph_ind`` whose entry equals that id, and the edges are the rows of
    ``edgedata`` whose ``'from'`` value is one of those nodes. Each edge gets
    its Ollivier-Ricci curvature, rounded to 6 decimals, as its ``'weight'``
    attribute. The curvatures of all graphs are pooled (one entry per edge
    row) and reduced to a threshold list by ``process_thresholds``.

    Parameters
    ----------
    num_graph : int
        Number of graphs; graph ids are taken to be ``1..num_graph``.
    edgedata : pandas.DataFrame
        Edge table with a ``'from'`` column; the first two entries of each
        row are used as the edge endpoints.
    graph_ind : sequence
        ``graph_ind[i]`` is the graph id of node ``i + 1``.
    n_thresholds : int, optional
        Passed to ``process_thresholds`` as its ``N`` argument (default 20).
    alpha : float, optional
        ``alpha`` argument handed to ``OllivierRicci`` (default 0.5).

    Returns
    -------
    graph_list : list of networkx.Graph
        One graph per id, in id order, with ``'weight'`` set on every edge.
    thresh : list
        Sorted thresholds returned by ``process_thresholds``.

    Notes
    -----
    The pooled curvature list is handed to ``process_thresholds`` unchanged,
    so any error it raises for that list propagates to the caller.
    """
    # Group node ids by graph in a single pass rather than rescanning
    # graph_ind once per graph. Node ids are 1-based positions.
    nodes_by_graph = {}
    for node_id, owner in enumerate(graph_ind, start=1):
        nodes_by_graph.setdefault(owner, []).append(node_id)

    thresh = []
    graph_list = []

    for graph_id in range(1, 1 + num_graph):
        id_location = nodes_by_graph.get(graph_id, [])

        # Rows whose source node belongs to this graph.
        graph_edges = edgedata[edgedata['from'].isin(id_location)].to_numpy()

        graph = nx.Graph()
        graph.add_nodes_from(id_location)
        graph.add_edges_from(graph_edges)

        # Nothing to weight on an edgeless graph, so skip the curvature
        # computation entirely; the graph is still returned.
        if len(graph_edges):
            orc = OllivierRicci(graph, alpha=alpha, verbose="INFO")
            orc.compute_ricci_curvature()

            # Copy the curvature from the OllivierRicci working graph onto
            # our own graph as the edge weight, and pool it for thresholds.
            for line in graph_edges:
                u, v = line[0], line[1]
                curvature = round(orc.G[u][v]["ricciCurvature"], 6)
                thresh.append(curvature)
                graph[u][v]['weight'] = curvature

        graph_list.append(graph)

    thresh = process_thresholds(thresh, n_thresholds)

    return graph_list, thresh
        







In [None]:
import pandas as pd
import numpy as np
import scipy
import pyflagser


import networkx as nx


def sub_PH(num_graph, G, threshold_array, m):
    """Betti numbers and cell counts of curvature-window subgraphs.

    The thresholds are sorted and, for every window index ``val`` in
    ``range(len(threshold_array) - m)``, the edges whose ``'weight'`` lies in
    the closed interval ``[threshold[val], threshold[val + m]]`` are selected.
    The subgraph induced by those edges is summarised by its node count, its
    edge count and the Betti numbers in dimensions 0 and 1 of its flag
    complex (computed with ``pyflagser.flagser_unweighted``, coefficients
    in Z/2).

    Parameters
    ----------
    num_graph : int
        Number of graphs to process; ``G[0] .. G[num_graph - 1]`` are read.
    G : sequence of networkx.Graph
        Graphs whose edges all carry a ``'weight'`` attribute.
    threshold_array : iterable
        Threshold values; sorted internally, the input is not modified.
    m : int
        Window width, measured in threshold steps.

    Returns
    -------
    Bet0, Bet1, cell0, cell1 : list of list
        One inner list per graph, one entry per window: Betti-0, Betti-1,
        number of nodes and number of edges of the window subgraph.
        A window containing no edges yields 0 for all four values.

    Notes
    -----
    If the homology computation fails for a non-empty window, a
    ``RuntimeWarning`` is issued and both Betti numbers are recorded as 0,
    so one failing window does not abort the whole run.
    """
    import warnings

    threshold_array = sorted(threshold_array)
    N = len(threshold_array) - m
    Bet0 = []
    Bet1 = []
    cell0 = []
    cell1 = []

    for graph_id in range(num_graph):
        B0 = []
        B1 = []
        c0 = []
        c1 = []
        graph = G[graph_id]

        for val in range(N):
            low, high = threshold_array[val], threshold_array[val + m]
            Rindex = [(u, v) for u, v, d in graph.edges(data=True) if low <= d['weight'] <= high]
            # Subgraph made of the selected edges and their endpoints only.
            sub = graph.edge_subgraph(Rindex).copy()

            n_edges = sub.number_of_edges()
            c0.append(sub.number_of_nodes())
            c1.append(n_edges)

            betti0 = betti1 = 0
            if n_edges:
                try:
                    # weight=None gives a 0/1 adjacency. With the default
                    # weight='weight' the entries would be curvatures, and an
                    # edge of curvature exactly 0 would become a 0 entry,
                    # i.e. a missing edge for flagser_unweighted.
                    adjacency_matrix = nx.adjacency_matrix(sub, weight=None).toarray()

                    diagr = pyflagser.flagser_unweighted(adjacency_matrix, min_dimension=0, max_dimension=2, directed=False, coeff=2, approximation=None)

                    # Read both values before storing either one, so a
                    # failure can never leave B0 and B1 with different lengths.
                    betti = diagr['betti']
                    betti0, betti1 = betti[0], betti[1]
                except Exception as exc:
                    warnings.warn(
                        f"sub_PH: homology failed for graph {graph_id}, window {val} "
                        f"({exc!r}); recording Betti numbers as 0",
                        RuntimeWarning,
                        stacklevel=2,
                    )
                    betti0 = betti1 = 0

            B0.append(betti0)
            B1.append(betti1)

        Bet0.append(B0)
        Bet1.append(B1)
        cell0.append(c0)
        cell1.append(c1)

    return Bet0, Bet1, cell0, cell1

In [None]:
import pandas as pd
import numpy as np
import scipy
import networkx as nx
import time


# Root folder holding one sub-folder per dataset; each sub-folder is expected
# to contain <name>_A.txt and <name>_graph_indicator.txt. Edit this single
# constant to point at the local copy of the data.
DATA_DIR = r"/home/astrit/Downloads/"

for name in ['BZR', 'COX2', 'MUTAG', 'PROTEINS', 'IMDB-BINARY', 'IMDB-MULTI', 'REDDIT-BINARY', 'REDDIT-MULTI-5K']:

    print(name)
    # perf_counter is a monotonic clock meant for measuring intervals.
    start_time = time.perf_counter()

    # Edge list: one "from, to" pair of node ids per row.
    edgedata = pd.read_csv(DATA_DIR + name + "/" + name + "_A.txt", header=None)
    edgedata.columns = ['from', 'to']

    # Graph membership: row i holds the graph id of node i + 1.
    graph_indicators = pd.read_csv(DATA_DIR + name + "/" + name + "_graph_indicator.txt", header=None)
    graph_indicators.columns = ["ID"]
    graph_ind = (graph_indicators["ID"].values.astype(int))

    num_graph = len(np.unique(np.array(graph_indicators))) # total number of graphs

    graph_info, thresh = get_thresh(num_graph, edgedata, graph_ind)
    end_time = time.perf_counter()

    for m in [2]:
        start_time_m = time.perf_counter()
        B0_sub, B1_sub, c0_sub, c1_sub = sub_PH(num_graph, graph_info, thresh, m)

        B0_sub = pd.DataFrame(B0_sub)
        B1_sub = pd.DataFrame(B1_sub)
        c0_sub = pd.DataFrame(c0_sub)
        c1_sub = pd.DataFrame(c1_sub)

        # One CSV per quantity, written to the current working directory.
        B0_sub.to_csv(name + "B0_ricci_sub"+str(m)+".csv")
        B1_sub.to_csv(name + "B1_ricci_sub"+str(m)+".csv")
        c0_sub.to_csv(name + "c0_ricci_sub"+str(m)+".csv")
        c1_sub.to_csv(name + "c1_ricci_sub"+str(m)+".csv")
        end_time_m = time.perf_counter()

        # Printed values: window width, seconds spent loading the data and
        # computing curvature, seconds spent on sub_PH plus writing the CSVs.
        print(m, end_time-start_time, end_time_m-start_time_m)






