<a href="https://colab.research.google.com/github/horasan/eng_to_sql_ner/blob/main/NER_B_2_full_life_cycle.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [73]:
import torch
import json
import networkx as nx
from itertools import product
#from transformers import RobertaTokenizerFast, RobertaForTokenClassification
#from transformers import AutoTokenizer, AutoModelForTokenClassification
import os
import torch
from collections import defaultdict
from typing import Dict

In [74]:
from google.colab import drive
# Mount Google Drive; the model, the tag maps and the schema below are all read from FULL_PATH.
drive.mount('/content/drive')
FOLDER_PATH = "NER_for_SQL"
FULL_PATH = f"/content/drive/My Drive/Colab Notebooks/{FOLDER_PATH}/"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [75]:

# Artifact file names, relative to FULL_PATH.
bio_tagged_dataset_file_name = "synthetic_queries_300_bio_tagged.txt"
tag2id_with_cust_file_name = "tag2id_with_cust.json"
id2tag_with_cust_file_name = "id2tag_with_cust.json"

# Directory holding the fine-tuned token-classification model.
trained_model_path = FULL_PATH + "ner-roberta-with-cust"
# NOTE(review): same directory as the model; the load cell currently takes the tokenizer from "roberta-base" instead.
trained_tokenizer_path = trained_model_path

# utils

In [76]:
def predict(text, tokenizer, model, id2tag):
    """Tag every whitespace-separated word of ``text`` with an NER label.

    The text is split on whitespace and passed to the tokenizer as pre-split
    words. Each word receives the label predicted for its FIRST sub-token.

    Args:
        text: input sentence.
        tokenizer: a fast tokenizer (must provide ``word_ids()``), called with
            ``is_split_into_words=True``.
        model: token-classification model whose output exposes ``.logits``.
        id2tag: mapping from integer label id to tag string.

    Returns:
        List of ``(word, tag)`` tuples in input order. Words dropped by
        truncation do not appear. Empty / whitespace-only text returns ``[]``.

    Raises:
        KeyError: if the model predicts a label id missing from ``id2tag``.
    """
    tokens = text.split()
    if not tokens:
        return []

    encoding = tokenizer(
        tokens,
        is_split_into_words=True,
        return_tensors="pt",
        truncation=True,
    )
    # Position -> index of the source word (None for special tokens).
    word_ids = encoding.word_ids()

    # Put the inputs on the model's device so this still works after model.to("cuda").
    device = getattr(model, "device", None)
    if device is None:
        inputs = dict(encoding)
    else:
        inputs = {name: tensor.to(device) for name, tensor in encoding.items()}

    model.eval()
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_ids = torch.argmax(logits, dim=-1)[0].tolist()

    results = []
    previous_word = None
    for position, word_idx in enumerate(word_ids):
        # Only the first sub-token of each word carries that word's label.
        if word_idx is not None and word_idx != previous_word:
            results.append((tokens[word_idx], id2tag[predicted_ids[position]]))
        previous_word = word_idx
    return results


In [77]:
from collections import defaultdict

def ner_tuples_to_json(tagged_tokens):
    """Collapse word-level BIO tags into an ``{entity_type: text}`` dict.

    Args:
        tagged_tokens: iterable of ``(token, tag)`` pairs, e.g. the output of
            ``predict``. Tags are ``"B-<TYPE>"``, ``"I-<TYPE>"`` or anything
            else (e.g. ``"O"``).

    Returns:
        Dict mapping each entity type to its words joined by single spaces.
        If one type occurs in several spans, all of their words are joined
        into that type's single string, in order of appearance.

    Decoding rules (lenient, conlleval-style):
    - ``B-X`` always starts an X span.
    - ``I-X`` continues the current span only if that span is also X.
      Otherwise it starts a new X span. A stray ``I-CURRENCY`` after
      ``B-AMOUNT`` therefore no longer gets glued onto the amount.
    - Any other tag closes the current span.
    """
    words_by_entity = defaultdict(list)
    current_entity = None

    for token, tag in tagged_tokens:
        is_begin = tag.startswith("B-")
        is_inside = tag.startswith("I-")
        if is_begin or (is_inside and tag[2:] != current_entity):
            current_entity = tag[2:]
        elif not is_inside:
            # "O" (or an unrecognised tag) ends the current span.
            current_entity = None
            continue
        # Drop a leading "▁" word-boundary marker in case sub-word pieces are passed in.
        words_by_entity[current_entity].append(token.lstrip("▁").strip())

    return {
        entity: " ".join(word for word in words if word)
        for entity, words in words_by_entity.items()
    }

