<a href="https://colab.research.google.com/github/brijeshrn/AI/blob/main/ContextDatPrep_triplets_V1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install openai



In [2]:
from google.colab import files
uploaded = files.upload()

Saving AdventureWorks.db to AdventureWorks.db


In [5]:
import sqlite3
conn = sqlite3.connect("AdventureWorks.db")
cursor = conn.cursor()

In [6]:
import pandas as pd

# --- 1. Get all table names ---
tables = pd.read_sql_query(
    "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';",
    conn
)["name"].tolist()

# --- 2. Extract schema for all tables ---
schema_rows = []
for table in tables:
    info = pd.read_sql_query(f"PRAGMA table_info({table});", conn)
    info["table"] = table
    schema_rows.append(info)
schema_df = pd.concat(schema_rows, ignore_index=True)

# --- 3. Show sample ---
pd.set_option('display.max_rows', 40, 'display.max_colwidth', 50)
display(schema_df)

# --- 4. Optionally, save as CSV ---
schema_df.to_csv("adventureworks_schema.csv", index=False)


Unnamed: 0,cid,name,type,notnull,dflt_value,pk,table
0,0,businessentityid,INTEGER,0,,0,salesperson
1,1,territoryid,INTEGER,0,,0,salesperson
2,2,salesquota,INTEGER,0,,0,salesperson
3,3,bonus,INTEGER,0,,0,salesperson
4,4,commissionpct,FLOAT,0,,0,salesperson
...,...,...,...,...,...,...,...
110,2,fromcurrencycode,TEXT,0,,0,currencyrate
111,3,tocurrencycode,TEXT,0,,0,currencyrate
112,4,averagerate,FLOAT,0,,0,currencyrate
113,5,endofdayrate,FLOAT,0,,0,currencyrate


In [7]:
#extract sample column values
schema_dict = {}
for table in tables:
    cols = pd.read_sql_query(f"PRAGMA table_info({table});", conn)
    sample_rows = pd.read_sql_query(f"SELECT * FROM {table} LIMIT 3;", conn)
    schema_dict[table] = {}
    for _, row in cols.iterrows():
        col = row["name"]
        samples = sample_rows[col].dropna().unique().tolist()[:2] if col in sample_rows else []
        schema_dict[table][col] = {
            "type": row["type"],
            "notnull": bool(row["notnull"]),
            "pk": bool(row["pk"]),
            "sample": samples
        }


In [8]:
import json

print(json.dumps(schema_dict, indent=2))

