In [None]:
import time
from collections import defaultdict
from typing import Dict, Optional, Tuple

import numpy as np
import pandas as pd

class PrimeKGLoader:
    """
    Prepares PrimeKG data for efficient loading into DGL heterogeneous graphs.

    Every node type gets its own block of sequential local IDs starting at 0,
    and every edge is re-expressed in terms of those local IDs. The loader
    also keeps the string <-> integer type mappings and the
    global ID -> (node_type, local_id) mapping built along the way.
    """

    def __init__(self):
        self.node_type_mapping = {}  # string -> int
        self.relationship_type_mapping = {}  # string -> int
        self.reverse_node_type_mapping = {}  # int -> string
        self.reverse_relationship_type_mapping = {}  # int -> string
        self.global_to_local_mapping = {}  # For reference: global_id -> (node_type, local_id)

    def load_and_prepare_primekg(self, nodes_csv_path: str, edges_csv_path: str):
        """
        Load PrimeKG data and prepare it for bulk_load_heterogeneous_graph.
        Each node type gets sequential IDs starting from 0.

        Args:
            nodes_csv_path: Path to nodes CSV file. The code reads the columns
                'id', 'node_type', 'name' and 'metadata_source'.
            edges_csv_path: Path to edges CSV file. The code reads the columns
                'source_id', 'target_id' and 'relationship_type'.

        Returns:
            Tuple of (node_types_dict, edge_types_dict, global_to_local_mapping).
            The third element is the loader's own mapping dict (not a copy).
        """
        print("Loading PrimeKG data...")
        start_time = time.time()

        # Start from a fresh mapping so that a second load on the same loader
        # does not keep IDs from the previous dataset. Rebinding (instead of
        # clearing) leaves any dict handed out by an earlier call untouched.
        self.global_to_local_mapping = {}

        # Load raw data
        print("  Reading CSV files...")
        nodes_df = pd.read_csv(nodes_csv_path, low_memory=False)
        edges_df = pd.read_csv(edges_csv_path, low_memory=False)

        print(f"  Loaded {len(nodes_df):,} nodes and {len(edges_df):,} edges")

        # Create type mappings
        print("  Creating type mappings...")
        self._create_type_mappings(nodes_df, edges_df)

        # Prepare node data (sequential IDs starting from 0 for each type)
        print("  Preparing node data...")
        node_types_dict = self._prepare_node_data(nodes_df)

        # Prepare edge data (using local IDs)
        print("  Preparing edge data...")
        edge_types_dict = self._prepare_edge_data(edges_df, nodes_df)

        total_time = time.time() - start_time
        print(f"\nData preparation completed in {total_time:.2f}s")

        # Print summary
        self._print_summary(node_types_dict, edge_types_dict)

        return node_types_dict, edge_types_dict, self.global_to_local_mapping

    def _create_type_mappings(self, nodes_df: pd.DataFrame, edges_df: pd.DataFrame):
        """
        Create mappings between string types and integer representations.

        Type names are sorted before numbering, so the integer assigned to a
        type depends only on the set of names present, not on row order.
        """
        # Node type mappings
        unique_node_types = sorted(nodes_df['node_type'].unique())
        self.node_type_mapping = {name: idx for idx, name in enumerate(unique_node_types)}
        self.reverse_node_type_mapping = {idx: name for name, idx in self.node_type_mapping.items()}

        # Relationship type mappings
        unique_rel_types = sorted(edges_df['relationship_type'].unique())
        self.relationship_type_mapping = {name: idx for idx, name in enumerate(unique_rel_types)}
        self.reverse_relationship_type_mapping = {
            idx: name for name, idx in self.relationship_type_mapping.items()
        }

        print(f"    Found {len(unique_node_types)} node types: {unique_node_types}")
        print(f"    Found {len(unique_rel_types)} relationship types: {unique_rel_types}")

    def _prepare_node_data(self, nodes_df: pd.DataFrame) -> Dict[str, pd.DataFrame]:
        """
        Group nodes by type and prepare DataFrames with sequential IDs starting from 0.

        Also records every node in ``self.global_to_local_mapping`` as
        global_id -> (node_type, local_id). The input frame is not modified.

        Returns:
            Dict mapping node_type_string -> DataFrame with columns
            ['node_id', 'name', 'metadata_source', 'node_type_id', 'original_global_id']
        """
        node_types_dict = {}

        # Work on a derived frame that carries the numeric type ID.
        typed_nodes = nodes_df.assign(
            node_type_id=nodes_df['node_type'].map(self.node_type_mapping)
        )

        for node_type_str, group_df in typed_nodes.groupby('node_type'):
            # Sort by original ID so local IDs are assigned deterministically.
            group_df = group_df.sort_values('id').reset_index(drop=True)

            num_nodes = len(group_df)
            global_ids = group_df['id'].values
            local_ids = np.arange(num_nodes)  # 0, 1, 2, ..., num_nodes-1

            # Store the mapping for edge processing.
            self.global_to_local_mapping.update(
                (global_id, (node_type_str, local_id))
                for global_id, local_id in zip(global_ids, local_ids)
            )

            # Prepare DataFrame for DGL
            node_types_dict[node_type_str] = pd.DataFrame({
                'node_id': local_ids,  # Sequential IDs starting from 0
                'name': group_df['name'].values,
                'metadata_source': group_df['metadata_source'].values,
                'node_type_id': group_df['node_type_id'].values,
                'original_global_id': global_ids,  # Keep original for reference
            })
            print(f"    {node_type_str}: {num_nodes:,} nodes (IDs: 0 to {num_nodes-1})")

        return node_types_dict

    def _prepare_edge_data(self, edges_df: pd.DataFrame, nodes_df: pd.DataFrame) -> Dict[Tuple[str, str, str], pd.DataFrame]:
        """
        Prepare edge data grouped by (src_type, edge_type, dst_type) using local IDs.

        Edges whose source or target ID does not appear in ``nodes_df`` are
        dropped (a warning with the count is printed). Must run after
        ``_prepare_node_data`` because it reads ``self.global_to_local_mapping``.

        Returns:
            Dict mapping (src_type, edge_type, dst_type) -> DataFrame with columns
            ['src', 'dst', 'relationship_type_id', 'original_src_id', 'original_dst_id']
        """
        # Node ID -> node type, for tagging each edge endpoint.
        node_id_to_type = dict(zip(nodes_df['id'], nodes_df['node_type']))

        # Derived frame with the relationship ID and both endpoint types.
        typed_edges = edges_df.assign(
            relationship_type_id=edges_df['relationship_type'].map(self.relationship_type_mapping),
            src_type=edges_df['source_id'].map(node_id_to_type),
            dst_type=edges_df['target_id'].map(node_id_to_type),
        )

        # Filter out edges with unknown nodes
        endpoints_known = typed_edges['src_type'].notna() & typed_edges['dst_type'].notna()
        valid_edges = typed_edges[endpoints_known]

        dropped = len(typed_edges) - len(valid_edges)
        if dropped:
            print(f"    Warning: Filtered out {dropped} edges with unknown nodes")

        # Split the global mapping into one global_id -> local_id dict per node
        # type, once, instead of rescanning the whole mapping for every edge type.
        local_ids_by_type = defaultdict(dict)
        for global_id, (node_type, local_id) in self.global_to_local_mapping.items():
            local_ids_by_type[node_type][global_id] = local_id

        edge_types_dict = {}
        grouped = valid_edges.groupby(['src_type', 'relationship_type', 'dst_type'])

        for (src_type, rel_type, dst_type), group_df in grouped:
            print(f"    Processing {src_type} --[{rel_type}]--> {dst_type}: {len(group_df):,} edges")

            # Vectorized translation of global endpoint IDs to local IDs.
            # .get() keeps the defaultdict from growing on unknown types.
            src_local = group_df['source_id'].map(local_ids_by_type.get(src_type, {}))
            dst_local = group_df['target_id'].map(local_ids_by_type.get(dst_type, {}))

            # Both endpoints must have been mapped.
            mapped = src_local.notna() & dst_local.notna()
            if not mapped.any():
                print(f"      Warning: No valid edges found for {src_type}-{rel_type}->{dst_type}")
                continue

            kept = group_df[mapped]
            edge_df = pd.DataFrame({
                'src': src_local[mapped].astype(int).values,  # Local IDs (0-based for each node type)
                'dst': dst_local[mapped].astype(int).values,  # Local IDs (0-based for each node type)
                'relationship_type_id': kept['relationship_type_id'].values,
                'original_src_id': kept['source_id'].values,  # Keep original for reference
                'original_dst_id': kept['target_id'].values,  # Keep original for reference
            })

            edge_types_dict[(src_type, rel_type, dst_type)] = edge_df
            print(f"      Created {len(edge_df):,} valid edges")

        return edge_types_dict

    def _print_summary(self, node_types_dict: Dict[str, pd.DataFrame],
                      edge_types_dict: Dict[Tuple[str, str, str], pd.DataFrame]):
        """
        Print summary of prepared data and verify the local ID layout.

        Raises:
            AssertionError: if a node type's local IDs do not run from 0 to
                len - 1. Raised explicitly so the check also runs under ``-O``.
        """
        banner = "=" * 60
        print("\n" + banner)
        print("PRIMEKG DATA PREPARATION SUMMARY")
        print(banner)

        print("\nNode Type Mappings:")
        for str_type, int_type in self.node_type_mapping.items():
            count = len(node_types_dict.get(str_type, []))
            print(f"  {int_type}: {str_type} ({count:,} nodes, IDs: 0 to {count-1})")

        print("\nRelationship Type Mappings:")
        for str_type, int_type in self.relationship_type_mapping.items():
            print(f"  {int_type}: {str_type}")

        print("\nPrepared Node Types:")
        total_nodes = 0
        for node_type, df in node_types_dict.items():
            min_id = df['node_id'].min()
            max_id = df['node_id'].max()
            print(f"  {node_type}: {len(df):,} nodes (local IDs: {min_id} to {max_id})")
            total_nodes += len(df)
        print(f"  TOTAL: {total_nodes:,} nodes")

        print("\nPrepared Edge Types:")
        total_edges = 0
        for (src_type, edge_type, dst_type), df in edge_types_dict.items():
            print(f"  {src_type} --[{edge_type}]--> {dst_type}: {len(df):,} edges")
            total_edges += len(df)
        print(f"  TOTAL: {total_edges:,} edges")

        print("\nData Format Verification:")
        for node_type, df in node_types_dict.items():
            if not df['node_id'].min() == 0:
                raise AssertionError(f"Node IDs for {node_type} don't start at 0!")
            if not df['node_id'].max() == len(df) - 1:
                raise AssertionError(f"Node IDs for {node_type} are not sequential!")
            print(f"  ✅ {node_type}: Sequential IDs 0 to {len(df)-1}")

        print(banner)

    def get_type_mappings(self):
        """Return the four type mappings (the loader's own dicts, not copies)."""
        return {
            'node_types': self.node_type_mapping,
            'relationship_types': self.relationship_type_mapping,
            'reverse_node_types': self.reverse_node_type_mapping,
            'reverse_relationship_types': self.reverse_relationship_type_mapping
        }

    def get_global_to_local_mapping(self):
        """Return a shallow copy of the global to local ID mapping."""
        return self.global_to_local_mapping.copy()

    def global_id_to_local(self, global_id: int) -> Tuple[str, int]:
        """
        Convert a global node ID to (node_type, local_id).

        Raises:
            ValueError: if the global ID is not in the mapping.
        """
        try:
            return self.global_to_local_mapping[global_id]
        except KeyError:
            raise ValueError(f"Global ID {global_id} not found in mapping") from None

    def local_id_to_global(self, node_type: str, local_id: int) -> int:
        """
        Convert (node_type, local_id) to global node ID.

        This is a linear scan over the whole mapping; no reverse index is kept.

        Raises:
            ValueError: if no node of that type has that local ID.
        """
        for global_id, (nt, lid) in self.global_to_local_mapping.items():
            if nt == node_type and lid == local_id:
                return global_id
        raise ValueError(f"Local ID ({node_type}, {local_id}) not found in mapping")


