In [1]:
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import Any

from neo4j import GraphDatabase

In [2]:
@dataclass
class Attr:
    """Column descriptor: name, inferred data type and nullability flag."""

    name: str
    # Holds a Python type object (e.g. ``int``, ``float``, ``str``), which is what
    # the type-inference cells pass in; the former ``str`` annotation was wrong.
    data_type: Any
    nullable: bool


@dataclass
class FK:
    """Foreign key: ``source_attr_name`` references ``target_table.target_attr_name``."""

    source_attr_name: str
    target_table: str
    target_attr_name: str

In [3]:
class Node:
    """Plain container for one graph node: a label plus parallel attr/vals sequences.

    NOTE(review): a later cell redefines ``Node`` as a frozen dataclass, which
    shadows this class when cells run in order.
    """

    def __init__(self, label, attr, vals):
        self.vals = vals
        self.label = label
        self.attr = attr


class Edge:
    """Plain container for one graph edge from ``source`` to ``target``.

    ``attr`` and ``vals`` are parallel sequences of property names and values.
    When omitted, each instance gets its own fresh empty list.
    """

    def __init__(self, label, source, target, attr=None, vals=None):
        self.label = label
        self.source = source
        self.target = target
        # The former ``[]`` defaults were evaluated once, so every Edge built
        # without explicit values shared the same two list objects.
        self.attr = [] if attr is None else attr
        self.vals = [] if vals is None else vals

In [4]:
@dataclass(frozen=True)
class Node:
    """Immutable, hashable graph node, so duplicates collapse when put in a set.

    ``attr`` and ``vals`` are parallel sequences of property names and values.
    They must be tuples: a frozen dataclass hashes all of its fields, and a list
    field would make ``hash()`` raise TypeError.
    """

    label: str
    attr: tuple[str, ...]
    vals: tuple[Any, ...]


@dataclass(frozen=True)
class Edge:
    """Immutable, hashable graph edge from ``source`` to ``target``.

    ``attr``/``vals`` are parallel tuples of property names and values.
    """

    label: str
    attr: tuple[str, ...]
    vals: tuple[Any, ...]
    source: Node
    target: Node

In [5]:
class NodeClass:
    """All nodes sharing one label; becomes one SQL table.

    Args:
        label: node label, used as the table name.
        attr: column descriptors for the table.
        pk: primary-key column names.
        nodes: the nodes belonging to this class.
    """

    def __init__(self, label, attr, pk, nodes):
        self.label, self.pk, self.nodes, self.attr = label, pk, nodes, attr


class EdgeClass:
    """All edges sharing one relationship type.

    Args:
        label: relationship type.
        source: label of the source node class.
        target: label of the target node class.
        attr: column descriptors for the edge properties.
        edges: the edges belonging to this class.
    """

    def __init__(self, label, source, target, attr, edges):
        self.label, self.source, self.target, self.attr, self.edges = (
            label, source, target, attr, edges
        )

In [6]:
class GraphSchema:
    """Pair of node classes and edge classes describing a graph."""

    def __init__(self, nodes, edges):
        self.nodes, self.edges = nodes, edges

In [7]:
class Table:
    """Relational table description used to detect many-to-many join tables.

    Fields are filled incrementally through the ``add_*`` helpers.
    """

    def __init__(self, name):
        self.name = name
        self.attr = []  # column descriptors, in insertion order
        self.fk = []  # FK descriptors
        self.pk = set()  # primary-key column names
        self.is_referenced = False  # presumably set when another table references this one
        self.tuples = []  # row storage; not read by any method here

    def add_attr(self, attr):
        """Append a column descriptor."""
        self.attr.append(attr)

    def add_fk(self, source_attribute, target_table_name, target_attribute):
        """Record that ``source_attribute`` references ``target_table_name.target_attribute``."""
        self.fk.append(FK(source_attribute, target_table_name, target_attribute))

    def add_pk(self, attr):
        """Add ``attr`` to the primary-key column set."""
        self.pk.add(attr)

    def is_m2m(self):
        """Return True if this table looks like a pure many-to-many join table.

        That means nothing references it, every FK column is part of the primary
        key, and the FKs point at exactly two distinct tables.
        """
        if self.is_referenced:
            return False
        if any(fk.source_attr_name not in self.pk for fk in self.fk):
            return False
        # The previous version added the undefined name ``target_table`` to the
        # set and raised NameError; count the distinct referenced tables directly.
        return len({fk.target_table for fk in self.fk}) == 2

    def extract_tables(self):
        """Return the sorted target-table names of this join table's FKs.

        Raises:
            RuntimeError: if the table is not a many-to-many join table.
        """
        if not self.is_m2m():
            raise RuntimeError("exctracting edges is possible only for m2m tables")
        return sorted(fk.target_table for fk in self.fk)

