# AI Data Classification Engine (Batch, Azure/OpenAI)

This notebook contains a **reusable classification pipeline** that:
- Reads data from an Excel/CSV file
- Sends rows in **batches** to Azure OpenAI or OpenAI
- Applies **per-column predefined categories** (optional)
- Generates **categorical + freetext outputs**
- Writes the results back to Excel

Each code cell is preceded by a short note explaining how to use or modify it.

## 1. Imports and JSON helper
Run this cell first.

It imports all required libraries and defines `clean_and_parse_json()`. This helper parses the JSON returned by the model while tolerating minor formatting issues such as Markdown code fences (```` ```json ````).

In [None]:
import os
import json
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional, Set

import pandas as pd
from openai import OpenAI, AzureOpenAI


def clean_and_parse_json(raw: str) -> dict:
    """
    Parse the JSON object returned by the model, tolerating common LLM formatting.

    Handled cases:
      - Markdown code fences (``` or ```json) around the payload, including a
        fence written on the same line as the JSON.
      - Stray text before the first "{" and/or after the last "}"
        (e.g. 'Here is the result: {...} Hope this helps').

    Args:
        raw: Raw text of the model response. ``None`` is treated as empty.

    Returns:
        The decoded JSON value (a dict for well-formed model responses).

    Raises:
        ValueError: If no parseable JSON can be recovered from ``raw``.
    """
    original = (raw or "").strip()
    text = original

    # Remove markdown-style code fences if present
    if text.startswith("```"):
        # Drop the opening fence line (``` or ```json)
        lines = text.splitlines()[1:]
        # Drop the closing fence line if present
        if lines and lines[-1].strip().startswith("```"):
            lines = lines[:-1]
        text = "\n".join(lines).strip()

    # Try direct parse first
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    # Salvage: keep only the span from the first "{" to the last "}".
    # Slicing the un-fenced original also covers one-line fenced replies,
    # which the line-based fence stripping above reduces to an empty string.
    first_brace = original.find("{")
    last_brace = original.rfind("}")
    if first_brace != -1 and last_brace > first_brace:
        try:
            return json.loads(original[first_brace : last_brace + 1])
        except json.JSONDecodeError:
            pass

    raise ValueError(f"❌ Invalid JSON returned by model:\n{text}")

## 2. Configuration and Client Factory
This cell defines:
- `ClassificationConfig`: all configuration for a classification run.
- `create_client()`: creates either an Azure OpenAI or OpenAI client.

**Instructions:**
- Set `provider` to `'azure'` or `'openai'` when creating the config later. For Azure, `model` must be your **deployment name**, not the underlying model name.
- For Azure, make sure environment variables `AZURE_OPENAI_API_KEY` and `AZURE_OPENAI_ENDPOINT` are set.
- For OpenAI, set `OPENAI_API_KEY`.

In [None]:
@dataclass
class ClassificationConfig:
    """
    Configuration for one classification run.

    Raises:
        ValueError: (at construction) if ``provider`` is not "azure"/"openai"
            (case-insensitive) or ``batch_size`` is smaller than 1.
    """

    # "openai" or "azure" (compared case-insensitively)
    provider: str = "azure"

    # OpenAI model name or Azure deployment name
    model: str = "gpt-4o-Global-Standard"  # For Azure: deployment name

    # System prompt (set to BASE_SYSTEM_PROMPT later)
    system_prompt: str = ""

    # Columns from your file used as input to the model
    input_columns: List[str] = field(default_factory=list)

    # Columns that the model must output
    output_columns: List[str] = field(default_factory=list)

    # Optional per-column predefined categories
    # Example:
    # {
    #   "Plausible": [
    #       {"name": "Yes", "definition": "..."},
    #       {"name": "No",  "definition": "..."}
    #   ],
    #   "RiskLevel": [
    #       {"name": "Low", "definition": "..."},
    #       {"name": "High", "definition": "..."}
    #   ]
    # }
    per_column_predefined_categories: Dict[str, List[Dict[str, Any]]] = field(
        default_factory=dict
    )

    # Optional business / domain context (project-specific)
    business_context: Dict[str, Any] = field(default_factory=dict)

    # Model parameters
    temperature: float = 0.0
    max_output_tokens: Optional[int] = 2000

    # How many rows per API call (must be >= 1)
    batch_size: int = 10

    def __post_init__(self) -> None:
        """Reject settings that would otherwise misbehave silently later."""
        # An unknown provider must not quietly fall back to a different
        # service: that would send the data somewhere the user did not intend.
        if self.provider.lower() not in ("azure", "openai"):
            raise ValueError(
                f"❌ Unsupported provider {self.provider!r}; use 'azure' or 'openai'"
            )
        # Rows are processed with range(0, n_rows, batch_size): 0 raises an
        # obscure error there, and a negative value silently processes nothing.
        if self.batch_size < 1:
            raise ValueError(f"❌ batch_size must be >= 1, got {self.batch_size}")


def create_client(config: ClassificationConfig):
    """
    Create a chat-completions client for the configured provider.

    Args:
        config: Run configuration. Only ``config.provider`` is read; it is
            compared case-insensitively against "azure" and "openai".

    Returns:
        ``AzureOpenAI`` for "azure" (key and endpoint from the
        ``AZURE_OPENAI_API_KEY`` / ``AZURE_OPENAI_ENDPOINT`` environment
        variables) or ``OpenAI`` for "openai" (key from ``OPENAI_API_KEY``).

    Raises:
        ValueError: If the provider is neither "azure" nor "openai".
    """
    provider = config.provider.lower()

    if provider == "azure":
        # Azure OpenAI (old-style chat completions)
        return AzureOpenAI(
            api_key=os.getenv("AZURE_OPENAI_API_KEY"),
            azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),  # e.g. "https://<resource>.openai.azure.com"
            api_version="2024-05-01-preview",
        )

    if provider == "openai":
        # Normal OpenAI
        return OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    # Refuse instead of silently falling back to public OpenAI, which could
    # route data to a provider the user did not choose (e.g. on a typo).
    raise ValueError(
        f"❌ Unsupported provider {config.provider!r}; use 'azure' or 'openai'"
    )

## 3. RowClassifier (Batch Classification Logic)
This class:
- Creates the API client
- Builds the JSON payload for a batch of rows
- Calls the model using `chat.completions.create`
- Parses the JSON response
- Tracks categories per output column for reuse

**Instructions:**
- You usually don't need to modify this unless you want to change how batching works.
- Focus on editing `ClassificationConfig` and `BASE_SYSTEM_PROMPT` instead.

In [None]:
class RowClassifier:
    """
    Classify batches of DataFrame rows with a chat-completions model.

    For each batch, one JSON payload is sent: the rows, the column
    configuration, the business context, and the categories seen so far.
    The model's JSON reply is parsed, and the outputs are returned keyed by
    row index.

    Values the model produces for categorical output columns are remembered
    per column and sent back as ``previous_categories_per_column``, so later
    batches can reuse them consistently.
    """

    # Maximum number of distinct values remembered for an output column that
    # has NO predefined categories. Freetext columns (e.g. "Reason") yield a new
    # value for almost every row. Without a cap, every one of them would be
    # echoed back to the model in all later requests, and the prompt would grow
    # with the size of the file. A column that exceeds the cap is treated as
    # freetext and is no longer tracked or sent back.
    MAX_TRACKED_VALUES_PER_COLUMN = 50

    def __init__(self, config: ClassificationConfig):
        self.config = config
        self.client = create_client(config)

        # Allowed category names for each column with predefined categories
        # (taken from the "name" key of each category definition).
        self._predefined_names: Dict[str, Set[str]] = {
            col: {c["name"] for c in cats if "name" in c}
            for col, cats in self.config.per_column_predefined_categories.items()
        }

        # Per-column category tracking (for reuse), seeded with the predefined
        # names. Example: {"Plausible": {"Yes", "No"}, "RiskLevel": {"Low", "High"}}
        self.seen_categories_per_col: Dict[str, Set[str]] = {
            col: set(names) for col, names in self._predefined_names.items()
        }

        # Columns that exceeded MAX_TRACKED_VALUES_PER_COLUMN (treated as freetext).
        self._untracked_cols: Set[str] = set()

    def _build_batch_payload(self, batch_rows: List[tuple]) -> Dict[str, Any]:
        """
        Build the JSON payload for a batch of rows.

        batch_rows: list of (row_index, row_series)
        Returns a dict that will be serialized to JSON and sent to the model.
        Missing (NaN/None) cells are sent as empty strings; all others as str().
        """
        rows_payload = []
        for idx, row in batch_rows:
            row_data = {
                col: ("" if pd.isna(row.get(col)) else str(row.get(col)))
                for col in self.config.input_columns
            }
            rows_payload.append({"row_index": int(idx), "row_data": row_data})

        return {
            "input_columns": self.config.input_columns,
            "output_columns": self.config.output_columns,
            "rows": rows_payload,
            # Optional per-column predefined categories
            "per_column_predefined_categories": self.config.per_column_predefined_categories,
            # Categories already used per column, for reuse and consistency
            "previous_categories_per_column": {
                col: sorted(vals) for col, vals in self.seen_categories_per_col.items()
            },
            "business_context": self.config.business_context,
        }

    @staticmethod
    def _response_text(response: Any) -> str:
        """
        Return the stripped text of the first choice of a chat completion.

        Raises:
            ValueError: If there is no choice, the output was cut off by the
                token limit, or the message has no text content.
        """
        if not response.choices:
            raise ValueError("❌ Model response contained no choices")
        choice = response.choices[0]
        finish_reason = getattr(choice, "finish_reason", None)
        if finish_reason == "length":
            # Truncated JSON cannot be parsed reliably; report the real cause
            # instead of a confusing JSON parse error.
            raise ValueError(
                "❌ Model output was truncated (finish_reason='length'); "
                "increase max_output_tokens or reduce batch_size"
            )
        content = choice.message.content
        if not content:
            raise ValueError(
                f"❌ Model returned no content (finish_reason={finish_reason!r})"
            )
        return content.strip()

    def _remember_categories(self, outputs: Dict[str, Any]) -> None:
        """Record one row's output values for reuse in later batches."""
        for col in self.config.output_columns:
            value = outputs.get(col)
            if not value:
                continue
            if self._predefined_names.get(col):
                # All allowed names are already tracked. Never feed a value the
                # model invented for a predefined column back to it as "seen".
                continue
            if col in self._untracked_cols:
                continue
            seen = self.seen_categories_per_col.setdefault(col, set())
            seen.add(str(value))
            if len(seen) > self.MAX_TRACKED_VALUES_PER_COLUMN:
                # Too many distinct values: treat as freetext and stop echoing.
                self._untracked_cols.add(col)
                del self.seen_categories_per_col[col]

    def classify_batch(self, batch_rows: List[tuple]) -> Dict[int, Dict[str, Any]]:
        """
        Classify a batch of rows in a single API call using chat.completions.

        Args:
            batch_rows: List of ``(row_index, row_series)`` tuples.

        Returns:
            dict: {row_index: outputs_dict, ...} for every row the model
            answered. Rows the model omitted are absent.

        Raises:
            ValueError: If the response is empty or truncated, is not valid
                JSON, or does not match the ``{"rows": [...]}`` structure.
        """
        if not batch_rows:
            return {}

        payload = self._build_batch_payload(batch_rows)

        # Call chat completions (old-style SDK usage)
        response = self.client.chat.completions.create(
            model=self.config.model,  # Azure: deployment name
            messages=[
                {"role": "system", "content": self.config.system_prompt},
                {"role": "user", "content": json.dumps(payload, ensure_ascii=False)},
            ],
            temperature=self.config.temperature,
            max_tokens=self.config.max_output_tokens,
        )

        # Parse JSON (robust against minor formatting issues)
        parsed = clean_and_parse_json(self._response_text(response))

        if not isinstance(parsed, dict) or not isinstance(parsed.get("rows"), list):
            raise ValueError(f"❌ JSON must contain 'rows' list:\n{parsed}")

        result_by_index: Dict[int, Dict[str, Any]] = {}

        for item in parsed["rows"]:
            if (
                not isinstance(item, dict)
                or "row_index" not in item
                or "outputs" not in item
            ):
                raise ValueError(
                    f"❌ Each item must have 'row_index' and 'outputs':\n{item}"
                )

            outputs = item["outputs"]
            if not isinstance(outputs, dict):
                raise ValueError(f"❌ 'outputs' must be a JSON object:\n{item}")

            row_index = int(item["row_index"])

            # Track categories per output column for reuse
            self._remember_categories(outputs)

            result_by_index[row_index] = outputs

        return result_by_index

## 4. File-Level Pipeline (`classify_file`)
This function:
- Loads an Excel or CSV file
- Verifies input columns
- Adds missing output columns
- Processes rows in batches using `RowClassifier`
- Saves the classified result to a new Excel file

**Instructions:**
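- You usually only change `input_path`, `output_path`, and `sheet_name` when calling this. `sheet_name` is only used for Excel input.
- Make sure your input file exists and has the input columns you configured.
- If a batch fails (API error or unparsable model response), a warning is printed and that batch's output cells are left unchanged. Processing continues with the next batch.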

In [None]:
def classify_file(
    input_path: str,
    output_path: str,
    config: ClassificationConfig,
    sheet_name: Optional[str] = None,
):
    """
    Classify every row of an Excel/CSV file and write the result to a new file.

    - Loads Excel (.xlsx/.xls) or CSV
    - Checks that all input columns exist and adds missing output columns
    - Runs model classification in batches of ``config.batch_size``
    - Writes the DataFrame (original + output columns) to ``output_path``

    Args:
        input_path: Source file (.xlsx, .xls or .csv).
        output_path: Destination file. A ``.csv`` extension writes CSV;
            any other extension is written as Excel.
        config: Classification settings.
        sheet_name: Excel sheet to read; ``None`` means the first sheet.
            Ignored for CSV input.

    Returns:
        The classified DataFrame (also written to ``output_path``).

    Raises:
        ValueError: Unsupported input format, a missing input column, or
            ``config.batch_size`` < 1.

    A batch that fails (API error, unparsable response) is reported and
    skipped. Its output cells keep their previous values (empty for newly
    added columns), and the run continues.
    """
    if config.batch_size < 1:
        raise ValueError(f"❌ batch_size must be >= 1, got {config.batch_size}")

    ext = os.path.splitext(input_path)[1].lower()
    if ext in [".xlsx", ".xls"]:
        # pandas interprets sheet_name=None as "read ALL sheets" and returns a
        # dict of DataFrames; map None to the first sheet instead.
        df = pd.read_excel(input_path, sheet_name=0 if sheet_name is None else sheet_name)
    elif ext == ".csv":
        df = pd.read_csv(input_path)
    else:
        raise ValueError("❌ Unsupported file format. Use .xlsx, .xls, or .csv")

    # Ensure input columns exist
    for col in config.input_columns:
        if col not in df.columns:
            raise ValueError(f"❌ Missing input column in file: {col}")

    # Create output columns if not already present. Existing ones (possibly
    # numeric, or all-NaN floats) are cast to object so string answers can be
    # stored without incompatible-dtype assignments.
    for col in config.output_columns:
        if col not in df.columns:
            df[col] = None
        else:
            df[col] = df[col].astype(object)

    classifier = RowClassifier(config)
    indices = list(df.index)
    failed_rows = 0

    # Process in batches
    for start in range(0, len(indices), config.batch_size):
        batch_indices = indices[start : start + config.batch_size]
        # Inclusive positional range, used consistently in all messages.
        first, last = start, start + len(batch_indices) - 1
        batch_rows = [(idx, df.loc[idx]) for idx in batch_indices]

        try:
            outputs_by_index = classifier.classify_batch(batch_rows)
        except Exception as e:  # best-effort: report and continue with the next batch
            failed_rows += len(batch_indices)
            print(f"⚠️ Error in batch {first}–{last}: {e}")
            continue

        # Fill DataFrame with outputs
        unanswered = 0
        for idx in batch_indices:
            outputs = outputs_by_index.get(idx)
            if outputs is None:
                unanswered += 1
                continue
            for col in config.output_columns:
                if col in outputs:
                    df.at[idx, col] = outputs[col]

        if unanswered:
            print(f"⚠️ Model returned no result for {unanswered} row(s) in batch {first}–{last}")
        print(f"✅ Processed rows {first}–{last}")

    # Save result
    if os.path.splitext(output_path)[1].lower() == ".csv":
        df.to_csv(output_path, index=False)
    else:
        df.to_excel(output_path, index=False)

    if failed_rows:
        print(f"⚠️ {failed_rows} row(s) could not be classified (failed batches).")
    print(f"✅ Classification complete. Output saved to: {output_path}")
    return df

## 5. System Prompt
This prompt defines the core behavior of the classifier:
- It explains how to treat input/output columns.
- It defines behavior for categorical vs freetext output columns.
- It enforces strict JSON output.

**Instructions:**
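- You can adapt the instructions here (e.g., domain or logistics rules). Project-specific rules are usually better passed through `business_context` in the config; the prompt already tells the model to respect it.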
- Keep the overall structure and formatting rules intact.


In [None]:
# System prompt sent as the "system" message of every batch request.
# It describes the JSON payload built for each batch and the exact JSON reply
# shape that the response parser expects ({"rows": [{"row_index", "outputs"}]}).
BASE_SYSTEM_PROMPT = """
You are an AI engine that classifies and enriches tabular data.
You will receive batched rows and must return STRICT JSON output for each row.

---------------------------------------------------------
INPUT YOU WILL RECEIVE
---------------------------------------------------------
You receive one JSON object containing:

- input_columns: list of column names whose values appear in each row_data.
- output_columns: list of column names you must generate.
- rows: list of items:
    {
      "row_index": <int>,
      "row_data": { "<col>": "<value>", ... }
    }
- per_column_predefined_categories (optional):
    {
      "<output_col>": [
        {"name": "<category>", "definition": "<meaning>"},
        ...
      ]
    }
- previous_categories_per_column (optional):
    {
      "<output_col>": ["existing", "categories", ...]
    }
- business_context (optional):
    Domain rules, logistics logic, constraints, or definitions.
    These override generic assumptions.

ALWAYS:
- Use ALL input columns together to understand meaning.
- Respect business_context when making decisions.

---------------------------------------------------------
HOW TO ASSIGN OUTPUT COLUMN VALUES
---------------------------------------------------------

For EACH output column C in output_columns:

1) IF C has predefined categories:
   - Treat it as a **strict categorical column**.
   - Use ONLY the category names in per_column_predefined_categories[C].
   - Use category **definitions** to choose the correct one.
   - NEVER invent new categories for this column.
   - previous_categories_per_column[C] may be referenced for consistency,
     but predefined categories override everything.

2) IF C does NOT have predefined categories:
   - Determine whether it is **categorical** or **freetext** based on its name and purpose.

   A) FREETEXT COLUMNS (e.g., "Reason", "Comment", "Explanation", "Description"):
      - MUST generate fresh, natural-language text for this row.
      - Completely IGNORE previous_categories_per_column[C].
      - Write clear, factual, row-specific explanations.

   B) CATEGORICAL COLUMNS WITHOUT A PREDEFINED LIST:
      - You MAY create new categories when needed.
      - BEFORE creating new categories:
          - Check previous_categories_per_column[C] and reuse an existing one if it fits.
      - Avoid producing near-duplicate categories.
      - Ensure consistency across similar rows.

---------------------------------------------------------
REASONING REQUIREMENTS
---------------------------------------------------------
- Use definition-based comparison for predefined categories.
- Use contextual reasoning for other categorical columns.
- Consider ALL input columns together.
- Apply business_context when interpreting thresholds, ranges, or domain-specific rules.
- Prefer stable, repeatable decisions over creative ones.

---------------------------------------------------------
STRICT OUTPUT FORMAT (NO EXCEPTIONS)
---------------------------------------------------------
You must return EXACTLY the following JSON structure:

{
  "rows": [
    {
      "row_index": <same integer>,
      "outputs": {
        "<OutputColumn1>": "<value>",
        "<OutputColumn2>": "<value>",
        ...
      }
    },
    ...
  ]
}

RULES:
- DO NOT add any text outside the JSON object.
- DO NOT use markdown or code fences.
- DO NOT add keys that are not in output_columns.
- DO NOT omit any output column.
- ALL values must be strings.
- row_index MUST match the input row_index.
- Return exactly one item in "rows" for EVERY input row; never skip, merge, or add rows.

---------------------------------------------------------
BEHAVIORAL RULES
---------------------------------------------------------
- No hallucinations.
- No renaming of columns or categories.
- No invented structure.
- For categorical columns: consistency is mandatory.
- For freetext columns: clarity and accuracy are mandatory.
- When unsure, choose the most contextually justified category.

Your job: analyze each row → produce valid JSON exactly as required.
""".strip()

## 6. Example Usage
This cell shows how to configure and run the classifier on a real file.

**Steps to use:**
1. Make sure your input Excel/CSV file exists (update the path as needed).
2. Adjust `input_columns` and `output_columns` to match your file.
3. Optionally define `per_column_predefined_categories` and `business_context`.
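4. Run cells 1–5 first (they define the helper, classes, functions, and `BASE_SYSTEM_PROMPT` used here). Then run this cell to generate the output Excel file.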


In [None]:
if __name__ == "__main__":
    # Example: logistics plausibility check for items described by their dimensions.

    # Allowed answers for the categorical output columns. "Reason" is not
    # listed here, so the prompt treats it as a freetext column.
    predefined_categories = {
        "Plausible": [
            {"name": "Yes", "definition": "Item is suitable for sea shipping."},
            {"name": "No", "definition": "Item is not suitable for sea shipping."},
        ],
        "RiskLevel": [
            {"name": "Low", "definition": "Fits comfortably within standard limits."},
            {"name": "Medium", "definition": "Near standard limits."},
            {"name": "High", "definition": "Exceeds limits or requires special handling."},
        ],
    }

    # Domain knowledge passed to the model as business_context.
    # NOTE(review): the units of these limits are not stated (cm? kg?);
    # consider adding them so the model can interpret the thresholds unambiguously.
    shipping_context = {
        "company": "Siemens Energy",
        "shipping_rules": {
            "max_length": 400,
            "max_width": 240,
            "max_height": 220,
            "max_weight": 5000,
        },
    }

    config = ClassificationConfig(
        provider="azure",  # or "openai"
        model="gpt-4o-rpe",  # Azure deployment name or OpenAI model name
        system_prompt=BASE_SYSTEM_PROMPT,
        input_columns=["Length", "Width", "Height", "Weight"],  # columns read from the file
        output_columns=["Plausible", "RiskLevel", "Reason"],  # columns the model fills in
        per_column_predefined_categories=predefined_categories,
        business_context=shipping_context,
        temperature=0.0,
        max_output_tokens=1500,
        batch_size=10,
    )

    classify_file(
        input_path="input_data.xlsx",  # change to your real file path
        output_path="classified_output.xlsx",
        config=config,
        sheet_name="Sheet1",  # Excel sheet to read; not used for CSV input
    )