In [None]:
from collections import defaultdict
import json
from decimal import Decimal

from boto3.dynamodb.conditions import Key, Attr
import pandas as pd


import sys
import os


from dyno_viewer.aws.ddb import (
    scan_items,
    get_table,
    serialise_dynamodb_json,
    get_table_info,
)


class DecimalEncoder(json.JSONEncoder):
    """JSON encoder that serialises ``Decimal`` values as plain JSON numbers.

    Decimals holding a fractional value (or written with a decimal point)
    become ``float``; whole-number decimals become ``int``. Every other
    unsupported type is delegated to ``json.JSONEncoder.default``, which
    raises ``TypeError``.
    """

    def default(self, obj):
        """Return a JSON-serialisable replacement for ``obj``.

        :param obj: Object the standard encoder could not serialise
        :return: ``float`` or ``int`` for a ``Decimal``
        :raises TypeError: For any non-``Decimal`` object (raised by the base class)
        """
        if isinstance(obj, Decimal):
            # str() switches to scientific notation for small magnitudes, so
            # Decimal("1E-7") has no "." in its text even though it is
            # fractional; int() would silently truncate it to 0. The integral
            # comparison catches that case, while the textual check keeps
            # values written with a decimal point (e.g. Decimal("1.0")) as
            # floats. Non-finite values skip the comparison and fall through
            # to int(), which raises.
            has_fraction = "." in str(obj) or (
                obj.is_finite() and obj != obj.to_integral_value()
            )
            return float(obj) if has_fraction else int(obj)
        return super().default(obj)

In [None]:
# Look up the "somebank" table through the project's get_table helper.
# NOTE(review): "ap-southeast-2" is presumably the AWS region and the trailing
# None an optional profile/session argument — confirm against get_table.
table = get_table("somebank", "ap-southeast-2", None)

In [None]:
# Scan the table with paginate=False and keep only the first element of the
# returned pair (the items).
# NOTE(review): the discarded second value is presumably a pagination cursor —
# confirm against scan_items.
items, _ = scan_items(table, paginate=False)

In [None]:
# Display the scanned items for inspection.
# NOTE(review): this renders every scanned item; consider showing a slice such
# as items[:5] before sharing the notebook, since outputs may contain real data.
items

In [None]:
# Read the key attribute names from the table metadata.
table_info = get_table_info(table)
primary_key = table_info["keySchema"]["primaryKey"]
# .get() leaves sort_key as None when the metadata has no "sortKey" entry.
sort_key = table_info["keySchema"].get("sortKey")

## v2

In [None]:
def group_by_primary_sort_key(items, primary_key, sort_key) -> dict:
    """Group items by their ``(primary key value, sort key value)`` pair.

    :param items: Iterable of item dictionaries
    :param primary_key: Name of the primary key attribute
    :param sort_key: Name of the sort key attribute; may be ``None``, in which
        case the sort key value of every group is ``None``
    :return: Plain dict mapping ``(pk_value, sk_value)`` to the list of items
        carrying those key values, in input order
    :rtype: dict
    """
    grouped = defaultdict(list)
    for item in items:
        pk_value = item.get(primary_key)
        sk_value = item.get(sort_key)
        # Skip only items that carry neither key attribute. Comparing against
        # None (not truthiness) keeps items whose key is a valid falsy value,
        # e.g. a numeric key equal to 0.
        if pk_value is None and sk_value is None:
            continue
        grouped[(pk_value, sk_value)].append(item)
    return dict(grouped)

In [None]:
# Group the scanned items by their (primary key, sort key) values.
group_pk_sk = group_by_primary_sort_key(items, primary_key, sort_key)

In [None]:
# The dict's (pk, sk) tuple keys become columns; transposing moves them to the
# index, and reset_index() turns the two index levels into regular columns.
df_group_pk_sk = pd.DataFrame(group_pk_sk).transpose().reset_index()
# Give the generated columns descriptive names.
df_group_pk_sk = df_group_pk_sk.rename(
    columns={0: "payload", "level_1": "sort_key", "level_0": "primary_key"}
)

In [None]:
# Collect the set of attribute names seen for each (primary key, sort key) pair.
result = defaultdict(set)
for item in items:
    pk_value = item.get(primary_key)
    sk_value = item.get(sort_key)
    # Skip only items that carry neither key attribute. Comparing against None
    # (not truthiness) keeps items whose key is a valid falsy value such as 0.
    if pk_value is None and sk_value is None:
        continue
    result[(pk_value, sk_value)].update(item.keys())

In [None]:
# Display the collected mapping as a plain dict rather than a defaultdict.
dict(result)

In [None]:
# Record, for each (primary key, sort key) pair, the Python type of every
# attribute value: {(pk, sk): {attribute name: [type, ...]}}.
result = defaultdict(dict)
for item in items:
    pk_value = item.get(primary_key)
    sk_value = item.get(sort_key)
    # Skip only items that carry neither key attribute. Comparing against None
    # (not truthiness) keeps items whose key is a valid falsy value such as 0.
    if pk_value is None and sk_value is None:
        continue
    for k, v in item.items():
        result[(pk_value, sk_value)].setdefault(k, []).append(type(v))

# Display the mapping as a plain dict.
dict(result)

