In [43]:
%%capture
%pip install sqlalchemy aiosqlite aiocache

732.26s - pydevd: Sending message related to process being replaced timed-out after 5 seconds


In [44]:
import hashlib
import json
import logging
from typing import Any, Dict, Iterator, Optional, Tuple, Union

from aiocache import Cache, cached
from aiocache.serializers import JsonSerializer, PickleSerializer
from sqlalchemy import text
from sqlalchemy.engine import Result
from sqlalchemy.exc import MultipleResultsFound, NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import sessionmaker
from sqlalchemy.sql import Select
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.expression import FunctionElement


In [45]:
# --- Configuration ---
class CacheConfig:
    """Backend, connection, serialization and TTL settings for a cached session.

    Args:
        cache_type: backend name: ``"redis"``, ``"memcached"`` or ``"memory"``.
        endpoint: cache server host.
        port: cache server port.
        db: backend database number (e.g. the Redis DB index).
        serializer: ``"json"`` or ``"pickle"``.
        ttl: default time-to-live of cached entries, in seconds (default 1 hour).
        **kwargs: extra keyword arguments forwarded to the aiocache constructor.

    Raises:
        ValueError: if ``cache_type`` or ``serializer`` is not a supported name.
    """

    def __init__(
        self,
        cache_type: str = "redis",
        endpoint: str = "127.0.0.1",
        port: int = 6379,
        db: int = 0,
        serializer: str = "json",
        ttl: int = 3600,  # Default TTL: 1 hour
        **kwargs: Any
    ) -> None:
        if cache_type == "redis":
            self.cache_class = Cache.REDIS
        elif cache_type == "memcached":
            self.cache_class = Cache.MEMCACHED
        elif cache_type == "memory":
            self.cache_class = Cache.MEMORY
        else:
            raise ValueError(
                f"Invalid cache_type: {cache_type}. Choose 'redis', 'memcached', or 'memory'."
            )

        # Unknown names are rejected instead of falling back to pickle: a typo
        # must not silently switch to a serializer that unpickles cache contents.
        if serializer == "json":
            self.serializer = JsonSerializer()
        elif serializer == "pickle":
            self.serializer = PickleSerializer()
        else:
            raise ValueError(
                f"Invalid serializer: {serializer}. Choose 'json' or 'pickle'."
            )

        self.cache_type = cache_type
        self.endpoint = endpoint
        self.port = port
        self.db = db
        self.ttl = ttl
        self.aiocache_kwargs = kwargs

In [46]:
# --- Result Mock ---
class ResultMock:
    """Minimal stand-in for a SQLAlchemy ``Result`` built from cached rows.

    ``CachedAsyncSession.execute`` returns this for cache-served queries. Each row
    is an indexable sequence (a list, as stored in the cache). The accessors mirror
    the ``Result`` methods of the same name. The ``one``/``scalar_one`` family
    raises ``NoResultFound`` / ``MultipleResultsFound`` on row-count mismatches.
    """

    def __init__(self, data: list) -> None:
        self._data = data

    def _check_exactly_one(self) -> None:
        """Raise unless exactly one row is present."""
        if not self._data:
            raise NoResultFound("No row was found when one was required")
        if len(self._data) > 1:
            raise MultipleResultsFound(
                "Multiple rows were found when exactly one was required"
            )

    def _check_at_most_one(self) -> None:
        """Raise if more than one row is present."""
        if len(self._data) > 1:
            raise MultipleResultsFound(
                "Multiple rows were found when one or none was required"
            )

    def all(self) -> list:
        """Return every row."""
        return self._data

    def fetchall(self) -> list:
        """Alias of :meth:`all`, as on ``Result``."""
        return self._data

    def first(self) -> Optional[Any]:
        """Return the first row, or ``None`` when there are no rows."""
        return self._data[0] if self._data else None

    def scalar(self) -> Optional[Any]:
        """Return the first column of the first row, or ``None`` when empty."""
        return self._data[0][0] if self._data else None

    def scalar_one(self) -> Any:
        """Return the first column of the only row; raise unless exactly one row."""
        self._check_exactly_one()
        return self._data[0][0]

    def scalar_one_or_none(self) -> Optional[Any]:
        """Return the first column of the only row, ``None`` if empty; raise if several."""
        self._check_at_most_one()
        return self._data[0][0] if self._data else None

    def one(self) -> Any:
        """Return the only row; raise unless exactly one row."""
        self._check_exactly_one()
        return self._data[0]

    def one_or_none(self) -> Optional[Any]:
        """Return the only row, ``None`` if empty; raise if several."""
        self._check_at_most_one()
        return self._data[0] if self._data else None

    def scalars(self, index: int = 0) -> "ResultMock":
        """Return a result over column ``index`` of every row.

        Approximates ``Result.scalars()``: ``all()``, ``first()``, ``one()``,
        ``one_or_none()`` and iteration on the returned object yield the column
        values themselves rather than rows.
        """
        return ResultMock([row[index] for row in self._data])

    def __iter__(self) -> Iterator[Any]:
        return iter(self._data)

    def partitions(self, size: Optional[int] = None) -> Iterator[list]:
        """Yield the rows in chunks of ``size``; with ``size=None`` yield all rows once."""
        if size is None:
            yield self._data
        else:
            for start in range(0, len(self._data), size):
                yield self._data[start:start + size]


