In [81]:
import os
import xml.etree.ElementTree as ET

import networkx as nx
from sqlalchemy import create_engine
from sqlalchemy import Column, String, Float, Integer
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker


In [85]:
# Database connection parameters.
# The URL is read from the environment so credentials stay out of the
# notebook. create_engine needs a real URL such as
# 'postgresql+psycopg2://user:password@host:5432/dbname'; a literal string
# like 'Config.DATABASE_URL' cannot be parsed as one.
# NOTE(review): if the project defines a Config object exposing DATABASE_URL,
# import it and use Config.DATABASE_URL here instead.
DATABASE_URL = os.environ.get('DATABASE_URL')
if not DATABASE_URL:
    raise RuntimeError(
        "DATABASE_URL is not set; export it before starting the kernel, e.g. "
        "postgresql+psycopg2://user:password@localhost:5432/dbname"
    )

# Create the SQLAlchemy engine.
engine = create_engine(DATABASE_URL)

# Session factory bound to the engine, plus the single module-level session
# that the bulk-insert helpers use.
Session = sessionmaker(bind=engine)
session = Session()

# Declarative base class for the ORM models (defined exactly once).
Base = declarative_base()


In [86]:
class Node(Base):
    """ORM model for the ``nodes`` table: one row per GraphML ``<node>``.

    Rows are written by ``bulk_insert_nodes`` from dicts built in
    ``migrate_graphml_to_postgres``; the GraphML data key each column is read
    from is noted beside it.
    """
    __tablename__ = 'nodes'
    
    # NOTE(review): ``id`` is Integer, but the migration inserts the GraphML
    # node ``id`` attribute (a string), and ``Edge.source``/``Edge.target`` are
    # String columns. This only works while every GraphML node id is numeric;
    # consider making the types agree — TODO confirm against the data.
    id = Column(Integer, primary_key=True)
    x = Column(Float)  # GraphML key d0
    y = Column(Float)  # GraphML key d1
    station_type = Column(String)  # GraphML key d3
    node_label = Column(String)  # GraphML key d7
    year = Column(Integer)  # GraphML key d2
    east_west = Column(String)  # GraphML key d4
    neighbourhood = Column(String)  # GraphML key d5 (may be absent -> NULL)
    district = Column(String)  # GraphML key d6 (may be absent -> NULL)

class Edge(Base):
    """ORM model for the ``edges`` table: one row per GraphML ``<edge>``.

    Rows are written by ``bulk_insert_edges`` from dicts built in
    ``migrate_graphml_to_postgres``; the GraphML data key each column is read
    from is noted beside it.
    """
    __tablename__ = 'edges'
    
    id = Column(String, primary_key=True)  # GraphML edge ``id`` attribute
    # ``source``/``target`` hold the GraphML node ids of the endpoints.
    # No ForeignKey is declared; presumably they match ``Node.id`` — verify.
    source = Column(String)
    target = Column(String)
    label = Column(String)  # GraphML key d10
    year = Column(Integer)  # GraphML key d12
    edge_type = Column(String)  # GraphML key d11
    frequency = Column(Float)  # GraphML key d13
    east_west = Column(String)  # GraphML key d14
    capacity = Column(Integer)  # GraphML key d15
    distance = Column(Float)  # GraphML key d16

# Create any missing tables for the models registered on Base.
# checkfirst=True (also the default) skips tables that already exist, so an
# existing table is never altered to match changed column definitions.
Base.metadata.create_all(bind=engine, checkfirst=True)

In [89]:
# GraphML's XML namespace, bound to the prefix used in the XPath queries below.
GRAPHML_NS = {'ns0': 'http://graphml.graphdrawing.org/xmlns'}


def _data_text(element, key, required=True):
    """Return the text of ``element``'s ``<data key="...">`` child.

    Args:
        element: a GraphML ``<node>`` or ``<edge>`` element.
        key: GraphML data key id, e.g. ``'d0'``.
        required: if true, a missing ``<data>`` child raises ValueError;
            otherwise it yields None.

    Returns:
        The child's text, or None when the child is absent (and not required)
        or present but empty.

    Raises:
        ValueError: if ``required`` and no matching ``<data>`` child exists.
    """
    data = element.find(f".//ns0:data[@key='{key}']", GRAPHML_NS)
    if data is None:
        if required:
            tag = element.tag.rsplit('}', 1)[-1]  # strip the '{namespace}' prefix
            raise ValueError(
                f"GraphML {tag} {element.get('id')!r} has no <data key='{key}'> element"
            )
        return None
    return data.text


def _convert(text, convert):
    """Apply ``convert`` (e.g. ``int`` or ``float``) to ``text``; pass None through."""
    return None if text is None else convert(text)