In [47]:
# Primary-key column(s) for each table name; composite keys list every column.
TABLE_PRIMARY_KEYS = dict(
    suppliers={"supplier_id"},
    customer_demographics={"customer_type_id"},
    territories={"territory_id"},
    shippers={"shipper_id"},
    orders={"order_id"},
    customer_customer_demo={"customer_type_id", "customer_id"},
    order_details={"product_id", "order_id"},
    employees={"employee_id"},
    categories={"category_id"},
    employee_territories={"employee_id", "territory_id"},
    customers={"customer_id"},
    products={"product_id"},
    region={"region_id"},
    us_states={"state_id"},
)

In [9]:
# Connection settings come from the environment so credentials are not committed
# to the notebook; the defaults match the local development instance.
driver = GraphDatabase.driver(
    os.environ.get("NEO4J_URI", "neo4j://localhost:7687"),
    auth=(
        os.environ.get("NEO4J_USER", "neo4j"),
        os.environ.get("NEO4J_PASSWORD", "password"),
    ),
)

In [11]:
# Read the records while the session is still open. A neo4j Result is a lazy
# cursor tied to its session, so iterating it after the `with` block closes the
# session yields nothing.
with driver.session() as session:
    result = list(session.run("match (n) return n"))

In [12]:
# Neo4j connection settings. Override them with environment variables instead of
# editing credentials into the notebook; the defaults target the local instance.
uri = os.environ.get("NEO4J_URI", "neo4j://localhost:7687")
user = os.environ.get("NEO4J_USER", "neo4j")
password = os.environ.get("NEO4J_PASSWORD", "password")

In [13]:
def nodeMapper(nodeRec):
    """Convert a neo4j node into a hashable ``Node``.

    Args:
        nodeRec: node returned by the driver. Iterating it yields property keys,
            and ``nodeRec[key]`` yields the value.

    Returns:
        Node whose ``attr``/``vals`` are parallel tuples of property names/values.

    Raises:
        IndexError: if the node has no labels.
    """
    # ``labels`` is an unordered frozenset, so taking "the first" element picked an
    # arbitrary label for multi-label nodes. Sorting makes the choice
    # deterministic; single-label nodes are unaffected.
    label = sorted(nodeRec.labels)[0]
    attrs = tuple(nodeRec)
    vals = tuple(nodeRec[k] for k in attrs)
    return Node(label, attrs, vals)

In [14]:
def edgeMapper(edgeRec, src, dst):
    """Convert a neo4j relationship into an ``Edge`` between two mapped nodes.

    Args:
        edgeRec: relationship returned by the driver; its ``type`` becomes the
            edge label, and its property keys and values become ``attr``/``vals``.
        src: mapped source ``Node``.
        dst: mapped target ``Node``.
    """
    keys = tuple(edgeRec)
    return Edge(
        label=edgeRec.type,
        attr=keys,
        vals=tuple(edgeRec[key] for key in keys),
        source=src,
        target=dst,
    )

In [422]:
# Fetch every (node)-[rel]->(node) triple and every node, mapping them to
# hashable Node/Edge objects. The context manager closes the session even if
# mapping fails. Both results are fully consumed inside it, because a Result
# cannot be read after its session closes.
driver = GraphDatabase.driver(uri, auth=(user, password))
triplets = []
with driver.session() as session:
    result = session.run("match (a)-[b]->(c) return a,b,c")
    for s, e, t in result:
        s = nodeMapper(s)
        t = nodeMapper(t)
        e = edgeMapper(e, s, t)
        triplets.append((s, e, t))

    # All nodes, so nodes without any relationship are not lost.
    results2 = list(session.run("match (a) return a"))