{
  "salesperson": {
    "businessentityid": {
      "type": "INTEGER",
      "notnull": false,
      "pk": false,
      "sample": [
        274,
        275
      ]
    },
    "territoryid": {
      "type": "INTEGER",
      "notnull": false,
      "pk": false,
      "sample": [
        "",
        2
      ]
    },
    "salesquota": {
      "type": "INTEGER",
      "notnull": false,
      "pk": false,
      "sample": [
        "",
        300000
      ]
    },
    "bonus": {
      "type": "INTEGER",
      "notnull": false,
      "pk": false,
      "sample": [
        0,
        4100
      ]
    },
    "commissionpct": {
      "type": "FLOAT",
      "notnull": false,
      "pk": false,
      "sample": [
        0.0,
        0.012
      ]
    },
    "salesytd": {
      "type": "FLOAT",
      "notnull": false,
      "pk": false,
      "sample": [
        559697.5639,
        3763178.1787
      ]
    },
    "saleslastyear": {
      "type": "FLOAT",
      "notnull": false,
      "pk": false

In [9]:
for table, cols in schema_dict.items():
    print(f"\nTable: {table}")
    for col in cols:
        print(f"  - {col}")


Table: salesperson
  - businessentityid
  - territoryid
  - salesquota
  - bonus
  - commissionpct
  - salesytd
  - saleslastyear
  - rowguid
  - modifieddate

Table: product
  - productid
  - NAME
  - productnumber
  - makeflag
  - finishedgoodsflag
  - color
  - safetystocklevel
  - reorderpoint
  - standardcost
  - listprice
  - size
  - sizeunitmeasurecode
  - weightunitmeasurecode
  - weight
  - daystomanufacture
  - productline
  - class
  - style
  - productsubcategoryid
  - productmodelid
  - sellstartdate
  - sellenddate
  - discontinueddate
  - rowguid
  - modifieddate

Table: productmodelproductdescriptionculture
  - productmodelid
  - productdescriptionid
  - cultureid
  - modifieddate

Table: productdescription
  - productdescriptionid
  - description
  - rowguid
  - modifieddate

Table: productreview
  - productreviewid
  - productid
  - reviewername
  - reviewdate
  - emailaddress
  - rating
  - comments
  - modifeddate
  - modifieddate

Table: productcategory
  - produ

In [16]:
for table in schema_dict:
    print(f"\nChecking FKs for table: {table}")
    cursor.execute(f"PRAGMA foreign_key_list('{table}')")
    fks = cursor.fetchall()
    print("PRAGMA returned:", fks)   # See what is returned
    for fk in fks:
        print("  FK entry:", fk)
        _, _, ref_table, from_col, to_col, *_ = fk
        if from_col in schema_dict[table]:
            schema_dict[table][from_col]["fk"] = {
                "ref_table": ref_table,
                "ref_column": to_col
            }



Checking FKs for table: salesperson
PRAGMA returned: []

Checking FKs for table: product
PRAGMA returned: []

Checking FKs for table: productmodelproductdescriptionculture
PRAGMA returned: []

Checking FKs for table: productdescription
PRAGMA returned: []

Checking FKs for table: productreview
PRAGMA returned: []

Checking FKs for table: productcategory
PRAGMA returned: []

Checking FKs for table: productsubcategory
PRAGMA returned: []

Checking FKs for table: salesorderdetail
PRAGMA returned: []

Checking FKs for table: salesorderheader
PRAGMA returned: []

Checking FKs for table: salesterritory
PRAGMA returned: []

Checking FKs for table: countryregioncurrency
PRAGMA returned: []

Checking FKs for table: currencyrate
PRAGMA returned: []


In [15]:
import json

print(json.dumps(schema_dict, indent=2))

{
  "salesperson": {
    "businessentityid": {
      "type": "INTEGER",
      "notnull": false,
      "pk": false,
      "sample": [
        274,
        275
      ]
    },
    "territoryid": {
      "type": "INTEGER",
      "notnull": false,
      "pk": false,
      "sample": [
        "",
        2
      ]
    },
    "salesquota": {
      "type": "INTEGER",
      "notnull": false,
      "pk": false,
      "sample": [
        "",
        300000
      ]
    },
    "bonus": {
      "type": "INTEGER",
      "notnull": false,
      "pk": false,
      "sample": [
        0,
        4100
      ]
    },
    "commissionpct": {
      "type": "FLOAT",
      "notnull": false,
      "pk": false,
      "sample": [
        0.0,
        0.012
      ]
    },
    "salesytd": {
      "type": "FLOAT",
      "notnull": false,
      "pk": false,
      "sample": [
        559697.5639,
        3763178.1787
      ]
    },
    "saleslastyear": {
      "type": "FLOAT",
      "notnull": false,
      "pk": false

In [17]:
# ---------- Configuration ----------
from openai import OpenAI
from google.colab import userdata
OPENAI_API_KEY = userdata.get('OPENAI_KEY')
 #---------- Client Setup ----------
client = OpenAI(api_key=OPENAI_API_KEY)

In [45]:
import openai
import json

SYSTEM_MSG = """You are a world-class SQL database architect.
Review the provided database schema (in JSON), and:
- Detect and add missing foreign keys (FKs), referencing the canonical AdventureWorks schema.
- Represent every FK as: "fk": {"table": "referenced_table", "column": "referenced_column"} (not as a string).
- Correct or annotate primary keys (PKs) if missing or incorrect.
- Add clear, human-readable descriptions to each table and column, if missing.
- Suggest corrections for datatypes or column naming inconsistencies.
- Only modify the schema structure; do not change any data.
- Preserve all existing keys and metadata in each column, especially "sample" values, unless you are explicitly correcting their type or structure.
Reply with a corrected schema_dict in valid JSON format, with all FKs as structured objects.
"""

def extract_json_from_markdown(text):
    text = text.strip()
    if text.startswith("```"):
        lines = text.splitlines()
        if lines[0].strip().startswith("```"):
            lines = lines[1:]
        if lines and lines[-1].strip().startswith("```"):
            lines = lines[:-1]
        text = "\n".join(lines).strip()
    return text

def restore_samples(original_schema, enriched_schema):
    for table, columns in original_schema.items():
        if table not in enriched_schema:
            continue
        for col, meta in columns.items():
            if col not in enriched_schema[table]:
                continue
            if isinstance(meta, dict) and "sample" in meta:
                if "sample" not in enriched_schema[table][col]:
                    enriched_schema[table][col]["sample"] = meta["sample"]
    return enriched_schema

def enrich_schema_with_llm(schema_dict, model="gpt-4o"):
    prompt = (
        "Here is the extracted database schema_dict:\n"
        + json.dumps(schema_dict, indent=2)
        + "\n\nPlease check and correct for the following:\n"
        "1. Add missing foreign keys (FKs) using the AdventureWorks canonical schema.\n"
        "2. Every FK must be as: \"fk\": {\"table\": \"referenced_table\", \"column\": \"referenced_column\"} (not as a string!).\n"
        "3. Add or correct primary keys (PKs).\n"
        "4. Add meaningful descriptions for tables and columns.\n"
        "5. Suggest datatype corrections where obvious.\n"
        "6. Preserve all column metadata fields, especially 'sample' values. Do not remove or overwrite unless you are correcting them.\n"
        "Return only the corrected schema_dict as valid JSON."
    )
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": SYSTEM_MSG},
            {"role": "user", "content": prompt}
        ],
        temperature=0.0,
    )
    raw_output = response.choices[0].message.content
    json_str = extract_json_from_markdown(raw_output)
    try:
        corrected_schema = json.loads(json_str)
        return corrected_schema
    except Exception as e:
        print("Failed to parse LLM output as JSON. Cleaned output was:\n", json_str)
        raise e

# Usage:
# corrected_schema = enrich_schema_with_llm(original_schema)
# corrected_schema = restore_samples(original_schema, corrected_schema)


In [46]:
# Assuming your schema_dict is already loaded

corrected_schema = enrich_schema_with_llm(schema_dict)

In [47]:
import json

print(json.dumps(corrected_schema, indent=2))

{
  "salesperson": {
    "description": "Contains sales performance data for each salesperson.",
    "businessentityid": {
      "type": "INTEGER",
      "notnull": true,
      "pk": true,
      "sample": [
        274,
        275
      ],
      "fk": {
        "table": "employee",
        "column": "businessentityid"
      },
      "description": "Primary key. Unique identifier for the salesperson."
    },
    "territoryid": {
      "type": "INTEGER",
      "notnull": false,
      "pk": false,
      "sample": [
        "",
        2
      ],
      "fk": {
        "table": "salesterritory",
        "column": "territoryid"
      },
      "description": "Sales territory associated with the salesperson."
    },
    "salesquota": {
      "type": "FLOAT",
      "notnull": false,
      "pk": false,
      "sample": [
        "",
        300000
      ],
      "description": "Sales quota for the salesperson."
    },
    "bonus": {
      "type": "FLOAT",
      "notnull": false,
      "pk": fals

In [29]:
#5 schema triplets
def extract_schema_triplets(schema_dict):
    triplets = []

    for table, columns in schema_dict.items():
        # Table-level description if available
        table_desc = columns.get("description", None)
        if table_desc:
            triplets.append([table, "description", table_desc])

        for col, props in columns.items():
            # Skip the table description key itself
            if col == "description":
                continue
            # Primary key
            if props.get("pk", False):
                triplets.append([f"{table}.{col}", "primary_key", table])
            # Foreign key
            fk = props.get("fk")
            if fk:
                triplets.append([f"{table}.{col}", "foreign_key", fk])
            # Column type
            if "type" in props:
                triplets.append([f"{table}.{col}", "column_type", props["type"]])
            # Column description
            if "description" in props:
                triplets.append([f"{table}.{col}", "description", props["description"]])
    return triplets

# Example usage:
schema_triplets = extract_schema_triplets(corrected_schema)


In [40]:
print(json.dumps(schema_triplets, indent=2))

[
  [
    "salesperson",
    "description",
    "Contains sales performance data for each salesperson."
  ],
  [
    "salesperson.businessentityid",
    "primary_key",
    "salesperson"
  ],
  [
    "salesperson.businessentityid",
    "foreign_key",
    {
      "table": "person",
      "column": "businessentityid"
    }
  ],
  [
    "salesperson.businessentityid",
    "column_type",
    "INTEGER"
  ],
  [
    "salesperson.businessentityid",
    "description",
    "Primary key. Unique identifier for the salesperson."
  ],
  [
    "salesperson.territoryid",
    "foreign_key",
    {
      "table": "salesterritory",
      "column": "territoryid"
    }
  ],
  [
    "salesperson.territoryid",
    "column_type",
    "INTEGER"
  ],
  [
    "salesperson.territoryid",
    "description",
    "Sales territory in which the salesperson operates."
  ],
  [
    "salesperson.salesquota",
    "column_type",
    "FLOAT"
  ],
  [
    "salesperson.salesquota",
    "description",
    "Sales quota for the sa

In [32]:
#business rule triplets
# Raw rule lines (easy to update or append)
business_rule_lines = [
    "productreview.rating, aggregation, avg",
    "productreview.rating, range, 1-5",
    "salesperson.salesytd, compare, salesperson.salesquota",
    "salesperson.yoy_growth_pct, threshold, 5",
    "salesorderheader.totaldue_fx, threshold, 5000",
    "salesorderheader.orderdate, year, 2013",
    "product.productid, bundle, product.productid",
    "productsubcategory.productcategoryid, group_by, productcategory.productcategoryid",
    "productreview.reviewdate, within_days_of, salesorderheader.orderdate:30",
    "productreview.productid, join, salesorderdetail.productid",
    "salesperson.salesytd, greater_than, salesperson.salesquota",
    "salesperson.salesytd, year_on_year_growth, salesperson.saleslastyear"
]

# Parse to list of dicts
business_rules_tripplets = []
for line in business_rule_lines:
    parts = [x.strip() for x in line.split(",")]
    rule = {"subject": parts[0], "predicate": parts[1], "object": parts[2]}
    business_rules_tripplets.append(rule)

# business_rules is now:
# [
#   {'subject': 'productreview.rating', 'predicate': 'aggregation', 'object': 'avg'},
#   ...
# ]


In [33]:
print(json.dumps(business_rules_tripplets, indent=2))

[
  {
    "subject": "productreview.rating",
    "predicate": "aggregation",
    "object": "avg"
  },
  {
    "subject": "productreview.rating",
    "predicate": "range",
    "object": "1-5"
  },
  {
    "subject": "salesperson.salesytd",
    "predicate": "compare",
    "object": "salesperson.salesquota"
  },
  {
    "subject": "salesperson.yoy_growth_pct",
    "predicate": "threshold",
    "object": "5"
  },
  {
    "subject": "salesorderheader.totaldue_fx",
    "predicate": "threshold",
    "object": "5000"
  },
  {
    "subject": "salesorderheader.orderdate",
    "predicate": "year",
    "object": "2013"
  },
  {
    "subject": "product.productid",
    "predicate": "bundle",
    "object": "product.productid"
  },
  {
    "subject": "productsubcategory.productcategoryid",
    "predicate": "group_by",
    "object": "productcategory.productcategoryid"
  },
  {
    "subject": "productreview.reviewdate",
    "predicate": "within_days_of",
    "object": "salesorderheader.orderdate:30"
  }

In [41]:
import json

# --- Replace with your actual variable names if different ---
# Example (uncomment and run if you want to test with sample data):
# business_rules_triplets = [{"subject": "foo", "predicate": "bar", "object": "baz"}]
# corrected_schema = {"table": {"column": "value"}}
# schema_triplets = [["table.column", "primary_key", "table"]]

# Save business_rules_triplets.json
with open("business_rules_triplets.json", "w") as f:
    json.dump(business_rules_tripplets, f, indent=2)

# Save corrected_schema.json
with open("corrected_schema.json", "w") as f:
    json.dump(corrected_schema, f, indent=2)

# Save schema_triplets.json
with open("schema_triplets.json", "w") as f:
    json.dump(schema_triplets, f, indent=2)

print("All three JSON files have been saved in your working directory!")


All three JSON files have been saved in your working directory!


In [42]:
with open("corrected_schema.json") as f:
    print(f.read())

{
  "salesperson": {
    "description": "Contains sales performance data for each salesperson.",
    "businessentityid": {
      "type": "INTEGER",
      "notnull": true,
      "pk": true,
      "fk": {
        "table": "person",
        "column": "businessentityid"
      },
      "description": "Primary key. Unique identifier for the salesperson."
    },
    "territoryid": {
      "type": "INTEGER",
      "notnull": false,
      "fk": {
        "table": "salesterritory",
        "column": "territoryid"
      },
      "description": "Sales territory in which the salesperson operates."
    },
    "salesquota": {
      "type": "FLOAT",
      "notnull": false,
      "description": "Sales quota for the salesperson."
    },
    "bonus": {
      "type": "FLOAT",
      "notnull": false,
      "description": "Bonus awarded to the salesperson."
    },
    "commissionpct": {
      "type": "FLOAT",
      "notnull": false,
      "description": "Commission percentage for the salesperson."
    },
  

In [44]:
from google.colab import files  # for Colab, or skip if on Jupyter

#files.download("business_rules_triplets.json")
files.download("corrected_schema.json")
files.download("schema_triplets.json")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [48]:
# Save corrected_schema.json
with open("corrected_schema.json", "w") as f:
    json.dump(corrected_schema, f, indent=2)
files.download("corrected_schema.json")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>