In [2]:
import os
import requests
import sys
import time
import json
import networkx as nx
import community as community_louvain # python-louvain library
import matplotlib.pyplot as plt
import random
import traceback

# --- Configuration ---
# Path to the 'database' directory of the Spider dataset. The scan below walks it
# recursively and processes every *.sqlite file it finds.
# Set the SPIDER_DATABASE_PATH environment variable to override it without editing
# this cell, or change the default below.
# Example: spider_database_path = "C:/Users/YourUser/Downloads/spider/database"
#    or spider_database_path = "/home/youruser/datasets/spider/database"
spider_database_path = os.environ.get("SPIDER_DATABASE_PATH", "../../database")

# Endpoint that receives a multipart SQLite upload and replies with the schema JSON.
# Override with the SCHEMA_API_URL environment variable.
api_url = os.environ.get("SCHEMA_API_URL", "http://localhost:9191/api/v1/db/get-schema/sqlite")
# --- End of configuration ---

# --- Function to call API to get Schema ---
def fetch_schema_from_api(db_file_path, endpoint_url):
    """
    Upload one SQLite file to the schema API and return the 'data' part of the reply.

    The file is read from disk and POSTed as the multipart field 'file' to
    ``endpoint_url``. The JSON reply is accepted only when its "code" is 0 and
    its "data" object contains a 'tables' entry.

    Args:
        db_file_path (str): Path of the .sqlite file to upload.
        endpoint_url (str): Full URL of the get-schema endpoint.

    Returns:
        dict | None: The reply's 'data' dict on success, otherwise None.
        Every failure is reported with print() instead of being raised, so a
        batch loop over many databases can keep going.
    """
    # Computed before any I/O so every error path below can name the file.
    # Previously an early failure (e.g. PermissionError from open) reached a
    # handler that referenced this name while it was still unbound.
    file_name = os.path.basename(db_file_path)

    # --- 1. Read the database file ---
    try:
        with open(db_file_path, 'rb') as f:
            db_content = f.read()
    except FileNotFoundError:
        print(f"  Error: File not found: {db_file_path}")
        return None
    except OSError as e:
        # e.g. permission denied, or the path is a directory.
        print(f"  Error: Cannot read file '{db_file_path}': {e}")
        return None

    if not db_content:
        print(f"  Error: File is empty or cannot be read: {db_file_path}")
        return None

    # --- 2. Upload it ---
    files = {'file': (file_name, db_content, 'application/octet-stream')}
    print(f"  Sending '{file_name}' to API...")
    # ConnectionError and Timeout both subclass RequestException, so they are
    # listed before the generic RequestException handler.
    try:
        response = requests.post(endpoint_url, files=files, timeout=60)
        response.raise_for_status()
    except requests.exceptions.ConnectionError:
        print(f"  Error: Cannot connect to API at {endpoint_url}.")
        return None  # Can add sys.exit(1) if you want to stop completely
    except requests.exceptions.Timeout:
        print(f"  Error: API request timed out for file '{file_name}'.")
        return None
    except requests.exceptions.RequestException as e:
        print(f"  Request Error: '{file_name}' -> {e}")
        return None
    except Exception as e:
        print(f"  Unexpected error when calling API: '{file_name}' -> {e}")
        return None

    # --- 3. Decode the JSON reply ---
    # This gets its own narrow try. In requests >= 2.27,
    # requests.exceptions.JSONDecodeError is itself a RequestException, so a
    # shared handler list sent it to the generic branch. Old and new requests
    # both raise a ValueError subclass here.
    try:
        result = response.json()
    except ValueError:
        print(f"  Error: Cannot decode JSON from API response for '{file_name}'.")
        print(f"  Response received (first 100 characters): {response.text[:100]}...")
        return None

    if not isinstance(result, dict):
        print(f"  Error: Unexpected API response format for '{file_name}'.")
        return None

    if result.get("code") != 0:
        error_message = result.get("message", "Unknown error from API")
        print(f"  API Error: '{file_name}' -> {error_message} (Error code: {result.get('code')})")
        return None

    schema_data = result.get("data", {})
    if not isinstance(schema_data, dict) or 'tables' not in schema_data:
        print(f"  Error: API returned success code but missing 'tables' data for '{file_name}'.")
        return None

    print(f"  Success: Schema retrieved for '{file_name}'.")
    return schema_data  # Only the 'data' part, which holds the schema

