In [None]:
import json

import duckdb
from jinja2 import Template

from src.duckdb_prompt_udf import prompt

# Open an in-memory DuckDB database (no database file is created on disk).
con = duckdb.connect(":memory:", read_only=False)

# Make the imported Python `prompt` callable available to SQL as a scalar function
# named "prompt": three string arguments followed by one float, returning a string.
con.create_function(
    name="prompt",
    function=prompt,
    parameters=[str, str, str, float],
    return_type=str,
)

# Bare expression: the notebook renders it as the cell output once setup has finished.
"Done!"

In [None]:
# System prompt sent with every categorization request.
# NOTE: this text is spliced into a single-quoted SQL string literal by the query
# template, so it intentionally contains no apostrophes.
system_prompt = "\n".join(
    [
        "You are a government auditor reviewing New Yorks 311 service request system.",
        "",
        "Your task:",
        "1. Review the given service request details (agency, complaint type, and description)",
        "2. Choose the most appropriate category and subcategory from the provided CATEGORIES json structure",
        "3. The sample data in CATEGORIES json serves as guidance for classification",
        "",
        "Important rules:",
        "- Only use CATEGORIES and SUBCATEGORIES that exist in the provided CATEGORIES json block",
        "- Do NOT use the sample labels as categories or subcategories",
        "- Do NOT create new categories",
        "",
        "Response format:",
        "```json",
        '{"category": "STRING", "subcategory": "STRING"}',
        "```",
        "Note: Your response must contain ONLY the JSON object, nothing else.",
    ]
)

# Structured-output description passed to the `prompt` SQL function as a JSON string.
# The outer object is a wrapper (name / strict / schema); the nested "schema" object is
# the actual JSON Schema for the reply. `required` is a JSON Schema keyword that only
# has an effect next to `properties`, so it belongs inside the nested schema rather
# than on the wrapper.
json_schema = json.dumps(
    {
        "name": "category_response",
        "type": "object",
        "strict": "true",
        "schema": {
            "type": "object",
            "properties": {
                "category": {"type": "string"},
                "subcategory": {"type": "string"},
            },
            # Both fields must be present in every reply.
            "required": ["category", "subcategory"],
        },
    }
)

# Define the SQL query template.
#
# Placeholders: system_prompt, json_schema, temperature, data_file, limit, output_file.
#
# The query counts requests per (agency, complaint_type, descriptor), keeps the
# `limit` most frequent groups, calls the `prompt` SQL function once per group with the
# categories JSON plus the group details, strips Markdown code fences from the reply,
# extracts `category` / `subcategory` from the resulting JSON and writes everything to
# `output_file`.
#
# String placeholders are spliced into quoted SQL literals/identifiers, so each one goes
# through Jinja's `replace` filter to double any embedded quote character. Without this,
# a single apostrophe in the prompt or schema text would terminate the SQL literal and
# break (or alter) the statement. For values without quotes the rendered SQL is unchanged.
query_template = """
COPY (
    WITH llm_categorization AS (
        SELECT
            regexp_replace(
                prompt(
                    '# CATEGORIES:\n'
                    || '```json\n'
                    ||categories::VARCHAR || '\n'
                    || '```\n'
                    ||'# SERVICE REQUEST:\n'
                    ||'AGENCY: ' || IFNULL(agency, 'N/A') || '\n'
                    ||'COMPLAINT TYPE: ' || IFNULL(complaint_type, 'N/A')  || '\n'
                    ||'DESCRIPTION: ' || IFNULL(descriptor, 'N/A') || '\n',
                    '{{ system_prompt | replace("'", "''") }}'::VARCHAR,
                    '{{ json_schema | replace("'", "''") }}'::VARCHAR,
                    {{ temperature }}::FLOAT
                ),
                '```json|```',
                '',
                'g'
            )::VARCHAR AS raw_llm_response,
            json_extract_string(raw_llm_response, '$.category')::VARCHAR AS category,
            json_extract_string(raw_llm_response, '$.subcategory')::VARCHAR AS subcategory,
            agency,
            complaint_type,
            descriptor,
            request_count,
        FROM (
            SELECT
                agency,
                complaint_type,
                descriptor,
                count(*) AS request_count
            FROM "{{ data_file | replace('"', '""') }}"
                group by 1,2,3
                order by 4 desc
                limit {{ limit }}
        ) CROSS JOIN read_json('./data/categories.json')
    )

    SELECT
        agency,
        complaint_type,
        descriptor,
        category,
        subcategory,
        raw_llm_response,
        request_count
    FROM llm_categorization
) TO '{{ output_file | replace("'", "''") }}';
"""

