In [2]:
import re

In [1]:
with open('final-poc01.sql', 'r') as file:
    sql_content = file.read()
print("File content loaded successfully")

File content loaded successfully


In [3]:
import re
from collections import defaultdict
from typing import Dict, Set, Tuple

class SQLSchemaAnalyzer:
    def __init__(self, sql_code: str):
        self.sql_code = sql_code
        self.tables_dependencies = defaultdict(set)
        self.schema_tables = defaultdict(set)
        
    def _split_schema_table(self, full_table_name: str) -> Tuple[str, str]:
        """Split a full table name into schema and table components."""
        parts = full_table_name.replace(';', '').strip().split('.')
        if len(parts) == 2:
            return parts[0], parts[1]
        return 'DEFAULT', parts[0]
    
    def _clean_table_name(self, table_name: str) -> str:
        """Remove any SQL specifics from table name."""
        return table_name.split(' ')[0].strip()
    
    def _get_cte_names(self, sql_block: str) -> Set[str]:
        """Extract names of CTEs (Common Table Expressions) from a SQL block."""
        # Find all CTE names after WITH
        cte_pattern = r'WITH\s+([^(]+?)\s+AS\s*\('
        ctes = set()
        
        # Handle multiple CTEs
        cte_blocks = re.finditer(r'WITH\s+(.*?)SELECT', sql_block, re.DOTALL | re.IGNORECASE)
        for cte_block in cte_blocks:
            cte_text = cte_block.group(1)
            # Split multiple CTEs
            for cte in re.finditer(r'([^\s,]+)\s+AS\s*\(', cte_text, re.IGNORECASE):
                ctes.add(cte.group(1).strip())
                
        return ctes
    
    def _extract_dependencies(self, sql_block: str, target_table: str, cte_names: Set[str]) -> Set[str]:
        """Extract table dependencies excluding CTEs."""
        dependencies = set()
        
        # Find all table references in FROM and JOIN clauses
        from_tables = re.findall(r'FROM\s+([^\s(]+)', sql_block, re.IGNORECASE)
        join_tables = re.findall(r'JOIN\s+([^\s(]+)', sql_block, re.IGNORECASE)
        
        # Process all found tables
        for table in from_tables + join_tables:
            table = self._clean_table_name(table)
            # Ignore CTEs and self-references
            if table not in cte_names and table != target_table:
                dependencies.add(table)
        
        return dependencies
    
    def extract_table_info(self):
        """Extract table creation and dependency information from SQL code."""
        # Find all CREATE TABLE statements
        create_statements = re.finditer(
            r'CREATE\s+TABLE\s+([^\s(]+)',
            self.sql_code,
            re.IGNORECASE
        )
        
        for create_match in create_statements:
            target_table = self._clean_table_name(create_match.group(1))
            schema, table = self._split_schema_table(target_table)
            self.schema_tables[schema].add(table)
            
            # Find the corresponding CREATE TABLE block
            start_idx = create_match.start()
            next_create = self.sql_code.find('CREATE TABLE', start_idx + 1)
            block_end = len(self.sql_code) if next_create == -1 else next_create
            block = self.sql_code[start_idx:block_end]
            
            # Get CTEs to ignore
            cte_names = self._get_cte_names(block)
            
            # Extract dependencies excluding CTEs
            dependencies = self._extract_dependencies(block, target_table, cte_names)
            
            # Process dependencies
            for source in dependencies:
                source_schema, source_table = self._split_schema_table(source)
                self.schema_tables[source_schema].add(source_table)
                self.tables_dependencies[target_table].add(source)
    
    def generate_mermaid_diagram(self) -> str:
        """Generate a Mermaid diagram showing schemas as subgraphs and table dependencies."""
        mermaid_lines = ["graph TD"]
        
        # Create subgraphs for each schema
        for schema, tables in self.schema_tables.items():
            mermaid_lines.append(f"    subgraph {schema}")
            for table in sorted(tables):
                # Use schema_table as node ID to ensure uniqueness
                node_id = f"{schema}_{table}"
                mermaid_lines.append(f"        {node_id}[\"{table}\"]")
            mermaid_lines.append("    end")
        
        # Add dependencies
        mermaid_lines.append("")  # Empty line for readability
        for target, sources in self.tables_dependencies.items():
            target_schema, target_table = self._split_schema_table(target)
            if(target_schema != 'DEFAULT'):
                target_id = f"{target_schema}_{target_table}"
            
            for source in sources:
                source_schema, source_table = self._split_schema_table(source)
                if(source_schema != 'DEFAULT'):
                    source_id = f"{source_schema}_{source_table}"
                
                if(target_schema != 'DEFAULT' and source_schema != 'DEFAULT'):
                    mermaid_lines.append(f"    {source_id} --> {target_id}")
        
        return "\n".join(mermaid_lines)