In [None]:
import torch
from typing import List, Dict
from dgl.dataloading import DataLoader, NeighborSampler

def create_simple_dataloader(graph, 
                            batch_size: int = 512,
                            fanouts: Optional[List[int]] = None,
                            shuffle: bool = True):
    """
    Create a DGL DataLoader whose seed set is every node of every node type.

    Args:
        graph: Graph to sample from. The code relies on ``graph.ntypes`` and
            ``graph.num_nodes(node_type)``.
        batch_size: Number of seed nodes per batch
        fanouts: Number of neighbors to sample per layer. ``None`` (the
            default) means ``[15, 10]``.
        shuffle: Whether to shuffle the seed nodes
        
    Returns:
        DGL DataLoader that covers all nodes
    """
    # Build the default per call: a list literal in the signature would be a
    # single object shared by every call and by whatever the sampler keeps.
    if fanouts is None:
        fanouts = [15, 10]

    sampler = NeighborSampler(fanouts)
    print("Using modern DGL DataLoader")

    # Seed nodes: the full ID range of each node type.
    all_nodes = {
        node_type: torch.arange(graph.num_nodes(node_type))
        for node_type in graph.ntypes
    }
    total_count = sum(graph.num_nodes(node_type) for node_type in all_nodes)
    print(f"Created dataloader for {total_count} nodes across {len(all_nodes)} types")
    
    # Create DataLoader
    dataloader = DataLoader(
        graph,
        all_nodes,
        sampler,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=False,
        num_workers=0  # Start with 0 for compatibility
    )
    
    print(f"DataLoader ready: {len(dataloader)} batches")
    return dataloader