In [423]:
# Group the distinct mapped nodes by label; sets drop nodes seen in several
# triples. Each group is then turned into a list.
tuples = defaultdict(set)

for node in (n for src, _, dst in triplets for n in (src, dst)):
    tuples[node.label].add(node)

for record in results2:
    for raw_node in record:
        mapped = nodeMapper(raw_node)
        tuples[mapped.label].add(mapped)

for label in list(tuples):
    tuples[label] = list(tuples[label])

In [424]:
# Close the module-level Neo4j session used for the graph queries above.
session.close()

In [425]:
# Number of distinct nodes collected per label.
for label, nodes in tuples.items():
    print(label, len(nodes))

territories 53
employees 9
region 4
orders 830
shippers 6
customers 91
products 77
suppliers 29
categories 8
us_states 51


In [426]:
# Build one NodeClass per label, inferring each property's Python type from the
# observed values. Nodes of one label can carry different property sets, so the
# union of all property names becomes the column list.
ncs = []
for label, nodes in tuples.items():
    a2t = {}
    for n in nodes:
        for attr, val in zip(n.attr, n.vals):
            seen, tp = a2t.get(attr), type(val)
            if seen is None or seen is tp:
                a2t[attr] = tp
            elif {seen, tp} <= {int, float}:
                # Mixed int/float column: widen to float. Keeping whichever type
                # happened to be seen last could create an integer column that
                # drops fractional parts on insert.
                a2t[attr] = float
            else:
                # No common numeric type: store as text.
                a2t[attr] = str
    # NOTE(review): nullable is always False, even when some nodes lack the
    # property. generate_migration appends NOT NULL when nullable is True, so the
    # two must be changed together.
    attrs = [Attr(name, tp, False) for name, tp in a2t.items()]
    ncs.append(NodeClass(label, attrs, set(), nodes))

In [427]:
# Bucket the mapped edges by relationship type; sets drop duplicate edges.
edgesSets = defaultdict(set)
for _src, edge, _dst in triplets:
    edgesSets[edge.label].add(edge)

In [428]:
# Build one EdgeClass per relationship type, inferring each edge property's
# Python type from the observed values.
ecs = []
for label, edges in edgesSets.items():
    a2t = {}
    for e in edges:
        # NOTE(review): source/target come from whichever edge is visited last.
        # This assumes every edge of this type links the same pair of node labels.
        source = e.source.label
        target = e.target.label
        for attr, val in zip(e.attr, e.vals):
            seen, tp = a2t.get(attr), type(val)
            if seen is None or seen is tp:
                a2t[attr] = tp
            elif {seen, tp} <= {int, float}:
                # Mixed int/float: widen to float so the column is not created as an
                # integer type that would drop fractional parts.
                a2t[attr] = float
            else:
                # No common numeric type: store as text.
                a2t[attr] = str
    attrs = [Attr(name, tp, False) for name, tp in a2t.items()]
    ecs.append(EdgeClass(label, source, target, attrs, edges))

In [390]:
# Show every edge class with its inferred property columns.
for edge_class in ecs:
    print(edge_class.label, edge_class.attr)

employee_territories [Attr(name='territory_id', data_type=<class 'str'>, nullable=False), Attr(name='employee_id', data_type=<class 'int'>, nullable=False)]
region_id []
ship_via []
employee_id []
customer_id []
reports_to []
order_details [Attr(name='discount', data_type=<class 'float'>, nullable=False), Attr(name='quantity', data_type=<class 'int'>, nullable=False), Attr(name='unit_price', data_type=<class 'float'>, nullable=False), Attr(name='order_id', data_type=<class 'int'>, nullable=False), Attr(name='product_id', data_type=<class 'int'>, nullable=False)]
supplier_id []
category_id []


In [79]:
# Process primary keys: attach each table's known PK columns to its node class.

In [429]:
# Attach the known primary-key column set to each node class (KeyError for a
# label missing from TABLE_PRIMARY_KEYS).
for node_class in ncs:
    node_class.pk = TABLE_PRIMARY_KEYS[node_class.label]

