## Dynamic Metadata Filtering for Knowledge Bases

> Source: the complete notebook is available in the aws-samples repository: [notebook here](https://github.com/aws-samples/rag-workshop-amazon-bedrock-knowledge-bases/blob/main/03-advanced-concepts/dynamic-metadata-filtering/dynamic-metadata-filtering-KB.ipynb)

In [None]:
import json
import boto3
from typing import List, Optional
from pydantic import BaseModel, model_validator

In [None]:
# Session init: create both clients in the session's default region
session = boto3.session.Session()
region = session.region_name
bedrock = boto3.client("bedrock-runtime", region_name=region)
bedrock_agent_runtime = boto3.client("bedrock-agent-runtime", region_name=region)

# Model used for entity extraction via tool use
MODEL_ID = "anthropic.claude-3-haiku-20240307-v1:0"

### Define Pydantic Models
We'll use Pydantic models to validate and structure our extracted entities:

In [None]:
class Entity(BaseModel):
    # Both fields default to None so a partially filled entity still validates
    Publisher: Optional[str] = None
    Year: Optional[int] = None

class ExtractedEntities(BaseModel):
    entities: List[Entity]

    @model_validator(mode='before')
    @classmethod
    def remove_duplicates(cls, values):
        # Runs on the raw input, so each entity is still a plain dict here
        if isinstance(values, dict) and 'entities' in values:
            unique_entities = []
            seen = set()
            for entity in values['entities']:
                # A sorted tuple of items is hashable and ignores key order
                entity_tuple = tuple(sorted(entity.items()))
                if entity_tuple not in seen:
                    seen.add(entity_tuple)
                    unique_entities.append(dict(entity_tuple))
            values['entities'] = unique_entities
        return values
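
The de-duplication step can be tried in isolation. The sketch below applies the same sorted-tuple technique to plain dictionaries, without Pydantic. The helper name `dedupe_entities` and the sample data are illustrative assumptions, not part of the notebook.

```python
def dedupe_entities(entities):
    """Drop repeated entity dicts while keeping first-seen order."""
    unique, seen = [], set()
    for entity in entities:
        # Sorting the items makes the key independent of dict key order
        key = tuple(sorted(entity.items()))
        if key not in seen:
            seen.add(key)
            unique.append(dict(key))
    return unique


# Hypothetical sample: the first two entries differ only in key order
sample = [
    {"Publisher": "Rockstar Games", "Year": 2010},
    {"Year": 2010, "Publisher": "Rockstar Games"},
    {"Publisher": "Nintendo", "Year": 2017},
]
deduped = dedupe_entities(sample)
print(deduped)
```

Because the key is built from sorted items, two dictionaries with the same content but a different key order count as the same entity.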

## Implement Entity Extraction using Tool Use
We'll define a tool for entity extraction with very basic instructions and call it through the Amazon Bedrock Converse API:


In [None]:
tool_name = "extract_entities"
tool_description = "Extract named entities from the text. If you are not 100% sure of the entity value, use 'unknown'."

tool_extract_entities = ["Publisher", "Year"]
tool_extract_property = ["entities"]

tool_entity_description = {
    "Publisher": {"type": "string", "description": "The publisher of the game. The first letter is uppercase."},
    "Year": {"type": "integer", "description": "The year when the game was released."}
}

tool_properties = {
    'tool_name':tool_name,
    'tool_description':tool_description,
    'tool_extract_entities':tool_extract_entities,
    'tool_extract_property':tool_extract_property,
    'tool_entity_description': tool_entity_description
}

def extract_entities(text, tool_properties):   
    tools = [{
            "toolSpec": {
                "name": tool_properties['tool_name'],
                "description": tool_properties['tool_description'],
                "inputSchema": {
                    "json": {
                        "type": "object",
                        "properties": {
                            "entities": {
                                "type": "array",
                                "items": {
                                    "type": "object",
                                    "properties": tool_properties['tool_entity_description'],
                                    "required": tool_properties['tool_extract_entities']
                                }
                            }
                        },
                        "required": tool_properties['tool_extract_property']
                    }
                }
            }
        }]
    
    response = bedrock.converse(
        modelId=MODEL_ID,
        inferenceConfig={
            "temperature": 0,
            "maxTokens": 4000
        },
        toolConfig={"tools": tools},
        messages=[{"role": "user", "content": [{"text": text}]}]
    )

    # Find the tool call made by the model and take its structured input
    json_entities = None
    for content in response['output']['message']['content']:
        if "toolUse" in content and content['toolUse']['name'] == tool_properties['tool_name']:
            json_entities = content['toolUse']['input']
            break

    if json_entities:
        return ExtractedEntities.model_validate(json_entities)
    else:
        print("No entities found in the response.")
        return None
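
To make the response handling easier to follow, here is a minimal sketch of the lookup step run against a hand-written dictionary shaped like a Converse response. The helper `find_tool_input`, the `toolUseId` value, and the mock content are assumptions for illustration; no call to Bedrock is made.

```python
def find_tool_input(response, tool_name):
    """Return the `input` of the first matching toolUse block, else None."""
    for block in response["output"]["message"]["content"]:
        tool_use = block.get("toolUse")
        if tool_use and tool_use["name"] == tool_name:
            return tool_use["input"]
    return None


# Hypothetical response: a text block followed by a tool call
mock_response = {
    "output": {
        "message": {
            "role": "assistant",
            "content": [
                {"text": "Extracting entities."},
                {
                    "toolUse": {
                        "toolUseId": "tooluse_example",
                        "name": "extract_entities",
                        "input": {
                            "entities": [
                                {"Publisher": "Rockstar Games", "Year": 2010}
                            ]
                        },
                    }
                },
            ],
        }
    },
    "stopReason": "tool_use",
}

tool_input = find_tool_input(mock_response, "extract_entities")
print(tool_input)
```

The message content is a list of blocks, so the tool call is not necessarily the first element; that is why the code scans the list instead of indexing into it.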

## Construct Metadata Filter
Now, let's create a function to construct the metadata filter based on the extracted entities:

In [None]:
def construct_metadata_filter(extracted_entities):
    if not extracted_entities or not extracted_entities.entities:
        return None

    # Only the first extracted entity is used to build the filter
    entity = extracted_entities.entities[0]
    conditions = []

    if entity.Publisher and entity.Publisher != 'unknown':
        conditions.append({
            "equals": {
                "key": "Publisher",
                "value": entity.Publisher
            }
        })

    # Year is validated as an int, so a truthiness check is sufficient
    if entity.Year:
        # Note: this also matches the extracted year itself, not only later years
        conditions.append({
            "greaterThanOrEquals": {
                "key": "Year",
                "value": int(entity.Year)
            }
        })

    if not conditions:
        return None
    # `andAll` expects at least two conditions, so return a lone condition as-is
    if len(conditions) == 1:
        return conditions[0]
    return {"andAll": conditions}
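
To see what such a filter selects, the sketch below evaluates a filter of this shape locally against a few metadata records. This is only an illustration of the filter semantics, not how the knowledge base evaluates filters internally; the `matches` helper and the document records are hypothetical.

```python
def matches(metadata, flt):
    """Check one metadata dict against a filter that uses only
    `equals`, `greaterThanOrEquals`, and `andAll`."""
    if "andAll" in flt:
        return all(matches(metadata, sub) for sub in flt["andAll"])
    if "equals" in flt:
        cond = flt["equals"]
        return metadata.get(cond["key"]) == cond["value"]
    if "greaterThanOrEquals" in flt:
        cond = flt["greaterThanOrEquals"]
        value = metadata.get(cond["key"])
        return value is not None and value >= cond["value"]
    raise ValueError(f"Unsupported operator: {list(flt)}")


example_filter = {
    "andAll": [
        {"equals": {"key": "Publisher", "value": "Rockstar Games"}},
        {"greaterThanOrEquals": {"key": "Year", "value": 2010}},
    ]
}

# Hypothetical metadata records, one per document
documents = [
    {"id": "doc-a", "Publisher": "Rockstar Games", "Year": 2013},
    {"id": "doc-b", "Publisher": "Rockstar Games", "Year": 2010},
    {"id": "doc-c", "Publisher": "Rockstar Games", "Year": 2008},
    {"id": "doc-d", "Publisher": "Nintendo", "Year": 2017},
]

selected = [d["id"] for d in documents if matches(d, example_filter)]
print(selected)  # ['doc-a', 'doc-b']
```

Note that `doc-b` is selected: `greaterThanOrEquals` includes the boundary year, so a query phrased as "after 2010" would need `greaterThan` to exclude 2010 itself.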

## Example

In [None]:
text="Provide a list of all video games published by Rockstar Games and released after 2010"
extracted_entities = extract_entities(text, tool_properties)
metadata_filter = construct_metadata_filter(extracted_entities)
print('Here is the prepared metadata filter:')
print(metadata_filter)

---
### Finally, we can call the Retrieve API with the new filters:

```python
def process_query(text, tool_properties, kb_id):
    # kb_id: ID of the knowledge base to query
    extracted_entities = extract_entities(text, tool_properties)
    metadata_filter = construct_metadata_filter(extracted_entities)
    print('Here is the prepared metadata filter:')
    print(metadata_filter)

    # Only attach a filter when one was built; boto3 rejects "filter": None
    vector_search_config = {}
    if metadata_filter:
        vector_search_config["filter"] = metadata_filter

    response = bedrock_agent_runtime.retrieve(
        knowledgeBaseId=kb_id,
        retrievalConfiguration={
            "vectorSearchConfiguration": vector_search_config
        },
        retrievalQuery={
            'text': text
        }
    )
    return response
```
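
Once the Retrieve call returns, the results can be inspected to confirm that the filter had the intended effect. The following is a minimal sketch that condenses a response into score, metadata, and a text snippet per result. The helper `summarize_results` and the mock response (including the bucket URI and score) are hypothetical and stand in for a real API response.

```python
def summarize_results(response, max_chars=80):
    """Condense each retrieval result to its score, metadata, and a snippet."""
    rows = []
    for result in response.get("retrievalResults", []):
        rows.append({
            "score": result.get("score"),
            "metadata": result.get("metadata", {}),
            "snippet": result["content"]["text"][:max_chars],
        })
    return rows


# Hypothetical response with a single retrieved chunk
mock_response = {
    "retrievalResults": [
        {
            "content": {"text": "Example chunk text about a game."},
            "location": {
                "type": "S3",
                "s3Location": {"uri": "s3://example-bucket/example.csv"},
            },
            "score": 0.61,
            "metadata": {"Publisher": "Rockstar Games", "Year": 2013},
        }
    ]
}

for row in summarize_results(mock_response):
    print(row["score"], row["metadata"], row["snippet"])
```

Checking the `metadata` of each returned chunk against the printed filter is a quick way to verify that only matching documents were retrieved.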