In [None]:
from DeepGraphDB import DeepGraphDB

# Notebook cell: read the PrimeKG CSVs, convert them to per-type frames with
# local IDs, and hand the result (plus the ID/type mappings) to the database.
db = DeepGraphDB()
# Initialize the loader
loader = PrimeKGLoader()

# Load and prepare data
nodes_csv = "nodes.csv"  # Replace with your actual path
edges_csv = "edges.csv"  # Replace with your actual path

# `mapping` is the loader's global_id -> (node_type, local_id) dict.
node_types_dict, edge_types_dict, mapping = loader.load_and_prepare_primekg(nodes_csv, edges_csv)

    
# Get type mappings for reference
mappings = loader.get_type_mappings()
print("\nType mappings created:")
print("Node types:", mappings['node_types'])
print("Relationship types:", mappings['relationship_types'])

# Verify data format: show the local ID range of each node type
print("\nData format verification:")
for node_type, df in node_types_dict.items():
    print(f"  {node_type}: node_id range {df['node_id'].min()}-{df['node_id'].max()}")

# Load the prepared frames into the database.
# NOTE(review): the hint printed below names `analyzer`, but the object
# actually used in this notebook is `db`.
print("\nReady to load into DGL!")
print("Use: analyzer.bulk_load_heterogeneous_graph(node_types_dict, edge_types_dict)")
db.bulk_load_heterogeneous_graph(node_types_dict, edge_types_dict)
# Give the database the same string -> int type mappings and global -> local
# ID mapping that the loader built.
db.set_mappings(loader.node_type_mapping, loader.relationship_type_mapping)
db.set_global_to_local_mapping(mapping)