In [430]:
def is_m2m(edge_class):
    """Classify the multiplicity of a collection of edges.

    Args:
        edge_class: iterable of edges exposing hashable ``source``/``target``
            and a ``vals`` sequence.

    Returns:
        ``(src_many, tgt_many, has_attr)``:
        ``src_many`` is True when some target appears on more than one edge,
        ``tgt_many`` is True when some source appears on more than one edge, and
        ``has_attr`` is True when any edge has non-empty ``vals``.
    """
    edges = list(edge_class)
    sources = [edge.source for edge in edges]
    targets = [edge.target for edge in edges]
    src_many = len(set(targets)) < len(targets)
    tgt_many = len(set(sources)) < len(sources)
    has_attr = any(edge.vals for edge in edges)
    return src_many, tgt_many, has_attr

In [431]:
# Preview the DDL for edge classes that become their own table: edges that carry
# properties, or many-to-many relationships.
for ec in ecs:
    src_many, tgt_many, has_attr = is_m2m(ec.edges)
    if not (has_attr or (src_many and tgt_many)):
        continue
    print(generate_migration(ec))
    print("\n\n")

CREATE TABLE employee_territories  (
	territory_id  text,
	employee_id  smallint
);



CREATE TABLE order_details  (
	discount  real,
	quantity  smallint,
	unit_price  real,
	order_id  smallint,
	product_id  smallint
);





In [432]:
def generate_deletion(nc):
    """Return a ``DROP TABLE IF EXISTS`` statement for the table ``nc.label``."""
    return "DROP TABLE IF EXISTS {};".format(nc.label)

In [433]:
# Default Python-type -> PostgreSQL column-type mapping used by generate_migration.
_DEFAULT_SQL_TYPES = {float: "real", int: "smallint", str: "text", bool: "boolean"}


def generate_migration(nc, type_map=None, default_type="text"):
    """Return a ``CREATE TABLE`` statement for a node or edge class.

    Args:
        nc: object with ``label`` (table name) and ``attr`` (iterable of objects
            with ``name``, ``data_type`` and ``nullable``).
        type_map: optional mapping from Python type to SQL type that overrides
            the defaults (float->real, int->smallint, str->text, bool->boolean).
            For example, pass ``{int: "bigint"}`` when integers can exceed the
            smallint range.
        default_type: SQL type for any Python type missing from the mapping.

    Returns:
        The DDL string, one column per line.
    """
    types = dict(_DEFAULT_SQL_TYPES)
    if type_map:
        types.update(type_map)

    columns = []
    for attr in nc.attr:
        # Previously an unmapped type silently reused the previous column's type,
        # or raised UnboundLocalError on the first column. Fall back explicitly.
        sql_type = types.get(attr.data_type, default_type)
        column = f"\t{attr.name}  {sql_type}"
        # NOTE(review): this appends NOT NULL when nullable is True, which looks
        # inverted. It is harmless today because every Attr is built with
        # nullable=False. Flipping it without also computing real nullability would
        # make every column NOT NULL, so fix both together.
        if attr.nullable:
            column += " NOT NULL"
        columns.append(column)

    return f"CREATE TABLE {nc.label}  (\n" + ",\n".join(columns) + "\n);"

In [434]:
def generate_pk_constraint(nc):
    """Return an ``ALTER TABLE`` statement adding ``nc``'s primary key.

    Args:
        nc: object with ``label`` (table name) and ``pk`` (collection of column names).

    Raises:
        ValueError: if ``nc.pk`` is empty.
    """
    if not nc.pk:
        raise ValueError(f"table {nc.label!r} has no primary-key columns")
    # Include every PK column (composite keys such as order_details have two) in a
    # stable order. The old code took one arbitrary element of the set.
    columns = ", ".join(sorted(nc.pk))
    return f"ALTER TABLE ONLY {nc.label} ADD CONSTRAINT pk_{nc.label} PRIMARY KEY ({columns});"

In [435]:
def _sql_literal(val):
    """Render one Python value as a PostgreSQL literal.

    Strings are single-quoted with embedded quotes doubled, ``None`` becomes
    ``NULL``, and anything else uses ``str()``.
    """
    if val is None:
        return "NULL"
    if isinstance(val, str):
        return "'" + val.replace("'", "''") + "'"
    return str(val)


