In [19]:
import os, sys 
import pandas as pd 
import numpy as np
import streamlit as st
import sqlparse
from collections import OrderedDict, Counter
from github import Github
from databricks import sql 
import streamlit_authenticator as stauth
import yaml 
from yaml.loader import SafeLoader
from dotenv import load_dotenv
load_dotenv()
import streamlit.components.v1 as components

# LLM libraries
from langchain_core.prompts import PromptTemplate
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
from langchain.chains.llm import LLMChain
from langchain_openai import ChatOpenAI

In [2]:
def list_catalog_schema_tables():
    """Return the table-metadata rows reported by the Databricks SQL endpoint.

    Connection settings are read from the environment variables
    ``DATABRICKS_SERVER_HOSTNAME``, ``DATABRICKS_HTTP_PATH`` and
    ``DATABRICKS_ACCESS_TOKEN``.

    Returns:
        Whatever ``cursor.fetchall()`` yields after ``cursor.tables()``; in the
        recorded run these are ``Row`` objects starting with TABLE_CAT,
        TABLE_SCHEM, TABLE_NAME and TABLE_TYPE.
    """
    connection_settings = {
        "server_hostname": os.getenv("DATABRICKS_SERVER_HOSTNAME"),
        "http_path": os.getenv("DATABRICKS_HTTP_PATH"),
        "access_token": os.getenv("DATABRICKS_ACCESS_TOKEN"),
    }

    # Both the connection and the cursor are released when the block exits.
    with sql.connect(**connection_settings) as connection, connection.cursor() as cursor:
        # Only table metadata is fetched; catalogs/schemas are not queried separately.
        cursor.tables()
        return cursor.fetchall()

In [3]:
# Fetch the table metadata rows from Databricks (requires the DATABRICKS_* env vars).
test_results = list_catalog_schema_tables()

In [4]:
# Display the raw rows returned by the metadata query.
test_results

