In [None]:
from pydantic import BaseModel, Field, ConfigDict, field_validator, computed_field
from aiobotocore.session import get_session
from botocore.exceptions import ClientError
from typing import Any, Self, Type
from decimal import Decimal
import uuid
import enum
import datetime
import builtins
import re

In [None]:
class Index:
    def __init__(
        self,
        partition_key: str,
        sort_key: str | None = None, *,
        index_name: str | None = None,
        local: bool = False
    ) -> None:
        self.partition_key = partition_key
        self.sort_key = sort_key
        self.index_name = index_name
        self.local = local

    def __repr__(self) -> str:
        return f'Index(partition_key={self.partition_key}, sort_key={self.sort_key}, index_name={self.index_name}, local={self.local})'

class Table:
    """A DynamoDB table: its name, primary key schema and optional extra indexes.

    ``type_index`` is the index whose partition key holds each item's type name;
    ``secondary_indexes`` lists any further indexes (may be None).

    NOTE(review): the endpoint and dummy credentials used by ``create`` are
    hard-coded for a local DynamoDB instance.
    """

    def __init__(
        self,
        table_name: str,
        primary_index: Index, *,
        type_index: Index | None = None,
        secondary_indexes: list[Index] | None = None
    ) -> None:
        self.table_name = table_name
        self.primary_index = primary_index
        self.type_index = type_index
        self.secondary_indexes = secondary_indexes

    async def create(self) -> None:
        """Check whether the table exists, then verify it or (stub) create it.

        Raises:
            ClientError: for any DynamoDB error other than the table not existing.
        """
        session = get_session()
        async with session.create_client(
            'dynamodb',
            region_name='us-east-1',
            endpoint_url='http://172.25.32.1:8000',
            aws_access_key_id='local',
            aws_secret_access_key='local',
            verify=False
        ) as ddb:
            try:
                table_info = await ddb.describe_table(TableName=self.table_name)
            except ClientError as e:
                # Only a missing table is expected here; auth, throttling or
                # endpoint errors must not be silently swallowed.
                if e.response.get('Error', {}).get('Code') != 'ResourceNotFoundException':
                    raise
                self._create_table()
            else:
                self._verify_table(table_info)

    def _verify_table(self, table_info: dict) -> None:
        """Stub: currently only prints the ``describe_table`` response."""
        print(table_info)

    def _create_table(self) -> None:
        """Stub: currently only reports that the table is missing."""
        print(f'Table {self.table_name} does not exist')

In [219]:
class AccessPattern:
    """Binds an ``Index`` to key templates such as ``'user:{id}'``.

    Placeholders are ``{name}`` tokens matching ``VARIABLE_PATTERN``; the
    ``_pkey_variables`` / ``_skey_variables`` properties return them with the
    braces included.
    """

    VARIABLE_PATTERN = r'\{[a-zA-Z0-9_]+\}'

    def __init__(
        self,
        index: Index,
        partition_key: str,
        sort_key: str | None = None
    ) -> None:
        self.index, self.partition_key, self.sort_key = index, partition_key, sort_key

    @property
    def _pkey_variables(self) -> list[str]:
        """Placeholders (with braces) in the partition-key template."""
        return [found.group(0) for found in re.finditer(self.VARIABLE_PATTERN, self.partition_key)]

    @property
    def _skey_variables(self) -> list[str] | None:
        """Placeholders in the sort-key template, or None when there is no sort-key template."""
        if not self.sort_key:
            return None
        return [found.group(0) for found in re.finditer(self.VARIABLE_PATTERN, self.sort_key)]

    def __repr__(self) -> str:
        return f'AccessPattern(index={self.index}, partition_key={self.partition_key}, sort_key={self.sort_key})'


In [220]:
# DynamoDB attribute-value type tags: each typed value is a one-key dict
# whose key is one of these strings.
# Scalars
STRING, NUMBER, BINARY = 'S', 'N', 'B'
BOOLEAN, NULL = 'BOOL', 'NULL'
# Documents
MAP, LIST = 'M', 'L'
# Sets
STRING_SET, NUMBER_SET, BINARY_SET = 'SS', 'NS', 'BS'


