In [1]:
from langchain_community.utilities import SQLDatabase

# Create the SQLDatabase wrapper used by the rest of the notebook, pointing it
# at the SQLite file Chinook.db via a database URI.
# NOTE(review): the three-slash sqlite URI is presumably resolved relative to
# the current working directory — confirm the file is present there.
db = SQLDatabase.from_uri("sqlite:///Chinook.db")

In [10]:
from IPython.display import display, HTML

# get_context() is called but its return value is not bound or displayed; only
# the last expression of the cell is echoed, which is why the output below is
# the list of usable table names.
db.get_context()
db.get_usable_table_names()

['Album',
 'Artist',
 'Customer',
 'Employee',
 'Genre',
 'Invoice',
 'InvoiceLine',
 'MediaType',
 'Playlist',
 'PlaylistTrack',
 'Track']

In [8]:
import re
from IPython.display import display, HTML

def visualize_db_schema(db_context):
    """Render a database context's ``table_info`` text as styled HTML cards.

    The ``table_info`` string is split into one chunk per ``CREATE TABLE``
    statement. For each chunk the table name, the DDL (everything before the
    first ``/*``) and the sample rows (the lines following an
    ``N rows from ...`` header inside the ``/* ... */`` comment) are extracted
    and placed into a ``<div class="db-table">`` card.

    Args:
        db_context: Mapping with a ``'table_info'`` key holding the schema
            text (``CREATE TABLE`` statements, each optionally followed by a
            ``/* ... */`` comment with sample rows).

    Returns:
        A string containing a ``<style>`` block followed by one card per
        recognised table. Chunks without a recognisable table name are
        skipped. All text taken from the database is HTML-escaped.

    Raises:
        KeyError: If ``db_context`` has no ``'table_info'`` entry.
    """
    # Local import so the function stays self-contained in a notebook cell.
    from html import escape

    table_info = db_context['table_info']

    # Each table's DDL is separated from the previous one by a blank line;
    # the split consumes the "CREATE TABLE" keyword, which is restored below.
    chunks = re.split(r'\n\nCREATE TABLE', table_info)

    # Accept double-quoted, backtick-quoted, bracketed and unquoted names so
    # tables are not dropped just because the dialect did not quote them.
    name_pattern = re.compile(
        r'CREATE TABLE\s+(?:IF NOT EXISTS\s+)?'
        r'("[^"]+"|`[^`]+`|\[[^\]]+\]|[\w.]+)'
    )
    # Any sample-row count is accepted, not only the default of 3.
    sample_pattern = re.compile(r'\n\d+ rows from.*?\n(.*?)\n\*/', re.DOTALL)

    pieces = ["""
    <style>
        .db-table { 
            border: 2px solid #4CAF50; 
            margin: 15px 0; 
            border-radius: 8px; 
            overflow: hidden;
        }
        .table-header { 
            background: #4CAF50; 
            color: white; 
            padding: 10px; 
            font-weight: bold; 
            font-size: 16px;
        }
        .table-schema { 
            padding: 10px; 
            font-family: monospace; 
            font-size: 12px;
            white-space: pre-wrap;
        }
        .sample-data { 
            padding: 10px; 
            border-top: 1px solid #ddd;
        }
        .sample-header { 
            font-weight: bold; 
            color: #2E7D32; 
            margin-bottom: 5px;
        }
    </style>
    """]

    for i, table in enumerate(chunks):
        if not table.strip():
            continue

        # Only chunks after the first lost their keyword to the split.
        if i > 0:
            table = "CREATE TABLE" + table

        name_match = name_pattern.search(table)
        if not name_match:
            continue
        table_name = name_match.group(1)
        if table_name[0] in '"`[':
            # Drop the surrounding quote characters.
            table_name = table_name[1:-1]

        # Everything before the first comment opener is the DDL itself.
        schema_part, comment_open, _ = table.partition('/*')
        schema_part = schema_part.strip()

        sample_data = ""
        if comment_open:
            sample_match = sample_pattern.search(table)
            if sample_match:
                sample_data = sample_match.group(1).strip()

        # Escape database-provided text so rows containing <, > or & are
        # shown literally instead of being parsed as markup. Quotes are left
        # alone because the values are only used as element content.
        table_name = escape(table_name, quote=False)
        schema_part = escape(schema_part, quote=False)
        sample_data = escape(sample_data, quote=False)

        pieces.append(f"""
        <div class="db-table">
            <div class="table-header">📊 {table_name} Table</div>
            <div class="table-schema">{schema_part}</div>
            <div class="sample-data">
                <div class="sample-header">Sample Data:</div>
                <pre style="margin:0; font-size:11px;">{sample_data}</pre>
            </div>
        </div>
        """)

    return "".join(pieces)

# Get database context and visualize
# context = db.get_context()
# html_viz = visualize_db_schema(context)
# display(HTML(html_viz))