In [None]:
import torch

# Random placeholder features: one 128-wide row for every integer from 0 up to
# the largest global node ID (so the row count is max_id + 1, not node count).
# NOTE(review): presumably rows are indexed by global ID inside
# load_node_features_for_gnn — confirm against DeepGraphDB.
x = torch.rand(max(db.global_to_local_mapping.keys())+1, 128)
db.load_node_features_for_gnn(x)

In [None]:
from dataloader import create_dataloader_with_negative_sampling

# Build a dataloader with negative sampling over the database's graph and pull
# one batch from it for inspection.
dataloader = create_dataloader_with_negative_sampling(
    db.graph,
    batch_size=256,
    fanouts=[10, 5],
    negative_sampler='global_uniform',
    num_negative_edges=1
)

# Alternative kept for reference: the plain neighbor-sampling loader defined earlier.
# dataloader = create_simple_dataloader(db.graph, batch_size=512, fanouts=[15, 10, 5], shuffle=True)

# First batch only; a fresh iterator is created just for this call.
batch = next(iter(dataloader))

In [None]:
# Display element 3 of the sampled batch as the cell output.
# NOTE(review): what sits at index 3 depends on
# create_dataloader_with_negative_sampling — confirm before relying on it.
batch[3]

In [None]:
# print("\nGraph statistics:")
# stats = db.get_graph_statistics()
# for key, value in stats.items():
#     print(f"  {key}: {value}")

