# Database Query and Data Management Notebook

This notebook provides comprehensive database querying and data management capabilities using `db_connection.py` and logic from `test_json_to_db.py`.

## Features:
- **Database Connection**: Uses `db_connection.py` for secure database connections
- **SQL Querying**: Execute custom SQL queries and view results
- **Data Insertion**: Insert data back into the database using the logic from `test_json_to_db.py`
- **Table Management**: View table schemas, counts, and sample data
- **Data Analysis**: Analyze and visualize query results


In [1]:
# Activate sp500 virtual environment
import subprocess
import sys
import os

# Check for the sp500 virtual environment. Nothing is activated here: the cell
# only prints guidance when the environment is missing or not the one in use.
venv_path = "../sp500/bin/activate"
if not os.path.exists(venv_path):
    print("Warning: sp500 virtual environment not found at ../sp500/bin/activate")
else:
    # Absolute path of the interpreter inside the virtual environment
    venv_python = os.path.abspath("../sp500/bin/python")
    # Substring test against the path of the interpreter running this kernel
    if not (venv_python in sys.executable):
        print(
            f"Note: Please run this notebook with the sp500 virtual environment activated",
            f"Run: source ../sp500/bin/activate",
            f"Or use: {venv_python} -m jupyter notebook",
            sep="\n",
        )

# Import required libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sqlalchemy import text, inspect
from sqlalchemy.orm import sessionmaker
from pathlib import Path
from typing import Dict, List, Any, Optional

# Add database directory to path
# The current working directory is appended to sys.path so that the project
# imports below can be resolved relative to wherever the notebook was started.
current_dir = Path.cwd()
sys.path.append(str(current_dir))

# Import database modules
# An ImportError is reported rather than raised. The imports run in order, so
# names from the failing line and every line after it stay undefined and any
# later cell that uses them will fail with NameError.
try:
    from db_connection import engine, Session
    from config.config import Config
    from models.sec_facts_raw import BronzeSecFacts, BronzeSecFactsDict
    from models.sec_submissions_raw import BronzeSecSubmissions
    print("Database modules imported successfully")
except ImportError as e:
    print(f"Error importing database modules: {e}")
    print("Make sure you're running this notebook from the database directory")


Note: Please run this notebook with the sp500 virtual environment activated
Run: source ../sp500/bin/activate
Or use: /Users/ssp/Documents/MS_CS/Projects_git/sp500_agentic_ai/backend/sp500/bin/python -m jupyter notebook
Error importing database modules: No module named 'pymysql'
Make sure you're running this notebook from the database directory


In [None]:
# Database Connection Test
def test_database_connection():
    """Check that the database can be reached and print which one is in use.

    Runs ``SELECT 1`` followed by ``SELECT DATABASE()`` in one session and
    prints the first column of each result.

    Returns:
        True when both statements succeed; False when any exception is
        raised (the error is printed, not propagated).
    """
    try:
        with Session() as db:
            # Smallest possible round trip to prove the connection works
            probe_row = db.execute(text("SELECT 1 as test")).fetchone()
            print(f"Database connection successful: {probe_row[0]}")

            # Ask the server which database this session is bound to
            name_row = db.execute(text("SELECT DATABASE() as db_name")).fetchone()
            print(f"Connected to database: {name_row[0]}")
    except Exception as exc:
        print(f"Database connection failed: {exc}")
        return False
    else:
        return True

# Run the check when the cell executes
test_database_connection()


In [None]:
# Database Table Information
def get_table_info():
    """Print and return an overview of every table in the current database.

    Reads TABLE_NAME, TABLE_ROWS, DATA_LENGTH and INDEX_LENGTH from
    ``information_schema.TABLES`` for the schema returned by ``DATABASE()``,
    ordered by table name.

    Returns:
        A pandas DataFrame with one row per table, or None when the query
        fails (the error is printed, not propagated).
    """
    try:
        # The statement is executed through `engine` by pandas, so no ORM
        # session is opened here.
        tables_query = text("""
                SELECT TABLE_NAME, TABLE_ROWS, DATA_LENGTH, INDEX_LENGTH
                FROM information_schema.TABLES 
                WHERE TABLE_SCHEMA = DATABASE()
                ORDER BY TABLE_NAME
            """)

        tables_df = pd.read_sql(tables_query, engine)

        print("Database Tables Overview:")
        print("=" * 60)
        print(tables_df.to_string(index=False))

        return tables_df
    except Exception as e:
        print(f"Error getting table info: {e}")
        return None

# Get table information
tables_info = get_table_info()


