<a href="https://colab.research.google.com/github/horasan/eng_to_sql_ner/blob/main/NER_B_1_mapping_be_to_db.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
# read data from google drive
drive.mount('/content/drive')
FOLDER_PATH = "NER_for_SQL"
FULL_PATH = "/content/drive/My Drive/Colab Notebooks/" + FOLDER_PATH + "/"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import json
import networkx as nx
from itertools import product
#from transformers import RobertaTokenizerFast, RobertaForTokenClassification
#from transformers import AutoTokenizer, AutoModelForTokenClassification
import os
import torch
from collections import defaultdict
from typing import Dict


In [3]:
def normalize(name):
    return name.lower()

In [4]:
def load_schema_config(schema_json_path):
    with open(schema_json_path, 'r') as f:
        schema = json.load(f)

    graph = nx.Graph()
    entity_to_tables = {}

    # Add tables and columns as nodes and map entities
    for table in schema['tables']:
        table_name = normalize(table['table_name'])
        graph.add_node(table_name)
        for col in table.get('columns', []):
            col_name = normalize(col['name'])
            if 'business_entity' in col:
                entity = normalize(col['business_entity'])
                entity_to_tables.setdefault(entity, set()).add(table_name)

    # Add relationships as edges with FK info as attributes
    for rel in schema.get('relationships', []):
        from_table = normalize(rel['from_table'])
        to_table = normalize(rel['to_table'])
        from_col = normalize(rel.get('from_column', ''))
        to_col = normalize(rel.get('to_column', ''))
        graph.add_edge(from_table, to_table, from_column=from_col, to_column=to_col)

    return entity_to_tables, graph

In [5]:
class SmartTableResolver:
    def __init__(self, schema):
        self.schema = schema
        self.graph = self._build_graph()
        self.entity_to_columns = self._build_entity_to_columns()

    def _build_graph(self):
        graph = defaultdict(list)
        for rel in self.schema['relationships']:
            from_table = rel['from_table'].lower()
            to_table = rel['to_table'].lower()
            graph[from_table].append(to_table)
            graph[to_table].append(from_table)  # For pathfinding only
        return graph

    def _build_entity_to_columns(self):
        entity_map = defaultdict(list)
        for table in self.schema['tables']:
            for col in table['columns']:
                if 'business_entity' in col and col['business_entity']:
                    entity_map[col['business_entity'].lower()].append((
                        table['table_name'].lower(), col['name'].lower()
                    ))
        return entity_map

    def resolve(self, entities):
        entity_candidates = {
            ent: self.entity_to_columns.get(ent.lower(), [])
            for ent in entities
        }

        # All candidate tables
        table_sets = [set(t for t, _ in candidates) for candidates in entity_candidates.values() if candidates]
        if not table_sets:
            return {'tables': [], 'relationships': []}

        # Flatten all unique tables
        from itertools import product
        all_candidates = list(product(*table_sets))

        best_path = None
        min_nodes = float('inf')

        for combo in all_candidates:
            all_paths = set()
            for i in range(len(combo)):
                for j in range(i + 1, len(combo)):
                    src = combo[i]
                    tgt = combo[j]
                    path = self._shortest_path(src, tgt)
                    if path:
                        all_paths.update(path)

            if all_paths and len(all_paths) < min_nodes:
                best_path = all_paths
                min_nodes = len(all_paths)

        if not best_path:
            return {'tables': list(set().union(*table_sets)), 'relationships': []}

        involved_tables = set(best_path)

        # Extract only real (directed) relationships from schema
        valid_relationships = set(
            (rel['from_table'].lower(), rel['from_column'].lower(), rel['to_table'].lower(), rel['to_column'].lower())
            for rel in self.schema['relationships']
        )

        relationships = []
        for rel in self.schema['relationships']:
            ft = rel['from_table'].lower()
            tt = rel['to_table'].lower()
            if ft in involved_tables and tt in involved_tables:
                relationships.append(rel)

        return {'tables': sorted(involved_tables), 'relationships': relationships}

    def _shortest_path(self, start, end):
        queue = deque([(start, [start])])
        visited = set()
        while queue:
            current, path = queue.popleft()
            if current == end:
                return path
            visited.add(current)
            for neighbor in self.graph[current]:
                if neighbor not in visited:
                    queue.append((neighbor, path + [neighbor]))
        return None


In [6]:
class SmartTableResolver_old:
    def __init__(self, entity_to_tables, graph, entity_preference=None):
        self.entity_to_tables = entity_to_tables
        self.graph = graph
        self.entity_preference = entity_preference or {}

    def combo_score(self, combo, ner_labels):
        score = 0
        for ent, tbl in zip(ner_labels, combo):
            prefs = self.entity_preference.get(ent, [])
            try:
                score += prefs.index(tbl)
            except ValueError:
                score += 100
        return score

    def resolve(self, ner_labels):
        ner_labels = [normalize(l) for l in ner_labels]
        print(f"NER labels normalized: {ner_labels}")

        candidates_per_entity = []
        for ent in ner_labels:
            if ent not in self.entity_to_tables:
                print(f"Unknown entity: {ent}")
                return []
            print(f"Entity '{ent}' candidates: {self.entity_to_tables[ent]}")
            candidates_per_entity.append(list(self.entity_to_tables[ent]))

        best_solution = None
        best_combo = None
        best_score = None
        best_relationships = None

        for combo in product(*candidates_per_entity):
            terminals = set(combo)
            disconnected = False
            nodes_in_paths = set()
            rel_columns = set()

            terminals_list = list(terminals)
            for i in range(len(terminals_list)):
                for j in range(i + 1, len(terminals_list)):
                    src = terminals_list[i]
                    tgt = terminals_list[j]
                    try:
                        path = nx.shortest_path(self.graph, src, tgt)
                        for u, v in zip(path[:-1], path[1:]):
                            edge_data = self.graph.get_edge_data(u, v, default={})
                            from_col = edge_data.get('from_column', '')
                            to_col = edge_data.get('to_column', '')
                            rel_columns.add( (u, from_col, v, to_col) )
                        nodes_in_paths.update(path)
                    except nx.NetworkXNoPath:
                        disconnected = True
                        break
                if disconnected:
                    break

            if disconnected:
                continue

            solution_nodes = nodes_in_paths.union(terminals)
            score = self.combo_score(combo, ner_labels)
            #print(f"Testing combo: {combo}")
            #print(f"Nodes in paths for combo: {solution_nodes}")
            #print(f"Combo score: {score}")

            if best_solution is None or score < best_score:
                best_solution = solution_nodes
                best_combo = combo
                best_score = score
                best_relationships = rel_columns
                #print(f"New best solution: {best_solution} with score {best_score}")

        if best_solution is None:
            print("No valid solution found")
            return []

        relationships_list = [
            {
                "from_table": from_tbl,
                "from_column": from_col,
                "to_table": to_tbl,
                "to_column": to_col
            }
            for (from_tbl, from_col, to_tbl, to_col) in best_relationships
        ]

        #print(f"Best combo chosen: {best_combo}")
        return {
            "tables": sorted(best_solution),
            "relationships": relationships_list
        }

In [7]:
def generate_sql_with_joins_using_where(resolved_result):
    tables = resolved_result["tables"]
    relationships = resolved_result["relationships"]

    # Lowercase table names for SQL consistency
    formatted_tables = [t.lower() for t in tables]
    select_clause = "SELECT *"
    from_clause = f"FROM {', '.join(formatted_tables)}"
    where_clauses = []

    for rel in relationships:
        left = f"{rel['from_table'].lower()}.{rel['from_column'].lower()}"
        right = f"{rel['to_table'].lower()}.{rel['to_column'].lower()}"
        where_clauses.append(f"{left} = {right}")

    where_clause = ""
    if where_clauses:
        where_clause = "WHERE " + " AND ".join(where_clauses)

    sql_query = f"{select_clause}\n{from_clause}\n{where_clause}"
    return sql_query


In [8]:
if __name__ == "__main__":

    # Write to file just for loading function (or modify loader to accept dict)
    schema_file = "NER-B-1_schema_info_trs.json"

    entity_to_tables, graph = load_schema_config(FULL_PATH + schema_file)

    # Define preferences: lower index = higher preference for table per entity
    entity_preference = {
        "dealer": ["TRS.DEAL_MASTER"],
        "deal_date": ["TRS.DEAL_MASTER"],
        "value_date": ["TRS.DEAL_MASTER", "TRS.DEAL_DETAIL"],
        "customer_name": ["TRS.CUSTOMER"]
    }

    resolver = SmartTableResolver_old(entity_to_tables, graph, entity_preference)



In [15]:
graph

<networkx.classes.graph.Graph at 0x7a13e0763f10>

In [9]:

def generate_sql_with_entity_filters_and_values(result: Dict, schema: Dict, business_entity_values: Dict[str, str], put_extracted_values=True) -> str:
    tables = result['tables']
    relationships = result['relationships']

    where_clauses = []

    for rel in relationships:
        left = f"{rel['from_table'].lower()}.{rel['from_column'].lower()}"
        right = f"{rel['to_table'].lower()}.{rel['to_column'].lower()}"
        where_clauses.append(f"{left} = {right}")

    entity_keys_normalized = {k.strip().lower(): k for k in business_entity_values}

    for table_info in schema['tables']:
        table = table_info['table_name'].lower()
        for col in table_info.get('columns', []):
            entity = col.get('business_entity')
            if not entity:
                continue
            norm_entity = entity.strip().lower()
            if norm_entity in entity_keys_normalized:
                original_key = entity_keys_normalized[norm_entity]
                value = business_entity_values[original_key]
                val_repr = f"'{value}'" if put_extracted_values else f"@{entity}"
                where_clauses.append(f"{table}.{col['name'].lower()} = {val_repr}")

    sql = f"SELECT *\nFROM {', '.join(tables)}\nWHERE\n    " + "\n    AND ".join(where_clauses)
    return sql

In [10]:
ner_labels_from_json = ["deal_type", "customer_name", "status", "value_date", "amount", "currency"]
tables = resolver.resolve(ner_labels_from_json)

NER labels normalized: ['deal_type', 'customer_name', 'status', 'value_date', 'amount', 'currency']
Entity 'deal_type' candidates: {'trs.deal_master'}
Entity 'customer_name' candidates: {'cus.customer'}
Entity 'status' candidates: {'trs.deal_master'}
Entity 'value_date' candidates: {'trs.deal_detail', 'trs.deal_master'}
Entity 'amount' candidates: {'trs.deal_detail'}
Entity 'currency' candidates: {'trs.deal_detail'}


In [11]:
print(tables)

{'tables': ['cus.customer', 'trs.deal_detail', 'trs.deal_master'], 'relationships': [{'from_table': 'trs.deal_detail', 'from_column': 'dealmasteroid', 'to_table': 'trs.deal_master', 'to_column': 'oid'}, {'from_table': 'trs.deal_master', 'from_column': 'customeroid', 'to_table': 'cus.customer', 'to_column': 'oid'}]}


In [12]:
def filter_valid_relationships(resolved, schema):
    # Build a map: table_name → set of column names
    table_columns = {
        table['table_name'].lower(): {col['name'].lower() for col in table['columns']}
        for table in schema['tables']
    }

    valid_rels = []
    for rel in resolved['relationships']:
        ft = rel['from_table'].lower()
        fc = rel['from_column'].lower()
        tt = rel['to_table'].lower()
        tc = rel['to_column'].lower()

        # Only keep if both columns exist in their respective tables
        if fc in table_columns.get(ft, set()) and tc in table_columns.get(tt, set()):
            valid_rels.append(rel)

    return {
        'tables': resolved['tables'],
        'relationships': valid_rels
    }


In [13]:
with open(FULL_PATH + schema_file, 'r') as f:
        sample_schema = json.load(f)

In [14]:
sample_schema

{'tables': [{'table_name': 'TRS.DEAL_MASTER',
   'columns': [{'name': 'OID'},
    {'name': 'DEALER', 'business_entity': 'DEALER'},
    {'name': 'OPERATIONDATE', 'business_entity': 'DEAL_DATE'},
    {'name': 'CUSTOMEROID'},
    {'name': 'PRDREFNO', 'business_entity': 'PRODUCT_REF_NO'},
    {'name': 'OPERATIONREFNO'},
    {'name': 'VALDATE', 'business_entity': 'VALUE_DATE'},
    {'name': 'DEALTP', 'business_entity': 'DEAL_TYPE'},
    {'name': 'APPSTATUS', 'business_entity': 'STATUS'}]},
  {'table_name': 'TRS.DEAL_DETAIL',
   'columns': [{'name': 'OID'},
    {'name': 'CURR', 'business_entity': 'CURRENCY'},
    {'name': 'VALDATE', 'business_entity': 'VALUE_DATE'},
    {'name': 'OPREFNO'},
    {'name': 'DEALMASTEROID'},
    {'name': 'ACCOID'},
    {'name': 'AMT', 'business_entity': 'AMOUNT'}]},
  {'table_name': 'CUS.CUSTOMER',
   'columns': [{'name': 'OID'},
    {'name': 'NAME', 'business_entity': 'CUSTOMER_NAME'},
    {'name': 'NO'},
    {'name': 'FROMCOUNTRY'}]},
  {'table_name': 'ACC.ACC

In [15]:
cleaned = filter_valid_relationships(tables, sample_schema)
print(cleaned)

{'tables': ['cus.customer', 'trs.deal_detail', 'trs.deal_master'], 'relationships': [{'from_table': 'trs.deal_detail', 'from_column': 'dealmasteroid', 'to_table': 'trs.deal_master', 'to_column': 'oid'}, {'from_table': 'trs.deal_master', 'from_column': 'customeroid', 'to_table': 'cus.customer', 'to_column': 'oid'}]}


In [16]:
print(cleaned)
print(generate_sql_with_joins_using_where(cleaned))

{'tables': ['cus.customer', 'trs.deal_detail', 'trs.deal_master'], 'relationships': [{'from_table': 'trs.deal_detail', 'from_column': 'dealmasteroid', 'to_table': 'trs.deal_master', 'to_column': 'oid'}, {'from_table': 'trs.deal_master', 'from_column': 'customeroid', 'to_table': 'cus.customer', 'to_column': 'oid'}]}
SELECT *
FROM cus.customer, trs.deal_detail, trs.deal_master
WHERE trs.deal_detail.dealmasteroid = trs.deal_master.oid AND trs.deal_master.customeroid = cus.customer.oid