def _node_record(node):
    """Build a ``nodes`` row dict from one GraphML ``<node>`` element."""
    return {
        'id': node.get('id'),
        'x': _convert(_data_text(node, 'd0'), float),
        'y': _convert(_data_text(node, 'd1'), float),
        'station_type': _data_text(node, 'd3'),
        'node_label': _data_text(node, 'd7'),
        # Node.year is an Integer column, so convert here just like the edge year.
        'year': _convert(_data_text(node, 'd2'), int),
        'east_west': _data_text(node, 'd4'),
        # Neighbourhood and district may be missing on some nodes -> NULL.
        'neighbourhood': _data_text(node, 'd5', required=False),
        'district': _data_text(node, 'd6', required=False),
    }


def _edge_record(edge):
    """Build an ``edges`` row dict from one GraphML ``<edge>`` element."""
    return {
        'id': edge.get('id'),
        'source': edge.get('source'),
        'target': edge.get('target'),
        'label': _data_text(edge, 'd10'),
        'year': _convert(_data_text(edge, 'd12'), int),
        'frequency': _convert(_data_text(edge, 'd13'), float),
        'east_west': _data_text(edge, 'd14'),
        'capacity': _convert(_data_text(edge, 'd15'), int),
        'distance': _convert(_data_text(edge, 'd16'), float),
        'edge_type': _data_text(edge, 'd11'),
    }


def _insert_in_batches(records, batch_size, insert_batch):
    """Pass ``records`` to ``insert_batch`` in lists of at most ``batch_size``.

    Each list handed to ``insert_batch`` is a fresh object that is not
    modified afterwards.
    """
    batch = []
    for record in records:
        batch.append(record)
        if len(batch) >= batch_size:
            insert_batch(batch)
            batch = []
    if batch:
        insert_batch(batch)


def migrate_graphml_to_postgres(graphml_file, batch_size=1000):
    """Copy the nodes and edges of a GraphML file into the ``nodes``/``edges`` tables.

    All nodes are inserted first, then all edges, via ``bulk_insert_nodes`` and
    ``bulk_insert_edges`` respectively, in batches of at most ``batch_size`` rows.

    NOTE(review): the data key ids (d0..d7 for nodes, d10..d16 for edges) are
    assigned by whatever wrote the GraphML file; confirm them against the
    file's ``<key>`` declarations.

    Args:
        graphml_file: path or file object accepted by ``ET.parse``.
        batch_size: maximum number of rows per INSERT statement (default 1000).

    Raises:
        ValueError: if ``batch_size`` < 1, if a node/edge lacks a required
            ``<data>`` element, or if a numeric field is not a valid number.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")

    # ET.parse reads the whole document into memory; batching only bounds the
    # size of each INSERT statement, not the parser's memory use.
    root = ET.parse(graphml_file).getroot()

    # Rows are built lazily, so a malformed element raises only after the
    # batches preceding it have already been handed to the insert helper.
    nodes = (_node_record(node) for node in root.findall('.//ns0:node', GRAPHML_NS))
    _insert_in_batches(nodes, batch_size, bulk_insert_nodes)

    edges = (_edge_record(edge) for edge in root.findall('.//ns0:edge', GRAPHML_NS))
    _insert_in_batches(edges, batch_size, bulk_insert_edges)


def _bulk_insert(model, rows, kind):
    """Insert ``rows`` into ``model``'s table in one statement, skipping conflicts.

    Best-effort: on any error the session is rolled back, the error is
    printed, and None is returned so the caller can tell the batch was lost.

    Args:
        model: the mapped ORM class (``Node`` or ``Edge``).
        rows: list of column-name -> value dicts.
        kind: label used in the error message ('node' or 'edge').

    Returns:
        0 for an empty ``rows`` (nothing is executed); otherwise the
        ``rowcount`` the driver reports for the statement (NOTE: may be -1 if
        the driver cannot determine it), or None if the insert failed.
    """
    if not rows:
        return 0
    try:
        # ON CONFLICT DO NOTHING: rows that would violate a unique/primary-key
        # constraint are skipped instead of failing the whole batch.
        stmt = insert(model).values(rows).on_conflict_do_nothing()
        result = session.execute(stmt)
        session.commit()
    except Exception as e:
        session.rollback()
        print(f"Error during {kind} insertion: {e}")
        return None
    return result.rowcount


def bulk_insert_nodes(node_data_list):
    """Insert a batch of ``nodes`` rows; see ``_bulk_insert`` for the contract."""
    return _bulk_insert(Node, node_data_list, 'node')


def bulk_insert_edges(edge_data_list):
    """Insert a batch of ``edges`` rows; see ``_bulk_insert`` for the contract."""
    return _bulk_insert(Edge, edge_data_list, 'edge')

In [90]:
# Load the base graph into the nodes/edges tables.
GRAPHML_PATH = 'graph_data/base-G.graphml'
migrate_graphml_to_postgres(graphml_file=GRAPHML_PATH)