# Time two database queries around a single seed node. `time` is imported in
# the first notebook cell.
test_node = 26719

print(f"\nTesting k-hop neighbors around node {test_node}...")

# 3-hop expansion restricted to one canonical edge type
# (effectphenotype --phenotype_phenotype--> effectphenotype).
start = time.time()
neighbors = db.get_k_hop_neighbors([test_node], k=3, edge_types=[("effectphenotype", "phenotype_phenotype" ,"effectphenotype")])
end = time.time()
# `neighbors` is read as a dict keyed by hop number; missing hops count as empty.
print(f"3-hop neighbors of node {test_node}: {len(neighbors.get(1, set()))} at hop 1, {len(neighbors.get(2, set()))} at hop 2, {len(neighbors.get(3, set()))} at hop 3")
print(f"K-hop query took {end-start:.4f} seconds")
print(f"Neighbors: {neighbors}")

print(f"\nTesting subgraph extraction around node {test_node}...")
# Unrestricted 3-hop subgraph around the same seed (no edge_types filter passed).
start = time.time()
subgraph, k_hop_res = db.extract_subgraph([test_node], k=3)
end = time.time()
# Sum counts over every node/edge type when the result exposes them;
# otherwise fall back to the single-type totals.
total_subgraph_nodes = sum(subgraph.num_nodes(ntype) for ntype in subgraph.ntypes) if hasattr(subgraph, 'ntypes') else subgraph.num_nodes()
total_subgraph_edges = sum(subgraph.num_edges(etype) for etype in subgraph.canonical_etypes) if hasattr(subgraph, 'canonical_etypes') else subgraph.num_edges()
print(f"3-hop subgraph: {total_subgraph_nodes} nodes, {total_subgraph_edges} edges")
print(f"Subgraph extraction took {end-start:.4f} seconds")

# print(f"\nTesting meta-path queries from node {test_node}...")
# start = time.time()
# meta_paths = db.find_meta_paths([test_node], ['authored', 'about'], max_paths_per_node=5)
# end = time.time()
# print(f"Meta-paths from node {test_node}: {len(meta_paths.get(test_node, []))} paths found")
# if meta_paths.get(test_node):
#     print(f"First path example: {meta_paths[test_node][0] if meta_paths[test_node] else 'None'}")
# print(f"Meta-path query took {end-start:.4f} seconds")

# print("\nTesting node queries...")
# start = time.time()
# nodes_df = db.query_nodes_by_feature('person', 'age', 'gt', 50, return_features=['h_index'])
# end = time.time()
# print(f"Filtered person nodes (age > 50):\n{nodes_df['node_ids'].shape}")
# print(f"Node query took {end-start:.4f} seconds")