def analyze_sql(sql_code: str) -> str:
    """Analyze SQL code and return a Mermaid diagram."""
    analyzer = SQLSchemaAnalyzer(sql_code)
    analyzer.extract_table_info()
    return analyzer.generate_mermaid_diagram()

# Example usage
if __name__ == "__main__":
    # Example SQL code
    sql_code = sql_content.upper()
    
    mermaid_diagram = analyze_sql(sql_code)
    print(mermaid_diagram)

graph TD
    subgraph MASTER_S_DEGREE
        MASTER_S_DEGREE_ADULT_PATIENTS["ADULT_PATIENTS"]
        MASTER_S_DEGREE_ALL_ADULT_ICU_PATIENTS["ALL_ADULT_ICU_PATIENTS"]
        MASTER_S_DEGREE_AVG_HOURLY_VITAL_SIGNS["AVG_HOURLY_VITAL_SIGNS"]
        MASTER_S_DEGREE_FILTERED_HOURLY_VITAL_SIGNS["FILTERED_HOURLY_VITAL_SIGNS"]
        MASTER_S_DEGREE_FILTERED_HOURLY_VITAL_SIGNS_24H_NO_OUTLIERS["FILTERED_HOURLY_VITAL_SIGNS_24H_NO_OUTLIERS"]
        MASTER_S_DEGREE_FIRST_ICU_STAYS["FIRST_ICU_STAYS"]
        MASTER_S_DEGREE_HOURLY_VITAL_SIGNS["HOURLY_VITAL_SIGNS"]
        MASTER_S_DEGREE_HOURLY_VITAL_SIGNS_FILTERED["HOURLY_VITAL_SIGNS_FILTERED"]
        MASTER_S_DEGREE_HOURLY_VITAL_SIGNS_FILTERED_WITH_TRENDS["HOURLY_VITAL_SIGNS_FILTERED_WITH_TRENDS"]
        MASTER_S_DEGREE_HOURLY_VITAL_SIGNS_NEW_SEPSIS["HOURLY_VITAL_SIGNS_NEW_SEPSIS"]
        MASTER_S_DEGREE_ICU_STAYS_MIN_DURATION["ICU_STAYS_MIN_DURATION"]
        MASTER_S_DEGREE_SEPSIS_3_PATIENTS["SEPSIS_3_PATIENTS"]
        MASTER_S_DEGREE_

In [8]:
# Enhanced pattern to match both schema and table names
create_table_pattern = r'CREATE TABLE (\w+\.\w+)'

# Find all CREATE TABLE statements and their dependencies
schemas_and_tables = set()  # Using set to avoid duplicates
dependencies = set()  # Store tables used in joins, merges, etc.

# Find explicit table creations
for match in re.finditer(create_table_pattern, sql_content):
    schemas_and_tables.add(match.group(1))

# Find dependencies (tables used in FROM and JOIN clauses)
from_pattern = r'FROM\s+(\w+\.\w+)'
join_pattern = r'JOIN\s+(\w+\.\w+)'

# Add tables from FROM clauses
for match in re.finditer(from_pattern, sql_content):
    dependencies.add(match.group(1))