def generate_inserts_for_nodeClass(nc):
    """Return one ``INSERT`` statement per node in ``nc.nodes``.

    Each statement lists only that node's own properties (``node.attr``) with the
    matching values rendered as SQL literals.
    """
    statements = []
    for node in nc.nodes:
        columns = ",".join(node.attr)
        values = ",".join(_sql_literal(val) for val in node.vals)
        statements.append(f"INSERT INTO {node.label} ({columns}) VALUES ({values});")
    return statements

In [436]:
def generate_inserts_for_edgeClass(ec):
    """Return one ``INSERT`` statement per edge in ``ec.edges``.

    Values are rendered as SQL literals: strings single-quoted with embedded
    quotes doubled, ``None`` as ``NULL``, anything else via ``str()``.
    The previous version pasted in the Python tuple repr. That broke on
    one-element tuples (``(5,)``), on strings containing an apostrophe (repr
    switches to double quotes, which SQL reads as an identifier) and on ``None``.
    """

    def literal(val):
        if val is None:
            return "NULL"
        if isinstance(val, str):
            return "'" + val.replace("'", "''") + "'"
        return str(val)

    statements = []
    for edge in ec.edges:
        columns = ",".join(edge.attr)
        values = ",".join(literal(val) for val in edge.vals)
        statements.append(f"INSERT INTO {edge.label} ({columns}) VALUES ({values});")
    return statements

In [437]:
# Preview the INSERT statements for every node class, one per line.
for nc in ncs:
    print("\n".join(generate_inserts_for_nodeClass(nc)))

INSERT INTO territories (region_id,territory_description,territory_id) VALUES (1,'Cary','27511');
INSERT INTO territories (region_id,territory_description,territory_id) VALUES (2,'Hoffman Estates','60179');
INSERT INTO territories (region_id,territory_description,territory_id) VALUES (2,'Seattle','98104');
INSERT INTO territories (region_id,territory_description,territory_id) VALUES (1,'Morristown','07960');
INSERT INTO territories (region_id,territory_description,territory_id) VALUES (2,'Santa Clara','95054');
INSERT INTO territories (region_id,territory_description,territory_id) VALUES (3,'Philadelphia','19428');
INSERT INTO territories (region_id,territory_description,territory_id) VALUES (4,'Bentonville','72716');
INSERT INTO territories (region_id,territory_description,territory_id) VALUES (2,'Redmond','98052');
INSERT INTO territories (region_id,territory_description,territory_id) VALUES (1,'Boston','02116');
INSERT INTO territories (region_id,territory_description,territory_id) 

In [438]:
# Name of the PostgreSQL database that the generated SQL script drops, recreates
# and connects to.
database = "northwind_map"

In [439]:
# Number of nodes gathered into each node class.
for node_class in ncs:
    print(node_class.label, len(node_class.nodes))

territories 53
employees 9
region 4
orders 830
shippers 6
customers 91
products 77
suppliers 29
categories 8
us_states 51


In [440]:
# Write the complete PostgreSQL script:
# 1. recreate the database;
# 2. create and fill one table per node class, then add the primary keys;
# 3. turn each edge class into either a join table (edges with properties, or
#    many-to-many) or FK constraint(s).
# Set NORTHWIND_SQL_OUT to write somewhere other than the author's local path.
sql_out_path = os.environ.get(
    "NORTHWIND_SQL_OUT", "/Users/a.palagashvili/coursework/psql/queries/tmp.sql"
)
nc_by_label = {nc.label: nc for nc in ncs}
# Classify each edge class once instead of re-running is_m2m in every loop.
edge_kinds = [(ec, *is_m2m(ec.edges)) for ec in ecs]


def _edge_key_column(ec, nc):
    """Return the last attribute of edge class ``ec`` that is a PK column of ``nc``.

    Raises:
        ValueError: if no edge attribute matches. The old inline loop silently
            reused a value left over from a previous edge class.
    """
    matches = [attr.name for attr in ec.attr if attr.name in nc.pk]
    if not matches:
        raise ValueError(f"edge class {ec.label!r} has no attribute matching the PK of {nc.label!r}")
    return matches[-1]