class DynamoSerializer:
    def serialize(self, v: Any) -> str:
        py_type = type(v)
        match py_type:
            case builtins.str:
                return {STRING: v}
            case enum.Enum:
                return {STRING: v.value}
            case datetime.date:
                return {STRING: v.isoformat()}
            case datetime.time:
                return {STRING: v.isoformat()}
            case datetime.datetime:
                return {STRING: v.isoformat()}
            case uuid.UUID:
                return {STRING: str(v)}
            case builtins.int:
                return {NUMBER: str(v)}
            case builtins.float:
                return {NUMBER: str(v)}
            case builtins.bytes:
                return {BINARY: v}
            case builtins.bool:
                return {BOOLEAN: str(v)}
            case None:
                return {NULL: True}
            case builtins.list:
                return {LIST: [self.serialize(x) for x in v]}
            case builtins.dict:
                return {MAP: {str(k): self.serialize(v) for k, v in v.items()}}
            
    def deserialize(self, v: dict[str, Any]) -> Any:
        match list(v.keys())[0]:
            case 'S':
                return v[STRING]
            case 'N':
                return v[NUMBER]
            case 'B':
                return v[BINARY]
            case 'BOOL':
                return v[BOOLEAN]
            case 'NULL':
                return None
            case 'L':
                return [self.deserialize(x) for x in v[LIST]]
            case 'M':
                return {self.deserialize(k): self.deserialize(v) for k, v in v[MAP].items()}
            case _:
                raise ValueError(f'Unknown type: {v}')

In [None]:
class ModelSettings(BaseModel):
    """Pydantic-validated model settings: target table, type name and access patterns.

    NOTE(review): the models in this notebook declare a plain nested ``Settings``
    class instead of instantiating ModelSettings, so these validators only run
    when a ModelSettings is constructed explicitly.
    """

    table: Table
    object_type: str | None = None
    access_patterns: list[AccessPattern] = Field(default_factory=list)

    @computed_field
    @property
    def _access_pattern_map(self) -> dict[Index, AccessPattern]:
        """Access patterns keyed by their ``Index`` object; a later pattern on the same index wins."""
        # Keys are Index instances, not strings; the annotation drives pydantic's serializer.
        return {ap.index: ap for ap in self.access_patterns}

    @field_validator('access_patterns')
    @classmethod
    def validate_access_patterns(cls, v: list[AccessPattern]) -> list[AccessPattern]:
        """Require at least one access pattern on the primary index (an index with no name).

        Raises:
            ValueError: if no pattern targets the primary index.
        """
        # NOTE(review): field validators do not run on defaults, so omitting
        # access_patterns entirely bypasses this check.
        if not any(not ap.index.index_name for ap in v):
            raise ValueError('No primary index defined')
        return v

    @field_validator('object_type')
    @classmethod
    def validate_object_type(cls, v: str | None) -> str | None:
        """Return ``v`` if truthy, otherwise ``cls.__name__.lower()``."""
        # NOTE(review): ``cls`` is ModelSettings here, so the fallback is always
        # 'modelsettings', and it only applies when object_type is passed
        # explicitly (defaults are not validated). Probably meant to be the
        # owning model's name — confirm intent.
        if v:
            return v
        return cls.__name__.lower()

    model_config = ConfigDict(
        arbitrary_types_allowed=True
    )

