# Step 16: In-Memory Vector Store Cache

This notebook demonstrates caching LLM responses in an in-memory vector store and retrieving them with semantic (embedding-based) search.

In [None]:
import os
from dotenv import load_dotenv

# Load settings from a local .env file into the process environment.
# override=True lets values from the .env file replace variables that are already set.
load_dotenv(override=True)

In [None]:
import asyncio
import time
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from typing import Annotated
from uuid import uuid4

from semantic_kernel import Kernel
from semantic_kernel.connectors.ai.embedding_generator_base import EmbeddingGeneratorBase
from semantic_kernel.connectors.ai.open_ai import (
    AzureChatCompletion,
    AzureTextEmbedding,
)
from semantic_kernel.connectors.in_memory import InMemoryStore
from semantic_kernel.data.vector import (
    VectorStoreField,
    vectorstoremodel,
    FieldTypes,
    VectorSearchOptions,
    VectorStore,
    VectorStoreCollection,
)
from semantic_kernel.filters import (
    FilterTypes,
    FunctionInvocationContext,
    PromptRenderContext,
)
from semantic_kernel.functions import FunctionResult

## Define Cache Data Model

In [None]:
@vectorstoremodel
@dataclass
class CacheRecord:
    """One cached prompt/response pair stored in the vector collection.

    Only ``prompt`` and ``result`` are required; the embedding defaults to an
    empty list and the key defaults to a freshly generated UUID string.
    """

    # Rendered prompt text, stored as an indexed field.
    prompt: Annotated[str, VectorStoreField(is_indexed=True)]
    # Response text produced for the prompt, stored as a full-text indexed field.
    result: Annotated[str, VectorStoreField(is_full_text_indexed=True)]
    # Embedding of ``prompt``; this is the vector field searched for similar prompts.
    # NOTE(review): dimensions=1536 presumably matches the configured embedding
    # model's output size - confirm against the deployment in use.
    prompt_embedding: Annotated[
        list[float], VectorStoreField(field_type=FieldTypes.VECTOR, dimensions=1536)
    ] = field(default_factory=list)
    # Record key; a random UUID4 string is generated when none is supplied.
    id: Annotated[str, VectorStoreField(field_type=FieldTypes.KEY)] = field(
        default_factory=lambda: str(uuid4())
    )

## Define Prompt Cache Filter

This filter intercepts prompt rendering and function invocation to implement semantic caching.

In [None]:
# Default name of the vector collection that holds the cached prompt/response records.
COLLECTION_NAME = "llm_responses"
# Metadata key under which the id of the matching cache record is attached to a
# result that was served from the cache.
RECORD_ID_KEY = "cache_record_id"


class PromptCacheFilter:
    """Semantic response cache implemented as a pair of kernel filters.

    ``on_prompt_render`` embeds the rendered prompt, searches the cache
    collection for the closest stored prompt and, when it is close enough,
    places the stored response on the context as the function result.
    ``on_function_invocation`` stores every result that did not come from the
    cache so that later, similar prompts can reuse it.
    """

    def __init__(
        self,
        embedding_service: EmbeddingGeneratorBase,
        vector_store: VectorStore,
        collection_name: str = COLLECTION_NAME,
        score_threshold: float = 0.2,
    ):
        """Bind the filter to an embedding service and a cache collection.

        Args:
            embedding_service: Service used to embed rendered prompts.
            vector_store: Store from which the cache collection is obtained.
            collection_name: Name of the collection holding ``CacheRecord`` items.
            score_threshold: A cached record is reused only when its search
                score is strictly below this value.
        """
        self.embedding_service = embedding_service
        self.vector_store = vector_store
        self.collection: VectorStoreCollection[str, CacheRecord] = (
            vector_store.get_collection(CacheRecord, collection_name=collection_name)
        )
        self.score_threshold = score_threshold

    async def _embed_prompt(self, prompt: str):
        """Return the embedding of ``prompt`` from the embedding service.

        Both the lookup and the store path go through this helper so that the
        vector used for searching and the vector that gets stored are produced
        by the same call and have the same representation.
        """
        embeddings = await self.embedding_service.generate_raw_embeddings([prompt])
        return embeddings[0]

    async def on_prompt_render(
        self,
        context: PromptRenderContext,
        next: Callable[[PromptRenderContext], Awaitable[None]],
    ):
        """Reuse a cached response when a similar prompt has been seen before.

        After the prompt has been rendered, the closest cached prompt is looked
        up. The score threshold decides whether that cached response is reused:
        the comparison direction is based on the default distance metric for
        the in memory vector store, which is cosine distance, so the closer to
        0 the closer the match.

        Args:
            context: Prompt render context; on a cache hit its
                ``function_result`` is set to the cached response, tagged with
                ``RECORD_ID_KEY`` in the metadata.
            next: The next filter (or the renderer itself) in the pipeline.
        """
        await next(context)
        rendered_prompt = context.rendered_prompt
        if not rendered_prompt:
            # Nothing was rendered, so there is nothing to look up. An explicit
            # guard is used instead of ``assert`` because asserts are stripped
            # when Python runs with -O.
            return
        query_vector = await self._embed_prompt(rendered_prompt)
        # NOTE(review): the collection is not created explicitly here; the
        # notebook relies on it being usable on first access - confirm this if
        # a store other than InMemoryStore is plugged in.
        search_results = await self.collection.search(
            vector=query_vector,
            vector_property_name="prompt_embedding",
            top=1,
        )
        async for hit in search_results.results:
            score = hit.score
            # A missing score cannot be compared; a score of exactly 0 is a
            # perfect match and must still count as a hit.
            if score is not None and score < self.score_threshold:
                context.function_result = FunctionResult(
                    function=context.function.metadata,
                    value=hit.record.result,
                    rendered_prompt=rendered_prompt,
                    metadata={RECORD_ID_KEY: hit.record.id},
                )
                break

    async def on_function_invocation(
        self,
        context: FunctionInvocationContext,
        next: Callable[[FunctionInvocationContext], Awaitable[None]],
    ):
        """Store the function result in the cache if it is new.

        A result is treated as new when it carries a rendered prompt and its
        metadata does not contain ``RECORD_ID_KEY`` (the tag added to results
        that were served from the cache).

        Args:
            context: Function invocation context whose ``result`` is inspected
                after the rest of the pipeline has run.
            next: The next filter (or the function itself) in the pipeline.
        """
        await next(context)
        result = context.result
        if result and result.rendered_prompt and RECORD_ID_KEY not in result.metadata:
            cache_record = CacheRecord(
                prompt=result.rendered_prompt,
                result=str(result),
                prompt_embedding=await self._embed_prompt(result.rendered_prompt),
            )
            # NOTE(review): as in the lookup path, no explicit collection
            # creation happens before the upsert - confirm for other stores.
            await self.collection.upsert(cache_record)