with open(sql_out_path, "w") as out:
    # IF EXISTS lets the script run on a server where the database does not exist yet.
    print(f"DROP DATABASE IF EXISTS {database};", file=out)
    print(f"CREATE DATABASE {database};", file=out)
    # psql meta-command. The backslash is escaped because "\c" is an invalid
    # Python escape sequence (a warning on current Pythons).
    print(f"\\c {database};", file=out)
    print("\n\n", file=out)

    for nc in ncs:
        print(generate_deletion(nc), file=out)
        print(generate_migration(nc), file=out)
        print("\n\n", file=out)

    for nc in ncs:
        print(*generate_inserts_for_nodeClass(nc), sep='\n', file=out)
        print("\n\n", file=out)

    for nc in ncs:
        print(generate_pk_constraint(nc), file=out)

    print("\n\n", file=out)
    # Join tables: table definition plus rows.
    for ec, src_many, tgt_many, has_attr in edge_kinds:
        if has_attr or (src_many and tgt_many):
            print(generate_deletion(ec), file=out)
            print(generate_migration(ec), file=out)
            print(*generate_inserts_for_edgeClass(ec), sep="\n", file=out)
            print("\n\n", file=out)

    # Keys and constraints for every edge class.
    for ec, src_many, tgt_many, has_attr in edge_kinds:
        source, target = ec.source, ec.target
        if has_attr or (src_many and tgt_many):
            source_fk = _edge_key_column(ec, nc_by_label[source])
            target_fk = _edge_key_column(ec, nc_by_label[target])

            q = f"ALTER TABLE ONLY {ec.label} ADD CONSTRAINT pk_{ec.label} PRIMARY KEY ({source_fk}, {target_fk});"
            print(q, file=out)

            q = f"ALTER TABLE ONLY {ec.label} ADD CONSTRAINT fk_{ec.label}_{source} FOREIGN KEY ({source_fk}) REFERENCES {source};"
            print(q, file=out)
            q = f"ALTER TABLE ONLY {ec.label} ADD CONSTRAINT fk_{ec.label}_{target} FOREIGN KEY ({target_fk}) REFERENCES {target};"
            print(q, file=out)
        else:
            # One-to-many: the FK column is named after the edge label and lives on
            # the "many" side.
            # NOTE(review): a one-to-one edge class (neither side repeats) gets no
            # FK constraint here — confirm that is intended.
            if src_many:
                q = f"ALTER TABLE ONLY {source} ADD CONSTRAINT fk_{source}_{target} FOREIGN KEY ({ec.label}) REFERENCES {target};"
                print(q, file=out)
            if tgt_many:
                q = f"ALTER TABLE ONLY {target} ADD CONSTRAINT fk_{target}_{source} FOREIGN KEY ({ec.label}) REFERENCES {source};"
                print(q, file=out)


In [None]:
def generate_fk_contraint_o2m(ec):
    """Return the FK constraint statement(s) for a one-to-many edge class.

    The FK column is named after the edge label (e.g. ``region_id``), as in the
    migration-writing cells. It goes on the "many" side: the source table when
    some target has several sources, the target table when some source has
    several targets.

    Args:
        ec: edge class with ``label``, ``source``, ``target`` and ``edges``.

    Returns:
        A list of ``ALTER TABLE`` statements; empty when neither side repeats.
    """
    src_many, tgt_many, _has_attr = is_m2m(ec.edges)
    source, target = ec.source, ec.target
    statements = []
    if src_many:
        statements.append(
            f"ALTER TABLE ONLY {source} ADD CONSTRAINT fk_{source}_{target} "
            f"FOREIGN KEY ({ec.label}) REFERENCES {target};"
        )
    if tgt_many:
        statements.append(
            f"ALTER TABLE ONLY {target} ADD CONSTRAINT fk_{target}_{source} "
            f"FOREIGN KEY ({ec.label}) REFERENCES {source};"
        )
    return statements