[Row(TABLE_CAT='hive_metastore', TABLE_SCHEM='dev_tools', TABLE_NAME='sqlgenpro_user_query_history', TABLE_TYPE='', REMARKS='UNKNOWN', TYPE_CAT=None, TYPE_SCHEM=None, TYPE_NAME=None, SELF_REFERENCING_COL_NAME=None, REF_GENERATION=None),
 Row(TABLE_CAT='hive_metastore', TABLE_SCHEM='online_food_business', TABLE_NAME='menu_items', TABLE_TYPE='', REMARKS='UNKNOWN', TYPE_CAT=None, TYPE_SCHEM=None, TYPE_NAME=None, SELF_REFERENCING_COL_NAME=None, REF_GENERATION=None),
 Row(TABLE_CAT='hive_metastore', TABLE_SCHEM='online_food_business', TABLE_NAME='order_details', TABLE_TYPE='', REMARKS='UNKNOWN', TYPE_CAT=None, TYPE_SCHEM=None, TYPE_NAME=None, SELF_REFERENCING_COL_NAME=None, REF_GENERATION=None),
 Row(TABLE_CAT='hive_metastore', TABLE_SCHEM='online_food_business', TABLE_NAME='orders', TABLE_TYPE='', REMARKS='UNKNOWN', TYPE_CAT=None, TYPE_SCHEM=None, TYPE_NAME=None, SELF_REFERENCING_COL_NAME=None, REF_GENERATION=None),
 Row(TABLE_CAT='hive_metastore', TABLE_SCHEM='online_food_business', TABLE

In [5]:
# Keep only the first four fields of every metadata row and give them
# readable column names (catalog, schema, table, table_type).
df_databricks = (
    pd.DataFrame(test_results)
    .iloc[:, :4]
    .set_axis(["catalog", "schema", "table", "table_type"], axis=1)
)

In [6]:
# Show the catalog / schema / table overview built above.
df_databricks

Unnamed: 0,catalog,schema,table,table_type
0,hive_metastore,dev_tools,sqlgenpro_user_query_history,
1,hive_metastore,online_food_business,menu_items,
2,hive_metastore,online_food_business,order_details,
3,hive_metastore,online_food_business,orders,
4,hive_metastore,online_food_business,payments,
5,hive_metastore,online_food_business,restaurants,
6,hive_metastore,online_food_business,reviews,
7,hive_metastore,online_food_business,users,
8,hive_metastore,test_demo,users_info,
9,samples,nyctaxi,trips,


## Creating ERD:

In [13]:
def create_erd_diagram(catalog,schema,tables_list):
    """Ask an LLM to produce Mermaid ERD code for the selected Databricks tables.

    For every table in ``tables_list`` a ``DESCRIBE TABLE`` query is run and the
    resulting ``col_name`` / ``data_type`` pairs are collected as
    ``"<column> : <type>"`` strings. The collected schema is printed and then
    sent to the chat model through a prompt template.

    Args:
        catalog: Catalog name; it is backtick-quoted in the query.
        schema: Schema name inside the catalog.
        tables_list: Iterable of table names to describe.

    Returns:
        The ``'text'`` entry of the LLM chain response. It is returned as-is,
        so it can contain explanatory prose around the Mermaid code.
    """
    table_schema = {}

    # Open a single connection for all tables and close it when done; the
    # connection does not depend on the table, so creating one per iteration
    # (and never closing it) only leaked sessions.
    with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"),
                     http_path       = os.getenv("DATABRICKS_HTTP_PATH"),
                     access_token    = os.getenv("DATABRICKS_ACCESS_TOKEN")) as conn:

        # Iterating through each selected table and collecting its columns.
        for table in tables_list:
            # NOTE(review): only the catalog is backtick-quoted; schema and table
            # are interpolated raw -- confirm callers pass trusted identifiers.
            query = f"DESCRIBE TABLE `{catalog}`.{schema}.{table}"
            df = pd.read_sql(sql=query,con=conn)
            table_schema[table] = [
                f"{col} : {col_type}"
                for col, col_type in zip(df['col_name'].tolist(), df['data_type'].tolist())
            ]

    print(table_schema)

    # Generating the mermaid code for the ERD diagram
    ### Defining the prompt template
    template_string = """ 
    You are an expert in creating ERD diagrams (Entity Relationship Diagrams) for databases. 
    You have been given the task to create an ERD diagram for the selected tables in the database. 
    The ERD diagram should contain the tables and the columns present in the tables. 
    You need to generate the Mermaid code for the complete ERD diagram.
    Make sure the ERD diagram is clear and easy to understand with proper relationships details.

    The selected tables in the database are given below (delimited by ##) in the dictionary format: Keys being the table names and values being the list of columns and their datatype in the table.

    ##
    {table_schema}
    ##

    Before generating the mermaid code, validate it and make sure it is correct and clear.     
    Give me the final mermaid code for the ERD diagram after proper analysis.
    """

    prompt_template = PromptTemplate.from_template(template_string)

    ### Defining the LLM chain
    llm_chain = LLMChain(
        llm=ChatOpenAI(model="gpt-4o-mini",temperature=0),
        prompt=prompt_template
    )

    response = llm_chain.invoke({"table_schema":table_schema})
    return response['text']

In [14]:
# Selected tables list 
# Inputs for the ERD example: the catalog, the schema, and the tables to include.
catalog = "hive_metastore"
schema = "online_food_business"
tables_list = ["menu_items","orders","users"]

In [15]:
# Describe the selected tables and ask the LLM for the ERD (prints the collected schema).
output = create_erd_diagram(catalog,schema,tables_list)

  df = pd.read_sql(sql=query,con=conn)


{'menu_items': ['menu_item_id : string', 'restaurant_id : string', 'item_name : string', 'price : double'], 'orders': ['order_id : string', 'user_id : string', 'order_time : timestamp', 'delivery_address : string', 'order_status : string', 'restaurant_id : string', 'total_amount : double'], 'users': ['user_id : string', 'name : string', 'gender : string', 'email : string', 'phone_number : string', 'delivery_address : string']}


In [12]:
# Show the raw LLM reply; it contains prose in addition to the Mermaid code.
print(output)

To create an Entity Relationship Diagram (ERD) using the provided tables and their columns, we first need to analyze the relationships between the tables. 

### Analysis of Tables and Relationships:

1. **Tables and Columns**:
   - **menu_items**:
     - menu_item_id: string (Primary Key)
     - restaurant_id: string (Foreign Key)
     - item_name: string
     - price: double
     
   - **orders**:
     - order_id: string (Primary Key)
     - user_id: string (Foreign Key)
     - order_time: timestamp
     - delivery_address: string
     - order_status: string
     - restaurant_id: string (Foreign Key)
     - total_amount: double
     
   - **users**:
     - user_id: string (Primary Key)
     - name: string
     - gender: string
     - email: string
     - phone_number: string
     - delivery_address: string

2. **Relationships**:
   - **users** to **orders**: One-to-Many (One user can place many orders)
   - **menu_items** to **orders**: Many-to-One (Many menu items can belong to one o

In [16]:
# Function to extract the Mermaid code block from the LLM response
def process_llm_response_for_mermaid(response: str) -> str:
    """Extract the Mermaid source from an LLM response.

    The Mermaid code is expected inside a fenced block opened with
    "```mermaid" and closed with "```".

    Args:
        response: Full text returned by the LLM.

    Returns:
        The stripped contents of the first "```mermaid" block. If the response
        has no such opening fence, the whole response is returned stripped
        (it is treated as bare Mermaid code). If the closing fence is missing,
        everything after the opening fence is returned.
    """
    opening_fence = "```mermaid"
    closing_fence = "```"

    fence_pos = response.find(opening_fence)
    if fence_pos == -1:
        # str.find signals "not found" with -1; adding the fence length to it
        # would silently produce a bogus start offset.
        return response.strip()

    start_idx = fence_pos + len(opening_fence)
    end_idx = response.find(closing_fence, start_idx)
    if end_idx == -1:
        # Unterminated block: keep the rest instead of slicing to -1, which
        # would drop the final character.
        end_idx = len(response)

    return response[start_idx:end_idx].strip()

In [17]:
# Strip the surrounding prose and keep only the Mermaid code from the LLM reply.
cleaned_mermaid_code = process_llm_response_for_mermaid(output)

In [18]:
# Inspect the extracted Mermaid code before rendering it.
print(cleaned_mermaid_code)

erDiagram
    USERS {
        string user_id PK
        string name
        string gender
        string email
        string phone_number
        string delivery_address
    }
    
    ORDERS {
        string order_id PK
        string user_id FK
        timestamp order_time
        string delivery_address
        string order_status
        string restaurant_id
        double total_amount
    }
    
    MENU_ITEMS {
        string menu_item_id PK
        string restaurant_id
        string item_name
        double price
    }

    USERS ||--o{ ORDERS : places
    ORDERS }o--o{ MENU_ITEMS : contains


In [20]:
def mermaid(code: str, height: int = 800) -> None:
    """Render Mermaid diagram source inside a Streamlit HTML component.

    The source is placed in a ``<pre class="mermaid">`` element and the Mermaid
    ES module is loaded from the jsDelivr CDN with ``startOnLoad`` enabled.

    Args:
        code: Mermaid diagram source (for example an ``erDiagram`` definition).
        height: Height in pixels used for both the scrollable container and
            the component itself. Defaults to 800.
    """
    import html  # stdlib; imported locally so this block is self-contained

    # The source is embedded as HTML text, not inside a JavaScript string, so
    # it needs HTML escaping. Escaping backslashes/backticks here would alter
    # the diagram source, while raw '<' or '&' could break the markup.
    code_escaped = html.escape(code, quote=False)

    components.html(
        f"""
        <div id="mermaid-container" style="width: 100%; height: {height}px; overflow: auto;">
            <pre class="mermaid">
                {code_escaped}
            </pre>
        </div>

        <script type="module">
            import mermaid from 'https://cdn.jsdelivr.net/npm/mermaid@10/dist/mermaid.esm.min.mjs';
            mermaid.initialize({{ startOnLoad: true }});
        </script>
        """,
        height=height
    )