# print("\nTesting node queries...")
# start = time.time()
# nodes_df = db.query_nodes_by_feature('person', 'age', 'gt', 50, return_features=['h_index'])
# nodes_df =  db.get_top_nodes_by_feature('person', 'h_index', top_k=5, ascending=False, return_features=['name'])
# end = time.time()
# print(f"Top k node x h-index:\n{nodes_df['name']}")
# print(f"Node query took {end-start:.4f} seconds")

# # print("\nTesting edge queries...")
# # start = time.time()
# # edges_df = db.query_edges(edge_type=('person', 'authored', 'paper'), limit=5)
# # end = time.time()
# # print(f"Authored edges (first 5):\n{edges_df}")
# # print(f"Edge query took {end-start:.4f} seconds")

# # Performance summary
# print(f"\n{'='*50}")
# print("PERFORMANCE SUMMARY")
# print(f"{'='*50}")
# print(f"Graph size: {stats.get('total_nodes', 0):,} nodes, {stats.get('total_edges', 0):,} edges")
# print("All operations completed successfully on large-scale graph!")
# print("Framework is ready for production use with millions of nodes/edges.")

In [None]:
# Seed nodes (global IDs) and hop count shared with the following cells.
nodes = [36278, 9, 120987]
k = 2

neighbors = db.get_k_hop_neighbors(nodes, k=k)
# Report exactly the hops that were requested; the previous message always
# listed hops 1-3, so with k = 2 it printed a meaningless "at hop 3" count.
hop_summary = ", ".join(
    f"{len(neighbors.get(hop, set()))} at hop {hop}" for hop in range(1, k + 1)
)
print(f"{k}-hop neighbors of nodes: {hop_summary}")

In [None]:
# Extract the k-hop subgraph around the seed nodes from the previous cell.
# Uses the shared `k` rather than a literal, so the extraction depth always
# matches the "{k}-hop" label printed below.
subgraph, k_hop_res = db.extract_subgraph(nodes, k=k)
# Sum counts over every node/edge type when the result exposes them;
# otherwise fall back to the single-type totals.
total_subgraph_nodes = sum(subgraph.num_nodes(ntype) for ntype in subgraph.ntypes) if hasattr(subgraph, 'ntypes') else subgraph.num_nodes()
total_subgraph_edges = sum(subgraph.num_edges(etype) for etype in subgraph.canonical_etypes) if hasattr(subgraph, 'canonical_etypes') else subgraph.num_edges()
print(f"{k}-hop subgraph: {total_subgraph_nodes} nodes, {total_subgraph_edges} edges")

In [None]:
# Two seed groups, one node each; request the 1-hop subgraph per group and
# let the database merge them.
# NOTE(review): the exact shape of `merged_graphs` and `stats` is defined by
# DeepGraphDB.merge_k_hop_subgraphs — confirm there.
group_of_nodes = [[36278], [120987]]
merged_graphs, stats = db.merge_k_hop_subgraphs(group_of_nodes, k=1)

In [None]:
# Meta-path query from node 6372 following protein_protein and then
# anatomy_protein_present; the result is shown as the cell output.
# NOTE(review): the trailing positional arguments (10, 3) are not named here —
# check the find_meta_paths signature for their meaning.
db.find_meta_paths([6372], ['protein_protein', 'anatomy_protein_present'], 10, 3)

In [None]:
# Reload the raw CSVs as plain DataFrames for a pandas-only cross-check.
# NOTE: this rebinds `nodes`, which an earlier cell defined as a list of seed
# node IDs — re-run that cell before reusing `nodes` as a seed list.
edges = pd.read_csv("edges.csv", low_memory=False)
nodes = pd.read_csv("nodes.csv", low_memory=False)

In [None]:
# Pandas-only 2-hop neighbourhood of the seed nodes, treating edges as undirected.
start_nodes = {36278, 120987}  # Replace with your actual node IDs