In [204]:
# Print the key and constraint statements for every edge class, without writing
# to the SQL file.
nc_by_label = {nc.label: nc for nc in ncs}
for ec in ecs:
    src_many, tgt_many, has_attr = is_m2m(ec.edges)
    source, target = ec.source, ec.target
    if not has_attr and (not src_many or not tgt_many):
        # One-to-many: FK column named after the edge label, on the "many" side.
        if src_many:
            q = f"ALTER TABLE ONLY {source} ADD CONSTRAINT fk_{source}_{target} FOREIGN KEY ({ec.label}) REFERENCES {target};"
            print(q)
        if tgt_many:
            q = f"ALTER TABLE ONLY {target} ADD CONSTRAINT fk_{target}_{source} FOREIGN KEY ({ec.label}) REFERENCES {source};"
            print(q)
    else:
        # Join table: for each side, use the last edge attribute that is one of that
        # node class's PK columns. A direct lookup raises KeyError/IndexError
        # instead of silently reusing values left over from an earlier edge class.
        source_nc = nc_by_label[source]
        target_nc = nc_by_label[target]
        source_fk = [a.name for a in ec.attr if a.name in source_nc.pk][-1]
        target_fk = [a.name for a in ec.attr if a.name in target_nc.pk][-1]

        q = f"ALTER TABLE ONLY {ec.label} ADD CONSTRAINT pk_{ec.label} PRIMARY KEY ({source_fk}, {target_fk});"
        print(q)

        q = f"ALTER TABLE ONLY {ec.label} ADD CONSTRAINT fk_{ec.label}_{source} FOREIGN KEY ({source_fk}) REFERENCES {source};"
        print(q)
        q = f"ALTER TABLE ONLY {ec.label} ADD CONSTRAINT fk_{ec.label}_{target} FOREIGN KEY ({target_fk}) REFERENCES {target};"
        print(q)

ALTER TABLE ONLY employee_territories ADD CONSTRAINT pk_employee_territories PRIMARY KEY (territory_id, employee_id);
ALTER TABLE ONLY employee_territories ADD CONSTRAINT fk_employee_territories_territories FOREIGN KEY (territory_id) REFERENCES territories;
ALTER TABLE ONLY employee_territories ADD CONSTRAINT fk_employee_territories_employees FOREIGN KEY (employee_id) REFERENCES employees;
ALTER TABLE ONLY territories ADD CONSTRAINT fk_territories_region FOREIGN KEY (region_id) REFERENCES region;
ALTER TABLE ONLY orders ADD CONSTRAINT fk_orders_shippers FOREIGN KEY (ship_via) REFERENCES shippers;
ALTER TABLE ONLY orders ADD CONSTRAINT fk_orders_employees FOREIGN KEY (employee_id) REFERENCES employees;
ALTER TABLE ONLY orders ADD CONSTRAINT fk_orders_customers FOREIGN KEY (customer_id) REFERENCES customers;
ALTER TABLE ONLY employees ADD CONSTRAINT fk_employees_employees FOREIGN KEY (reports_to) REFERENCES employees;
ALTER TABLE ONLY order_details ADD CONSTRAINT pk_order_details PRIMARY

In [195]:
# Debug: PK set of the source node class last matched by a constraint loop above.
source_nc.pk

{'territory_id'}

In [198]:
# Debug: recompute which attribute of the currently bound `ec` (whatever edge
# class the last loop left behind) is a PK column of `source_nc`.
for attr in ec.attr:
    if attr.name in source_nc.pk:
        source_fk = attr.name

In [199]:
# Debug: inspect the matched source-side key column.
source_fk

'territory_id'

In [191]:
# Debug: label of the target node class last matched by a constraint loop above.
target_nc.label

'employees'

In [186]:
# Debug: bind `nc` to the "employees" node class. If no class has that label,
# `nc` is left bound to the last element of `ncs`.
for nc in ncs:
    if nc.label == "employees":
        break

In [188]:
# Debug: confirm which node class `nc` is bound to.
nc.label

'employees'

In [185]:
# Debug: source and target node labels of the currently bound edge class.
ec.source, ec.target

('territories', 'employees')

In [143]:
# NOTE(review): edges are passed as an empty list even though `ecs` is built
# above, and no visible cell reads `schema` — confirm intent.
schema = GraphSchema(ncs, [])

In [151]:
# Generate the CREATE TABLE migrations for the tables.

In [133]:
# Debug: property names of one mapped `orders` node.
tuples['orders'][0].attr

('ship_postal_code',
 'ship_city',
 'required_date',
 'freight',
 'order_date',
 'shipped_date',
 'employee_id',
 'ship_via',
 'customer_id',
 'ship_address',
 'ship_name',
 'order_id',
 'ship_country')