In [47]:
# --- Cached Async Session ---
class CachedAsyncSession(AsyncSession):
    """``AsyncSession`` that caches the rows of ``Select``/``FunctionElement`` queries.

    Cacheable statements are looked up in ``self.cache`` (one aiocache instance built
    from ``cache_config``). On a miss the query runs against the database, its rows
    are stored as lists, and the result is returned wrapped in a ``ResultMock``.
    All other statements are delegated unchanged to ``AsyncSession.execute``.

    NOTE(review): cached reads do not see later writes (nor changes pending in the
    current transaction); call ``invalidate_cache`` after mutating data.
    """

    def __init__(
        self,
        *args: Any,
        cache_config: CacheConfig,
        **kwargs: Any,
    ) -> None:
        """Create the session and its cache client.

        Args:
            *args, **kwargs: forwarded to ``AsyncSession``.
            cache_config: backend, connection, serializer and TTL settings.
        """
        super().__init__(*args, **kwargs)
        self.cache_config = cache_config
        # A single cache instance serves reads, writes and invalidation, so all
        # three always talk to the configured backend/endpoint.
        # NOTE(review): endpoint/port/db are passed for every backend type;
        # confirm the memory/memcached backends accept them.
        self.cache = Cache(
            cache_class=cache_config.cache_class,
            endpoint=cache_config.endpoint,
            port=cache_config.port,
            db=cache_config.db,
            serializer=cache_config.serializer,
            ttl=cache_config.ttl,
            **cache_config.aiocache_kwargs,
        )

    async def execute(  # type: ignore[override]
        self,
        statement: Union[Select, TextClause, FunctionElement],
        *args: Any,
        **kwargs: Any,
    ) -> Union[Result, ResultMock]:
        """Execute ``statement``, serving ``Select``/``FunctionElement`` rows from cache.

        Args:
            statement: the statement to execute.
            *args, **kwargs: forwarded to ``AsyncSession.execute`` (e.g. ``params``,
                ``execution_options``). They are also part of the cache key, so
                different parameter sets never share an entry.

        Returns:
            A ``ResultMock`` over list-rows for cacheable statements; otherwise the
            ``Result`` returned by ``AsyncSession.execute``.
        """
        if not isinstance(statement, (Select, FunctionElement)):
            return await super().execute(statement, *args, **kwargs)

        cache_key = self._generate_cache_key(statement, *args, **kwargs)

        # Cache errors are treated as misses so an unavailable cache backend
        # degrades to plain database access instead of failing the query.
        try:
            rows = await self.cache.get(cache_key)
        except Exception:
            logging.getLogger(__name__).warning(
                "Cache read failed for %s; querying the database",
                cache_key,
                exc_info=True,
            )
            rows = None

        if rows is None:
            result = await super().execute(statement, *args, **kwargs)
            rows = [list(row) for row in result.all()]
            try:
                await self.cache.set(cache_key, rows, ttl=self.cache_config.ttl)
            except Exception:
                logging.getLogger(__name__).warning(
                    "Cache write failed for %s", cache_key, exc_info=True
                )

        return ResultMock(rows)

    def _generate_cache_key(
        self,
        statement: Union[Select, TextClause, FunctionElement],
        *args: Any,
        **kwargs: Any,
    ) -> str:
        """Build a deterministic ``db_cache:<sha256>`` key for a statement and its arguments.

        The statement is compiled with bind placeholders. The bound values
        (``compiled.params``) are hashed next to the SQL text, together with the
        positional and keyword arguments given to ``execute``. This covers values
        embedded in ``FunctionElement``s and ``params`` passed positionally, and it
        avoids ``literal_binds`` compilation, which may fail for types without a
        literal renderer.
        """
        compiled = statement.compile()
        payload = {
            "sql": str(compiled),
            "bound": compiled.params,
            "args": args,
            "kwargs": kwargs,
        }
        # default=repr keeps non-JSON values (datetimes, Decimals, option
        # objects) from raising.
        # NOTE(review): values whose repr does not reflect their state (the
        # default object repr) do not distinguish keys reliably.
        serialized = json.dumps(payload, sort_keys=True, default=repr)
        return "db_cache:" + hashlib.sha256(serialized.encode()).hexdigest()

    async def invalidate_cache(self, key_pattern: str) -> None:
        """Remove cached query results.

        For Redis, deletes every key matching ``key_pattern`` (a Redis glob such as
        ``"db_cache:*"`` or ``"user:*:profile"``). For the other backends the pattern
        is ignored and the whole cache is cleared.

        Args:
            key_pattern: Redis ``SCAN MATCH`` pattern selecting the keys to delete.

        NOTE(review): if a ``namespace`` is configured via ``aiocache_kwargs``,
        stored keys may carry that prefix; confirm the pattern still matches them.
        """
        if self.cache_config.cache_type == "redis":
            # The underlying redis client is an attribute; no await needed.
            redis_client = self.cache.client
            async for key in redis_client.scan_iter(key_pattern):
                await redis_client.delete(key)
        else:
            # Pattern deletion is not used for these backends; clear everything.
            await self.cache.clear()