# Render the template: substitute every placeholder to obtain the final SQL text.
template = Template(query_template)
query = template.render(
    {
        "system_prompt": system_prompt,
        "json_schema": json_schema,
        "temperature": 0.2,
        "data_file": "./data/cityofnewyork/service_requests_2024.parquet",
        "output_file": "./output/llm_categorize_output_2024.csv",
        "limit": 1200,
    }
)

# Run the rendered statement and print whatever rows DuckDB returns for it.
print("Executing query...")

result = con.execute(query).fetchall()
print(result)

In [None]:
# Check for invalid categories in the output file.
#
# The query flattens categories.json into (category, subcategory) pairs and anti-joins
# the categorization output against them: rows whose pair has no counterpart are written
# to invalid_categories_2024.csv.
#
# The input must be the same file the categorization step writes
# (llm_categorize_output_2024.csv); it is also the file used for the denominator below.
qa_query = """
COPY (
    WITH categories_data AS (
        SELECT UNNEST(categories) AS category_obj
        FROM read_json('./data/categories.json')
    ),

    unnested AS (
    SELECT
        category_obj['category'] AS category,
        UNNEST(category_obj['subcategories']) AS subcategory_obj
    FROM categories_data
    ),

    all_categories AS (
        SELECT
            category,
            subcategory_obj['category'] AS subcategory
        FROM unnested
    )

    SELECT *
    FROM './output/llm_categorize_output_2024.csv' AS o
    LEFT JOIN all_categories AS c
        ON o.category = c.category
        AND o.subcategory = c.subcategory
    WHERE c.subcategory is null
) TO './output/invalid_categories_2024.csv';
"""

# Execute the query
print("Executing query...")
con.execute(qa_query)

# Read the result: number of invalid rows vs. number of categorized rows.
invalid_count = con.execute(
    "SELECT COUNT(*) FROM './output/invalid_categories_2024.csv'"
).fetchone()[0]
all_categories_count = con.execute(
    "SELECT COUNT(*) FROM './output/llm_categorize_output_2024.csv'"
).fetchone()[0]

# Guard the ratio: an empty categorization output would otherwise raise ZeroDivisionError.
if all_categories_count:
    print(f"Invalid categories: {invalid_count / all_categories_count * 100:.2f} %")
else:
    print("Invalid categories: n/a (categorization output is empty)")

In [None]:
# Query to calculate the percentage of data with no matching categories.
#
# Every service request is left-joined to the categorization output on
# (agency, complaint_type, descriptor); a request counts as mismatched when the joined
# category or subcategory is NULL.
#
# The join uses IS NOT DISTINCT FROM instead of `=`: in SQL `NULL = NULL` is not true,
# so with plain equality any request whose agency, complaint type or descriptor is NULL
# could never match its categorized group and would always be reported as mismatched.
mismatch_query = """
WITH llm_output AS (
    SELECT *
    FROM './output/llm_categorize_output_2024.csv'
),
service_requests AS (
    SELECT *
    FROM './data/cityofnewyork/service_requests_2024.parquet'
),
joined_data AS (
    SELECT
        sr.*,
        lo.category AS llm_category,
        lo.subcategory AS llm_subcategory
    FROM service_requests sr
    LEFT JOIN llm_output lo
    ON sr.agency IS NOT DISTINCT FROM lo.agency
    AND sr.complaint_type IS NOT DISTINCT FROM lo.complaint_type
    AND sr.descriptor IS NOT DISTINCT FROM lo.descriptor
)
SELECT
    COUNT(*) AS total_records,
    SUM(CASE WHEN llm_category IS NULL OR llm_subcategory IS NULL THEN 1 ELSE 0 END) AS mismatched_records
FROM joined_data;
"""