# Add tables from JOIN clauses
for match in re.finditer(join_pattern, sql_content):
    dependencies.add(match.group(1))

print("Created Tables:")
for table in sorted(schemas_and_tables):
    print(f"  - {table}")

print("\nDependencies (Referenced Tables):")
for table in sorted(dependencies):
    print(f"  - {table}")

Created Tables:
  - MASTER_S_DEGREE.ADULT_PATIENTS
  - MASTER_S_DEGREE.ALL_ADULT_ICU_PATIENTS
  - MASTER_S_DEGREE.AVG_HOURLY_VITAL_SIGNS
  - MASTER_S_DEGREE.FILTERED_HOURLY_VITAL_SIGNS
  - MASTER_S_DEGREE.FILTERED_HOURLY_VITAL_SIGNS_24H_NO_OUTLIERS
  - MASTER_S_DEGREE.FIRST_ICU_STAYS
  - MASTER_S_DEGREE.HOURLY_VITAL_SIGNS
  - MASTER_S_DEGREE.HOURLY_VITAL_SIGNS_FILTERED
  - MASTER_S_DEGREE.HOURLY_VITAL_SIGNS_FILTERED_WITH_TRENDS
  - MASTER_S_DEGREE.HOURLY_VITAL_SIGNS_NEW_SEPSIS
  - MASTER_S_DEGREE.ICU_STAYS_MIN_DURATION
  - MASTER_S_DEGREE.SEPSIS_3_PATIENTS
  - MASTER_S_DEGREE.SEPSIS_3_PATIENTS_EXCLUDED_PREEXISTING
  - MASTER_S_DEGREE.SEPSIS_AND_SHOCK_PATIENTS
  - MASTER_S_DEGREE.TMP_HIGH_LACTATE
  - MASTER_S_DEGREE.TMP_HYPOTENSION
  - MASTER_S_DEGREE.TMP_SOFA_SCORE
  - MASTER_S_DEGREE.TMP_VASOPRESSOR_USE
  - MASTER_S_DEGREE.TMP_VITAL_SIGNS
  - MASTER_S_DEGREE.VITAL_SIGNS_6H_WINDOW

Dependencies (Referenced Tables):
  - ADM.ADMITTIME
  - MASTER_S_DEGREE.ADULT_PATIENTS
  - MASTER_S_DEGRE

In [10]:
import re

# Exemplo de entrada SQL
sql_input = """
CREATE VIEW view_teste AS
SELECT a.id, b.nome
FROM tabela_a a, tabela_b b
JOIN tabela_c c ON b.id = c.id;
"""

# Padrão para extrair o nome da view
view_pattern = r'CREATE VIEW (\w+) AS'

# Padrões para extrair tabelas das cláusulas FROM e JOIN
from_pattern = r'FROM\s+([\w\.]+)(?:\s+\w+)?(?:\s*,\s*([\w\.]+)(?:\s+\w+)?)*'
join_pattern = r'JOIN\s+([\w\.]+)(?:\s+\w+)?'

# Extrair o nome da view
view_match = re.search(view_pattern, sql_input, re.IGNORECASE)
view_name = view_match.group(1) if view_match else None

# Extrair tabelas da cláusula FROM
from_tables = set()
from_matches = re.finditer(from_pattern, sql_input, re.IGNORECASE)
for match in from_matches:
    tables = [table for table in match.groups() if table]
    from_tables.update(tables)

# Extrair tabelas da cláusula JOIN
join_tables = set()
join_matches = re.finditer(join_pattern, sql_input, re.IGNORECASE)
for match in join_matches:
    join_tables.add(match.group(1))

# Juntar todas as dependências
dependencies = from_tables.union(join_tables)

# Exibir resultados
if view_name:
    print(f"View: {view_name}")
    print("Dependências diretas:")
    for table in sorted(dependencies):
        print(f"  - {table}")
else:
    print("Nenhuma view encontrada.")

View: view_teste
Dependências diretas:
  - tabela_a
  - tabela_b
  - tabela_c
