In [1]:
from table_schema import generate_schema_prompt
from postgresql_setup import perform_query_on_postgresql_databases

## Check Schema


In [2]:
# Ask the schema helper for a description of the table this notebook works on.
result = generate_schema_prompt(
    db_name="california_schools",
    table_name="satscores",
    sql_dialect="PostgreSQL",
)

In [3]:
# Use print() so the multi-line schema text is shown with real line breaks.
print(result)

CREATE TABLE satscores (
cds text NOT NULL,
rtype text NULL,
sname text NULL,
dname text NULL,
cname text NULL,
enroll12 bigint NULL,
numtsttakr bigint NULL,
avgscrread bigint NULL,
avgscrmath bigint NULL,
avgscrwrite bigint NULL,
numge1500 bigint NULL,
    PRIMARY KEY (cds),
    FOREIGN KEY (cds) REFERENCES schools(cdscode)
);

Example data:
           cds rtype                        sname                             dname  cname  enroll12  numtsttakr  avgscrread  avgscrmath  avgscrwrite  numge1500
03100330000000     D                         None Amador County Office of Education Amador        16           0         NaN         NaN          NaN        NaN
03739810000000     D                         None             Amador County Unified Amador       317          97       525.0       514.0        503.0       60.0
03739810330050     S North Star Independent Study             Amador County Unified Amador        17           0         NaN         NaN          NaN        NaN


## Reproduce the error


### Preprocessing

We need to add a new column, `RWRatio`, to the `satscores` table.


In [5]:
# Setup: add the RWRatio column that the UPDATE statements below write to.
# IF NOT EXISTS keeps this cell re-runnable: if an earlier run stopped before
# the cleanup cell, the column is still there and a plain ADD COLUMN would fail.
# An already-present column is left as it is.
add_column_query = """
ALTER TABLE satscores
ADD COLUMN IF NOT EXISTS RWRatio numeric(5,2) DEFAULT 0;
"""
perform_query_on_postgresql_databases(add_column_query, db_name='california_schools')

### Error query


In [7]:
# Intentionally invalid statement: PostgreSQL rejects aggregate functions
# (AVG) used directly in the SET clause of an UPDATE. The query text is left
# exactly as written because reproducing that error is the purpose of this cell.
error_query = """
UPDATE satscores
SET RWRatio = CASE
    WHEN cname = 'Alameda' THEN AVG(CAST(avgscrread AS numeric(5,2))) / AVG(CAST(avgscrwrite AS numeric(5,2)))
    ELSE 0
END;
"""
try:
    perform_query_on_postgresql_databases(error_query, db_name='california_schools')
except Exception as exc:
    # The failure is the demonstration: show it, but do not abort a
    # "Restart & Run All" before the reference solution and cleanup cells.
    print(f"{type(exc).__name__}: {exc}")

GroupingError: aggregate functions are not allowed in UPDATE
LINE 4:     WHEN cname = 'Alameda' THEN AVG(CAST(avgscrread AS numer...
                                        ^


## Reference Solution


In [8]:
# Reference fix: compute the aggregate ratio once in a CTE, then join that
# single value into the UPDATE instead of aggregating inside SET.
correct_update_query = """
WITH cte_avg AS (
    SELECT 
        AVG(CAST(avgscrread AS numeric(5,2))) / AVG(CAST(avgscrwrite AS numeric(5,2))) AS avg_ratio
    FROM satscores
    WHERE cname = 'Alameda'
)
UPDATE satscores
SET RWRatio = cte_avg.avg_ratio
FROM cte_avg
WHERE cname = 'Alameda';
"""

perform_query_on_postgresql_databases(
    correct_update_query,
    db_name="california_schools",
)

## Test Cases


In [12]:
from decimal import Decimal


def test_satscores_update():
    try:

        # Test case 1: Check if RWRatio is updated for Alameda records
        query = (
            "SELECT COUNT(*) FROM satscores WHERE cname = 'Alameda' AND RWRatio <> 0"
        )
        result = perform_query_on_postgresql_databases(query,db_name='california_schools')
        assert (
            result[0][0] > 0
        ), f"Expected some records to be updated, but found {result[0][0]}"

        # Test case 2: Verify the RWRatio calculation for a specific record
        query = """
        SELECT cds, avgscrread, avgscrwrite, RWRatio 
        FROM satscores 
        WHERE cname = 'Alameda' AND avgscrread IS NOT NULL AND avgscrwrite IS NOT NULL
        LIMIT 1
        """
        result = perform_query_on_postgresql_databases(query,db_name='california_schools')
        assert len(result) > 0, "No valid records found for Alameda"
        record = result[0]
        expected_ratio = Decimal(record[1]) / Decimal(record[2])
        actual_ratio = Decimal(record[3])
        assert abs(actual_ratio - expected_ratio) < Decimal(
            "0.01"
        ), f"Expected ratio {expected_ratio}, but got {actual_ratio}"

        # Test case 3: Check that non-Alameda records are not updated
        query = (
            "SELECT COUNT(*) FROM satscores WHERE cname <> 'Alameda' AND RWRatio <> 0"
        )
        result = perform_query_on_postgresql_databases(query,db_name='california_schools')
        assert (
            result[0][0] == 0
        ), f"Expected 0 non-Alameda records to be updated, but found {result[0][0]}"

        # Test case 4: Verify the overall count of updated records
        query = "SELECT COUNT(*) FROM satscores WHERE RWRatio <> 0"
        result = perform_query_on_postgresql_databases(query,db_name='california_schools')
        alameda_count = perform_query_on_postgresql_databases(
            "SELECT COUNT(*) FROM satscores WHERE cname = 'Alameda'",db_name='california_schools'
        )[0][0]
        assert (
            result[0][0] == alameda_count
        ), f"Expected {alameda_count} records to be updated, but found {result[0][0]}"

        print("All test cases passed successfully!")

    except Exception as e:
        print(f"Test failed: {e}")


# Run the checks; this must happen after the reference solution cell has
# populated RWRatio and before the cleanup cell drops the column.
test_satscores_update()

All test cases passed successfully!


## Clean up


In [14]:
# Teardown: remove the helper column so the table returns to its original
# schema. IF EXISTS makes it safe to run even when the column is already gone.
cleanup_query = """
ALTER TABLE satscores
DROP COLUMN IF EXISTS RWRatio;
"""
perform_query_on_postgresql_databases(
    cleanup_query,
    db_name="california_schools",
)
print("Cleanup completed: RWRatio column dropped.")

Cleanup completed: RWRatio column dropped.