# Hop 1: every edge with a seed at either end, and all endpoints of those
# edges (the seeds themselves are included).
one_hop_edges = edges.loc[
    edges['source_id'].isin(start_nodes) | edges['target_id'].isin(start_nodes)
]
one_hop_nodes = set(one_hop_edges['source_id']) | set(one_hop_edges['target_id'])

# Hop 2: repeat the expansion starting from the hop-1 node set.
two_hop_edges = edges.loc[
    edges['source_id'].isin(one_hop_nodes) | edges['target_id'].isin(one_hop_nodes)
]
two_hop_nodes = set(two_hop_edges['source_id']) | set(two_hop_edges['target_id'])

# Node set of the subgraph: everything reachable within 2 hops.
subgraph_nodes = two_hop_nodes

# Induced subgraph: keep only edges whose two endpoints are both in that set.
subgraph = edges.loc[
    edges['source_id'].isin(subgraph_nodes) & edges['target_id'].isin(subgraph_nodes)
]

# Show the resulting edge table.
print(subgraph)

In [None]:
# Ten source nodes with the most outgoing edge rows: count rows per
# source_id, sort by that count descending, and show the top 10 as cell output.
edges.groupby(['source_id']).size().reset_index(name='count').sort_values(by='count', ascending=False).head(10)

In [None]:
# print(nodes.loc[100])

# local_type, local_id = db.global_to_local_mapping[100]
# print(db.node_data[local_type]['name'][local_id])

# db.bulk_modify_nodes({100: {'name': 'SGAHII'}})

# local_type, local_id = db.global_to_local_mapping[100]
# print(db.node_data[local_type]['name'][local_id])

In [None]:
# print(db.graph.num_nodes())
# print(db.graph.num_edges())
# print('----------------------')
# print(len(db.global_to_local_mapping.keys()))

# count = 0
# for k, v in db.node_data.items():
#     count += v['name'].shape[0]

# print(count)

# db.bulk_delete_nodes([63423, 63376, 64544])

# print(db.graph.num_nodes())
# print(db.graph.num_edges())
# print('----------------------')
# print(len(db.global_to_local_mapping.keys()))

# count = 0
# for k, v in db.node_data.items():
#     count += v['name'].shape[0]

# print(count)

In [None]:
import random
import time

# --- Core remapping function for a single type batch ---
def remap_type(typ, global_to_local, global_to_vector):
    """
    Compact the IDs of one type in place after deletions.

    For the entries of ``typ`` that remain in ``global_to_local``:
      1. Local IDs are renumbered 0..n-1, keeping their current relative order.
      2. Global IDs are renumbered to the consecutive run starting at the
         smallest remaining global ID of that type.
      3. ``global_to_vector`` follows, so each vector value stays attached to
         the same logical entry.

    Neither dict is re-sorted here; re-inserted keys end up at the end of the
    dicts' insertion order.

    Args:
        typ: Type label to compact.
        global_to_local: dict of global_id -> (local_id, type). Mutated in place.
        global_to_vector: dict of global_id -> vector ID. Mutated in place.

    Raises:
        KeyError: if an entry of ``typ`` has no value in ``global_to_vector``.
        ValueError: if the compacted global ID range would land on a global ID
            owned by something other than this type. Nothing is modified then.
    """
    # Remaining entries of this type, ordered by their current local ID
    # (sorted() is stable, so ties keep dict order).
    entries = sorted(
        ((gid, loc, global_to_vector[gid])
         for gid, (loc, t) in global_to_local.items() if t == typ),
        key=lambda entry: entry[1],
    )
    if not entries:
        return

    old_ids = {gid for gid, _, _ in entries}
    first_gid = min(old_ids)
    new_globals = [first_gid + offset for offset in range(len(entries))]

    # Writing into a slot held by another type would silently destroy that
    # entry, so refuse before touching either dict.
    clashes = [
        gid for gid in new_globals
        if gid not in old_ids and (gid in global_to_local or gid in global_to_vector)
    ]
    if clashes:
        raise ValueError(
            f"Cannot remap type {typ!r}: target global IDs {clashes[:5]} "
            f"are already used by other entries"
        )

    # Clear old entries for this type
    for gid, _, _ in entries:
        del global_to_local[gid]
        del global_to_vector[gid]

    # Reassign mappings: position in `entries` is the new local ID.
    for new_local, ((_, _, vec_id), new_gid) in enumerate(zip(entries, new_globals)):
        global_to_local[new_gid] = (new_local, typ)
        global_to_vector[new_gid] = vec_id