## --- Function to perform clustering from Schema data (UPDATED) ---
def _extract_tables_and_foreign_keys(tables_info):
    """Return (table identifiers, deterministically ordered unique FK pairs between distinct known tables)."""
    tables = [table['tableIdentifier'] for table in tables_info]
    known_tables = set(tables)  # O(1) membership for the relation filter below

    foreign_keys = set()
    for table_info in tables_info:
        source_table = table_info['tableIdentifier']
        for column_info in table_info.get('columns', []):
            for relation in column_info.get('relations') or []:
                target_table = relation.get('tableIdentifier')
                # Keep only relations that point at another table of this schema.
                if target_table and target_table in known_tables and source_table != target_table:
                    foreign_keys.add((source_table, target_table))

    # Set iteration order over strings depends on the per-process hash seed.
    # Sorting gives Louvain the same edge order on every run.
    return tables, sorted(foreign_keys, key=str)


def _plot_cluster_graph(G, partition, db_identifier, output_dir):
    """Draw G coloured by community and save it as a PNG in output_dir; errors are printed, not raised."""
    fig = None
    try:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir, exist_ok=True)
            print(f"    Created directory for graph storage: {output_dir}")

        # Large canvas so labels of bigger schemas stay readable.
        fig = plt.figure(figsize=(20, 16))

        # A small k gives a more compact layout; the fixed seed keeps it stable across runs.
        pos = nx.spring_layout(G, k=0.3, seed=42)

        community_ids = sorted(set(partition.values()))
        # pyplot.get_cmap replaces plt.cm.get_cmap (deprecated in Matplotlib 3.7, removed in 3.9).
        # The lut is at least the number of communities, so cmap(i) gives each one its own slot.
        cmap = plt.get_cmap('tab20', max(len(community_ids), 2))
        community_color_map = {cid: cmap(i) for i, cid in enumerate(community_ids)}
        node_colors = [community_color_map[partition[node]] for node in G.nodes()]

        nx.draw_networkx_nodes(G, pos,
                               node_size=2000,
                               node_color=node_colors,
                               alpha=0.8)
        # Light, semi-transparent edges so they do not hide the nodes.
        nx.draw_networkx_edges(G, pos,
                               width=1.5,
                               alpha=0.5,
                               edge_color='grey')
        nx.draw_networkx_labels(G, pos,
                                font_size=14,
                                font_weight='bold',
                                font_family='sans-serif')

        plt.axis('off')
        plt.title(f"Clusters for {db_identifier}", fontsize=18, fontweight='bold')
        plt.tight_layout()

        # File name: every non-alphanumeric character of db_identifier becomes '_'.
        safe_db_name = "".join(c if c.isalnum() else "_" for c in db_identifier)
        plot_filename = os.path.join(output_dir, f"clusters_{safe_db_name}.png")
        plt.savefig(plot_filename, dpi=150, bbox_inches='tight')
        print(f"    Saved cluster graph to: {plot_filename}")
    except Exception as plot_err:
        print(f"    Error when drawing or saving graph: {plot_err}")
    finally:
        # Always release the figure, even when drawing fails part-way, so a
        # long batch run does not accumulate open figures.
        if fig is not None:
            plt.close(fig)