## Setup Kernel and Services

In [None]:
# Initialize the kernel and register the two AI services used in this notebook:
# a chat completion service (id "default") and a text embedding service (id "embedder").
# NOTE(review): no endpoint or credentials are passed here; presumably the Azure
# connectors pick their settings up from the environment loaded earlier - confirm.
kernel = Kernel()
chat = AzureChatCompletion(service_id="default")
embedding = AzureTextEmbedding(service_id="embedder")
kernel.add_service(chat)
kernel.add_service(embedding)

In [None]:
# Create the in-memory vector store that will hold the cache collection
vector_store = InMemoryStore()
print("Vector store initialized")

## Register Cache Filter

Create and register the cache filter with the kernel.

In [None]:
# Create the cache filter and add it to the kernel.
# The same filter object is registered twice: its prompt-rendering hook performs
# the cache lookup and its function-invocation hook stores new results.
cache = PromptCacheFilter(embedding_service=embedding, vector_store=vector_store)
kernel.add_filter(FilterTypes.PROMPT_RENDERING, cache.on_prompt_render)
kernel.add_filter(FilterTypes.FUNCTION_INVOCATION, cache.on_function_invocation)
print("Cache filter registered")

## Test Semantic Caching

Run queries to demonstrate semantic caching behavior. Responses to queries that are semantically similar to an earlier one should be served from the cache instead of triggering a new LLM call.

In [None]:
# Helper that runs a prompt through the kernel and reports how long it took
async def execute_async(title: str, prompt: str):
    """Invoke ``prompt`` on the notebook's kernel and print the elapsed time.

    Args:
        title: Label printed in front of the prompt to identify the run.
        prompt: Prompt text passed to ``kernel.invoke_prompt``.

    Returns:
        Whatever ``kernel.invoke_prompt`` returns for the prompt.
    """
    print(f"{title}: {prompt}")
    # perf_counter is monotonic, so the measured duration cannot be distorted
    # by system clock adjustments the way a time.time() difference can.
    start = time.perf_counter()
    result = await kernel.invoke_prompt(prompt)
    elapsed = time.perf_counter() - start
    print(f"\tElapsed Time: {elapsed:.3f}")
    return result

In [None]:
# First query - nothing similar has been cached yet (on a fresh kernel),
# so this will make an actual LLM call
result1 = await execute_async("First run", "What's the tallest building in New York?")
print(f"Result 1: {result1}")

In [None]:
# Second query - different topic from the first, so it is expected to miss the
# cache and make another LLM call
result2 = await execute_async("Second run", "How are you today?")
print(f"Result 2: {result2}")

In [None]:
# Third query - semantically similar to first, should retrieve from cache (faster!)
result3 = await execute_async("Third run", "What is the highest building in New York City?")
print(f"Result 3: {result3}")
print("\nâš¡ Notice the significantly reduced time for the third query - it was retrieved from cache!")

## Interactive Testing

Use the loop below to test with your own queries and observe caching behavior.

In [None]:
# Interactive loop for testing queries
# Type 'exit' to quit
while True:
    user_input = input("User > ")
    
    if user_input == "exit":
        break
    
    result = await execute_async("Test", user_input)
    print(f"Result: {result}")