class DynaModel(BaseModel):
    """Base class for pydantic models stored as items in one DynamoDB table.

    Subclasses declare a nested ``Settings`` class with ``table`` (a ``Table``)
    and ``access_patterns`` (a list of ``AccessPattern``); ``__init_subclass__``
    copies it to ``_settings``. Each access pattern's ``{field}`` templates are
    rendered into the key attributes of its index.

    NOTE(review): ``_client`` hard-codes a local endpoint and dummy credentials.
    """

    _serializer = DynamoSerializer()
    _settings: ModelSettings | None = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if not self._settings:
            raise ValueError('Model settings not defined')

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        settings = getattr(cls, 'Settings', None)
        if settings:
            cls._settings = settings

    @classmethod
    async def get(
        cls,
        vars: dict[str, Any]
    ) -> Self | None:
        """Fetch one item by its primary key.

        Args:
            vars: values for the placeholders of the primary access pattern.

        Returns:
            The validated model, or None when no item exists for the key.

        Raises:
            ValueError: if the model has no primary-index access pattern.
            KeyError: if ``vars`` lacks a placeholder value.
        """
        access_pattern = cls._find_access_pattern()
        serializer = DynamoSerializer()
        key = {
            name: serializer.serialize(value)
            for name, value in cls._key_values(access_pattern, vars).items()
        }

        async with cls._client() as ddb:
            response = await ddb.get_item(
                TableName=cls._settings.table.table_name,
                Key=key
            )

        if 'Item' not in response:
            return None
        return cls._from_item(response['Item'])

    @classmethod
    async def query(
        cls,
        vars: dict[str, Any] | None = None,
        skey_condition: str | None = None,
        index: Index | None = None
    ) -> list[Self]:
        """Query one index, following pagination until all pages are read.

        Args:
            vars: values for the access pattern's placeholders (default: none).
            skey_condition: when truthy, also constrain the sort key using the
                rendered sort-key template. ``'<'``, ``'<='``, ``'>'``, ``'>='``
                and ``'begins_with'`` are honoured; any other truthy value means
                equality.
            index: the index to query; defaults to the primary index.

        Raises:
            ValueError: if no access pattern matches ``index``.
            KeyError: if ``vars`` lacks a needed placeholder value.
        """
        values = vars or {}
        access_pattern = cls._find_access_pattern(index)
        key_index = access_pattern.index
        serializer = DynamoSerializer()

        # Fixed placeholders: attribute names may contain characters that are
        # not valid in expression placeholders.
        pkey_value = cls._render_template(
            access_pattern.partition_key, access_pattern._pkey_variables, values
        )
        key_expression = '#pk = :pk'
        attribute_names = {'#pk': key_index.partition_key}
        attribute_values = {':pk': serializer.serialize(pkey_value)}

        if skey_condition and access_pattern.sort_key:
            skey_value = cls._render_template(
                access_pattern.sort_key, access_pattern._skey_variables or [], values
            )
            if skey_value:
                key_expression += ' AND ' + cls._sort_key_expression(skey_condition)
                attribute_names['#sk'] = key_index.sort_key
                attribute_values[':sk'] = serializer.serialize(skey_value)

        query_kwargs: dict[str, Any] = {
            'TableName': cls._settings.table.table_name,
            'KeyConditionExpression': key_expression,
            'ExpressionAttributeValues': attribute_values,
            'ExpressionAttributeNames': attribute_names
        }
        # Only secondary indexes are named; IndexName=None is rejected by the API.
        if key_index.index_name:
            query_kwargs['IndexName'] = key_index.index_name

        items: list[dict[str, Any]] = []
        async with cls._client() as ddb:
            while True:
                response = await ddb.query(**query_kwargs)
                items.extend(response.get('Items', []))
                last_key = response.get('LastEvaluatedKey')
                if not last_key:
                    break
                query_kwargs['ExclusiveStartKey'] = last_key

        return [cls._from_item(item) for item in items]

    async def create(self) -> None:
        """Write this model as an item, adding the key attributes of every access pattern.

        Key attributes overwrite same-named fields in the dumped data.
        NOTE(review): ``put_item`` has no condition, so an existing item with the
        same primary key is replaced.
        """
        data = self.model_dump()

        # Render every pattern from the original dump before merging, so one
        # pattern's key attributes cannot leak into another's templates.
        key_values: dict[str, str] = {}
        for access_pattern in self._settings.access_patterns:
            key_values.update(self._key_values(access_pattern, data))
        data.update(key_values)

        item = {name: self._serializer.serialize(value) for name, value in data.items()}

        async with self._client() as ddb:
            await ddb.put_item(
                TableName=self._settings.table.table_name,
                Item=item
            )

    @staticmethod
    def _client():
        """Return the aiobotocore client context manager for the local DynamoDB endpoint."""
        return get_session().create_client(
            'dynamodb',
            region_name='us-east-1',
            endpoint_url='http://172.25.32.1:8000',
            aws_access_key_id='local',
            aws_secret_access_key='local',
            verify=False
        )

    @staticmethod
    def _render_template(template: str, placeholders: list[str], values: dict[str, Any]) -> str:
        """Replace each ``{name}`` placeholder in ``template`` with ``str(values[name])``."""
        for placeholder in placeholders:
            template = template.replace(placeholder, str(values[placeholder[1:-1]]))
        return template

    @staticmethod
    def _sort_key_expression(skey_condition: str) -> str:
        """Build the sort-key clause of a KeyConditionExpression (uses ``#sk``/``:sk``)."""
        condition = str(skey_condition).strip().lower()
        if condition == 'begins_with':
            return 'begins_with(#sk, :sk)'
        if condition in ('<', '<=', '>', '>='):
            return f'#sk {condition} :sk'
        # Historical behaviour: any other truthy condition is an equality match.
        return '#sk = :sk'

    @classmethod
    def _key_values(cls, access_pattern: AccessPattern, values: dict[str, Any]) -> dict[str, str]:
        """Render the key attributes of ``access_pattern``'s index from ``values``.

        The sort-key attribute is included only when the pattern has a sort-key
        template that renders to a non-empty string.
        """
        index = access_pattern.index
        key_values = {
            index.partition_key: cls._render_template(
                access_pattern.partition_key, access_pattern._pkey_variables, values
            )
        }
        if access_pattern.sort_key:
            skey_value = cls._render_template(
                access_pattern.sort_key, access_pattern._skey_variables or [], values
            )
            if skey_value:
                key_values[index.sort_key] = skey_value
        return key_values

    @classmethod
    def _find_access_pattern(cls, index: Index | None = None) -> AccessPattern:
        """Return the pattern for ``index`` or, when None, the first primary-index pattern.

        Raises:
            ValueError: if no access pattern matches.
        """
        for access_pattern in cls._settings.access_patterns:
            if index is not None and access_pattern.index == index:
                return access_pattern
            if index is None and not access_pattern.index.index_name:
                return access_pattern
        raise ValueError('No access pattern found')

    @classmethod
    def _from_item(cls, item: dict[str, Any]) -> Self:
        """Deserialize a raw DynamoDB item and validate it into ``cls``."""
        serializer = DynamoSerializer()
        return cls.model_validate({name: serializer.deserialize(value) for name, value in item.items()})

In [222]:
primary_index = Index('PK', 'SK')
type_index = Index('_TYPE', 'created_at', index_name='_TYPE-index')
secondary_index = Index('GS1PK', 'GS1SK', index_name='GS1')

table = Table(
    'Table',
    primary_index,
    type_index=type_index,
    secondary_indexes=[
        secondary_index
    ]
)

await table.create()

{'Table': {'AttributeDefinitions': [{'AttributeName': 'GS1SK', 'AttributeType': 'S'}, {'AttributeName': 'SK', 'AttributeType': 'S'}, {'AttributeName': 'created_at', 'AttributeType': 'S'}, {'AttributeName': 'PK', 'AttributeType': 'S'}, {'AttributeName': '_TYPE', 'AttributeType': 'S'}, {'AttributeName': 'GS1PK', 'AttributeType': 'S'}], 'TableName': 'Table', 'KeySchema': [{'AttributeName': 'PK', 'KeyType': 'HASH'}, {'AttributeName': 'SK', 'KeyType': 'RANGE'}], 'TableStatus': 'ACTIVE', 'CreationDateTime': datetime.datetime(2025, 6, 3, 20, 40, 33, 890000, tzinfo=tzlocal()), 'ProvisionedThroughput': {'LastIncreaseDateTime': datetime.datetime(1969, 12, 31, 16, 0, tzinfo=tzlocal()), 'LastDecreaseDateTime': datetime.datetime(1969, 12, 31, 16, 0, tzinfo=tzlocal()), 'NumberOfDecreasesToday': 0, 'ReadCapacityUnits': 0, 'WriteCapacityUnits': 0}, 'TableSizeBytes': 0, 'ItemCount': 0, 'TableArn': 'arn:aws:dynamodb:ddblocal:000000000000:table/Table', 'BillingModeSummary': {'BillingMode': 'PAY_PER_REQUE

In [223]:
class CollectionQueryResponse:
    """Groups raw DynamoDB items from a mixed-type result by model class.

    Each model class's type name is the partition-key template of its access
    pattern on the table's type index. Items whose type attribute matches a
    known type name are deserialized and validated into that model class.
    Items with no type attribute, or with an unknown type, are skipped.

    NOTE(review): the type index is taken from the last model's table, so all
    models are assumed to share one table.
    """

    def __init__(
        self,
        items: list[dict[str, Any]],
        models: list[Type[DynaModel]]
    ) -> None:
        self.items = items
        self.models = models

        # type name -> validated instances of the matching model class
        self._model_map: dict[str, list[DynaModel]] = {}
        self._model_types: dict[str, Type[DynaModel]] = {}
        self._table_type_index: Index | None = None
        for model in models:
            self._table_type_index = model._settings.table.type_index
            type_name = self._type_name(model, self._table_type_index)
            self._model_map[type_name] = []
            self._model_types[type_name] = model

        if self._table_type_index is None:
            # No models were given, so no item can be classified.
            return

        serializer = DynamoSerializer()
        type_attribute = self._table_type_index.partition_key
        for item in items:
            if type_attribute not in item:
                continue
            item_type = serializer.deserialize(item[type_attribute])
            model = self._model_types.get(item_type)
            if model is None:
                continue
            data = {name: serializer.deserialize(value) for name, value in item.items()}
            self._model_map[item_type].append(model.model_validate(data))

    @staticmethod
    def _type_name(model: Type[DynaModel], type_index: Index | None) -> str:
        """Return ``model``'s type name: its type-index pattern's partition-key template.

        Raises:
            ValueError: if ``model`` has no access pattern on ``type_index``.
        """
        for access_pattern in model._settings.access_patterns:
            if access_pattern.index == type_index:
                return access_pattern.partition_key
        raise ValueError('No type access pattern found')

    def __getitem__(self, key: Type[DynaModel]) -> list[DynaModel]:
        """Return the validated instances of model class ``key`` found in the items."""
        return self._model_map[self._type_name(key, self._table_type_index)]

class Collection:
    """Placeholder, presumably for querying several model types at once.

    Not implemented yet: ``query`` raises instead of silently returning None.
    """

    def __init__(
        self,
    ) -> None:
        ...

    async def query(
        self
    ) -> CollectionQueryResponse:
        """Not implemented.

        Raises:
            NotImplementedError: always, until the collection query is written.
        """
        raise NotImplementedError('Collection.query is not implemented yet')

In [224]:
class MetadataMixin(BaseModel):
    """Audit fields shared by models: creation/update timestamps and actor names."""

    # Each timestamp is taken from datetime.now() (naive local time) at
    # construction. NOTE(review): consider timezone-aware UTC, because
    # created_at is also used as a sort key.
    created_at: datetime.datetime = Field(default_factory=lambda: datetime.datetime.now())
    updated_at: datetime.datetime = Field(default_factory=lambda: datetime.datetime.now())
    # Actor names default to 'SYSTEM'.
    created_by: str = Field(default='SYSTEM')
    updated_by: str = Field(default='SYSTEM')

In [225]:
class User(DynaModel, MetadataMixin):
    """A user, stored in its own partition and indexed by type and by email."""

    id: uuid.UUID = Field(default_factory=uuid.uuid4)
    name: str
    email: str

    class Settings:
        table = table
        access_patterns = [
            AccessPattern(primary_index, 'user:{id}', 'user:{id}'),
            AccessPattern(type_index, 'user', '{created_at}'),
            # GS1 is declared with both GS1PK and GS1SK. Items missing the
            # range-key attribute are not projected into the GSI, so the sort
            # key must be rendered too. It mirrors the partition key, as in the
            # other patterns.
            AccessPattern(secondary_index, 'user-email:{email}', 'user-email:{email}')
        ]

In [226]:
class Post(DynaModel, MetadataMixin):
    """A post, stored in its author's partition and indexed by type and by post id."""

    id: uuid.UUID = Field(default_factory=uuid.uuid4)
    user_id: uuid.UUID
    title: str
    content: str

    class Settings:
        table = table
        access_patterns = [
            # Primary key: the author's partition, one sort-key row per post.
            AccessPattern(index=primary_index, partition_key='user:{user_id}', sort_key='post:{id}'),
            # Type index: all posts, sorted by the rendered created_at value.
            AccessPattern(index=type_index, partition_key='post', sort_key='{created_at}'),
            # GS1: direct lookup by post id.
            AccessPattern(index=secondary_index, partition_key='post:{id}', sort_key='post:{id}'),
        ]

In [227]:
# Sample user; `id` and the audit timestamps come from field defaults.
model = User(name='David', email='david@gmail.com')

In [228]:
# Write the user item (its fields plus the rendered key attributes) to the table.
await model.create()

In [229]:
await User.get({'id': model.id})

User(created_at=datetime.datetime(2025, 6, 4, 21, 10, 47, 976791), updated_at=datetime.datetime(2025, 6, 4, 21, 10, 47, 976802), created_by='SYSTEM', updated_by='SYSTEM', id=UUID('1bb81f8b-31e2-423c-b477-a96598b105c5'), name='David', email='david@gmail.com')

In [230]:
# Query the type index with partition value 'user'. No skey_condition is given,
# so every user item is matched.
await User.query(index=type_index)

[User(created_at=datetime.datetime(2025, 6, 4, 21, 10, 47, 976791), updated_at=datetime.datetime(2025, 6, 4, 21, 10, 47, 976802), created_by='SYSTEM', updated_by='SYSTEM', id=UUID('1bb81f8b-31e2-423c-b477-a96598b105c5'), name='David', email='david@gmail.com')]

In [232]:
post1 = Post(
    user_id=model.id,
    title="Hello World",
    content="This is a post"
)

post2 = Post(
    user_id=model.id,
    title="Foo Bar",
    content="This is another post"
)

post3 = Post(
    user_id=model.id,
    title="Baz Qux",
    content="This is yet another post"
)

await post1.create()
await post2.create()
await post3.create()