# Execute the query
print("Executing mismatch query...")
mismatch_result = con.execute(mismatch_query).fetchone()

# Calculate and print the percentage of mismatched data
total_records, mismatched_records = mismatch_result
if total_records:
    mismatch_percentage = (mismatched_records / total_records) * 100
    print(
        f"Percentage of service requests without matching category data: {mismatch_percentage:.2f} %"
    )
else:
    # With zero input rows SUM(...) is NULL and the ratio is undefined.
    print("Percentage of service requests without matching category data: n/a (no service requests)")

In [None]:
# Query to count matching service requests per category and subcategory.
#
# Steps:
#   1. all_categories   - (category, subcategory) pairs flattened from categories.json.
#   2. valid_categories - the distinct category names from those pairs.
#   3. joined_data      - service requests joined to the categorization output and
#                         counted per assigned (category, subcategory).
#   4. final SELECT     - flags whether the assigned category exists at all and whether
#                         the assigned (category, subcategory) pair exists.
#
# Notes on the non-obvious parts:
#   * The request/output join uses IS NOT DISTINCT FROM so that groups whose agency,
#     complaint type or descriptor is NULL still match (`NULL = NULL` is not true in SQL).
#   * is_valid_category is checked against the category list alone; deriving it from the
#     pair join would report a valid category as invalid whenever only the subcategory
#     is wrong.
#   * ORDER BY is applied on the outermost SELECT; ordering inside a CTE is not
#     guaranteed to survive the later joins.
matching_query = """
WITH llm_output AS (
    SELECT *
    FROM './output/llm_categorize_output_2024.csv'
),
service_requests AS (
    SELECT *
    FROM './data/cityofnewyork/service_requests_2024.parquet'
),
categories_data AS (
    SELECT UNNEST(categories) AS category_obj
    FROM read_json('./data/categories.json')
),
categories_data_unnested AS (
    SELECT
        category_obj['category'] AS category,
        UNNEST(category_obj['subcategories']) AS subcategory_obj
    FROM categories_data
),
all_categories AS (
    SELECT
        category,
        subcategory_obj['category'] AS subcategory
    FROM categories_data_unnested
),
valid_categories AS (
    SELECT DISTINCT category
    FROM all_categories
),
joined_data AS (
    SELECT
        lo.category AS category,
        lo.subcategory AS subcategory,
        COUNT(*) AS matching_requests
    FROM service_requests sr
    JOIN llm_output lo
        ON sr.agency IS NOT DISTINCT FROM lo.agency
        AND sr.complaint_type IS NOT DISTINCT FROM lo.complaint_type
        AND sr.descriptor IS NOT DISTINCT FROM lo.descriptor
    WHERE lo.category IS NOT NULL AND lo.subcategory IS NOT NULL
    GROUP BY lo.category, lo.subcategory
)

SELECT
    jd.category,
    jd.subcategory,
    jd.matching_requests,
    vc.category IS NOT NULL AS is_valid_category,
    ac.subcategory IS NOT NULL AS is_valid_subcategory
FROM joined_data jd
LEFT JOIN valid_categories vc
    ON jd.category = vc.category
LEFT JOIN all_categories ac
    ON jd.category = ac.category
    AND jd.subcategory = ac.subcategory
ORDER BY jd.matching_requests DESC;
"""

# Execute the query
print("Executing matching query...")
matching_result = con.execute(matching_query).fetchall()

# Pretty print the result, one line per (category, subcategory) pair.
for (
    category,
    subcategory,
    matching_requests,
    is_valid_category,
    is_valid_subcategory,
) in matching_result:
    print(
        f"Category: {category}, Subcategory: {subcategory}, Matching Requests: {matching_requests}, "
        f"Valid Category: {is_valid_category}, Valid Subcategory: {is_valid_subcategory}"
    )