In [None]:
def ddb_item_schema(
    items: list[dict],
    primary_key: str,
    sort_key: str,
    model_name_attr: str | None = None,
) -> dict:
    """Generate a schema for DynamoDB items based on their primary and sort keys.

    :param items: List of DynamoDB items
    :type items: list[dict]
    :param primary_key: Primary key attribute name
    :type primary_key: str
    :param sort_key: Sort key attribute name
    :type sort_key: str
    :param model_name_attr: Optional attribute for the item name that is used to define what the item is (e.g. "customer", "order")
    :type model_name_attr: str | None

    :return: A dictionary with primary keys and sort keys as keys and their types as values
    :rtype: dict
    """
    result = defaultdict(dict)
    for item in items:
        pk_value = item.get(primary_key)
        sk_value = item.get(sort_key)
        ddb_model_name = item.get(model_name_attr)
        if not pk_value and not sk_value:
            continue
        result[(pk_value, sk_value)] = {"model_name": ddb_model_name, "type": {}}

        for k, v in item.items():
            result[(pk_value, sk_value)]["type"].setdefault(k, []).append(type(v))
    return dict(result)

In [None]:
# Build the per-item schema, reading each item's model name from its "type" attribute.
schema = ddb_item_schema(
    items,
    primary_key,
    sort_key,
    model_name_attr="type",
)
# Show the distinct, non-empty model names that were found.
{entry["model_name"] for entry in schema.values() if entry["model_name"]}

In [None]:
from typing import Dict, List, Any, Optional, Set, Type, Union
from pydantic import BaseModel, Field, create_model


def _key_prefix(key_value) -> str:
    """Return the text before the first "#" of a key value (the whole text if there is none)."""
    return str(key_value).split("#")[0]


def _fallback_model_name(pk, sk) -> str:
    """Join the key prefixes with "_" to name a model, skipping a key that is None."""
    return "_".join(_key_prefix(key) for key in (pk, sk) if key is not None)


def generate_pydantic_models(schema: dict) -> Dict[str, Type[BaseModel]]:
    """
    Generate Pydantic models from the schema returned by ddb_item_schema.

    Schema entries are grouped by model name (the entry's ``"model_name"``, or
    ``"<pk prefix>_<sk prefix>"`` when it is empty) and one model is created
    per group. Within a group the attribute types of all entries are merged:
    an attribute seen with several types becomes a ``Union``, and an attribute
    missing from some entries becomes ``Optional`` with a ``None`` default.
    The ``schema`` argument is not modified.

    :param schema: Dictionary with (pk, sk) tuples as keys and schema dictionaries as values
    :type schema: dict

    :return: Dictionary mapping model names to Pydantic model classes
    :rtype: Dict[str, Type[BaseModel]]
    """
    # Group the per-entry {attribute: [types]} maps by entity type.
    type_maps_by_model: Dict[str, List[dict]] = {}
    for (pk, sk), schema_dict in schema.items():
        # The fallback name is only computed when no model name is set, so
        # non-string or missing keys cannot break entries that carry a name.
        model_name = str(
            schema_dict.get("model_name") or _fallback_model_name(pk, sk)
        )
        type_maps_by_model.setdefault(model_name, []).append(schema_dict["type"])

    models = {}
    for model_name, type_maps in type_maps_by_model.items():
        fields = {}
        # dict.fromkeys de-duplicates while keeping first-seen order.
        attr_names = dict.fromkeys(
            attr_name for type_map in type_maps for attr_name in type_map
        )
        for attr_name in attr_names:
            attr_types = tuple(
                dict.fromkeys(
                    attr_type
                    for type_map in type_maps
                    for attr_type in type_map.get(attr_name, ())
                )
            )
            if not attr_types:
                annotation = Any
            elif len(attr_types) == 1:
                # Index instead of pop() so the caller's schema stays intact.
                annotation = attr_types[0]
            else:
                # Subscripting with a tuple is equivalent to Union[a, b, ...]
                # and does not need the 3.11+ star-unpacking syntax.
                annotation = Union[attr_types]

            # create_model takes (annotation, default) pairs; Ellipsis marks a
            # required field.
            if all(attr_name in type_map for type_map in type_maps):
                fields[attr_name] = (annotation, ...)
            else:
                fields[attr_name] = (Optional[annotation], None)
        models[model_name] = create_model(model_name, **fields)

    return models


# Rebuild the schema from the scanned items (model names come from each item's
# "type" attribute) and generate the Pydantic models, keyed by model name.
schema = ddb_item_schema(items, primary_key, sort_key, model_name_attr="type")
pydantic_models = generate_pydantic_models(schema)


In [None]:
# Extract the JSON schema of each generated Pydantic model (one dict per model).
model_schemas = [model.model_json_schema() for model in pydantic_models.values()]


# Write the schemas to a JSON file
# NOTE(review): os is already imported in the first cell; this re-import is redundant.
import os

# The target directory is "output" inside the parent of the current working directory.
output_dir = os.path.join(os.path.dirname(os.getcwd()), "output")
# exist_ok=True makes the cell safe to re-run when the directory already exists.
os.makedirs(output_dir, exist_ok=True)
output_file = os.path.join(output_dir, "ddb_model_schemas.json")

# Mode "w" overwrites any file left by a previous run.
with open(output_file, "w") as f:
    json.dump(model_schemas, f, indent=2)

print(f"Model schemas written to: {output_file}")