In [78]:
def normalize(name):
    """Canonical form used for table / column / entity names: lower-case."""
    lowered = name.lower()
    return lowered

In [79]:
def load_schema_config(schema_json_path):
    """Read a schema JSON file and build the entity index and the table graph.

    Layout read here::

        {"tables": [{"table_name": ..., "columns": [{"name": ..., "business_entity": ...}]}],
         "relationships": [{"from_table": ..., "from_column": ...,
                            "to_table": ..., "to_column": ...}]}

    All names are passed through ``normalize``.

    Returns:
        ``(entity_to_tables, graph)`` where

        - ``entity_to_tables`` maps each business entity to the set of tables
          having a column tagged with it. Columns with a missing or empty
          ``business_entity`` are ignored.
        - ``graph`` is an undirected ``nx.Graph`` with one node per table and
          one edge per relationship. Each edge carries ``from_table``,
          ``from_column``, ``to_table`` and ``to_column`` attributes. The table
          attributes record the FK direction, which an undirected graph
          otherwise loses, so a consumer walking the edge from either end can
          still pair each column with its own table.
    """
    with open(schema_json_path, 'r', encoding='utf-8') as f:
        schema = json.load(f)

    graph = nx.Graph()
    entity_to_tables = {}

    # Tables become nodes; columns tagged with a business entity feed the index.
    for table in schema['tables']:
        table_name = normalize(table['table_name'])
        graph.add_node(table_name)
        for col in table.get('columns', []):
            entity = col.get('business_entity')
            if entity:
                entity_to_tables.setdefault(normalize(entity), set()).add(table_name)

    # Relationships become edges annotated with their FK direction and columns.
    for rel in schema.get('relationships', []):
        from_table = normalize(rel['from_table'])
        to_table = normalize(rel['to_table'])
        graph.add_edge(
            from_table,
            to_table,
            from_table=from_table,
            from_column=normalize(rel.get('from_column', '')),
            to_table=to_table,
            to_column=normalize(rel.get('to_column', '')),
        )

    return entity_to_tables, graph