In [None]:
# Record Counts for Main Tables
def get_record_counts():
    """Print and return the row count of each main ORM-mapped table.

    Counts BronzeSecFacts, BronzeSecFactsDict and BronzeSecSubmissions in a
    single session. A table whose count fails is reported with a warning and
    recorded as 0, so one failing table does not hide the others.

    Returns:
        Dict mapping model name to row count, or an empty dict when the
        session itself cannot be used (the error is printed).
    """
    # Each model is looked up lazily, inside the per-table try below, so a
    # model class that is not defined is reported as 0 instead of aborting
    # the whole function.
    model_lookups = {
        'BronzeSecFacts': lambda: BronzeSecFacts,
        'BronzeSecFactsDict': lambda: BronzeSecFactsDict,
        'BronzeSecSubmissions': lambda: BronzeSecSubmissions,
    }

    try:
        with Session() as session:
            counts = {}

            for table, lookup in model_lookups.items():
                try:
                    counts[table] = session.query(lookup()).count()
                except Exception as e:
                    # `except Exception` (not a bare except) lets
                    # KeyboardInterrupt/SystemExit through, and the reason
                    # for the 0 is shown instead of being swallowed.
                    print(f" Warning: could not count {table}: {e}")
                    counts[table] = 0

            print(" Record Counts:")
            print("=" * 30)
            for table, count in counts.items():
                print(f"{table:20}: {count:,} records")

            total_records = sum(counts.values())
            print(f"{'Total':20}: {total_records:,} records")

            return counts
    except Exception as e:
        print(f" Error getting record counts: {e}")
        return {}

# Get record counts
record_counts = get_record_counts()


In [None]:
# Custom SQL Query Function
def execute_sql_query(query: str, return_df: bool = True, limit: int = 1000):
    """
    Execute a custom SQL query and return results

    A query that starts with SELECT and does not contain the keyword LIMIT
    gets ``LIMIT <limit>`` appended as a safety net. Trailing semicolons and
    whitespace are removed first so the appended clause stays inside the
    statement. Other queries are executed unchanged.

    Args:
        query: SQL query string
        return_df: Whether to return pandas DataFrame (True) or raw results (False)
        limit: Maximum number of rows to return (safety limit)

    Returns:
        pandas DataFrame or list of results; None when anything fails
        (the error is printed, not propagated)
    """
    import re  # local import so this cell does not depend on another cell's imports

    try:
        # Whole-word match: a column such as `credit_limit` must not count as
        # a LIMIT clause. A LIMIT inside a string literal or subquery still
        # counts, as it did before.
        has_limit = re.search(r'\bLIMIT\b', query, re.IGNORECASE) is not None
        if query.strip().upper().startswith('SELECT') and not has_limit:
            # Strip every trailing ';' and whitespace character, in any order,
            # so "SELECT 1;\n" becomes "SELECT 1 LIMIT n" and not
            # "SELECT 1;\n LIMIT n".
            trimmed = query.rstrip("; \t\r\n")
            query = f"{trimmed} LIMIT {limit}"

        print(f" Executing query:")
        print(f"   {query}")
        print("-" * 50)

        if return_df:
            # text() gives this path the same statement handling as the raw
            # path below.
            # NOTE(review): with a plain string pandas may hand the SQL to the
            # driver untouched, where a literal '%' (LIKE '%x%') could be read
            # as a format marker - confirm against the installed versions.
            result_df = pd.read_sql(text(query), engine)
            print(f" Query executed successfully - {len(result_df)} rows returned")
            return result_df

        # Only the raw-result path needs an ORM session
        with Session() as session:
            result = session.execute(text(query)).fetchall()
        print(f" Query executed successfully - {len(result)} rows returned")
        return result

    except Exception as e:
        print(f" Error executing query: {e}")
        return None

# Example usage:
# df = execute_sql_query("SELECT * FROM bronze_sec_facts LIMIT 10")


In [None]:
# Sample Queries - BronzeSecFacts Analysis
# Demo cell: runs two read-only queries through execute_sql_query and prints
# each result that came back (a failed query returns None and is skipped).
print(" Sample Queries for BronzeSecFacts Table")
print("=" * 50)

# Query 1: the ten CIKs with the most rows
query1 = """
SELECT cik, COUNT(*) as record_count
FROM bronze_sec_facts 
GROUP BY cik 
ORDER BY record_count DESC 
LIMIT 10
"""

# Query 2: the five most recently filed rows that carry a value
query2 = """
SELECT cik, tag, val, end_date, filed
FROM bronze_sec_facts 
WHERE val IS NOT NULL 
ORDER BY filed DESC 
LIMIT 5
"""