def cluster_schema_from_data(schema_data,
                             db_identifier="Unknown DB",
                             plot_graph=False,
                             output_dir="cluster_plots",
                             resolution_value=1.0,
                             random_state=None):
    """
    Build a table-relationship graph from schema_data and cluster it with Louvain.

    Every entry of schema_data['tables'] becomes a node (its 'tableIdentifier').
    Every column relation pointing at another table of the same schema becomes
    an undirected edge. Nodes are grouped with python-louvain's best_partition.
    If there are no edges, each table forms its own cluster.

    Args:
        schema_data (dict): The schema 'data' object; expected to hold a 'tables' list.
        db_identifier (str): Label used in log output, the plot title and the plot filename.
        plot_graph (bool): If True, save a PNG of the clustered graph into output_dir.
        output_dir (str): Directory for plot images (created if missing).
        resolution_value (float): Resolution passed to best_partition.
        random_state: Optional seed forwarded to best_partition for reproducible
            partitions. None keeps the previous, unseeded call.

    Returns:
        list[list[str]] | None: Clusters ordered by size (largest first), each
        sorted alphabetically; [] when there are no tables; None if an
        unexpected error occurred (the traceback is printed).
    """
    print(f"  -> Clustering schema for '{db_identifier}'...")
    try:
        tables_info = schema_data.get('tables', [])
        if not tables_info:
            print("    Error: No 'tables' information found in schema_data.")
            return []

        # --- Step 1: Extract tables and foreign keys ---
        tables, foreign_keys = _extract_tables_and_foreign_keys(tables_info)
        print(f"    Extracted: {len(tables)} tables, {len(foreign_keys)} foreign keys.")

        # --- Step 2: Build Graph ---
        G = nx.Graph()
        G.add_nodes_from(tables)
        G.add_edges_from(foreign_keys)

        if G.number_of_nodes() == 0:
            print("    Notice: No tables to create graph.")
            return []

        # --- Step 3: Apply Community Algorithm (Louvain) ---
        if G.number_of_edges() > 0:
            print(f"    Applying Louvain with resolution = {resolution_value}")
            louvain_kwargs = {'resolution': resolution_value}
            if random_state is not None:
                # Forwarded only when given, so the default call is unchanged.
                louvain_kwargs['random_state'] = random_state
            partition = community_louvain.best_partition(G, **louvain_kwargs)
        else:
            partition = {node: i for i, node in enumerate(G.nodes())}
            print("    Warning: Graph has no edges. Each table is its own cluster.")

        # --- Step 4: Process Clustering Results ---
        communities = {}
        for node, community_id in partition.items():
            communities.setdefault(community_id, []).append(node)

        print(f"    Detected {len(communities)} clusters.")
        cluster_list = []
        for i, nodes in enumerate(sorted(communities.values(), key=len, reverse=True)):
            sorted_nodes = sorted(nodes)
            print(f"      Cluster {i + 1} (Size: {len(sorted_nodes)}): {sorted_nodes}")
            cluster_list.append(sorted_nodes)

        # --- Step 5: Draw Graph (Optional) ---
        # Plot failures are printed inside the helper and never fail the clustering.
        if plot_graph:
            _plot_cluster_graph(G, partition, db_identifier, output_dir)

        return cluster_list

    except Exception as e:
        print(f"  Error during clustering for '{db_identifier}': {e}")
        traceback.print_exc()
        return None

# --- Main Script Logic ---
if not os.path.isdir(spider_database_path):
    print(f"ERROR: Path '{spider_database_path}' does not exist or is not a directory.")
    sys.exit(1)

print(f"Starting to scan and retrieve schemas from SQLite files in: {spider_database_path}")
print(f"API Endpoint: {api_url}")

processed_db_count = 0
fetch_success_count = 0
cluster_success_count = 0
failed_items = []  # "<db_full_path> (<reason>)" for every database that failed
all_database_clusters = {}  # Clustering results: {db_full_path: cluster_list}

# --- Flags to enable/disable graph drawing ---
GENERATE_PLOTS = True  # Set to False to skip writing one graph image per database
PLOT_OUTPUT_DIR = "spider_cluster_plots"  # Where images go when GENERATE_PLOTS is True
# -----------------------------

for dirpath, dirnames, filenames in os.walk(spider_database_path):
    # os.walk yields entries in arbitrary, filesystem-dependent order. In a
    # top-down walk, sorting dirnames in place also fixes the order in which
    # sub-directories are visited, so runs are reproducible across machines.
    dirnames.sort()
    filenames.sort()
    for filename in filenames:
        if not filename.lower().endswith(".sqlite"):
            continue

        db_full_path = os.path.join(dirpath, filename)
        db_folder_name = os.path.basename(dirpath)
        db_identifier = f"{db_folder_name}/{filename}"  # e.g. "aan_1/aan_1.sqlite"

        print(f"\nProcessing: {db_identifier}")
        processed_db_count += 1

        # 1. Get schema from API (None on any failure).
        schema_data = fetch_schema_from_api(db_full_path, api_url)
        if not schema_data:
            failed_items.append(f"{db_full_path} (API/Fetch error)")
            continue

        fetch_success_count += 1
        # 2. Cluster the retrieved schema (None means clustering failed).
        clusters = cluster_schema_from_data(schema_data,
                                            db_identifier=db_identifier,
                                            plot_graph=GENERATE_PLOTS,
                                            output_dir=PLOT_OUTPUT_DIR)
        if clusters is None:
            failed_items.append(f"{db_full_path} (Clustering error)")
        else:
            cluster_success_count += 1
            all_database_clusters[db_full_path] = clusters

print("\n--- Processing Complete ---")
print(f"Total .sqlite files processed: {processed_db_count}")
print(f"Number of successful schema retrievals: {fetch_success_count}")
print(f"Number of successful clusterings: {cluster_success_count}")
print(f"Number of failed processes (API or clustering): {len(failed_items)}")

if failed_items:
    print("\nList of failed items:")
    for item in failed_items:
        print(f"- {item}")

# Now the variable `all_database_clusters` contains clustering results for each database
# Example of how to access results:
# if all_database_clusters:
#     first_db = list(all_database_clusters.keys())[0]
#     print(f"\nExample clustering results for first database ({first_db}):")
#     for i, cluster in enumerate(all_database_clusters[first_db]):
#         print(f"  Cluster {i+1}: {cluster}")

Starting to scan and retrieve schemas from SQLite files in: E:/Workspace/Repositories/thesis/test/pipeline/SPIDER/database
API Endpoint: http://localhost:9191/api/v1/db/get-schema/sqlite

Processing: aan_1/aan_1.sqlite
  Sending 'aan_1.sqlite' to API...
  Success: Schema retrieved for 'aan_1.sqlite'.
  -> Clustering schema for 'aan_1/aan_1.sqlite'...
    Extracted: 5 tables, 4 foreign keys.
    Applying Louvain with resolution = 1.0
    Detected 2 clusters.
      Cluster 1 (Size: 3): ['Affiliation', 'Author', 'Author_list']
      Cluster 2 (Size: 2): ['Citation', 'Paper']
    Created directory for graph storage: spider_cluster_plots


  colors = plt.cm.get_cmap('tab20', max(num_communities, 2))  # Use a richer color palette


    Saved cluster graph to: spider_cluster_plots\clusters_aan_1_aan_1_sqlite.png

Processing: academic/academic.sqlite
  Sending 'academic.sqlite' to API...
  Success: Schema retrieved for 'academic.sqlite'.
  -> Clustering schema for 'academic/academic.sqlite'...
    Extracted: 15 tables, 17 foreign keys.
    Applying Louvain with resolution = 1.0
    Detected 6 clusters.
      Cluster 1 (Size: 3): ['author', 'domain_author', 'writes']
      Cluster 2 (Size: 3): ['conference', 'domain', 'domain_conference']
      Cluster 3 (Size: 3): ['domain_keyword', 'keyword', 'publication_keyword']
      Cluster 4 (Size: 3): ['cite', 'domain_publication', 'publication']
      Cluster 5 (Size: 2): ['domain_journal', 'journal']
      Cluster 6 (Size: 1): ['organization']
    Saved cluster graph to: spider_cluster_plots\clusters_academic_academic_sqlite.png

Processing: activity_1/activity_1.sqlite
  Sending 'activity_1.sqlite' to API...
  Success: Schema retrieved for 'activity_1.sqlite'.
  -> Clust