In [137]:
# Debug: check that the value at index 3 of one orders node is a Python float
# (index 3 was 'freight' in the attribute listing above; order may vary per node).
type(tuples['orders'][0].vals[3]) == float

True

In [66]:
# NOTE(review): relies on out-of-order execution. In document order the
# module-level `t` is the last Node bound by the triplet loops
# (`t = nodeMapper(t)`), and a Node has no len(). The recorded 108 came from an
# earlier kernel state.
len(t)

108

In [103]:
# NOTE(review): expects `t` to be a list, which in document order is only built
# by a later cell (out-of-order execution).
t[0]

<__main__.Node at 0x119dd0040>

In [100]:
# Debug: do two collected nodes compare equal? The recorded reprs
# (`<__main__.Node at 0x...>`) come from the plain Node class, which compares by
# identity. Re-run with the dataclass Node, which compares field values.
t[0] == t[1]

False

In [94]:
# Debug: collect the distinct employee nodes whose last property value is
# "Andrew" (membership is tested with ==) and print their values. The repeated
# identical rows in the recorded output are consistent with identity-based
# Node equality at the time it ran.
t = []
for k in tuples['employees']:
    if k.vals[-1] == "Andrew":
        if k not in t:
            t.append(k)
            print(k.vals)

['(206) 555-9482', 'USA', '3457', 'Andrew received his BTS commercial in 1974 and a Ph.D. in international marketing from the University of Dallas in 1981.  He is fluent in French and Italian and reads German.  He joined the company as a sales representative, was promoted to sales manager in January 1992 and to vice president of sales in March 1993.  Andrew is a member of the Sales Management Roundtable, the Seattle Chamber of Commerce, and the Pacific Rim Importers Association.', '908 W. Capital Way', 'Dr.', 'Tacoma', '1952-02-19', '<memory at 0x1456b6940>', 'Fuller', '1992-08-14', 'Vice President, Sales', 'http://accweb/emmployees/fuller.bmp', 2, 'WA', '98401', 'Andrew']
['(206) 555-9482', 'USA', '3457', 'Andrew received his BTS commercial in 1974 and a Ph.D. in international marketing from the University of Dallas in 1981.  He is fluent in French and Italian and reads German.  He joined the company as a sales representative, was promoted to sales manager in January 1992 and to vice 

In [55]:
# Debug: size of the employees group. The recorded 895 predates value-based
# dedup; the per-label count cell above reports 9.
len(tuples['employees'])

895

In [13]:
# Debug: source node of the first triple. The recorded output shows a raw neo4j
# Node, i.e. from before the triplets were mapped.
n = triplets[0][0]

In [18]:
# Debug: display the node picked above.
n

<Node id=1 labels=frozenset({'territories'}) properties={'region_id': 1, 'territory_description': 'Westboro', 'territory_id': '01581'}>

In [19]:
# Debug: relationship of the first triple.
e = triplets[0][1]

In [28]:
# NOTE(review): edgeMapper takes (edgeRec, src, dst). This call omits src and dst
# and raises TypeError as written.
edgeMapper(e)

In [23]:
# NOTE(review): works only on a raw neo4j relationship. Once triplets hold mapped
# Edge dataclasses, `e` has no keys() method.
e.keys()

dict_keys(['territory_id', 'employee_id'])

In [51]:
# Debug: number of (source, edge, target) triples fetched.
len(triplets)

16481

In [44]:
# NOTE(review): `nodes` is not defined at module level by any visible cell; this
# depends on state from a deleted or earlier cell and fails on a fresh kernel.
node = nodes[0]['n']

In [48]:
# Debug: first label of `node`.
list(node.labels)[0]

'suppliers'

In [42]:
# Debug: keys of `node`. The recorded ['n'] suggests `node` was a result Record
# rather than a graph node at the time — verify.
node.keys()

['n']

In [34]:
# Debug: materialize `result`. The recorded [] is presumably because the lazy
# Result was already consumed or its session already closed when this ran.
list(result)

[]

In [24]:
# Debug: display `result`.
result

<neo4j.work.result.Result at 0x11b3affd0>

In [25]:
# Debug: rebind `result` to a list of its items.
result = [item for item in result]

In [26]:
# Debug: display the materialized list.
result

[]