df1 = execute_sql_query(query1)
if df1 is not None:
    print("\nTop 10 CIKs by record count:", df1.to_string(index=False), sep="\n")

# Blank line followed by a separator rule between the two results
print()
print("=" * 50)

df2 = execute_sql_query(query2)
if df2 is not None:
    print("\nSample recent records:", df2.to_string(index=False), sep="\n")


In [None]:
# Data Insertion Functions (Based on test_json_to_db.py logic)
def insert_dataframe_to_table(df: pd.DataFrame, table_name: str, if_exists: str = 'append'):
    """
    Write a pandas DataFrame to a database table via ``DataFrame.to_sql``.

    Rows are sent with multi-row INSERT statements in chunks of 1000 and the
    DataFrame index is not written.

    Args:
        df: pandas DataFrame to insert
        table_name: Target table name
        if_exists: What to do if table exists ('append', 'replace', 'fail')

    Returns:
        ``len(df)`` on success (the frame's length, not a count reported by
        the database); 0 when the write raises (the error is printed)
    """
    write_options = {
        'if_exists': if_exists,
        'index': False,
        'method': 'multi',
        'chunksize': 1000,
    }
    try:
        df.to_sql(table_name, engine, **write_options)
    except Exception as exc:
        print(f" Error inserting data into {table_name}: {exc}")
        return 0

    row_total = len(df)
    print(f" Successfully inserted {row_total} rows into {table_name}")
    return row_total

def bulk_insert_records(records: List[Dict], model_class, batch_size: int = 1000):
    """
    Bulk insert records using SQLAlchemy ORM (based on test_json_to_db.py logic)

    Records are turned into ``model_class`` instances and written in batches;
    each batch is committed on its own, so a failure part-way through leaves
    the earlier batches in the database.

    Args:
        records: List of dictionaries representing records
        model_class: SQLAlchemy model class
        batch_size: Number of records to insert per batch (must be >= 1)

    Returns:
        Number of records actually committed. This is 0 when nothing was
        written, and the committed subtotal when a later batch fails
        (the error is printed, not propagated).
    """
    # Tracked outside the try so the except branch can report what was
    # already committed instead of claiming 0.
    total_inserted = 0

    try:
        if batch_size < 1:
            # A negative step would make range() yield nothing and the call
            # would silently insert no rows.
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

        with Session() as session:
            batch_starts = range(0, len(records), batch_size)
            for batch_number, start in enumerate(batch_starts, start=1):
                batch = records[start:start + batch_size]

                # Create model instances and write them in one bulk call
                model_instances = [model_class(**record) for record in batch]
                session.bulk_save_objects(model_instances)
                session.commit()

                total_inserted += len(batch)
                print(f" Inserted batch {batch_number}: {len(batch)} records")

            print(f" Total records inserted: {total_inserted}")
            return total_inserted

    except Exception as e:
        print(f" Error in bulk insert: {e}")
        if total_inserted:
            print(f" {total_inserted} records were already committed before the failure")
        return total_inserted

print(" Data insertion functions loaded")


In [None]:
# Data Analysis and Visualization Functions
def analyze_cik_data(cik: str):
    """
    Analyze data for a specific CIK

    Fetches at most the 100 most recently filed rows for the CIK from
    ``bronze_sec_facts`` through ``execute_sql_query`` and prints a short
    summary: row count, filed-date range, number of distinct tags and the
    ten most frequent tags. All figures describe the fetched rows only.

    Args:
        cik: CIK number to analyze. Must consist of ASCII digits only; an
            int is accepted and converted with ``str()``.

    Returns:
        DataFrame with the fetched rows, or None when the CIK is rejected,
        no rows are found, or any error occurs (errors are printed).
    """
    try:
        # The value is pasted into the SQL text below, so anything that is
        # not purely digits is refused before a query is built. This keeps
        # quotes and other SQL syntax out of the statement.
        cik_text = str(cik)
        if not (cik_text.isascii() and cik_text.isdigit()):
            raise ValueError(f"CIK must contain only digits, got {cik!r}")

        # Get facts for this CIK
        facts_query = f"""
        SELECT tag, val, end_date, filed
        FROM bronze_sec_facts 
        WHERE cik = '{cik_text}'
        ORDER BY filed DESC
        LIMIT 100
        """

        facts_df = execute_sql_query(facts_query)

        if facts_df is not None and len(facts_df) > 0:
            print(f" Analysis for CIK {cik}:")
            print(f"   Total records: {len(facts_df)}")
            print(f"   Date range: {facts_df['filed'].min()} to {facts_df['filed'].max()}")
            print(f"   Unique tags: {facts_df['tag'].nunique()}")

            # Show top tags by frequency
            top_tags = facts_df['tag'].value_counts().head(10)
            print(f"\n   Top 10 tags:")
            for tag, count in top_tags.items():
                print(f"     {tag}: {count}")

            return facts_df
        else:
            print(f"No data found for CIK {cik}")
            return None

    except Exception as e:
        print(f" Error analyzing CIK {cik}: {e}")
        return None