In [80]:
class SmartTableResolver_old:
    """Pick the tables (and FK joins) that together cover a list of entity labels.

    Each entity label maps to one or more candidate tables (``entity_to_tables``).
    Every combination of candidates is tried. A combination is valid only if all
    of its tables are pairwise connected in ``graph``. Among valid combinations
    the lowest ``combo_score`` wins; ties go to the first combination in sorted
    candidate order.

    Args:
        entity_to_tables: normalized entity name -> collection of normalized
            table names.
        graph: undirected networkx graph of tables. Edges carry ``from_column``
            / ``to_column`` and, optionally, ``from_table`` / ``to_table``
            attributes giving the FK direction.
        entity_preference: optional ``{entity: [table, ...]}``. A lower index
            means a stronger preference. Names are compared case-insensitively.
    """

    # Score added for an (entity, table) pair whose table is not in the preference list.
    UNLISTED_TABLE_PENALTY = 100

    def __init__(self, entity_to_tables, graph, entity_preference=None):
        self.entity_to_tables = entity_to_tables
        self.graph = graph
        # Normalize preference keys and tables so they line up with the lower-cased
        # labels and table names used in resolve(). Without this, "TRS.DEAL_MASTER"
        # never equals "trs.deal_master" and every preference is silently ignored.
        self.entity_preference = {
            normalize(ent): [normalize(tbl) for tbl in tables]
            for ent, tables in (entity_preference or {}).items()
        }

    def combo_score(self, combo, ner_labels):
        """Score a combination of tables (one per label); lower is better.

        Each pair adds the table's index in the entity's preference list, or
        ``UNLISTED_TABLE_PENALTY`` when the table is not listed.
        """
        score = 0
        for ent, tbl in zip(ner_labels, combo):
            prefs = self.entity_preference.get(normalize(ent), [])
            tbl = normalize(tbl)
            score += prefs.index(tbl) if tbl in prefs else self.UNLISTED_TABLE_PENALTY
        return score

    def _edge_relationship(self, u, v):
        """Return ``(from_table, from_column, to_table, to_column)`` for edge u-v.

        The graph is undirected, so a path may cross an FK edge in either
        direction. When the edge records ``from_table`` / ``to_table`` those
        are used, keeping each column paired with the table it belongs to.
        Otherwise the traversal order ``(u, v)`` is used.
        NOTE(review): without those attributes the columns may be attached to
        the wrong tables when the path runs against the FK direction.
        """
        edge_data = self.graph.get_edge_data(u, v, default={})
        return (
            edge_data.get('from_table', u),
            edge_data.get('from_column', ''),
            edge_data.get('to_table', v),
            edge_data.get('to_column', ''),
        )

    def _connect(self, terminals):
        """Join every pair of ``terminals`` via shortest paths.

        Returns ``(nodes_on_paths, relationship_tuples)``, or ``None`` when
        some pair of terminals has no path between them.
        """
        nodes_in_paths = set()
        rel_columns = set()
        for i, src in enumerate(terminals):
            for tgt in terminals[i + 1:]:
                try:
                    path = nx.shortest_path(self.graph, src, tgt)
                except nx.NetworkXNoPath:
                    return None
                nodes_in_paths.update(path)
                for u, v in zip(path[:-1], path[1:]):
                    rel_columns.add(self._edge_relationship(u, v))
        return nodes_in_paths, rel_columns

    def resolve(self, ner_labels):
        """Choose the tables and joins that cover all ``ner_labels``.

        Args:
            ner_labels: entity labels (any case).

        Returns:
            ``{"tables": [...], "relationships": [{"from_table", "from_column",
            "to_table", "to_column"}, ...]}``. Tables are sorted and include
            intermediate tables on join paths. Relationships are sorted.
            Returns ``[]`` if a label is unknown or no candidate combination is
            connected.
            NOTE(review): the failure value is a list, not a dict; callers
            that index ``["tables"]`` must check for it.
        """
        ner_labels = [normalize(label) for label in ner_labels]
        print(f"NER labels normalized: {ner_labels}")

        candidates_per_entity = []
        for ent in ner_labels:
            if ent not in self.entity_to_tables:
                print(f"Unknown entity: {ent}")
                return []
            print(f"Entity '{ent}' candidates: {self.entity_to_tables[ent]}")
            # Sorted so the search order, and hence tie-breaking, is deterministic.
            candidates_per_entity.append(sorted(self.entity_to_tables[ent]))

        best = None  # (score, solution_nodes, relationship_tuples)
        for combo in product(*candidates_per_entity):
            connected = self._connect(sorted(set(combo)))
            if connected is None:
                continue
            nodes_in_paths, rel_columns = connected
            score = self.combo_score(combo, ner_labels)
            if best is None or score < best[0]:
                best = (score, nodes_in_paths.union(combo), rel_columns)

        if best is None:
            print("No valid solution found")
            return []

        _, best_solution, best_relationships = best
        relationships_list = [
            {
                "from_table": from_tbl,
                "from_column": from_col,
                "to_table": to_tbl,
                "to_column": to_col
            }
            for (from_tbl, from_col, to_tbl, to_col) in sorted(best_relationships)
        ]

        return {
            "tables": sorted(best_solution),
            "relationships": relationships_list
        }

In [81]:
def generate_sql_with_joins_using_where(resolved_result):
    """Render a ``SELECT *`` joining the resolved tables via WHERE equalities.

    Args:
        resolved_result: dict with ``"tables"`` (table names) and
            ``"relationships"`` (dicts with from_table/from_column/to_table/to_column).

    Returns:
        Three newline-separated lines: ``SELECT *``, ``FROM <tables>`` and
        ``WHERE <a = b AND ...>``. The third line is empty when there are no
        relationships. Every identifier is lower-cased.
    """
    table_names = resolved_result["tables"]
    join_specs = resolved_result["relationships"]

    from_list = ", ".join(name.lower() for name in table_names)

    equalities = []
    for spec in join_specs:
        lhs = ".".join((spec['from_table'].lower(), spec['from_column'].lower()))
        rhs = ".".join((spec['to_table'].lower(), spec['to_column'].lower()))
        equalities.append(lhs + " = " + rhs)

    where_line = ("WHERE " + " AND ".join(equalities)) if equalities else ""
    return "\n".join(("SELECT *", "FROM " + from_list, where_line))


In [82]:
if __name__ == "__main__":

    # Schema JSON (under FULL_PATH): tables, business-entity columns and FK relationships.
    schema_file = "NER-B-1_schema_info_trs.json"

    entity_to_tables, graph = load_schema_config(FULL_PATH + schema_file)

    # Define preferences: lower index = higher preference for table per entity.
    # Table names must be the schema's real names: customer_name lives in
    # CUS.CUSTOMER, not TRS.CUSTOMER.
    entity_preference = {
        "dealer": ["TRS.DEAL_MASTER"],
        "deal_date": ["TRS.DEAL_MASTER"],
        "value_date": ["TRS.DEAL_MASTER", "TRS.DEAL_DETAIL"],
        "customer_name": ["CUS.CUSTOMER"]
    }

    resolver = SmartTableResolver_old(entity_to_tables, graph, entity_preference)



In [83]:

def generate_sql_with_entity_filters_and_values(result: Dict, schema: Dict, business_entity_values: Dict[str, str], put_extracted_values=True) -> str:
    """Build a SELECT over the resolved tables with join and entity-value filters.

    Args:
        result: dict with ``"tables"`` (names written verbatim into FROM) and
            ``"relationships"`` (dicts with from_table/from_column/to_table/to_column).
        schema: full schema dict. Every column tagged with a ``business_entity``
            that has a captured value becomes a filter, but only for tables
            listed in ``result["tables"]``, so every condition refers to a table
            in the FROM list.
        business_entity_values: entity name -> captured text. Keys are matched
            case-insensitively, ignoring surrounding whitespace.
        put_extracted_values: if True, inline each value as a quoted SQL string
            literal. Otherwise emit a named placeholder ``@<business_entity>``.

    Returns:
        SQL text. The WHERE clause is omitted when there are no conditions.
    """
    tables = result['tables']
    relationships = result['relationships']
    selected_tables = {t.lower() for t in tables}

    where_clauses = [
        f"{rel['from_table'].lower()}.{rel['from_column'].lower()} = "
        f"{rel['to_table'].lower()}.{rel['to_column'].lower()}"
        for rel in relationships
    ]

    entity_keys_normalized = {k.strip().lower(): k for k in business_entity_values}

    for table_info in schema['tables']:
        table = table_info['table_name'].lower()
        if table not in selected_tables:
            # A filter on a table missing from FROM would make the SQL invalid.
            continue
        for col in table_info.get('columns', []):
            entity = col.get('business_entity')
            if not entity:
                continue
            original_key = entity_keys_normalized.get(entity.strip().lower())
            if original_key is None:
                continue
            if put_extracted_values:
                # Values come from free user text: double embedded single quotes
                # (standard SQL escaping) so they cannot terminate the literal.
                # NOTE: inlined values are for display; to execute, prefer
                # put_extracted_values=False with bound parameters.
                value = str(business_entity_values[original_key]).replace("'", "''")
                val_repr = f"'{value}'"
            else:
                val_repr = f"@{entity}"
            where_clauses.append(f"{table}.{col['name'].lower()} = {val_repr}")

    sql = f"SELECT *\nFROM {', '.join(tables)}"
    if where_clauses:
        sql += "\nWHERE\n    " + "\n    AND ".join(where_clauses)
    return sql

In [84]:
def filter_valid_relationships(tables_info, full_schema):
    """Keep the schema relationships that join two of the selected tables.

    A relationship is kept when both of its tables are in
    ``tables_info["tables"]`` and both of its columns exist in those tables
    according to ``full_schema``. All name comparisons are case-insensitive.

    Returns:
        ``{"tables": tables_info["tables"]`` (unchanged), ``"relationships":
        [matching relationship dicts from full_schema, in schema order]}``.
    """
    valid_tables = {t.lower() for t in tables_info['tables']}

    # table -> set of its column names (tables without a column list have none)
    table_columns = {
        table['table_name'].lower(): {col['name'].lower() for col in table.get('columns', [])}
        for table in full_schema['tables']
    }

    def _is_valid(rel):
        ft, fc = rel['from_table'].lower(), rel['from_column'].lower()
        tt, tc = rel['to_table'].lower(), rel['to_column'].lower()
        return (
            ft in valid_tables and tt in valid_tables
            and fc in table_columns.get(ft, set())
            and tc in table_columns.get(tt, set())
        )

    return {
        "tables": tables_info['tables'],
        "relationships": [rel for rel in full_schema.get('relationships', []) if _is_valid(rel)]
    }


# Load saved model

In [85]:
from transformers import RobertaTokenizerFast, RobertaForTokenClassification
from transformers import AutoTokenizer, AutoModelForTokenClassification

# NOTE(review): the tokenizer comes from the public "roberta-base" checkpoint, not from
# trained_tokenizer_path. That is only correct if fine-tuning left the tokenizer unchanged
# (e.g. no added tokens) — confirm.
# predict() calls the tokenizer with is_split_into_words=True, which RoBERTa fast
# tokenizers accept only when created with add_prefix_space=True.
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base", add_prefix_space=True)
# Fine-tuned token-classification weights saved under FULL_PATH.
model = AutoModelForTokenClassification.from_pretrained(trained_model_path)

