In [1]:
import json
import logging
import re
import warnings
from datetime import datetime
from pathlib import Path
from pprint import pprint
from typing import Annotated, Any, Generator, Literal, Type, TypeVar

# Standard imports
import numpy as np
import numpy.typing as npt
import pandas as pd
import polars as pl

# Visualization
# import matplotlib.pyplot as plt

# NumPy settings: print arrays with 4 decimal places
np.set_printoptions(precision=4)

# Pandas settings: widen the display limits so large frames are not truncated
pd.options.display.max_rows = 1_000
pd.options.display.max_columns = 1_000
pd.options.display.max_colwidth = 600

# Polars settings: widen string length, column and row limits in table output
pl.Config.set_fmt_str_lengths(1_000)
pl.Config.set_tbl_cols(n=1_000)
pl.Config.set_tbl_rows(n=200)

# NOTE(review): this silences *every* warning for the whole kernel (including
# deprecation warnings from SQLAlchemy / pydantic / celery); consider narrowing it
# to specific categories or modules.
warnings.filterwarnings("ignore")

# Black code formatter (Optional) — needs the `lab_black` module to be installed
%load_ext lab_black

# auto reload imports: re-import edited project modules before executing each cell
%load_ext autoreload
%autoreload 2

In [None]:
from rich.console import Console
from rich.theme import Theme

# Named colour styles registered on the shared rich `console` created below.
_theme_styles: dict[str, str] = {
    "white": "#FFFFFF",  # Bright white
    "info": "#00FF00",  # Bright green
    "warning": "#FFD700",  # Bright gold
    "error": "#FF1493",  # Deep pink
    "success": "#00FFFF",  # Cyan
    "highlight": "#FF4500",  # Orange-red
}
custom_theme = Theme(_theme_styles)
console = Console(theme=custom_theme)


def create_path(path: str | Path) -> None:
    """
    Create parent directories for the given path if they don't exist.

    Parameters
    ----------
    path : str | Path
        The file path for which to create parent directories.

    """
    Path(path).parent.mkdir(parents=True, exist_ok=True)


def go_up_from_current_directory(*, go_up: int = 1) -> None:
    """Prepend an ancestor of the current working directory to ``sys.path``.

    This lets a notebook that lives in a sub-folder import packages located
    ``go_up`` directory levels above its working directory.

    Parameters
    ----------
    go_up : int, default=1
        Number of directory levels to climb from the current working directory.
        ``0`` adds the current working directory itself.

    Returns
    -------
    None
        The resolved absolute path is printed, not returned.

    Raises
    ------
    ValueError
        If ``go_up`` is negative.
    """
    import os
    import sys

    if go_up < 0:
        raise ValueError(f"go_up must be >= 0, got {go_up}")

    # Anchor explicitly on the working directory. ``__name__`` is a module name
    # (e.g. "__main__"), not a filesystem path, so it must not be used as an anchor.
    target_dir = os.path.abspath(os.path.join(os.getcwd(), *([os.pardir] * go_up)))

    # Re-running the cell must not pile up duplicate entries; drop any existing
    # copies, then put the path first so it still takes precedence.
    sys.path[:] = [entry for entry in sys.path if entry != target_dir]
    sys.path.insert(0, target_dir)
    print(target_dir)

In [3]:
# Put the directory one level above the working directory on sys.path —
# presumably so the `src` / `schemas` / `config` imports in later cells resolve.
go_up_from_current_directory(go_up=1)

/Users/mac/Desktop/MyProjects/batch-process


In [None]:
import json
import time
from datetime import datetime
from typing import Any

from celery import chord, current_task, group
from schemas import DataProcessingSchema
from src import create_logger
from src.celery import celery_app
from src.database import get_db_session
from src.database.db_models import BaseTask, DataProcessingJob

# Module-level logger used by the data-processing Celery tasks defined in this cell.
logger = create_logger(name="data_processing")


# Note: When `bind=True`, celery automatically passes the task instance as the first argument
# meaning that we need to use `self` and this provides additional functionality like retries, etc
@celery_app.task(bind=True, base=BaseTask)
def process_data_chunk(self, chunk_data: list[str], chunk_id: int) -> dict[str, Any | None | float | int]:  # noqa: ANN001, ARG001
    """
    Process (simulate work on) one chunk of a larger dataset.

    Every string item is upper-cased; any non-string item is passed through
    unchanged. After each item the task reports a ``PROGRESS`` state with
    ``current``/``total``/``chunk_id`` metadata.

    Parameters
    ----------
    chunk_data : list[str]
        Items to process. Non-string items are tolerated and returned as-is.
    chunk_id : int
        Identifier of this chunk; used to restore ordering when chunks are combined.

    Returns
    -------
    dict[str, Any | None | float | int]
        ``chunk_id``, ``processed_data`` (list), ``processing_time`` (seconds)
        and ``items_count``.

    Raises
    ------
    celery.exceptions.Retry
        Raised by ``self.retry`` on any error so the task is re-scheduled.
    """
    try:
        start_time = time.time()

        processed_data: list[Any] = []
        total_items: int = len(chunk_data)

        for i, item in enumerate(chunk_data, start=1):
            # Report progress on the bound task itself rather than through the
            # global `current_task` proxy; inside a worker both refer to this task.
            self.update_state(
                state="PROGRESS",
                meta={"current": i, "total": total_items, "chunk_id": chunk_id},
            )

            # Simulate processing time
            time.sleep(0.9)

            processed_data.append(item.upper() if isinstance(item, str) else item)

        processing_time: float = time.time() - start_time

        logger.info(f"Processed chunk {chunk_id} with {total_items} items in {processing_time:.2f}s")

        return {
            "chunk_id": chunk_id,
            "processed_data": processed_data,
            "processing_time": processing_time,
            "items_count": total_items,
        }

    except Exception as e:
        logger.error(f"Error processing chunk {chunk_id}: {e}")
        raise self.retry(exc=e) from e


@celery_app.task
def combine_processed_chunks(chunk_results: list[Any]) -> dict[str, Any]:
    """
    Chord callback: merge the outputs of ``process_data_chunk`` and persist them.

    Parameters
    ----------
    chunk_results : list[Any]
        One result dict per chunk, as returned by ``process_data_chunk`` (keys
        ``chunk_id``, ``processed_data``, ``processing_time``, ``items_count``).
        May be empty, in which case ``avg_processing_time`` is reported as ``0.0``.

    Returns
    -------
    dict[str, Any]
        ``status``, ``total_chunks``, ``total_items``, ``avg_processing_time``
        (mean seconds per chunk, rounded to 2 d.p.) and the new ``job_id``.
    """
    try:
        with get_db_session() as session:
            # Sort chunks by chunk_id
            sorted_results = sorted(chunk_results, key=lambda x: x["chunk_id"])
            num_chunks = len(sorted_results)

            # Combine all processed data
            combined_data: list[Any] = []
            total_processing_time: float = 0.0
            total_items: int = 0

            for result in sorted_results:
                combined_data.extend(result["processed_data"])
                total_processing_time += result["processing_time"]
                total_items += result["items_count"]

            # An empty chunk list would otherwise raise ZeroDivisionError here.
            avg_processing_time = round(total_processing_time / num_chunks, 2) if num_chunks else 0.0

            # Save to database
            data = DataProcessingSchema(
                job_name="bulk_data_processing",
                input_data=json.dumps({"chunks": sorted_results}),
                output_data=json.dumps({"combined_data": combined_data, "total_items": total_items}),
                processing_time=avg_processing_time,
                status="completed",
                completed_at=datetime.now(),
            ).model_dump()
            job = DataProcessingJob(**data)
            session.add(job)
            # Flush so the database assigns `job.id` before it is returned.
            session.flush()

            logger.info(f"Combined {num_chunks} chunks with {total_items} total items")

            return {
                "status": "completed",
                "total_chunks": num_chunks,
                "total_items": total_items,
                "avg_processing_time": avg_processing_time,
                "job_id": job.id,
            }

    except Exception as e:
        logger.error(f"Error combining chunks: {e}")
        raise


@celery_app.task
def process_large_dataset(data: list[Any], chunk_size: int = 10) -> dict[str, Any]:
    """
    Split ``data`` into chunks and process them in parallel with a Celery chord.

    Each chunk is handled by ``process_data_chunk``. Once all chunks finish,
    ``combine_processed_chunks`` receives the list of chunk results.

    Parameters
    ----------
    data : list[Any]
        Items to process.
    chunk_size : int, default=10
        Maximum number of items per chunk; must be a positive integer.

    Returns
    -------
    dict[str, Any]
        Dispatch summary: ``status``, ``total_items``, ``chunks``, ``chunk_size``
        and ``chord_id`` (id of the result returned by the chord's ``apply_async``).

    Raises
    ------
    ValueError
        If ``chunk_size`` is not positive.
    """
    try:
        # range() would crash on 0 and silently yield no chunks for negative sizes.
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

        # Split data into chunks
        chunks: list[list[Any]] = [data[i : i + chunk_size] for i in range(0, len(data), chunk_size)]

        # Create a chord: process chunks in parallel, then combine results
        job = chord(
            group(process_data_chunk.s(chunk, i) for i, chunk in enumerate(chunks)),
            combine_processed_chunks.s(),
        )

        result = job.apply_async()

        logger.info(f"Dispatched data processing job with {len(data)} items in {len(chunks)} chunks")

        return {
            "status": "dispatched",
            "total_items": len(data),
            "chunks": len(chunks),
            "chunk_size": chunk_size,
            "chord_id": result.id,
        }

    except Exception as e:
        logger.error(f"Error dispatching large dataset processing: {e}")
        raise


<@task: None of None>

In [None]:
def _get_prediction(
    record: "PersonSchema | MultiPersonsSchema",
    model_dict: dict[str, Any],
    threshold: float = 0.5,
) -> list[dict[str, Any]]:
    """Score a record (one or several persons) and return row-wise predictions.

    Parameters
    ----------
    record : PersonSchema | MultiPersonsSchema
        Input record containing person or multiple person data.
    model_dict : dict[str, Any]
        Must provide ``"processor"`` (with ``transform`` and
        ``get_feature_names_out``) and ``"model"`` (with ``predict_proba``).
    threshold : float, default=0.5
        ``survived`` is 1 when ``probability`` is strictly greater than this.

    Returns
    -------
    list[dict[str, Any]]
        One dict per input row: the original columns plus ``probability``
        (column 1 of ``predict_proba``) and ``survived`` (0/1).

    Notes
    -----
    The annotations are strings because ``PersonSchema`` / ``MultiPersonsSchema``
    are only imported in a later cell; evaluating them eagerly raises ``NameError``
    on a fresh kernel.
    NOTE(review): ``record_to_dataframe`` is not defined or imported anywhere in
    this notebook — import it before calling this function.
    """
    input_df: pl.DataFrame = record_to_dataframe(record)  # type: ignore
    processor = model_dict["processor"]
    features: npt.NDArray[np.float64] = processor.transform(input_df)
    # The processor output includes `num_vars__survived` (presumably the target);
    # drop it so only feature columns reach the model.
    feature_df: pl.DataFrame = pl.DataFrame(
        features, schema=processor.get_feature_names_out().tolist()
    ).drop(["num_vars__survived"])

    y_pred: npt.NDArray[np.float64] = model_dict["model"].predict_proba(feature_df)[:, 1]
    predictions_df: pl.DataFrame = input_df.with_columns(probability=y_pred).with_columns(
        survived=(pl.col("probability") > threshold).cast(pl.Int64)
    )
    return predictions_df.to_dicts()

In [None]:
from sqlalchemy import delete, insert, select, update

from schemas import EmailSchema
from src.database.db_models import EmailLog, get_db_session, init_db

In [None]:
# Initialise the database via the project's `init_db` helper (presumably creates
# the tables — confirm in src.database.db_models) before the CRUD examples below.
init_db()

## [Docs](https://docs.sqlalchemy.org/en/20/orm/queryguide/select.html)

### [Insert](https://docs.sqlalchemy.org/en/20/orm/queryguide/dml.html#orm-bulk-insert-statements)

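- Unit-of-work API (`Session.add` + `flush`), still the standard ORM pattern in SQLAlchemy 2.0. You get the ORM object back, so generated fields such as `id` can be read after `flush()`.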

```python
with get_db_session() as session:
    data_dict = input_data.to_data_model_dict()
    record = EmailLog(**data_dict)
    session.add(record)
    session.flush()
    output_data = {key: getattr(record, key) for key in record.output_fields()}
```

<br>

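- ORM bulk INSERT statement (`session.execute(insert(Model), [params])`). By default no ORM object is returned, so use it when you do not need the generated fields.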

```py
with get_db_session() as session:
    data_dict = input_data.to_data_model_dict()
    session.execute(insert(EmailLog), [data_dict])
```

In [None]:
# Build one validated email payload. Fields not passed here (e.g. `status`,
# `created_at`, `sent_at`, used in later cells) presumably fall back to
# EmailSchema defaults — confirm in `schemas`.
input_data: EmailSchema = EmailSchema(
    recipient="marketing@client.com",
    subject="Partnership Proposal",
    body="We would like to discuss a potential partnership opportunity.",
)
console.print(input_data)

In [None]:
# Inspect the plain-dict form that the next cell passes to the EmailLog ORM model.
input_data.model_dump()

In [None]:
# Insert one row with the unit-of-work API: add the ORM object, then flush so
# database-generated values (e.g. the primary key) are populated on `record`.
with get_db_session() as session:
    data_dict = input_data.model_dump()
    record = EmailLog(**data_dict)
    session.add(record)
    session.flush()
    # Read the exposed fields while the session is still open (presumably to avoid
    # touching expired/detached attributes after it closes).
    output_data = {key: getattr(record, key) for key in record.output_fields()}


console.print(output_data)

In [None]:
# Fetch the most recently inserted email created before "now".
# - Use the 2.0-style `select()` like the other cells, not the legacy `session.query()`.
# - Order by id and limit to one row, so `scalar_one()` still works once the table
#   holds several rows (after the bulk insert below, or on a re-run against the same DB).
with get_db_session() as session:
    statement = (
        select(EmailLog)
        .where(EmailLog.created_at < datetime.now())
        .order_by(EmailLog.id.desc())
        .limit(1)
    )
    record = session.execute(statement).scalar_one()
    output_data = {key: getattr(record, key) for key in record.output_fields()}


console.print(output_data)

In [None]:
# Three more payloads for the bulk-insert example. `input_data_4` additionally
# sets `created_at` (a datetime) and `sent_at` (an ISO-8601 string).
# NOTE(review): `sent_at` is a string while `created_at` is a datetime — confirm
# that EmailSchema validates/coerces both consistently.
input_data_2: EmailSchema = EmailSchema(
    recipient="emeka2@example.com",
    subject="test!!!",
    body="this is an example body",
    status="processing",
)
input_data_3: EmailSchema = EmailSchema(
    recipient="john.doe@example.com",
    subject="Meeting Reminder",
    body="Hi John, just a reminder about our meeting tomorrow at 10 AM.",
    status="processing",
)
input_data_4: EmailSchema = EmailSchema(
    recipient="info@company.org",
    subject="New Product Launch",
    body="Dear valued customer, check out our exciting new product!",
    status="sent",
    created_at=datetime(2025, 7, 10, 9, 0, 0),
    sent_at="2025-07-10T09:05:00",
)
console.print((input_data_2, input_data_3, input_data_4))

### [Bulk Insert](https://docs.sqlalchemy.org/en/20/orm/queryguide/dml.html#orm-bulk-insert-statements)

- Old API

```py
with get_db_session() as session:
    data_list: list[dict[str, Any]] = [_data.to_data_model_dict() for _data in (input_data_2, input_data_3, input_data_4)]
    session.bulk_insert_mappings(EmailLog, data_list)
```

<br>

- New API

```py
with get_db_session() as session:
    data_list: list[dict[str, Any]] = [
        _data.to_data_model_dict()
        for _data in (input_data_2, input_data_3, input_data_4)
    ]
    session.execute(insert(EmailLog), data_list)
```

In [None]:
# Bulk INSERT: one execute() call with the insert statement plus a list of
# parameter dicts (one per row), instead of adding ORM objects one by one.
with get_db_session() as session:
    data_list: list[dict[str, Any]] = [_data.model_dump() for _data in (input_data_2, input_data_3, input_data_4)]
    session.execute(insert(EmailLog), data_list)

### Select

In [None]:
# Select a single record
# NOTE(review): `scalar_one()` raises unless exactly one row matches. After the
# Update cell below sets id 1 to status "sent", re-running this cell will fail.
with get_db_session() as session:
    statement = select(EmailLog).where(EmailLog.id == 1, EmailLog.status == "pending")
    record = session.execute(statement).scalar_one()
    output_data = {key: getattr(record, key) for key in record.output_fields()}


console.print(output_data)

In [None]:
# Select all records
with get_db_session() as session:
    statement = select(EmailLog)
    # NOTE: `record` is an iterable of EmailLog rows (from `.scalars()`), not a
    # single record. It is consumed below while the session is still open.
    record = session.execute(statement).scalars()

    output_data = [{key: getattr(row, key) for key in row.output_fields()} for row in record]

console.print(output_data)

### [Update](https://docs.sqlalchemy.org/en/20/orm/queryguide/dml.html#orm-update-and-delete-with-custom-where-criteria)

In [None]:
# Mark email id 1 as sent. `sent_at` is passed as a "YYYY-MM-DD HH:MM:SS" string.
with get_db_session() as session:
    statement = (
        update(EmailLog)
        .where(EmailLog.id == 1)
        .values(status="sent", sent_at=datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    )
    # execute() returns a Result, which is not needed here. The session is closed
    # by the `with` block (presumably committing — see `get_db_session`).
    session.execute(statement)

# Verify that the record was updated
with get_db_session() as session:
    statement = select(EmailLog)
    record = session.execute(statement).scalars()

    output_data = [{key: getattr(row, key) for key in row.output_fields()} for row in record]

console.print(output_data)

### [Delete](https://docs.sqlalchemy.org/en/20/orm/queryguide/dml.html#orm-update-and-delete-with-custom-where-criteria)

In [None]:
# Delete email id 2.
with get_db_session() as session:
    statement = delete(EmailLog).where(EmailLog.id == 2)
    # execute() returns a Result, which is not needed here. The session is closed
    # by the `with` block (presumably committing — see `get_db_session`).
    session.execute(statement)

# Verify that the record was deleted
with get_db_session() as session:
    statement = select(EmailLog)
    record = session.execute(statement).scalars()

    output_data = [{key: getattr(row, key) for key in row.output_fields()} for row in record]

console.print(output_data)

In [None]:
from config import app_config

In [None]:
# Flatten the configured beat schedule into a plain dict, then append the
# health-check entry alongside the regular periodic tasks.
_beat_cfg = app_config.celery_config.beat_config
beat_dict: dict[str, dict[str, Any]] = {
    entry_name: entry_cfg for entry_name, entry_cfg in _beat_cfg.beat_schedule.model_dump().items()
}

# Add the health_check
beat_dict["health_check"] = _beat_cfg.health_check.model_dump()


console.print(beat_dict)

In [None]:
# Exploration: display the items view of the dumped beat schedule.
app_config.celery_config.beat_config.beat_schedule.model_dump().items()

In [None]:
import json
from datetime import datetime
from typing import Any

from pydantic import BaseModel, field_serializer


class MyModel(BaseModel):
    """Toy model demonstrating a custom ``field_serializer`` for ``others``."""

    name: str
    age: int
    role: str
    salary: float = 0.0
    others: Any | None = None

    @field_serializer("others")
    def serialize(self, value: Any) -> str:
        """Serialize ``others`` to a JSON string.

        Datetimes, at any nesting depth, are encoded as ISO-8601 strings. The
        output is therefore always valid JSON and round-trips through
        ``json.loads`` (``None`` becomes ``"null"``).
        """

        def _encode(obj: Any) -> Any:
            # Called by json.dumps only for objects it cannot encode natively.
            if isinstance(obj, datetime):
                return obj.isoformat()
            raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

        return json.dumps(value, default=_encode)


def my_func(name: str, **kwargs) -> MyModel:
    """Build a ``MyModel`` from ``name`` plus arbitrary keyword fields.

    Keyword arguments are forwarded unchanged to the model constructor, so
    validation (and handling of unknown keys) is left to ``MyModel``.
    """
    return MyModel(name=name, **kwargs)


# Demo call. `friend` is not a declared MyModel field (presumably dropped by
# pydantic's default `extra` handling — confirm). `others` is passed through the
# custom field serializer when `model_dump()` is called.
result = my_func(
    "Neidu",
    age=30,
    role="AI Engineer",
    friend="None",
    others=["Hi"],
    # Use the line below instead to exercise the datetime branch of the serializer.
    # others=datetime.now(),
)


In [None]:
# `others` is serialized to a string by MyModel's field serializer; decode it back.
# NOTE(review): json.loads only succeeds when that string is valid JSON.
print(result.model_dump())

json.loads(result.model_dump()["others"])

In [None]:
import json
import time
from datetime import datetime
from pathlib import Path
from typing import Any

import joblib

from celery import chord, current_task, group
from schemas import ModelOutput, MultiPersonsSchema, MultiPredOutput, PersonSchema
from src import PACKAGE_PATH, create_logger
from src.celery import celery_app
from src.database import get_db_session
from src.database.db_models import BaseTask, MLPredictionJob
from src.ml.utils import get_batch_prediction, get_prediction

# Module-level logger used by the ML prediction Celery tasks defined in this cell.
# NOTE(review): in this notebook this rebinds the global `logger`, so the
# data-processing tasks defined earlier will also log through "ml_prediction".
logger = create_logger(name="ml_prediction")


@celery_app.task(bind=True, base=BaseTask)
def process_prediction_chunk(self, persons_data: list[dict[str, Any]], chunk_id: int) -> dict[str, Any]:  # noqa: ANN001
    """
    Process a chunk of ML predictions.

    The chunk is validated as a ``MultiPersonsSchema``. The model is loaded once
    for the whole chunk, and each person is scored with ``get_prediction``. After
    each person the task reports a ``PROGRESS`` state.

    Parameters
    ----------
    persons_data : list[dict[str, Any]]
        List of person data dictionaries for prediction
    chunk_id : int
        Identifier of this chunk; used to restore ordering when chunks are combined.

    Returns
    -------
    dict[str, Any]
        ``chunk_id``, ``prediction_results`` (list of dumped ``ModelOutput``),
        ``processing_time`` (seconds), ``items_count`` and ``status``.

    Raises
    ------
    celery.exceptions.Retry
        Raised by ``self.retry`` on any error so the task is re-scheduled.
    """
    try:
        start_time = time.time()

        # Validate input data
        multi_persons = MultiPersonsSchema(persons=persons_data)  # type: ignore
        total_items: int = len(multi_persons.persons)

        # Load model once for the entire chunk.
        # SECURITY: joblib.load unpickles the file and can execute arbitrary code;
        # only load model artifacts from a trusted location.
        model_dict_fp: Path = PACKAGE_PATH / "models/model.pkl"
        with open(model_dict_fp, "rb") as f:
            model_dict = joblib.load(f)

        # Process predictions
        prediction_results: list[dict[str, Any]] = []

        for i, person in enumerate(multi_persons.persons, start=1):
            # Report progress on the bound task itself rather than through the
            # global `current_task` proxy; inside a worker both refer to this task.
            self.update_state(
                state="PROGRESS",
                meta={"current": i, "total": total_items, "chunk_id": chunk_id},
            )

            # Make individual prediction
            result: ModelOutput = get_prediction(person, model_dict)
            prediction_results.append(result.model_dump())

        processing_time: float = time.time() - start_time

        logger.info(f"Processed chunk {chunk_id} with {total_items} predictions in {processing_time:.2f}s")

        return {
            "chunk_id": chunk_id,
            "prediction_results": prediction_results,
            "processing_time": processing_time,
            "items_count": total_items,
            "status": "success",
        }

    except Exception as e:
        logger.error(f"Error processing prediction chunk {chunk_id}: {e}")
        raise self.retry(exc=e) from e


@celery_app.task
def combine_prediction_results(chunk_results: list[dict[str, Any]]) -> dict[str, Any]:
    """
    Combine results from multiple prediction chunks and persist them as one job.

    Parameters
    ----------
    chunk_results : list[dict[str, Any]]
        Results returned by ``process_prediction_chunk``, one per chunk. May be
        empty, in which case ``avg_processing_time`` is reported as ``0.0``.

    Returns
    -------
    dict[str, Any]
        ``status``, ``total_chunks``, ``total_predictions``, ``avg_processing_time``
        (mean seconds per chunk, rounded to 2 d.p.), the new ``job_id`` and the
        combined ``predictions`` ordered by ``chunk_id``.
    """
    try:
        with get_db_session() as session:
            # Sort chunks by chunk_id
            sorted_results = sorted(chunk_results, key=lambda x: x["chunk_id"])
            num_chunks = len(sorted_results)

            # Combine all prediction results
            combined_predictions: list[dict[str, Any]] = []
            total_processing_time: float = 0.0
            total_items: int = 0

            for result in sorted_results:
                combined_predictions.extend(result["prediction_results"])
                total_processing_time += result["processing_time"]
                total_items += result["items_count"]

            # An empty chunk list would otherwise raise ZeroDivisionError here.
            avg_processing_time = round(total_processing_time / num_chunks, 2) if num_chunks else 0.0

            # Save to database
            job_data = {
                "job_name": "batch_ml_prediction",
                "input_data": json.dumps({"chunks": num_chunks, "total_items": total_items}),
                "output_data": json.dumps({"predictions": combined_predictions}),
                "processing_time": avg_processing_time,
                "prediction_count": total_items,
                "status": "completed",
                "completed_at": datetime.now(),
            }

            job = MLPredictionJob(**job_data)
            session.add(job)
            # Flush so the database assigns `job.id` before it is returned.
            session.flush()

            logger.info(f"Combined {num_chunks} chunks with {total_items} total predictions")

            return {
                "status": "completed",
                "total_chunks": num_chunks,
                "total_predictions": total_items,
                "avg_processing_time": avg_processing_time,
                "job_id": job.id,
                "predictions": combined_predictions,
            }

    except Exception as e:
        logger.error(f"Error combining prediction results: {e}")
        raise


@celery_app.task
def process_batch_predictions(persons_data: list[dict[str, Any]], chunk_size: int = 10) -> dict[str, Any]:
    """
    Process a large batch of ML predictions by splitting into chunks and using chord.

    Each chunk is scored by ``process_prediction_chunk``. Once all chunks finish,
    ``combine_prediction_results`` receives the list of chunk results.

    Parameters
    ----------
    persons_data : list[dict[str, Any]]
        List of person data dictionaries for prediction
    chunk_size : int, optional
        Size of each processing chunk, by default 10. Must be a positive integer.

    Returns
    -------
    dict[str, Any]
        Dispatch summary: ``status``, ``total_items``, ``chunks``, ``chunk_size``
        and ``chord_id`` (id of the result returned by the chord's ``apply_async``).

    Raises
    ------
    ValueError
        If ``chunk_size`` is not positive.
    """
    try:
        # range() would crash on 0 and silently yield no chunks for negative sizes.
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

        # Split data into chunks
        chunks = [persons_data[i : i + chunk_size] for i in range(0, len(persons_data), chunk_size)]

        # Create a chord: process chunks in parallel, then combine results
        job = chord(
            group(process_prediction_chunk.s(chunk, i) for i, chunk in enumerate(chunks)),
            combine_prediction_results.s(),
        )

        result = job.apply_async()

        logger.info(f"Dispatched batch prediction job with {len(persons_data)} items in {len(chunks)} chunks")

        return {
            "status": "dispatched",
            "total_items": len(persons_data),
            "chunks": len(chunks),
            "chunk_size": chunk_size,
            "chord_id": result.id,
        }

    except Exception as e:
        logger.error(f"Error dispatching batch predictions: {e}")
        raise


@celery_app.task(bind=True, base=BaseTask)
def process_dlq_message(self, message_data: dict[str, Any]) -> dict[str, Any]:  # noqa: ANN001
    """
    Record a message taken from the dead letter queue (DLQ).

    The payload is validated as a batch (``MultiPersonsSchema``, when it has a
    ``"persons"`` key) or as a single person (``PersonSchema``). It is then stored
    as an ``MLPredictionJob`` row with status ``"dlq_processed"``. No prediction
    is run.

    Parameters
    ----------
    message_data : dict[str, Any]
        Message data from DLQ

    Returns
    -------
    dict[str, Any]
        ``status``, ``message_type`` (``"batch"``/``"single"``), ``item_count``
        and the new ``job_id``.

    Raises
    ------
    ValueError, TypeError
        If the payload does not match either schema. These are re-raised without
        retrying, because a retry cannot fix a malformed payload.
    celery.exceptions.Retry
        Raised by ``self.retry`` for any other error (e.g. database failures).
    """
    # Validate outside the retry handler: a malformed payload fails identically on
    # every attempt (pydantic's ValidationError is a ValueError subclass).
    try:
        if "persons" in message_data:
            # Batch message
            record = MultiPersonsSchema(**message_data)
            message_type = "batch"
            item_count = len(record.persons)
        else:
            # Single message: constructed only to validate the payload
            PersonSchema(**message_data)
            message_type = "single"
            item_count = 1
    except (ValueError, TypeError) as e:
        logger.error(f"Error processing DLQ message: {e}")
        raise

    try:
        logger.warning(f"Processing DLQ message: {message_type} with {item_count} items")

        # For now, just log the DLQ data - you can extend this to save to a DLQ table
        with get_db_session() as session:
            job_data = {
                "job_name": f"dlq_{message_type}_processing",
                "input_data": json.dumps(message_data),
                "output_data": json.dumps({"status": "dlq_processed", "message_type": message_type}),
                "processing_time": 0.0,
                "prediction_count": 0,
                "status": "dlq_processed",
                "completed_at": datetime.now(),
            }

            job = MLPredictionJob(**job_data)
            session.add(job)
            # Flush so the database assigns `job.id` before it is logged/returned.
            session.flush()

            logger.info(f"DLQ message processed and logged with job_id: {job.id}")

            return {
                "status": "dlq_processed",
                "message_type": message_type,
                "item_count": item_count,
                "job_id": job.id,
            }

    except Exception as e:
        logger.error(f"Error processing DLQ message: {e}")
        raise self.retry(exc=e) from e