# --- Batch delete and remap without elems list ---
def delete_many(global_ids, global_to_local, global_to_vector):
    """
    Delete multiple global_ids, then remap affected types in one pass.
    Operates only on the two mappings, which are modified in place.

    Args:
        global_ids: Iterable of global IDs to remove. Repeated IDs are
            deleted once.
        global_to_local: dict of global_id -> (local_id, type).
        global_to_vector: dict of global_id -> vector ID.

    Raises:
        KeyError: if any requested ID is missing from either mapping. The
            check runs before anything is removed, so a failed call leaves
            both mappings as they were.
    """
    # De-duplicate while keeping the caller's order; a repeated ID would
    # otherwise fail on its second pop after earlier IDs were already removed.
    unique_ids = list(dict.fromkeys(global_ids))

    missing = [
        gid for gid in unique_ids
        if gid not in global_to_local or gid not in global_to_vector
    ]
    if missing:
        raise KeyError(f"Global IDs not present in the mappings: {missing[:5]}")

    # 1. Remove specified IDs and remember which types lost entries.
    types_to_fix = set()
    for gid in unique_ids:
        _, typ = global_to_local.pop(gid)
        global_to_vector.pop(gid)
        types_to_fix.add(typ)

    # 2. Remap each affected type
    for typ in types_to_fix:
        remap_type(typ, global_to_local, global_to_vector)

    # 3. Resort both dicts by global_id: popping every key and re-inserting
    #    via update() rebuilds the insertion order in ascending key order.
    sorted_keys = sorted(global_to_local)
    global_to_local.update({k: global_to_local.pop(k) for k in sorted_keys})
    global_to_vector.update({k: global_to_vector.pop(k) for k in sorted_keys})

# --- Test harness: 1M IDs split evenly across NUM_TYPES (100) types ---
NUM_TYPES = 100
TOTAL_IDS = 1_000_000
IDS_PER_TYPE = TOTAL_IDS // NUM_TYPES

global_to_local = {}   # global_id -> (local_id, type label)
global_to_vector = {}  # global_id -> random stand-in vector ID

# Each type gets one contiguous run of global IDs; local IDs restart at 0 per type.
gid_start = 0
for t in range(NUM_TYPES):
    # One character per type, counting up from 'A'. Past 26 types the labels
    # are no longer letters, but they stay unique.
    typ = chr(ord('A') + t)
    for local in range(IDS_PER_TYPE):
        global_to_local[gid_start] = (local, typ)
        global_to_vector[gid_start] = random.randint(1_000_000, 9_999_999)
        gid_start += 1

# Delete a batch of IDs (1000 distinct global IDs chosen at random)
to_delete = random.sample(list(global_to_local.keys()), 1000)
print(f"Deleting {len(to_delete)} IDs...")

# Time the batch operation
t0 = time.perf_counter()
delete_many(to_delete, global_to_local, global_to_vector)
t1 = time.perf_counter()
print(f"Batch deletion and remap took {t1 - t0:.4f} seconds")

# Sanity checks
remaining = len(global_to_local)
print(f"Remaining entries: {remaining} (should be {TOTAL_IDS - len(to_delete)})")
# Check one random type (picked with probability proportional to its entry
# count, since the choice is made over one label per remaining entry)
typ = random.choice([tp for _, (_, tp) in global_to_local.items()])
ids_of_type = [gid for gid, (_, tt) in global_to_local.items() if tt == typ]
print(f"Type {typ} has {len(ids_of_type)} entries spanning {min(ids_of_type)} to {max(ids_of_type)}")