In [48]:
# --- Example Usage ---
import hashlib
import json
from typing import Any, Dict, Optional, Tuple, Union

from aiocache import Cache, cached
from aiocache.serializers import JsonSerializer, PickleSerializer
from sqlalchemy import text
from sqlalchemy.engine import Result
from sqlalchemy.exc import MultipleResultsFound, NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy.sql import Select
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.expression import FunctionElement
from sqlalchemy import Column, Integer, String, select


# Declarative base whose metadata holds the demo tables (used by create_all below).
Base = declarative_base()

class User(Base):
    """Demo ORM model mapped to the ``users`` table."""
    __tablename__ = "users"
    id = Column(Integer, primary_key=True)  # primary key
    name = Column(String)
    email = Column(String)

async def setup_db_and_cache(
    database_url: str = "sqlite+aiosqlite:///:memory:",
    cache_config: Optional[CacheConfig] = None,
    echo: bool = True,
) -> Tuple[sessionmaker, CacheConfig]:
    """Create the demo schema and a ``CachedAsyncSession`` factory.

    Args:
        database_url: async SQLAlchemy URL. Defaults to an in-memory aiosqlite DB.
        cache_config: cache settings. Defaults to a local Redis (127.0.0.1:6379)
            with pickle serialization and a 60-second TTL.
        echo: passed to ``create_async_engine``; logs emitted SQL when true.

    Returns:
        ``(async_session, cache_config)``: the session factory and the cache
        configuration it uses.
    """
    engine = create_async_engine(database_url, echo=echo)

    if cache_config is None:
        # Pickle, because cached rows of ``select(User)`` hold ORM instances,
        # which JSON cannot encode.
        cache_config = CacheConfig(
            cache_type="redis",
            endpoint="127.0.0.1",
            port=6379,
            serializer="pickle",
            ttl=60,
        )

    async_session = sessionmaker(
        engine, expire_on_commit=False, class_=CachedAsyncSession, cache_config=cache_config
    )

    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)

    return async_session, cache_config