## Load the saved id-to-tag mapping

In [86]:
import json
# JSON object keys are always strings; turn them back into the integer label ids
# that predict() looks up in id2tag.
with open(FULL_PATH + id2tag_with_cust_file_name, "r") as f:
    id2tag = {int(label_id): tag for label_id, tag in json.load(f).items()}


## Capture the business entities

In [87]:
text = "Get MM trades for ABC BANK with status approved and value date is tomorrow and amount is 3000 and cur is EUR"
# Word-level (word, BIO tag) pairs.
result = predict(text=text, tokenizer=tokenizer, model=model, id2tag=id2tag)


In [88]:
# Collapse the BIO tags into {ENTITY_TYPE: "captured text"}.
result_json = ner_tuples_to_json(tagged_tokens=result)
print(result_json)

{'DEAL_TYPE': 'MM', 'CUSTOMER_NAME': 'ABC BANK', 'STATUS': 'approved', 'VALUE_DATE': 'tomorrow', 'AMOUNT': '3000', 'CURRENCY': 'EUR'}


# Generate SQL

In [89]:
captured_business_entity_values = result_json

# The entity labels are the captured dict's keys, upper-cased.
ner_labels_from_captured_business_entity_values = [
    label.upper() for label in captured_business_entity_values
]
print(ner_labels_from_captured_business_entity_values)

['DEAL_TYPE', 'CUSTOMER_NAME', 'STATUS', 'VALUE_DATE', 'AMOUNT', 'CURRENCY']


## Resolve the tables that cover the captured entities

In [90]:
# Dict with "tables" and "relationships" (or [] if no connected set of tables exists).
tables = resolver.resolve(ner_labels=ner_labels_from_captured_business_entity_values)

NER labels normalized: ['deal_type', 'customer_name', 'status', 'value_date', 'amount', 'currency']
Entity 'deal_type' candidates: {'trs.deal_master'}
Entity 'customer_name' candidates: {'cus.customer'}
Entity 'status' candidates: {'trs.deal_master'}
Entity 'value_date' candidates: {'trs.deal_master', 'trs.deal_detail'}
Entity 'amount' candidates: {'trs.deal_detail'}
Entity 'currency' candidates: {'trs.deal_detail'}


## Load schema data

In [91]:
# Raw schema dict, used below to pick join columns and entity filter columns.
with open(FULL_PATH + schema_file, 'r') as f:
    trs_schema = json.load(f)

In [92]:
# Keep only the schema relationships whose tables and columns all belong to the resolved tables.
cleaned_tables = filter_valid_relationships(tables_info=tables, full_schema=trs_schema)
print(cleaned_tables)

{'tables': ['cus.customer', 'trs.deal_detail', 'trs.deal_master'], 'relationships': [{'from_table': 'TRS.DEAL_MASTER', 'from_column': 'CUSTOMEROID', 'to_table': 'CUS.CUSTOMER', 'to_column': 'OID'}, {'from_table': 'TRS.DEAL_DETAIL', 'from_column': 'DEALMASTEROID', 'to_table': 'TRS.DEAL_MASTER', 'to_column': 'OID'}]}


## SQL joining only the related tables (join conditions, no value filters)

In [93]:
print(generate_sql_with_joins_using_where(resolved_result=cleaned_tables))

SELECT *
FROM cus.customer, trs.deal_detail, trs.deal_master
WHERE trs.deal_master.customeroid = cus.customer.oid AND trs.deal_detail.dealmasteroid = trs.deal_master.oid


## SQL with join conditions plus the extracted entity values as filters

In [94]:
sql = generate_sql_with_entity_filters_and_values(
    result=cleaned_tables,
    schema=trs_schema,
    business_entity_values=captured_business_entity_values,
    put_extracted_values=True,
)
print(sql)

SELECT *
FROM cus.customer, trs.deal_detail, trs.deal_master
WHERE
    trs.deal_master.customeroid = cus.customer.oid
    AND trs.deal_detail.dealmasteroid = trs.deal_master.oid
    AND trs.deal_master.valdate = 'tomorrow'
    AND trs.deal_master.dealtp = 'MM'
    AND trs.deal_master.appstatus = 'approved'
    AND trs.deal_detail.curr = 'EUR'
    AND trs.deal_detail.valdate = 'tomorrow'
    AND trs.deal_detail.amt = '3000'
    AND cus.customer.name = 'ABC BANK'