def create_data_visualization(df: pd.DataFrame, chart_type: str = 'bar'):
    """
    Create visualizations for query results

    Chart behaviour:
        bar: first column on x and second on y (first 20 rows) when the
            second column is numeric; otherwise the 20 most frequent values
            of the first column.
        line / scatter: first column on x, second column on y.
        hist: 20-bin histogram of the first numeric column.

    A request that cannot be drawn (unknown chart type, too few columns, no
    numeric column for 'hist') is reported as an error instead of showing an
    empty figure.

    Args:
        df: pandas DataFrame to visualize
        chart_type: Type of chart ('bar', 'line', 'hist', 'scatter')

    Returns:
        None. Errors are printed, not propagated.
    """
    supported_types = ('bar', 'line', 'hist', 'scatter')
    fig = None  # stays None until a figure exists, so cleanup knows what to close

    try:
        # Validate before any figure is created so a bad request leaves
        # nothing behind.
        if chart_type not in supported_types:
            raise ValueError(
                f"unsupported chart_type {chart_type!r}; expected one of {supported_types}"
            )

        if chart_type == 'hist':
            numeric_cols = df.select_dtypes(include=[np.number]).columns
            if len(numeric_cols) == 0:
                raise ValueError("'hist' needs at least one numeric column")
        elif len(df.columns) < 2:
            raise ValueError(f"{chart_type!r} needs at least two columns")

        fig = plt.figure(figsize=(12, 6))
        ax = plt.gca()

        if chart_type == 'bar':
            col1, col2 = df.columns[0], df.columns[1]
            y_values = df[col2]
            # Any numeric dtype counts (int32, float32, nullable ints, ...),
            # not only int64/float64. Booleans are kept on the categorical
            # branch, as before.
            is_numeric = (
                pd.api.types.is_numeric_dtype(y_values)
                and not pd.api.types.is_bool_dtype(y_values)
            )
            if is_numeric:
                df.head(20).plot(x=col1, y=col2, kind='bar', ax=ax)
                plt.title(f'{col1} vs {col2}')
            else:
                df[col1].value_counts().head(20).plot(kind='bar', ax=ax)
                plt.title(f'Distribution of {col1}')
            plt.xticks(rotation=45)

        elif chart_type in ('line', 'scatter'):
            col1, col2 = df.columns[0], df.columns[1]
            df.plot(x=col1, y=col2, kind=chart_type, ax=ax)
            plt.title(f'{col1} vs {col2}')

        else:
            # 'hist': numeric_cols was computed during validation above
            df[numeric_cols[0]].hist(bins=20, ax=ax)
            plt.title(f'Distribution of {numeric_cols[0]}')

        plt.tight_layout()
        plt.show()

    except Exception as e:
        # Do not leave a half-drawn figure open when drawing failed
        if fig is not None:
            plt.close(fig)
        print(f" Error creating visualization: {e}")

print(" Data analysis and visualization functions loaded")


## Usage Examples

### 1. Execute Custom SQL Queries
```python
# Example: Get revenue data for a specific CIK
revenue_query = """
SELECT tag, val, end_date, filed
FROM bronze_sec_facts 
WHERE cik = '0000001800' 
AND tag LIKE '%revenue%'
ORDER BY filed DESC
"""
revenue_df = execute_sql_query(revenue_query)
```

### 2. Analyze Specific CIK Data
```python
# Analyze data for Abbott Laboratories (CIK: 0000001800)
abbott_data = analyze_cik_data('0000001800')
```

### 3. Insert New Data
```python
# Example: Insert processed data back to database
sample_data = pd.DataFrame({
    'cik': ['0000001800'],
    'tag': ['test_tag'],
    'val': [1000000],
    'end_date': ['2023-12-31'],
    'filed': ['2024-01-15']
})

# Insert using DataFrame method
insert_dataframe_to_table(sample_data, 'bronze_sec_facts')

# Or insert using ORM method
records = sample_data.to_dict('records')
bulk_insert_records(records, BronzeSecFacts)
```

### 4. Create Visualizations
```python
# Create bar chart from query results
create_data_visualization(df1, 'bar')
```
