# 04 - Transaction Categorization Test

**Objective:** Execute the transaction categorization using Databricks `ai_query()` against the three test layers (Obvious, Ambiguous, Unknown).

### Tasks:
1. Define the system prompt with full taxonomy context and few-shot examples.
2. Ensure the catalog data is available in a Unity Catalog table.
3. Run `ai_query()` on the transaction catalog.
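4. Save all results to a Unity Catalog table, with an optional CSV export to the `results/` directory.
5. Compare `ai_query()` against the task-specific functions `ai_classify()` and `ai_extract()`.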

In [None]:
# Configuration
MODEL_NAME = "databricks-meta-llama-3-3-70b-instruct"  # Update based on Step 03 discovery

# Unity Catalog details
# 'ciq-bp_dummy-dev' is the CATALOG name. We need a SCHEMA and a TABLE name.
CATALOG_NAME = "ciq-bp_dummy-dev"
SCHEMA_NAME = "default"  # You can change this to 'bronze' or another schema if preferred
TABLE_NAME = "transaction_code_catalog"

# Construct the full path with backticks to handle the hyphens in the catalog name
CATALOG_TABLE_PATH = f"`{CATALOG_NAME}`.`{SCHEMA_NAME}`.`{TABLE_NAME}`"

TAXONOMY_PATH = "../taxonomy/transaction_categorization_taxonomy.md"
LOCAL_CATALOG_PATH = "../taxonomy/data/transaction_code_catalog.csv"

print(f"Using model: {MODEL_NAME}")
print(f"Target table: {CATALOG_TABLE_PATH}")
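The backtick quoting used for `CATALOG_TABLE_PATH` can be wrapped in a small helper. This is a minimal sketch: `quote_ident` and `full_table_path` are hypothetical names that are not part of this notebook, and it assumes Spark SQL's rule that a backtick inside a quoted identifier is escaped by doubling it.

```python
def quote_ident(name: str) -> str:
    """Wrap one identifier part in backticks, doubling any embedded backtick."""
    return "`" + name.replace("`", "``") + "`"


def full_table_path(catalog: str, schema: str, table: str) -> str:
    """Join catalog, schema, and table into a three-level Unity Catalog name."""
    return ".".join(quote_ident(part) for part in (catalog, schema, table))


print(full_table_path("ciq-bp_dummy-dev", "default", "transaction_code_catalog"))
```

Quoting every part, rather than only the hyphenated catalog name, keeps the path valid if the schema or table is later renamed to something that also needs quoting.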

### 1. Upload Local Data to Unity Catalog
This cell creates the Unity Catalog table from our local CSV. It writes in `overwrite` mode, so an existing table with the same name is replaced.
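If replacing an existing table is not wanted, the write can be guarded by an existence check. The sketch below is one way to do that, not the notebook's method: `upload_if_missing` and its `load_frame` callback are hypothetical, and it assumes a PySpark version where `spark.catalog.tableExists()` is available and accepts a catalog-qualified name.

```python
def upload_if_missing(spark, table_path, load_frame):
    """Create table_path from load_frame() only when the table is absent.

    Returns True if the table was written, False if it already existed.
    load_frame is a zero-argument callable returning the data to upload
    (for example a pandas DataFrame); it is only called when needed.
    """
    if spark.catalog.tableExists(table_path):
        return False
    sdf = spark.createDataFrame(load_frame())
    sdf.write.mode("overwrite").saveAsTable(table_path)
    return True
```

In a Databricks notebook this could be called as `upload_if_missing(spark, CATALOG_TABLE_PATH, lambda: pd.read_csv(LOCAL_CATALOG_PATH))`, so the CSV is read only when the table is missing.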

In [None]:
import pandas as pd

try:
    print(f"Reading local catalog from: {LOCAL_CATALOG_PATH}")
    pdf = pd.read_csv(LOCAL_CATALOG_PATH)
    
    # Convert to Spark DataFrame
    sdf = spark.createDataFrame(pdf)
    
    # Write to Unity Catalog
    print(f"Writing to Unity Catalog: {CATALOG_TABLE_PATH}...")
    sdf.write.mode("overwrite").saveAsTable(CATALOG_TABLE_PATH)
    print("Table created successfully.")
except Exception as e:
    print(f"Error uploading table: {e}")
    print("If running locally, this step will fail. Ensure you are in a Databricks Notebook.")

### 2. Build Prompt
We load the full taxonomy markdown file to provide rich context to the LLM.

In [None]:
with open(TAXONOMY_PATH, "r", encoding="utf-8") as f:
    taxonomy_md = f.read()

system_prompt = f"""
You are a transaction categorization engine for a US bank. 
Given a transaction code (TRANCD) and description (DESC), classify it into the StrategyCorp taxonomy below.

{taxonomy_md}

### Rules:
1. First determine Block A (Non-fee item) or Block B (Fee item). Fee items typically contain: "fee", "charge", "surcharge", "penalty", "service charge".
2. Refunds/Reversals of fees are the exception: they must be Block A > Money movement > Deposits.
3. Classify through Level 2 > Level 3 > Level 4. Use EXACT strings from the taxonomy.
4. Use "Unclassified" if no mapping fits. Do not guess.
5. Return ONLY valid JSON. Do not include any explanation, reasoning, or markdown outside the JSON block. 

Your entire response must be a single JSON object with these keys: 
category_1, category_2, category_3, category_4, include_in_scoring, credit_debit, confidence.
"""

# Plain string (not an f-string), so JSON braces are written singly.
few_shot_examples = """
### Few-Shot Examples:
Input: TRANCD=183, DESC="ACH Debit - SERMONS"
Output: {
  "category_1":"Non-fee item",
  "category_2":"Money movement",
  "category_3":"ACH",
  "category_4":null,
  "include_in_scoring":true,
  "credit_debit":"Debit",
  "confidence":0.99
}

Input: TRANCD=299, DESC="ATM Service Charge"
Output: {
  "category_1":"Fee item",
  "category_2":"All others",
  "category_3":"Money movement",
  "category_4":"ATM",
  "include_in_scoring":false,
  "credit_debit":"Debit",
  "confidence":0.98
}

Input: TRANCD=141, DESC="Transfer from DDA"
Output: {
  "category_1":"Non-fee item",
  "category_2":"Money movement",
  "category_3":"Transfers & Payments",
  "category_4":null,
  "include_in_scoring":true,
  "credit_debit":"Credit",
  "confidence":0.95
}
"""

full_system_prompt = system_prompt + "\n" + few_shot_examples
print("Prompt prepared.")
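Since the rules ask the model to use exact strings from the taxonomy, it can help to check that the labels used in the few-shot examples appear verbatim in the taxonomy text. A minimal sketch under that assumption: `labels_missing_from_taxonomy` is a hypothetical helper, and `toy_taxonomy` is an illustrative stand-in, not the contents of the real taxonomy file.

```python
def labels_missing_from_taxonomy(taxonomy_text, labels):
    """Return the labels that do not appear verbatim in the taxonomy text."""
    return [label for label in labels if label not in taxonomy_text]


# Illustrative stand-in for the taxonomy file, not the real taxonomy.
toy_taxonomy = "Non-fee item > Money movement > ACH\nFee item > All others"

missing = labels_missing_from_taxonomy(toy_taxonomy, ["ACH", "Transfers & Payments"])
print(missing)
```

In the notebook, the same function could be called with `taxonomy_md` and the category strings from `few_shot_examples`; an empty list means every example label was found. This is a plain substring check, so a label that is a substring of a longer label also passes.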

### 3. Execute Classification (SQL)
We use `ai_query()` to call the model and `from_json()` to parse the results into structured columns.
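Because the prompt is embedded in SQL text, quoting matters. One alternative to escaping the prompt by hand is to bind it as a named parameter marker. This is a sketch under stated assumptions: it assumes PySpark 3.5+ (or an equivalent Databricks Runtime) where `spark.sql()` accepts an `args` dict of Python values, `build_classification_query` is a hypothetical helper, and whether `ai_query()` accepts a parameter marker inside its request expression on your runtime should be verified before relying on it.

```python
def build_classification_query(model_name, table_path, prompt):
    """Return (sql_text, args) with the prompt bound as a named parameter.

    The prompt never appears in the SQL text, so quotes and backslashes
    inside it need no escaping.
    """
    sql_text = f"""
SELECT
  TRANCD,
  sample_desc_1,
  ai_query(
    '{model_name}',
    CONCAT(:system_prompt, 'TRANCD=', TRANCD, ', DESC="', sample_desc_1, '"')
  ) AS llm_raw
FROM {table_path}
"""
    # The separator is appended in Python so the SQL needs no newline literal.
    args = {"system_prompt": prompt + "\nClassify: "}
    return sql_text, args
```

In a Databricks notebook the result would be used as `sql_text, args = build_classification_query(MODEL_NAME, CATALOG_TABLE_PATH, full_system_prompt)` followed by `spark.sql(sql_text, args=args)`.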

In [None]:
# Backslash-escape backslashes and single quotes so the prompt survives as a SQL string literal
escaped_prompt = full_system_prompt.replace("\\", "\\\\").replace("'", "\\'")

classification_query = f"""
WITH raw_results AS (
  SELECT 
    TRANCD,
    sample_desc_1,
    volume,
    source_file,
    ai_query(
      '{MODEL_NAME}',
      CONCAT('{escaped_prompt}', '\nClassify: TRANCD=', TRANCD, ', DESC="', sample_desc_1, '"')
    ) as llm_raw
  FROM {CATALOG_TABLE_PATH}
)
SELECT 
  *, 
  from_json(llm_raw, 'category_1 STRING, category_2 STRING, category_3 STRING, category_4 STRING, include_in_scoring BOOLEAN, credit_debit STRING, confidence DOUBLE') as parsed
FROM raw_results
"""

try:
    print("Executing categorization query...")
    results_df = spark.sql(classification_query)
    display(results_df)
    
    # Save results for evaluation
    output_table_name = f"`{CATALOG_NAME}`.`{SCHEMA_NAME}`.transaction_classification_results"
    print(f"Saving results to table: {output_table_name}...")
    results_df.write.mode("overwrite").saveAsTable(output_table_name)
    
    # Also export to CSV for local script processing
    # Note: /Workspace paths work in Databricks
    # results_df.toPandas().to_csv("../results/latest_run.csv", index=False)
    
except Exception as e:
    print(f"Error during classification: {e}")
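The query above relies on `from_json()` receiving clean JSON; when the string cannot be parsed, the `parsed` column is expected to come back null. Models sometimes wrap the object in markdown fences or add surrounding text despite the instructions, so a Python-side fallback for the `llm_raw` column can be useful. A minimal sketch, with `parse_llm_json` as a hypothetical helper that is not part of this notebook:

```python
import json

EXPECTED_KEYS = (
    "category_1", "category_2", "category_3", "category_4",
    "include_in_scoring", "credit_debit", "confidence",
)


def parse_llm_json(raw):
    """Return the first JSON object found in raw as a dict, or None.

    Tolerates replies wrapped in markdown fences or surrounded by extra text.
    Only the expected keys are kept; missing keys are set to None.
    """
    if not raw:
        return None
    decoder = json.JSONDecoder()
    start = raw.find("{")
    while start != -1:
        try:
            obj, _ = decoder.raw_decode(raw, start)
        except json.JSONDecodeError:
            # Not a valid object at this brace; try the next one.
            start = raw.find("{", start + 1)
            continue
        return {key: obj.get(key) for key in EXPECTED_KEYS}
    return None
```

Rows where `parsed` is null could be collected to the driver and passed through this function, which keeps the SQL path as the primary parser and limits the fallback to the failures.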

### 4. Task-Specific AI Functions (ai_classify & ai_extract)

We can compare `ai_query` (general-purpose) with task-specific functions like `ai_classify()` for flat categorization and `ai_extract()` for identifying specific attributes. These are optimized for simpler tasks and might be faster or more reliable for the first level of classification.
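To make that comparison concrete, one simple measure is the share of rows where the two functions agree on the Level 1 label. The sketch below assumes both predictions have been collected into Python as pairs of strings; `agreement_rate` is a hypothetical helper and `sample` is made-up illustrative data, not results from this notebook.

```python
def agreement_rate(pairs):
    """Share of rows where both predictions match, ignoring case and padding.

    pairs is an iterable of (ai_query_label, ai_classify_label). Rows with a
    missing prediction on either side count as disagreements.
    """
    total = 0
    matches = 0
    for left, right in pairs:
        total += 1
        if left is not None and right is not None:
            if left.strip().lower() == right.strip().lower():
                matches += 1
    return matches / total if total else 0.0


# Made-up rows for illustration only.
sample = [
    ("Fee item", "Fee item"),
    ("Non-fee item", "fee item"),
    ("Non-fee item", None),
    ("Non-fee item", "Non-fee item "),
]
print(agreement_rate(sample))  # 2 of the 4 rows agree
```

Agreement alone does not say which function is right; it only flags the rows worth checking against the labeled test layers.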

In [None]:
# Test 1: Classify into Block A (Non-fee item) or Block B (Fee item)
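labels = ["Non-fee item", "Fee item"]
# Build the SQL array from the list so the labels are defined in one place
l1_labels_sql = ", ".join(f"'{label}'" for label in labels)

classify_l1_query = f"""
SELECT 
  TRANCD,
  sample_desc_1 as description,
  ai_classify(sample_desc_1, ARRAY({l1_labels_sql})) as l1_prediction
FROM {CATALOG_TABLE_PATH}
"""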

try:
    print("Executing ai_classify for Level 1 (Block assignment)...\n")
    display(spark.sql(classify_l1_query))
except Exception as e:
    print(f"Error with ai_classify: {e}")

In [None]:
# Test 2: Classify Level 2 categories
l2_labels = ["Money movement", "Account operations", "NSF/OD", "Misc", "Service Charges", "Interchange"]
labels_sql = ", ".join(f"'{label}'" for label in l2_labels)

classify_l2_query = f"""
SELECT 
  TRANCD,
  sample_desc_1 as description,
  ai_classify(sample_desc_1, ARRAY({labels_sql})) as l2_prediction
FROM {CATALOG_TABLE_PATH}
"""

try:
    print("Executing ai_classify for Level 2...\n")
    display(spark.sql(classify_l2_query))
except Exception as e:
    print(f"Error with ai_classify: {e}")

In [None]:
# Test 3: Extracting attributes using ai_extract
# Useful for spotting 'reversals', 'refunds', or 'fee' types within the text.
entities = ["transaction_method", "is_reversal", "is_fee", "is_refund"]
entities_sql = ", ".join([f"'{e}'" for e in entities])

extract_query = f"""
SELECT 
  TRANCD,
  sample_desc_1 as description,
  ai_extract(sample_desc_1, ARRAY({entities_sql})) as extracted_info
FROM {CATALOG_TABLE_PATH}
"""

try:
    print("Executing ai_extract for specific attributes...\n")
    display(spark.sql(extract_query))
except Exception as e:
    print(f"Error with ai_extract: {e}")