async def run_queries(async_session):
    """Demonstrate a cache miss, a cache hit and an invalidation.

    Inserts three users, then runs the same ``SELECT`` three times:
    - the first run populates the cache;
    - the second is expected to be served from it;
    - the third runs after ``invalidate_cache("db_cache:*")``.

    Args:
        async_session: factory producing ``CachedAsyncSession`` objects, e.g. the
            one returned by ``setup_db_and_cache``.
    """
    async with async_session() as session:
        # Drop entries left by an earlier run of this notebook that are still within
        # their TTL. Otherwise the "first" query could return rows cached from a
        # different (since discarded) database.
        await session.invalidate_cache("db_cache:*")

        async with session.begin():
            session.add_all([
                User(name='user1', email='user1@example.com'),
                User(name='user2', email='user2@example.com'),
                User(name='user3', email='user3@example.com'),
            ])

        user1_stmt = select(User).where(User.name == "user1")

        async def show_user1(label: str) -> None:
            # Execute the shared query and print its first row under ``label``.
            result = await session.execute(user1_stmt)
            print(label, result.first())

        await show_user1("First Query Result (from DB):")
        await show_user1("Second Query Result (should be from cache):")

        # Invalidate every cached query result, so the next run hits the database.
        await session.invalidate_cache("db_cache:*")
        await show_user1("Third Query Result (after invalidation, from DB):")

In [49]:
# Create the schema and the cached-session factory. Top-level ``await`` relies on
# IPython's autoawait; in a plain script, run this inside ``asyncio.run(...)``.
async_session, cache_config = await setup_db_and_cache()



2024-12-10 02:04:13,266 INFO sqlalchemy.engine.Engine BEGIN (implicit)


2024-12-10 02:04:13,269 INFO sqlalchemy.engine.Engine PRAGMA main.table_info("users")
2024-12-10 02:04:13,271 INFO sqlalchemy.engine.Engine [raw sql] ()
2024-12-10 02:04:13,282 INFO sqlalchemy.engine.Engine PRAGMA temp.table_info("users")
2024-12-10 02:04:13,283 INFO sqlalchemy.engine.Engine [raw sql] ()
2024-12-10 02:04:13,291 INFO sqlalchemy.engine.Engine 
CREATE TABLE users (
	id INTEGER NOT NULL, 
	name VARCHAR, 
	email VARCHAR, 
	PRIMARY KEY (id)
)


2024-12-10 02:04:13,292 INFO sqlalchemy.engine.Engine [no key 0.00126s] ()
2024-12-10 02:04:13,295 INFO sqlalchemy.engine.Engine COMMIT


In [50]:
# Run the demo: insert users, then query "user1" before/after caching and invalidation.
await run_queries(async_session)

2024-12-10 02:04:13,317 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2024-12-10 02:04:13,322 INFO sqlalchemy.engine.Engine INSERT INTO users (name, email) VALUES (?, ?) RETURNING id
2024-12-10 02:04:13,325 INFO sqlalchemy.engine.Engine [generated in 0.00036s (insertmanyvalues) 1/3 (ordered; batch not supported)] ('user1', 'user1@example.com')
2024-12-10 02:04:13,332 INFO sqlalchemy.engine.Engine INSERT INTO users (name, email) VALUES (?, ?) RETURNING id
2024-12-10 02:04:13,333 INFO sqlalchemy.engine.Engine [insertmanyvalues 2/3 (ordered; batch not supported)] ('user2', 'user2@example.com')
2024-12-10 02:04:13,337 INFO sqlalchemy.engine.Engine INSERT INTO users (name, email) VALUES (?, ?) RETURNING id
2024-12-10 02:04:13,338 INFO sqlalchemy.engine.Engine [insertmanyvalues 3/3 (ordered; batch not supported)] ('user3', 'user3@example.com')
2024-12-10 02:04:13,345 INFO sqlalchemy.engine.Engine COMMIT
2024-12-10 02:04:13,356 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2024-12-10 02:04: