In [45]:
from pymilvus import connections, utility, FieldSchema, CollectionSchema, DataType, Collection, MilvusClient
import logging
from typing import Any
import polars as pl
from pathlib import Path
from milvus_model.hybrid import BGEM3EmbeddingFunction
from LegalDefAgent.src.settings import settings
from torch.cuda import is_available as cuda_available

logger = logging.getLogger(__name__)

class VectorDBBuilder:
    def __init__(self):
        self.config = settings
        self.milvus_uri = self.config.MILVUSDB_URI
        self.collection_name = self.config.MILVUSDB_COLLECTION_NAME
        self.batch_size = self.config.DB_CONFIG.BATCH_SIZE
        self.ef = BGEM3EmbeddingFunction(
                model_name='BAAI/bge-m3',
                device='cuda' if cuda_available() else 'cpu',
                use_fp16=True if cuda_available() else False #set to false if device='cpu'
            )
        self.dense_dim = self.ef.dim["dense"]
        
    def setup_collection(self) -> Collection:
        """Setup Milvus collection with proper schema."""
        fields = [
            FieldSchema(
                name="id", 
                dtype=DataType.INT64,
                is_primary=True, 
                #auto_id=True, 
                #max_length=100
            ),
            FieldSchema(
                name="definition_text", 
                dtype=DataType.VARCHAR, 
                max_length=5000
            ),
            FieldSchema(
                name="definendum_label", 
                dtype=DataType.VARCHAR, 
                max_length=256
            ),
            FieldSchema(
                name="dataset", 
                dtype=DataType.VARCHAR, 
                max_length=10
            ),
            FieldSchema(
                name="document_id", 
                dtype=DataType.VARCHAR, 
                max_length=40
            ),
            FieldSchema(
                name="frbr_work", 
                dtype=DataType.VARCHAR, 
                max_length=120
            ),
            FieldSchema(
                name="frbr_expression", 
                dtype=DataType.VARCHAR, 
                max_length=120
            ),
            FieldSchema(
                name="sparse_vector", 
                dtype=DataType.SPARSE_FLOAT_VECTOR,
            ),
            FieldSchema(
                name="dense_vector", 
                dtype=DataType.FLOAT_VECTOR,
                dim=self.dense_dim
            ),
        ]
        
        schema = CollectionSchema(fields, "Definitions embeddings")
        collection_name = "Definitions"
        
        # Drop existing collection if it exists
        if utility.has_collection(collection_name):
            Collection(collection_name).drop()
            
        collection = Collection(
            collection_name, 
            schema, 
            consistency_level="Strong"
        )
        
        # Create and load index
        sparse_index = {
            "index_type": "SPARSE_INVERTED_INDEX", 
            "metric_type": "IP"
        }
        dense_index = {
            "index_type": "FLAT", 
            "metric_type": "COSINE"
        }
        collection.create_index("sparse_vector", sparse_index)
        collection.create_index("dense_vector", dense_index)
        collection.load()
        
        return collection
    
    def build_vector_db(self, df: pl.DataFrame, defs_embeddings=None, definendum_embeddings=None) -> None:
        """Build vector database from processed definitions."""
        logger.info("Building vector database...")
        try:

            Path(self.milvus_uri).parent.mkdir(parents=True, exist_ok=True)
            client = MilvusClient(
                uri=self.milvus_uri
            )
            connections.connect(uri=self.milvus_uri)
            
            # Setup collection
            collection = self.setup_collection()

            definendum_list = (df.select(
                pl.col('label')
                .str.replace('#', '')
                .str.replace(r'([a-zà-ÿ])([A-Z])', r'${1} ${2}', n=-1)  # Add space between lowercase and uppercase letters
                .str.to_lowercase()  # Convert to lowercase after splitting
            )
            )['label'].to_list()

            
            # Generate embeddings and insert in batches
            for i in range(0, len(df), self.batch_size):
                batch_df = df.slice(i, self.batch_size)
                
                # Generate embeddings for the batch
                batch_texts = batch_df['definition_text'].to_list()
                if not defs_embeddings:
                    batch_embeddings = self.ef(batch_texts)
                else:
                    batch_sparse_embeddings = definendum_embeddings['sparse'][i:i+self.batch_size]
                    batch_dense_embeddings = defs_embeddings['dense'][i:i+self.batch_size]
                
                # Prepare batch data
                batch_data = [
                    batch_df['id'].to_list(),
                    batch_df['definition_text'].to_list(),
                    batch_df['label'].to_list(),
                    batch_df['def_n'].to_list(),
                    batch_df['dataset'].to_list(),
                    batch_df['document_id'].to_list(),
                    batch_df['frbr_work'].to_list(),
                    batch_df['frbr_expression'].to_list(),
                    batch_sparse_embeddings,
                    batch_dense_embeddings
                ]
                
                # Insert batch
                collection.insert(batch_data)
                
            logger.info(f"Inserted {collection.num_entities} entities into vector database")
            
        except Exception as e:
            logger.error(f"Error building vector database: {e}")
            raise
        finally:
            connections.disconnect(alias='default')

In [44]:
import pickle

df = pl.read_parquet('../data/definitions_corpus/definitions.parquet')

with open('../data/definendums_embeddings.pkl', 'rb') as f:
    definendums_emb = pickle.load(f)


with open('../data/def_embeddings.pkl', 'rb') as f:
    defs_emb = pickle.load(f)


In [46]:
builder = VectorDBBuilder()

builder.build_vector_db(df=df, defs_embeddings=defs_emb, definendum_embeddings=definendums_emb)

Fetching 30 files:   0%|          | 0/30 [00:00<?, ?it/s]

In [5]:
from typing import Annotated, Literal, Sequence, Dict, Any
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage
from langgraph.graph.message import add_messages
from langgraph.graph import START, StateGraph
from typing import TypedDict


# Define parent graph
class SupervisorState(TypedDict):
    input: str
    messages: Annotated[Sequence[BaseMessage], add_messages]
    foo: str  # note that this key is shared with the parent graph state


# Define subgraph
class DefinitionAgentState(TypedDict):
    question: str
    query: Dict[str, str]
    retrieved_definitions: list[str]
    relevant_definitions: list[str]
    foo: str  # note that this key is shared with the parent graph state
    bar: str


def subgraph_node_1(state: DefinitionAgentState):
    return {"bar": "bar"}


def subgraph_node_2(state: DefinitionAgentState):
    # note that this node is using a state key ('bar') that is only available in the subgraph
    # and is sending update on the shared state key ('foo')
    return {"foo": state["foo"] + state["bar"]}


subgraph_builder = StateGraph(DefinitionAgentState)
subgraph_builder.add_node(subgraph_node_1)
subgraph_builder.add_node(subgraph_node_2)
subgraph_builder.add_edge(START, "subgraph_node_1")
subgraph_builder.add_edge("subgraph_node_1", "subgraph_node_2")
subgraph = subgraph_builder.compile()


# Define parent graph
class SupervisorState(TypedDict):
    foo: str


def node_1(state: SupervisorState):
    return {"foo": "hi! " + state["foo"]}


builder = StateGraph(SupervisorState)
builder.add_node("node_1", node_1)
# note that we're adding the compiled subgraph as a node to the parent graph
builder.add_node("node_2", subgraph)
builder.add_edge(START, "node_1")
builder.add_edge("node_1", "node_2")
graph = builder.compile()

In [7]:
for chunk in graph.stream({"foo": "foo"}):
    print(chunk)

print()
for chunk in graph.stream({"foo": "foo"}, subgraphs=True):
    print(chunk)

{'node_1': {'foo': 'hi! foo'}}
{'node_2': {'foo': 'hi! foobar'}}

((), {'node_1': {'foo': 'hi! foo'}})
(('node_2:2a0eb3a6-537b-21d4-1e82-09d2163162bc',), {'subgraph_node_1': {'bar': 'bar'}})
(('node_2:2a0eb3a6-537b-21d4-1e82-09d2163162bc',), {'subgraph_node_2': {'foo': 'hi! foobar'}})
((), {'node_2': {'foo': 'hi! foobar'}})


----

In [14]:
from typing import Literal
from typing_extensions import TypedDict

from langchain_groq import ChatGroq
from langgraph.graph import MessagesState, END
from langgraph.types import Command

from LegalDefAgent.src import utils
from LegalDefAgent.src.settings import settings

members = ["researcher", "coder"]
# Our team supervisor is an LLM node. It just picks the next agent to process
# and decides when the work is completed
options = members + ["FINISH"]

system_prompt = (
    "You are a supervisor tasked with managing a conversation between the"
    f" following workers: {members}. Given the following user request,"
    " respond with the worker to act next. Each worker will perform a"
    " task and respond with their results and status. When finished,"
    " respond with FINISH."
)


class Router(TypedDict):
    """Worker to route to next. If no workers needed, route to FINISH."""

    next: Literal[*options]


llm = ChatGroq(model="gemma2-9b-it")


def supervisor_node(state: MessagesState) -> Command[Literal[*members, "__end__"]]:
    messages = [
        {"role": "system", "content": system_prompt},
    ] + state["messages"]
    response = llm.with_structured_output(Router).invoke(messages)
    goto = response["next"]
    if goto == "FINISH":
        goto = END

    return Command(goto=goto)

In [15]:
from langchain_core.messages import HumanMessage
from langgraph.graph import StateGraph, START, END
from langgraph.prebuilt import create_react_agent
from langchain_core.tools import tool



@tool
def tavily_tool(urls) -> str:
    """Use requests and bs4 to scrape the provided web pages for detailed information."""
    return "The meaning of life is 33."


research_agent = create_react_agent(
    llm, tools=[tavily_tool], state_modifier="You are a researcher. DO NOT do any math."
)


def research_node(state: MessagesState) -> Command[Literal["supervisor"]]:
    result = research_agent.invoke(state)
    return Command(
        update={
            "messages": [
                HumanMessage(content=result["messages"][-1].content, name="researcher")
            ]
        },
        goto="supervisor",
    )


# NOTE: THIS PERFORMS ARBITRARY CODE EXECUTION, WHICH CAN BE UNSAFE WHEN NOT SANDBOXED
code_agent = create_react_agent(llm, tools=[tavily_tool])


def code_node(state: MessagesState) -> Command[Literal["supervisor"]]:
    result = code_agent.invoke(state)
    return Command(
        update={
            "messages": [
                HumanMessage(content=result["messages"][-1].content, name="coder")
            ]
        },
        goto="supervisor",
    )


builder = StateGraph(MessagesState)
builder.add_edge(START, "supervisor")
builder.add_node("supervisor", supervisor_node)
builder.add_node("researcher", research_node)
builder.add_node("coder", code_node)
graph = builder.compile()

In [17]:
for s in graph.stream(
    {
        "messages": [
            (
                "user",
                "What's the meaning of life?",
            )
        ]
    },
    subgraphs=True,
):
    print(s)
    print("----")

((), {'supervisor': None})
----
(('researcher:0ca49cf6-2cd5-2dee-734a-586fed6e5337',), {'agent': {'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_6v95', 'function': {'arguments': '{"urls":"https://www.example.com"}', 'name': 'tavily_tool'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 88, 'prompt_tokens': 959, 'total_tokens': 1047, 'completion_time': 0.16, 'prompt_time': 0.030585387, 'queue_time': 0.022670333, 'total_time': 0.190585387}, 'model_name': 'gemma2-9b-it', 'system_fingerprint': 'fp_10c08bf97d', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-a89004bf-434e-4a53-b649-64f88a2d7ca5-0', tool_calls=[{'name': 'tavily_tool', 'args': {'urls': 'https://www.example.com'}, 'id': 'call_6v95', 'type': 'tool_call'}], usage_metadata={'input_tokens': 959, 'output_tokens': 88, 'total_tokens': 1047})]}})
----
(('researcher:0ca49cf6-2cd5-2dee-734a-586fed6e5337',), {'tools': {'messages': [ToolMessage(content='The meani

KeyboardInterrupt: 

In [None]:
legal_agent = create_react_agent(
    llm, tools=[tavily_tool], state_modifier="You are a researcher. DO NOT do any math."
)

In [22]:
from datetime import datetime
from typing import Literal

from langchain_community.tools import DuckDuckGoSearchResults, OpenWeatherMapQueryRun
from langchain_community.utilities import OpenWeatherMapAPIWrapper
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, SystemMessage
from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnableSerializable
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, MessagesState, StateGraph
from langgraph.managed import RemainingSteps
from langgraph.prebuilt import ToolNode

from LegalDefAgent.src.llm import get_model
from LegalDefAgent.src.settings import settings
from LegalDefAgent.src.retriever import vector_store, exist_db


# ---- tools

from langchain_core.tools import tool

vectorstore = vector_store.setup_vectorstore()
retriever = vectorstore.as_retriever(search_kwargs={"k": 7})

async def query_vectorstore(query: dict):
    #implement hybrid search
    retrieved_definitions = retriever.invoke(query)

    return [{'definition_text': d.page_content, 'metadata': {'dataset': d.metadata['dataset'], 'document': d.metadata['document_id']}} for d in retrieved_definitions[:2]]


async def definition_search(query: str, legislation: str = None, date_filters: tuple = None) -> str:
    """
    Searches and retrieves the most similar definitions to the given query in a vector DB.

    Args:
        definendum: The term to be defined, as extracted from the user's query.
        legislation: The legislation to search in.
            possible values: "EU", "IT", None
        date_filters: Date filters in the form of a tuple (from_date, to_date) to apply to the search.
            e.g. ("2021-01-01", "2021-12-31")
    """

    retrieved_definitions = await query_vectorstore(query)

    return retrieved_definitions

In [2]:
from LegalDefAgent.src.llm import _MODEL_TABLE

_MODEL_TABLE_INV = {v: k for k, v in _MODEL_TABLE.items()}
_MODEL_TABLE_INV

{'gpt-4o-mini': <OpenAIModelName.GPT_4O_MINI: 'gpt-4o-mini'>,
 'gpt-4o': <OpenAIModelName.GPT_4O: 'gpt-4o'>,
 'llama3-8b-8192': <GroqModelName.LLAMA_3_8B: 'groq-llama3-8b-8192'>,
 'llama3-70b-8192': <GroqModelName.LLAMA_3_70B: 'groq-llama3-70b-8192'>,
 'llama-3.3-70b-versatile': <GroqModelName.LLAMA_33_70B: 'groq-llama-3.3-70b-versatile'>,
 'gemma2-9b-it': <GroqModelName.GEMMA2_9B_IT: 'groq-gemma2-9b-it'>,
 'open-mistral-nemo': <MistralModelName.NEMO_12B: 'open-mistral-nemo'>,
 'gemma2:2b': <OllamaModelName.GEMMA2_2B: 'ollama-gemma2:2b'>,
 'llama3.2': <OllamaModelName.LLAMA_32_3B: 'ollama-llama3.2'>,
 'phi3': <OllamaModelName.PHI3_4B: 'ollama-phi3'>,
 'fake': <FakeModelName.FAKE: 'fake'>}

In [23]:
await definition_search("definendum")

[{'definition_text': "examination: means a formalised test evaluating the person's knowledge and understanding;",
  'metadata': {'dataset': 'EurLex', 'document': '32015R0340.xml'}},
 {'definition_text': 'data item: means a single attribute of a complete data set, which is allocated a value that defines its current status;',
  'metadata': {'dataset': 'EurLex', 'document': '32010R0073.xml'}}]

In [23]:
from langchain_core.tools import tool
from langgraph.graph import START, StateGraph
from typing import TypedDict
from typing import Annotated, Literal, Sequence, Dict, Any
from typing_extensions import TypedDict
import uuid

from langchain_core.messages import BaseMessage, AIMessage, HumanMessage
from langgraph.graph.message import add_messages
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import PromptTemplate
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.callbacks import dispatch_custom_event
from langchain_core.runnables import RunnableConfig


from LegalDefAgent.src.retriever import vector_store
from LegalDefAgent.src.settings import settings
from LegalDefAgent.src.schema.task_data import Task
from LegalDefAgent.src.schema.definition import DefinitionsList, Definition, RelevantDefinitionsIDList
from LegalDefAgent.src.schema.grader import DefinitionRelevance

from LegalDefAgent.src.schema.task_data import Task
import LegalDefAgent.src.utils as utils
from LegalDefAgent.src.llm import get_model


vectorstore = vector_store.setup_vectorstore()
retriever = vectorstore.as_retriever(search_kwargs={"k": 7})


async def query_vectorstore(query: str):
    #implement hybrid search
    retrieved_definitions = await retriever.ainvoke(query)

    return utils.docs_list_to_json_list(retrieved_definitions)


async def semantic_filter(model, question, retrieved_definitions):
    parser = JsonOutputParser(pydantic_object=RelevantDefinitionsIDList)
    prompt = PromptTemplate(
        template="""
        You are a legal expert assessing the relevance of legal definitions to a user question.
        Below you are provided with a dcitionary containing definitions that were automatically retrieved.
        Your task is to filter the list of definitions provided to you keeping only the relevant ones.
        For each  the text of the definition contains keyword(s) or semantic meaning related to the user's question, keep it. Otherwise discard it.
        Output only the relevant definitions using the formatting instructions provided.
        VERY IMPORTANT NOTES:
        * You can only answer with valid, directly parsable json.
        * If you can't find any relevant definitions, you should output this: "relevant_definitions": []\n\n
        Here are the formatting instructions: {format_instructions}
        Here are the retrieved definitions: {context}
        Here is the question asked by the user: {question}
        """,
        input_variables=["context", "question"],
        partial_variables={"format_instructions": parser.get_format_instructions()}
    )

    chain = prompt | model | parser
    response = chain.invoke({"context": retrieved_definitions, "question": question}) # definendum or question??

    return response


async def definition_search(question: str, definendum: str, legislation: str | None = None, date_filters: tuple | None = None) -> str:
    """
    Searches and retrieves the most similar definitions to the given query in a vector DB.

    Args:
        question: The entire user's question.
        definendum: The term to be defined, as extracted from the user's query.
        legislation: The legislation to search in. Possible values: "EU", "IT", None
        date_filters: Date filters in the form of a tuple (from_date, to_date) to apply to the search. e.g. ("2021-01-01", "2021-12-31")
    """

    model = get_model(settings.DEFAULT_MODEL)

    retrieved_definitions = await query_vectorstore(definendum)
    print(retrieved_definitions)

    relevant_definitions_ids = await semantic_filter(model, question, retrieved_definitions)
    print(relevant_definitions_ids)

    relevant_definitions = [d for d in retrieved_definitions if d['metadata']['id'] in relevant_definitions_ids['relevant_definitions']]


    return relevant_definitions

In [24]:
question = "What's the definition of dog in the EurLex dataset?"
definendum = "dog"

await definition_search(question=question, definendum=definendum)

[{'metadata': {'dataset': 'EurLex', 'def_n': '#def_29', 'definendum_label': '#dog', 'document_id': '32020R0688.xml', 'frbr_expression': '/akn/eu/act/regulation/2019-12-17/688/eng@/!main', 'frbr_work': '/akn/eu/act/regulation/2019-12-17/688/!main', 'id': 8263}, 'definition_text': 'dog: means a kept animal of the Canis lupus species;'}, {'metadata': {'dataset': 'EurLex', 'def_n': '#def_10', 'definendum_label': '#dog', 'document_id': '32020R0689.xml', 'frbr_expression': '/akn/eu/act/regulation/2019-12-17/689/eng@/!main', 'frbr_work': '/akn/eu/act/regulation/2019-12-17/689/!main', 'id': 5008}, 'definition_text': 'dog: means a kept animal of the Canis lupus species;'}, {'metadata': {'dataset': 'EurLex', 'def_n': '#def_1', 'definendum_label': '#dog', 'document_id': '32019R2035.xml', 'frbr_expression': '/akn/eu/act/regulation/2019-06-28/2035/eng@/!main', 'frbr_work': '/akn/eu/act/regulation/2019-06-28/2035/!main', 'id': 2772}, 'definition_text': 'dog: means a kept animal of the Canis lupus sp

[{'metadata': {'dataset': 'EurLex',
   'def_n': '#def_29',
   'definendum_label': '#dog',
   'document_id': '32020R0688.xml',
   'frbr_expression': '/akn/eu/act/regulation/2019-12-17/688/eng@/!main',
   'frbr_work': '/akn/eu/act/regulation/2019-12-17/688/!main',
   'id': 8263},
  'definition_text': 'dog: means a kept animal of the Canis lupus species;'},
 {'metadata': {'dataset': 'EurLex',
   'def_n': '#def_10',
   'definendum_label': '#dog',
   'document_id': '32020R0689.xml',
   'frbr_expression': '/akn/eu/act/regulation/2019-12-17/689/eng@/!main',
   'frbr_work': '/akn/eu/act/regulation/2019-12-17/689/!main',
   'id': 5008},
  'definition_text': 'dog: means a kept animal of the Canis lupus species;'},
 {'metadata': {'dataset': 'EurLex',
   'def_n': '#def_1',
   'definendum_label': '#dog',
   'document_id': '32019R2035.xml',
   'frbr_expression': '/akn/eu/act/regulation/2019-06-28/2035/eng@/!main',
   'frbr_work': '/akn/eu/act/regulation/2019-06-28/2035/!main',
   'id': 2772},
  'def

In [13]:
from langchain_core.messages import AIMessage, SystemMessage, trim_messages

state ={   "messages": [
    {
      "content": "hi",
      "additional_kwargs": {},
      "response_metadata": {},
      "type": "human",
      "id": "fe160588-b249-40f3-92de-0e3887eb6025",
      "example": False
    },
    {
        "content": "What's the definition of dog in the EurLex dataset?",
        "additional_kwargs": {},
        "response_metadata": {},
        "type": "human",
        "id": "b6f1d8f4-4e8b-4b7c-9b2f-5a1c4b1f4d6b",
        "example": False
    },
    {'foo': 'bar'}
  ]
}

print(state['messages'][-1])
state['messages'] = trim_messages(state['messages'], token_counter=len, strategy="last", max_tokens=1)
print(state['messages'][-1])

{'foo': 'bar'}


ValueError: Message dict must contain 'role' and 'content' keys, got {'foo': 'bar'}
For troubleshooting, visit: https://python.langchain.com/docs/troubleshooting/errors/MESSAGE_COERCION_FAILURE 