In [1]:
import os
import gzip
from tqdm import tqdm
from graph_tool import Graph, GraphView, load_graph
from graph_tool.util import find_vertex
import graph_tool as gt

In [2]:
class GraphProcessor:
    def __init__(
        self, 
        data_dir, 
        triples_file='wikidata5m_transductive.tar.gz', 
        graph_file='wikidata5m_graph.gt.gz', 
        max_subgraph_edges=1000000):
        
        triples_path = os.path.join(data_dir, triples_file)
        graph_path = os.path.join(data_dir, graph_file)
        
        self.graph = self.build_full_graph(data_dir, triples_path, graph_path, max_subgraph_edges)

    def build_graph(self, triples):
        g = Graph(directed=True)
        vprop = g.new_vertex_property("string")
        eprop = g.new_edge_property("string")

        vertices = {}
        for s, p, o in triples:
            if s not in vertices:
                v1 = g.add_vertex()
                vprop[v1] = s
                vertices[s] = v1
            else:
                v1 = vertices[s]

            if o not in vertices:
                v2 = g.add_vertex()
                vprop[v2] = o
                vertices[o] = v2
            else:
                v2 = vertices[o]

            e = g.add_edge(v1, v2)
            eprop[e] = p

        g.vertex_properties["name"] = vprop
        g.edge_properties["name"] = eprop

        return g
    
    def merge_subgraph(main_graph, subgraph):
        # Dictionary to store the mapping of old vertices in subgraph to new vertices in main_graph
        vertex_map = {}

        for v in subgraph.vertices():
            v_name = subgraph.vp.name[v]
            
            # Check if vertex already exists in main_graph
            existing_vertices = find_vertex(main_graph, main_graph.vp.name, v_name)
            if existing_vertices:
                v_new = existing_vertices[0]
            else:
                v_new = main_graph.add_vertex()
                main_graph.vp.name[v_new] = v_name

            vertex_map[v] = v_new

        for e in subgraph.edges():
            source = vertex_map[e.source()]
            target = vertex_map[e.target()]
            e_new = main_graph.add_edge(source, target)
            main_graph.ep.name[e_new] = subgraph.ep.name[e]

    
    def build_full_graph(self, data_dir, triples_path, graph_path, max_subgraph_edges=1000000):       
        if os.path.exists(graph_path):
            with gzip.open(graph_path, 'rb') as f:
                return load_graph(f)
            
        else:
            # Extract triples using the gzip approach
            with gzip.open(triples_path, 'rt') as file:
                lines = file.readlines()[1:-1]
            
            num_chunks = -(-len(lines) // max_subgraph_edges)  # Ceiling division
            subgraphs = []  # to store paths of saved subgraphs
                    
            for i in tqdm(range(num_chunks), desc="Building Subgraphs"):
                start = i * max_subgraph_edges
                end = min((i + 1) * max_subgraph_edges, len(lines))
                triples_chunk = [tuple(line.strip().split('\t')) for line in lines[start:end]]
                
                subgraph = self.build_graph(triples_chunk)
                
                subgraph_path = os.path.join(data_dir, f"wikidata5m_subgraph_{i}.gt")
                subgraph.save(subgraph_path)
                subgraphs.append(subgraph_path)

            # Start with the first subgraph
            main_graph = load_graph(subgraphs[0])
            
            # Add the other subgraphs to it one by one
            for subgraph_path in tqdm(subgraphs[1:], desc="Merging Subgraphs"):
                subgraph = load_graph(subgraph_path)
                self.merge_subgraph(main_graph, subgraph)
                
            # Saving the final merged graph in gzipped format
            with gzip.open(graph_path, 'wb') as f:
                main_graph.save(f)
            
            # Deleting the subgraph files
            for subgraph_path in subgraphs:
                os.remove(subgraph_path)

            return main_graph

    # def get_k_hop_neighbors(self, node, k=2):
    #     """Retrieve k-hop neighbors for a given node."""
    #     v = find_vertex(self.graph, self.vprop, node)
    #     if v:
    #         v = v[0]
    #         neighbors = set()
    #         for e in self.graph.edges():
    #             if e.source() == v or e.target() == v:
    #                 neighbors.add(int(self.vprop[e.source()]))
    #                 neighbors.add(int(self.vprop[e.target()]))
    #         return list(neighbors)
    #     else:
    #         return []

    # def prune_by_degree(self, vertices, k=2, max_nodes=256):
    #     """Prune the subgraph based on node degree."""
    #     while len(vertices) > max_nodes:
    #         degree_dict = {v: self.graph.vertex(v).out_degree() for v in vertices}
    #         sorted_vertices = sorted(degree_dict.keys(), key=lambda x: degree_dict[x])
    #         vertices = sorted_vertices[:max_nodes]
    #     return vertices

    # def get_subgraph_for_entities(self, entities, k=2, max_nodes=256):
    #     """Retrieve a subgraph for a given list of entities."""
    #     all_neighbors = set()
    #     for entity in entities:
    #         all_neighbors.update(self.get_k_hop_neighbors(entity, k))
    #     subgraph_vertices = list(all_neighbors)
    #     pruned_subgraph_vertices = self.prune_by_degree(subgraph_vertices, k=k, max_nodes=max_nodes)

    #     subgraph = GraphView(self.graph, vfilt=lambda v: self.vprop[v] in pruned_subgraph_vertices)

    #     return subgraph

In [3]:
graph_dir = os.path.join(os.getcwd(), '../data/wikidata5m')
graph_path = os.path.join(os.getcwd(), graph_dir, 'wikidata5m_graph.gt.gz')
triples_path = os.path.join(os.getcwd(), graph_dir, 'wikidata5m_transductive.tar.gz')

In [4]:
gp = GraphProcessor(data_dir=graph_dir, max_subgraph_edges=21000000)

In [5]:
from collections import Counter
import numpy as np

In [10]:
def graph_statistics(graph, n):
    # Calculate the average degree
    degrees = graph.get_total_degrees(graph.get_vertices())
    avg_degree = np.mean(degrees)

    # Calculate the mode of the degree
    degree_counts = Counter(degrees)
    mode_degree = max(degree_counts, key=degree_counts.get)

    # Count the number of nodes with degree less than n
    count_below_n = sum(1 for d in degrees if d > n)

    return avg_degree, mode_degree, count_below_n

In [11]:
avg_degree, mode_degree, count_below_n = graph_statistics(gp.graph, 10)

In [12]:
print(f"Average Degree: {avg_degree}")
print(f"Mode Degree: {mode_degree}")
print(f"Number of nodes with degree more than {10}: {count_below_n}")

Average Degree: 8.977967938089266
Mode Degree: 3
Number of nodes with degree more than 10: 533984


In [6]:
def num_edge_types(graph, edge_property_name="name"):
    # Extract the edge property values into a set to get unique values
    unique_edge_types = set(graph.ep[edge_property_name])
    return len(unique_edge_types)

# Usage example:
n_edge_types = num_edge_types(gp.graph, "name")
print(f"Total number of unique edge types: {n_edge_types}")

Total number of unique edge types: 822


In [7]:
with gzip.open(triples_path, 'rt') as file:
    lines = file.readlines()[1:-1]

In [8]:
for line in lines[:15]:
    triple = line.strip().split('\t')
    print(triple)

['Q6719921', 'P31', 'Q11446']
['Q4925109', 'P175', 'Q5165801']
['Q11010724', 'P734', 'Q59853']
['Q1236794', 'P31', 'Q1134686']
['Q1053299', 'P159', 'Q1761']
['Q1306857', 'P141', 'Q211005']
['Q247581', 'P361', 'Q849059']
['Q15489648', 'P31', 'Q5']
['Q4816942', 'P31', 'Q3957']
['Q374813', 'P19', 'Q139046']
['Q14579', 'P400', 'Q390389']
['Q7425556', 'P131', 'Q1989']
['Q3810769', 'P54', 'Q2049524']
['Q8021705', 'P54', 'Q192597']
['Q90457', 'P102', 'Q328195']


In [9]:
def find_edge(graph, source_name, target_name):
    """
    Given a graph, a source entity name, and a target entity name,
    this function checks if there's an edge between the two entities.
    If the edge exists, it returns the relation type (edge property).
    Otherwise, it returns None.
    """
    vprop = graph.vp.name
    eprop = graph.ep.name
    
    source_vertex = find_vertex(graph, vprop, source_name)
    target_vertex = find_vertex(graph, vprop, target_name)

    if not source_vertex or not target_vertex:
        return None

    # Since find_vertex returns a list, we get the first item as our vertex.
    source_vertex = source_vertex[0]
    target_vertex = target_vertex[0]

    edge = graph.edge(source_vertex, target_vertex)

    if edge:
        return eprop[edge]
    else:
        return None

for line in lines[:15]:
    triple = line.strip().split('\t')
    edge = find_edge(gp.graph, triple[0], triple[2])
    print(edge)

P31
P175
P734
P31
P159
P141
P361
P31
P31
P19
P400
P131
P54
P54
P102


In [None]:
gp.graph.get_all_neighbours(0)

In [10]:
vertex = find_vertex(gp.graph, gp.graph.vp.name, 'Q52234')[0]
print(vertex)
neighbors = gp.graph.get_all_neighbors(vertex)
neighbors = [gp.graph.vp.name[v] for v in neighbors]
print(neighbors)

1506658
['Q38', 'Q52227', 'Q6655', 'Q270091', 'Q747074', 'Q13372', 'Q160628', 'Q270220', 'Q6723', 'Q16205', 'Q52233', 'Q160628', 'Q16205', 'Q487473', 'Q52227', 'Q52233', 'Q270220', 'Q3655365', 'Q13372', 'Q270091', 'Q2314005', 'Q3966766', 'Q27881215']


In [14]:
def get_k_hop_neighbors(graph: Graph, node, k=1):
    """Retrieve k-hop neighbors for a given node."""
    
    visited = set()
    queue = [(node, 0)]  # The tuple contains the node and its distance from the source node
    
    while queue:
        current_node, distance = queue.pop(0)
        if current_node not in visited:
            visited.add(current_node)
            
            if distance < k:
                neighbors = graph.vertex(current_node).all_neighbors()
                edges = graph.vertex(current_node).all_edges()
                out_degree = graph.vertex(current_node).out_degree()
                in_degree = graph.vertex(current_node).in_degree()
                dummy_var = None
                for neighbor in graph.vertex(current_node).all_neighbors():
                    queue.append((neighbor, distance + 1))

    # Remove the source node
    visited.remove(node)
    return list(visited)

def extract_subgraph_path_degree(graph: Graph, source_entities, k, max_nodes):
    # Step 1: Extract k-hop subgraph
    all_neighbors = set(source_entities)
    for _ in range(k):
        current_neighbors = set()
        for entity in all_neighbors:
            vertex = find_vertex(gp.graph, gp.graph.vp.name, entity)[0]
            current_neighbors.update(get_k_hop_neighbors(graph, vertex, k))
        all_neighbors.update(current_neighbors)
    
    # Convert set to list to work with indices
    all_neighbors = list(all_neighbors)
    
    # # Step 2: Calculate local degree
    # degrees = {node: 0 for node in all_neighbors}
    # for i, node1 in enumerate(all_neighbors):
    #     for j, node2 in enumerate(all_neighbors):
    #         if i != j and graph.edge(node1, node2):
    #             degrees[node1] += 1
    #             degrees[node2] += 1
    
    # # Step 3: Prune nodes if needed
    # if len(all_neighbors) > max_nodes:
    #     sorted_nodes = sorted(degrees.keys(), key=lambda x: degrees[x], reverse=True)
    #     all_neighbors = sorted_nodes[:max_nodes]
    
    # Extract the subgraph
    subgraph = GraphView(graph, vfilt=lambda v: graph.vp.name[v] in all_neighbors)
    return subgraph

In [12]:
def average_node_degree(graph):
    """
    Compute the average degree of nodes in the graph.

    Parameters:
    - graph: A graph_tool.Graph object.

    Returns:
    - float: The average degree of nodes in the graph.
    """
    
    # Total degrees (sum of in and out degrees for each node)
    total_degrees = sum(v.out_degree() + v.in_degree() for v in graph.vertices())
    
    # Average degree
    avg_degree = total_degrees / graph.num_vertices()
    
    return avg_degree

In [13]:
avg_node_degree = average_node_degree(gp.graph)
print(f"Average node degree: {avg_node_degree}")

Average node degree: 8.977967938089266


In [28]:
entities = [('Q76', 0, 11), ('Q782', 25, 30), ('Q61061', 61, 69), ('Q30', 78, 90), ('Q8445', 111, 117), ('Q200', 145, 147), ('Q444353', 171, 175), ('Q61061', 183, 191), ('Q30', 200, 212), ('Q3220821', 230, 236), ('Q30', 245, 257), ('Q30', 262, 268), ('Q61', 273, 287), ('Q11651', 293, 299), ('Q61061', 301, 309), ('Q30', 318, 330), ('Q30', 335, 341), ('Q2057908', 346, 351), ('Q22686', 360, 371), ('Q22686', 390, 401), ('Q8445', 406, 412), ('Q16279311', 417, 423), ('Q22686', 432, 443), ('Q203', 449, 452), ('Q22686', 464, 475), ('Q61061', 489, 497), ('Q30', 506, 518)]
entities = [entity[0] for entity in entities]
internal_entity_ids = []
for entity in entities:
    internal_id = find_vertex(gp.graph, gp.graph.vp.name, entity)[0]
    print(internal_id)
    print(type(internal_id))
    neighbors = gp.graph.get_all_neighbors(internal_id)
    print(neighbors)
    break
print(internal_entity_ids)


350504
<class 'graph_tool.libgraph_tool_core.Vertex'>
[ 509439  732044  439350     119  840198  507772 1222367  194793 2194781
 2240924 2281144  223915  772394  132450    5558    8478 3221287   81266
    6654  109367     757  419388 2037149  987831 2295135  432470  327782
 3164783    4627   84539  358545  496908   29120    2394 1939895  376375
  396127   19427 1549253     740 2575705 3962655 3126404      43  665725
      15  560465  886008 1083555  133696  260568 2297224   64568    1397
 3735010      49 1496342     717  174432 3994725     367  226337  280855
  677905  109613  277933   49191  207888  223915  247536  272619  277933
  280855  300044  327782  338331  345010  350503      43    7288    7289
   19427   56087   84539   91698  109683  132844  136681  142145  167353
  183525  196769  204579  382166  396127  398554  402324  496908  510474
  520348  521234  539655  560465  560466  575715  580939  593500  607154
  610148  610149  645636  652752  674163  675534  681997  745910  7996

In [66]:
batch_of_entities = [
    [('Q782', 25, 30), ('Q61061', 61, 69), ('Q30', 78, 90), ('Q8445', 111, 117), ('Q200', 145, 147), ('Q444353', 171, 175), ('Q61061', 183, 191), ('Q30', 200, 212), ('Q3220821', 230, 236), ('Q30', 245, 257), ('Q30', 262, 268), ('Q61', 273, 287), ('Q11651', 293, 299), ('Q61061', 301, 309), ('Q2057908', 346, 351), ('Q22686', 360, 371), ('Q22686', 390, 401), ('Q8445', 406, 412), ('Q16279311', 417, 423), ('Q22686', 432, 443), ('Q203', 449, 452), ('Q22686', 464, 475), ('Q61061', 489, 497), ('Q30', 506, 518)],
    [('Q76', 0, 11), ('Q61061', 61, 69), ('Q30', 78, 90), ('Q8445', 111, 117), ('Q200', 145, 147), ('Q444353', 171, 175), ('Q61061', 183, 191), ('Q30', 200, 212), ('Q3220821', 230, 236), ('Q30', 245, 257), ('Q30', 262, 268), ('Q61', 273, 287), ('Q11651', 293, 299), ('Q61061', 301, 309), ('Q30', 335, 341),('Q22686', 390, 401), ('Q8445', 406, 412), ('Q16279311', 417, 423), ('Q22686', 432, 443), ('Q203', 449, 452), ('Q22686', 464, 475), ('Q61061', 489, 497), ('Q30', 506, 518)],
    [('Q76', 0, 11), ('Q782', 25, 30), ('Q30', 78, 90), ('Q8445', 111, 117), ('Q200', 145, 147), ('Q444353', 171, 175), ('Q61061', 183, 191), ('Q30', 200, 212), ('Q3220821', 230, 236), ('Q30', 245, 257), ('Q30', 262, 268), ('Q61', 273, 287), ('Q11651', 293, 299), ('Q30', 318, 330), ('Q30', 335, 341), ('Q2057908', 346, 351),('Q8445', 406, 412), ('Q16279311', 417, 423), ('Q22686', 432, 443), ('Q203', 449, 452), ('Q22686', 464, 475), ('Q61061', 489, 497), ('Q30', 506, 518)],
    [('Q76', 0, 11), ('Q782', 25, 30), ('Q61061', 61, 69), ('Q8445', 111, 117), ('Q200', 145, 147), ('Q444353', 171, 175), ('Q61061', 183, 191), ('Q30', 200, 212), ('Q3220821', 230, 236), ('Q30', 245, 257), ('Q30', 262, 268), ('Q61', 273, 287), ('Q61061', 301, 309), ('Q30', 318, 330), ('Q30', 335, 341), ('Q2057908', 346, 351), ('Q22686', 360, 371), ('Q16279311', 417, 423), ('Q22686', 432, 443), ('Q203', 449, 452), ('Q22686', 464, 475), ('Q61061', 489, 497), ('Q30', 506, 518)],
    [('Q76', 0, 11), ('Q782', 25, 30), ('Q61061', 61, 69), ('Q30', 78, 90), ('Q200', 145, 147), ('Q444353', 171, 175), ('Q61061', 183, 191), ('Q30', 200, 212), ('Q3220821', 230, 236), ('Q30', 245, 257), ('Q30', 262, 268), ('Q11651', 293, 299), ('Q61061', 301, 309), ('Q30', 318, 330), ('Q30', 335, 341), ('Q2057908', 346, 351), ('Q22686', 360, 371), ('Q22686', 390, 401), ('Q22686', 432, 443), ('Q203', 449, 452), ('Q22686', 464, 475), ('Q61061', 489, 497), ('Q30', 506, 518)],
    [('Q76', 0, 11), ('Q782', 25, 30), ('Q61061', 61, 69), ('Q30', 78, 90), ('Q8445', 111, 117), ('Q444353', 171, 175), ('Q61061', 183, 191), ('Q30', 200, 212), ('Q3220821', 230, 236), ('Q30', 245, 257), ('Q61', 273, 287), ('Q11651', 293, 299), ('Q61061', 301, 309), ('Q30', 318, 330), ('Q30', 335, 341), ('Q2057908', 346, 351), ('Q22686', 360, 371), ('Q22686', 390, 401), ('Q8445', 406, 412), ('Q203', 449, 452), ('Q22686', 464, 475), ('Q61061', 489, 497), ('Q30', 506, 518)],
    [('Q76', 0, 11), ('Q782', 25, 30), ('Q61061', 61, 69), ('Q30', 78, 90), ('Q8445', 111, 117), ('Q200', 145, 147), ('Q61061', 183, 191), ('Q30', 200, 212), ('Q3220821', 230, 236), ('Q30', 262, 268), ('Q61', 273, 287), ('Q11651', 293, 299), ('Q61061', 301, 309), ('Q30', 318, 330), ('Q30', 335, 341), ('Q2057908', 346, 351), ('Q22686', 360, 371), ('Q22686', 390, 401), ('Q8445', 406, 412), ('Q203', 449, 452), ('Q22686', 464, 475), ('Q61061', 489, 497), ('Q30', 506, 518)],
    [('Q76', 0, 11), ('Q782', 25, 30), ('Q61061', 61, 69), ('Q30', 78, 90), ('Q8445', 111, 117), ('Q200', 145, 147), ('Q444353', 171, 175), ('Q30', 200, 212), ('Q30', 245, 257), ('Q30', 262, 268), ('Q61', 273, 287), ('Q11651', 293, 299), ('Q61061', 301, 309), ('Q30', 318, 330), ('Q30', 335, 341), ('Q2057908', 346, 351), ('Q22686', 360, 371), ('Q22686', 390, 401), ('Q8445', 406, 412), ('Q16279311', 417, 423), ('Q22686', 464, 475), ('Q61061', 489, 497), ('Q30', 506, 518)],
    [('Q76', 0, 11), ('Q782', 25, 30), ('Q61061', 61, 69), ('Q30', 78, 90), ('Q8445', 111, 117), ('Q200', 145, 147), ('Q444353', 171, 175), ('Q3220821', 230, 236), ('Q30', 245, 257), ('Q30', 262, 268), ('Q61', 273, 287), ('Q11651', 293, 299), ('Q61061', 301, 309), ('Q30', 318, 330), ('Q30', 335, 341), ('Q2057908', 346, 351), ('Q22686', 360, 371), ('Q22686', 390, 401), ('Q8445', 406, 412), ('Q16279311', 417, 423), ('Q22686', 432, 443), ('Q61061', 489, 497), ('Q30', 506, 518)],
    [('Q76', 0, 11), ('Q782', 25, 30), ('Q61061', 61, 69), ('Q30', 78, 90), ('Q8445', 111, 117), ('Q200', 145, 147),('Q61061', 183, 191), ('Q30', 200, 212), ('Q30', 245, 257), ('Q30', 262, 268), ('Q61', 273, 287), ('Q11651', 293, 299), ('Q61061', 301, 309), ('Q30', 318, 330), ('Q30', 335, 341), ('Q2057908', 346, 351), ('Q22686', 360, 371), ('Q22686', 390, 401), ('Q8445', 406, 412), ('Q16279311', 417, 423), ('Q22686', 432, 443), ('Q203', 449, 452), ('Q30', 506, 518)],
    [('Q76', 0, 11), ('Q782', 25, 30), ('Q61061', 61, 69), ('Q30', 78, 90), ('Q8445', 111, 117),('Q444353', 171, 175), ('Q61061', 183, 191), ('Q30', 200, 212), ('Q3220821', 230, 236), ('Q30', 262, 268), ('Q61', 273, 287), ('Q11651', 293, 299), ('Q61061', 301, 309), ('Q30', 318, 330), ('Q30', 335, 341), ('Q2057908', 346, 351), ('Q22686', 360, 371), ('Q22686', 390, 401), ('Q8445', 406, 412), ('Q16279311', 417, 423), ('Q22686', 432, 443), ('Q203', 449, 452), ('Q22686', 464, 475), ],
    [('Q76', 0, 11), ('Q782', 25, 30), ('Q61061', 61, 69), ('Q30', 78, 90), ('Q200', 145, 147), ('Q444353', 171, 175), ('Q61061', 183, 191), ('Q30', 200, 212), ('Q3220821', 230, 236), ('Q30', 245, 257), ('Q61', 273, 287), ('Q11651', 293, 299), ('Q61061', 301, 309), ('Q30', 318, 330), ('Q30', 335, 341), ('Q2057908', 346, 351), ('Q22686', 360, 371), ('Q22686', 390, 401), ('Q8445', 406, 412), ('Q16279311', 417, 423), ('Q22686', 432, 443), ('Q203', 449, 452), ('Q22686', 464, 475), ('Q61061', 489, 497), ],
    [('Q76', 0, 11), ('Q782', 25, 30), ('Q61061', 61, 69),('Q8445', 111, 117), ('Q200', 145, 147), ('Q444353', 171, 175), ('Q61061', 183, 191), ('Q30', 200, 212), ('Q3220821', 230, 236), ('Q30', 245, 257), ('Q30', 262, 268), ('Q11651', 293, 299), ('Q61061', 301, 309), ('Q30', 318, 330), ('Q30', 335, 341), ('Q2057908', 346, 351), ('Q22686', 360, 371), ('Q22686', 390, 401), ('Q8445', 406, 412), ('Q16279311', 417, 423), ('Q22686', 432, 443), ('Q203', 449, 452), ('Q61061', 489, 497), ('Q30', 506, 518)],
    [('Q76', 0, 11), ('Q782', 25, 30), ('Q30', 78, 90), ('Q8445', 111, 117), ('Q200', 145, 147), ('Q444353', 171, 175), ('Q61061', 183, 191), ('Q30', 200, 212), ('Q3220821', 230, 236), ('Q30', 245, 257), ('Q30', 262, 268), ('Q61', 273, 287), ('Q61061', 301, 309), ('Q30', 318, 330), ('Q30', 335, 341), ('Q2057908', 346, 351), ('Q22686', 360, 371), ('Q22686', 390, 401), ('Q8445', 406, 412), ('Q16279311', 417, 423), ('Q22686', 432, 443), ('Q22686', 464, 475), ('Q30', 506, 518)],
    [('Q76', 0, 11), ('Q61061', 61, 69), ('Q30', 78, 90), ('Q8445', 111, 117), ('Q200', 145, 147), ('Q444353', 171, 175), ('Q61061', 183, 191), ('Q30', 200, 212), ('Q3220821', 230, 236), ('Q30', 245, 257), ('Q30', 262, 268), ('Q61', 273, 287), ('Q11651', 293, 299), ('Q30', 318, 330), ('Q30', 335, 341), ('Q2057908', 346, 351), ('Q22686', 360, 371), ('Q22686', 390, 401), ('Q8445', 406, 412), ('Q16279311', 417, 423), ('Q203', 449, 452), ('Q61061', 489, 497), ('Q30', 506, 518)],
    [('Q782', 25, 30), ('Q61061', 61, 69), ('Q30', 78, 90), ('Q8445', 111, 117), ('Q200', 145, 147), ('Q444353', 171, 175), ('Q61061', 183, 191), ('Q30', 200, 212), ('Q3220821', 230, 236), ('Q30', 245, 257), ('Q30', 262, 268), ('Q61', 273, 287), ('Q11651', 293, 299), ('Q61061', 301, 309), ('Q30', 335, 341), ('Q2057908', 346, 351), ('Q22686', 360, 371), ('Q22686', 390, 401), ('Q8445', 406, 412), ('Q22686', 432, 443), ('Q22686', 464, 475), ('Q61061', 489, 497), ('Q30', 506, 518)],

]

In [17]:
subgraph = extract_subgraph_path_degree(gp.graph, entities, k=1, max_nodes=64)

KeyboardInterrupt: 

In [51]:
def retrieve_pruned_khop_subgraph(graph, entities, k=2, min_connections=2):
    """
    Retrieves a subgraph consisting of the k-hop neighbors of the given entities.
    Prunes nodes at each hop that have fewer than 'min_connections' links to the 
    set of entities from the previous hop.
    
    :param graph: The graph_tool graph.
    :param entities: List of entity names.
    :param k: Depth of subgraph retrieval.
    :param min_connections: Minimum number of connections a node must have to be retained.
    
    :return: A pruned subgraph.
    """

    # Getting the vertex property map for entity names
    vprop = graph.vp.name

    # Convert entity names to vertices for initial set
    current_vertices = [v for v in graph.vertices() if vprop[v] in entities]
    all_retrieved_vertices = set(current_vertices)

    for _ in range(k):
        next_vertices = set()

        # For each vertex in the current set, find its 1-hop neighbors
        for v in current_vertices:
            next_vertices.update(v.out_neighbors())
            # next_vertices.update(v.in_neighbors())

        # Prune neighbors that have fewer than 'min_connections' links to the current set of vertices
        to_remove = set()
        for neighbor in next_vertices:
            links_to_current = sum(1 for out_neigh in neighbor.out_neighbors() if out_neigh in current_vertices)
            # links_to_current += sum(1 for in_neigh in neighbor.in_neighbors() if in_neigh in current_vertices)
            
            if links_to_current < min_connections:
                to_remove.add(neighbor)

        next_vertices.difference_update(to_remove)
        all_retrieved_vertices.update(next_vertices)

        # Set next vertices as current vertices for the next iteration
        current_vertices = list(next_vertices)

    # Create a subgraph from the all retrieved vertices
    subgraph = GraphView(graph, vfilt=lambda v: v in all_retrieved_vertices)

    return subgraph


In [60]:
pruned_subgraph = retrieve_pruned_khop_subgraph(gp.graph, entities, k=2, min_connections=2)

In [61]:
pruned_subgraph.num_vertices()

489

In [73]:
import multiprocessing

def _khop_worker(i, entity_list, graph, k, min_connections, results):
    # Here, we'll recreate the method logic using passed parameters
    subgraph = retrieve_pruned_khop_subgraph(entity_list, graph, k, min_connections)
    results[i] = subgraph

# def retrieve_pruned_khop_subgraph_standalone(entities, graph, k=1, min_connections=1):
#     # The content of your original retrieve_pruned_khop_subgraph method
#     # Make sure to replace all references to self.graph with graph
#     ...

def batch_retrieve_pruned_khop_subgraph(graph, batch_entities, k=1, min_connections=1, num_processes=None):
    if not num_processes:
        num_processes = multiprocessing.cpu_count()

    manager = multiprocessing.Manager()
    results = manager.list([None] * len(batch_entities))

    with multiprocessing.Pool(processes=num_processes) as pool:
        pool.starmap(_khop_worker, [(i, entity_list, graph, k, min_connections, results) for i, entity_list in enumerate(batch_entities)])

    return list(results)



In [74]:
batch_of_subgraphs = batch_retrieve_pruned_khop_subgraph(gp.graph, batch_of_entities, k=2, min_connections=2)

KeyboardInterrupt: 