# Entity Extraction from New Hampshire Case Law
*With IBM Granite Models*

The [New Hampshire Case Law Dataset](https://huggingface.co/datasets/free-law/nh) comes from the Caselaw Access Project via Hugging Face.

## In this notebook

In this notebook, we'll explore the process of extracting meaningful information from text using entity extraction techniques, and then leverage that information to build and query a simple knowledge graph. Specifically, we'll guide you through the following steps:

- **Entity Extraction**: We'll start by processing a body of text to identify and extract key entities, such as people, organizations, dates, and locations. Entity extraction is a crucial part of natural language processing (NLP) that helps in transforming raw text into structured data.
- **Knowledge Graph Construction**: Once we've extracted entities, we'll build a basic knowledge graph. A knowledge graph represents entities as nodes and relationships as edges, providing a structured way to understand the interconnections between different entities within the text. This helps in visualizing and storing the extracted data meaningfully.
- **Querying the Knowledge Graph**: With the knowledge graph in place, we can retrieve specific information by posing questions. We'll implement methods to query the graph, including resolving entities from the question to entities in the graph. This will allow us to identify relevant graph structures that correspond to the user's query.
- **Question Answering**: Finally, we'll use the results retrieved from the knowledge graph to answer the user's question. By using the structured information from the graph, we can provide detailed answers and offer insights into the relationships and context within the body of text.

This process of transforming unstructured text into a knowledge graph and then querying it is useful for applications such as legal research, medical case studies, or business intelligence. By the end of this notebook, you'll have hands-on experience building a simple pipeline that takes raw text, extracts valuable entities, and then allows users to query the data to obtain meaningful answers. Equipped with these techniques, we will move on to the more sophisticated techniques of Graph RAG.

## Prerequisites

To get started, you'll need:
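* A [Replicate account](https://replicate.com/) and API token.
* A Neo4j database that the notebook can connect to. Its connection details are read from `NEO4J_URI`, `NEO4J_USERNAME`, and `NEO4J_PASSWORD`.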

## Setting up the environment

### Install dependencies

Granite Kitchen comes with a bundle of dependencies that are required for notebooks. See the list of packages in its [`setup.py`](https://github.com/ibm-granite-community/granite-kitchen/blob/main/setup.py). 

In [None]:
!pip install git+https://github.com/ibm-granite-community/utils \
    langchain_community \
    replicate \
    datasets \
    transformers \
    tiktoken \
    neo4j \
    stringcase \
    langchain_huggingface \
    sentence-transformers \
    langchain_chroma

## Selecting System Components

### Choose your LLM
The LLM will be used both for extracting entities from the case text and for answering the question, given the retrieved text.

Follow the instructions in [Getting Started with Replicate](https://github.com/ibm-granite-community/granite-kitchen/blob/cee1513c77429d7ddbf0e5a49b29b7bc9ca0d996/recipes/Getting_Started/Getting_Started_with_Replicate.ipynb), selecting a Granite Instruct model from the [`ibm-granite`](https://replicate.com/ibm-granite) org.

To connect to a model on a provider other than Replicate, substitute this code cell with one from the [LLM component recipe](https://github.com/ibm-granite-community/granite-kitchen/blob/main/recipes/Components/Langchain_LLMs.ipynb).

In [None]:
from langchain_community.llms import Replicate
from ibm_granite_community.notebook_utils import set_env_var, get_env_var

model = Replicate(
    model="ibm-granite/granite-3.0-8b-instruct",
    replicate_api_token=get_env_var("REPLICATE_API_TOKEN"),
)

### Get the tokenizer

Retrieve the tokenizer used by your chosen LLM.

In [None]:
from transformers import AutoTokenizer

model_path = "ibm-granite/granite-3.0-8b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_path)

## Acquiring the Data

We will use a New Hampshire case law dataset to help the model answer questions about NH laws.

### Download the documents

Download the [New Hampshire CAP Caselaw](https://huggingface.co/datasets/free-law/nh) dataset from Hugging Face using the `datasets` library. Each document is a case; there should be about 21,540 of them in total. We will use a small subset in this recipe.

In [None]:
from langchain_community.document_loaders import HuggingFaceDatasetLoader

# Load the documents from the dataset
loader = HuggingFaceDatasetLoader("free-law/nh", page_content_column="text")
all_documents = loader.load()
print("Document Count: " + str(len(all_documents)))

### Inspect the documents

The documents contain case law text in their `page_content`, and metadata fields for the court name, decision date, and other information. 

In [None]:
import json

for doc in all_documents[:1]:
    print(json.dumps(doc.metadata, indent=4), "\n")
    print(doc.page_content, "\n")

## Splitting the Documents

In this example, we take the caselaw text, split it into chunks, and extract entities from each chunk. 

### Split the document into chunks

Split the document into text chunks that can fit into the model's context window.

In [None]:
from langchain_text_splitters import TokenTextSplitter

doc_chunks = {}
documents = [doc for doc in all_documents[:30] if doc.metadata["id"] in ['4439812', '4439539', '4440694']]
# documents = all_documents[:30]

# Split the documents into chunks
text_splitter = TokenTextSplitter(chunk_size=1000, chunk_overlap=50)
for doc in documents:
    id = doc.metadata["id"]
    chunks = text_splitter.split_documents([doc])
    doc_chunks[id] = chunks
    print(f"Case {id}: " + str(len(chunks)))

In [None]:
for doc in documents:
    id = doc.metadata["id"]
    chunks = doc_chunks[id]
    if len(chunks) > 0:
        print(f"Case {id}: " + str(len(chunks)))
        print(json.dumps(doc.metadata, indent=4))
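
To make the `chunk_size` and `chunk_overlap` parameters used above more concrete, here is a minimal sketch of sliding-window chunking over a list of tokens. The `split_with_overlap` helper is hypothetical and only approximates what a token-based splitter does; it is not the `TokenTextSplitter` implementation.

```python
def split_with_overlap(tokens, chunk_size, chunk_overlap):
    """Split a token list into windows of chunk_size that share chunk_overlap tokens."""
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be non-negative and smaller than chunk_size")
    step = chunk_size - chunk_overlap
    chunks = []
    for start in range(0, len(tokens), step):
        chunks.append(tokens[start:start + chunk_size])
        if start + chunk_size >= len(tokens):
            break  # the last window already reaches the end of the text
    return chunks

# Toy "tokens": the integers 0..9 stand in for token IDs.
toy_tokens = list(range(10))
toy_chunks = split_with_overlap(toy_tokens, chunk_size=4, chunk_overlap=1)
for chunk in toy_chunks:
    print(chunk)
```

Each window starts `chunk_size - chunk_overlap` tokens after the previous one, so an entity mention that straddles a chunk boundary has a chance of appearing whole in at least one chunk.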

### Inspect the chunks

Each text chunk inherits the metadata from the document. For the purpose of this recipe, note that the `judge` is not well captured in many of these documents; we will be extracting it from the case law text.

In [None]:
import json
for doc in documents[:10]:
    id = doc.metadata["id"]
    print(json.dumps(doc.metadata, indent=4))
    for chunk in doc_chunks[id]:
        print(chunk.page_content)

We can see from this output that the "judge" in the metadata is not reliable, so we will pick that entity out of the text.

## Extracting Entities

### Identify entity categories of interest

For our example query, we will be interested in the judge and any precedent cited.

In [None]:
categories = {
    "Counsel for Plaintiff/Petitioner": "The attorney or law firm representing the plaintiff/petitioner.",
    "Counsel for Defendant/Respondent": "The attorney or law firm representing the defendant/respondent.",
    "Judge/Justice": "The name of the judge or justice involved in the case, including their role (e.g., trial judge, appellate judge, presiding justice).",
    "Statute/Act": "The statute or act referenced or applied in the case (e.g., 'Civil Rights Act of 1964').",
    "Precedent Cited": "Previous case law referred to in the case.",
    "Constitutional Provision": "The constitutional article or amendment referenced in the case (e.g., 'First Amendment,' 'Article III').",
    "Decision/Holding": "The final judgment of the court (e.g., 'Affirmed,' 'Reversed').",
    "Disposition": "The outcome of the case (e.g., 'dismissed with prejudice,' 'remanded').",
    "Remedy": "Type of compensation or relief provided (e.g., 'compensatory damages,' 'injunctive relief').",
    "Sentence": "In a criminal case, the sentence handed down (e.g., '5 years imprisonment')."
}

categories_str = "\n".join(f"{k}: {v}" for k, v in categories.items())

### Construct a prompt template

The template instructs the model to extract entities from a text. The list of entity categories is provided to guide the model's output, and ensure we are obtaining a pre-determined set of entities.

In [None]:
query = f"""\
<|start_of_role|>system<|end_of_role|>
Below is a list of entity categories:

{categories_str}

Given this list of entity categories, you will be asked to extract entities belonging to these categories from a text passage.
Consider only the list of entity categories above; do not extract any additional entities. For each entity found, list the category and the entity, separated by a colon. Do not use the words "Entity" or "Category".

Here are some examples:
1. Remedy: Compensatory Damages
2. Counsel for Defendant/Respondent: Jane C.
3. Precedent Cited: State vs. Tiger
<|end_of_text|>
<|start_of_role|>user<|end_of_role|>
Find the entities in the following text, and list them in the format specified above:

{{}}
<|end_of_text|>
<|start_of_role|>assistant<|end_of_role|>"""


### Extract entities from each chunk of text

The response will be a list of entities mentioned in each text chunk. Extraction takes about a minute for 10 documents.

In [None]:
doc_extracts = {}
for doc in documents:
    id = doc.metadata['id']
    extracts = []
    for i, chunk in enumerate(doc_chunks[id]):
        print(f"\nChunk {i} of {id}")
        full_query = query.format(chunk.page_content)
        print(str(len(tokenizer.tokenize(full_query))) + " tokens")
        response = model.invoke(full_query, max_tokens=1000)
        print(response)
        extracts.append(response)

    doc_extracts[id] = extracts
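
The extraction loop fills the prompt with `query.format(...)`. That works because the template is an f-string in which `{{}}` survives as a literal `{}` placeholder. The sketch below shows the same two-stage substitution on a shortened, hypothetical category list; the `sample_*` names are ours and are not part of the recipe.

```python
# Hypothetical, shortened category list used only for this illustration.
sample_categories = {"Remedy": "Type of compensation or relief provided."}
sample_categories_str = "\n".join(f"{k}: {v}" for k, v in sample_categories.items())

# Stage 1: the f-string fills in the categories; "{{}}" becomes a literal "{}".
sample_template = f"""\
Categories:
{sample_categories_str}

Text:
{{}}
"""

# Stage 2: str.format drops a passage into the remaining placeholder.
sample_prompt = sample_template.format("The court awarded compensatory damages.")
print(sample_prompt)
```

One caveat of this design: braces in the passage passed to `format` are harmless, but braces inside a category description would end up in the template itself and be treated as additional placeholders.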

## Building the Graph Database

### Construct Graph Triples

Using the extracted entities along with the text chunk, construct graph triples. Graph triples are 3-tuples of subject, predicate, and object. In our recipe, they are 3-tuples of entity, role, and case.

For ease of querying, in this recipe we use the entity name as the identifier for each entity. In a production system you would want to assign a unique ID to each entity.
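
As a rough illustration of what assigning an ID could look like, the sketch below derives a deterministic identifier from the role and a normalized entity name. The `entity_id` helper and its normalization rule are assumptions made for this example; real entity resolution (for instance, recognizing that two differently written names refer to the same judge) needs considerably more than this.

```python
import uuid

def entity_id(name, role):
    """Derive a stable ID from the role and a whitespace- and case-normalized name."""
    normalized = " ".join(name.lower().split())
    return str(uuid.uuid5(uuid.NAMESPACE_URL, f"{role}|{normalized}"))

id_a = entity_id("Durkin v. Snow", "Precedent Cited")
id_b = entity_id("  durkin  v.  snow ", "Precedent Cited")
id_c = entity_id("Durkin v. Snow", "Judge/Justice")

print(id_a == id_b)  # same role, same name after normalization
print(id_a == id_c)  # same name, different role
```

Because the ID is computed from the content, rebuilding the graph yields the same IDs without keeping a separate lookup table.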

In [None]:
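def get_triples_from_extract(extract, case_name):
    """Parse numbered 'Category: Entity' lines from a model response into (entity, role, case) triples."""
    triples = []
    for line in extract.splitlines():
        if not line.strip():
            continue  # Skip blank lines in the response.
        try:
            # Take the number off of the front.
            line = line.split(". ", 1)[1]
            # Split on the first ": " only, so the entity text may itself contain a colon.
            role, entity = line.split(": ", 1)
            if role in categories:
                triples.append((entity.strip(), role, case_name))
        except (ValueError, IndexError):
            print(f"Error parsing case {case_name} line: {line}")
    return triples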

doc_triples = {}
for doc in documents:
    id = doc.metadata['id']
    name = doc.metadata['name_abbreviation']
    triples = []
    for i, extract in enumerate(doc_extracts[id]):
        # Break response up into entity triples.
        new_triples = get_triples_from_extract(extract, name)
        triples.extend(new_triples)
    # Add triples from metadata.
    triples.append((doc.metadata["court"], 'Court', name))

    # Add to triples for the document.
    if id in doc_triples:
        doc_triples[id].extend(triples)
    else:
        doc_triples[id] = triples

# Get all of the triples, filtering those that have no entity.
all_triples = []
for id, triples in doc_triples.items():
    print(f"Case {id}")
    for triple in triples:
        e = triple[0].lower()
        if "not explicitly mentioned" not in e \
            and "not mentioned" not in e \
            and "not applicable" not in e \
            and "not specified" not in e:
            all_triples.append(triple)
            print(triple)


### Populate the graph

In [None]:
from neo4j import GraphDatabase
from stringcase import snakecase, lowercase

# Define the list of (entity, role, case) triples
triples = all_triples

# Connect to the Neo4j database
uri = get_env_var("NEO4J_URI")
username = get_env_var("NEO4J_USERNAME")
password = get_env_var("NEO4J_PASSWORD")
driver = GraphDatabase.driver(uri, auth=(username, password))

def add_triple(tx, entity, role, case):
    query = (
        "MERGE (e:Entity {name: $entity}) "
        "MERGE (c:Case {name: $case}) "
        "MERGE (e)-[r:%s]->(c)"
    ) % snakecase(lowercase(role.replace('/', '_')))
    tx.run(query, entity=entity, case=case)

def build_graph(triples):
    with driver.session() as session:
        # Empty the graph first
        session.run("MATCH (n) DETACH DELETE n")
        # Fill the graph
        for entity, role, case in triples:
            session.execute_write(add_triple, entity, role, case)

# Build the graph from the triples list
build_graph(triples)

# Keep the driver open: the following cells reuse it to query the graph.
# Call driver.close() once you have finished working with the notebook.

### Inspect the contents of the graph

In [None]:
with driver.session() as session:
    # Query to find all nodes
    result = session.run("MATCH (n) RETURN n.name AS name")
    print("Nodes in the graph:")
    for record in result:
        print(record["name"])

    # Query to find all relationships
    result = session.run("MATCH ()-[r]->() RETURN type(r) AS rel")
    print("\nRelationship types in the graph:")
    rels = [record["rel"] for record in result]
    # unique rels
    rels = list(set(rels))
    for rel in rels:
        print(rel)
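
The relationship types in the graph are derived from the category names: `add_triple` writes the type into the query text after passing the role through `snakecase(lowercase(...))`. The sketch below is a hypothetical stand-in for that conversion using only the standard library. It is not the `stringcase` implementation, although for the category names in this recipe we expect it to give the same results.

```python
import re

def to_relationship_type(role):
    """Lowercase a category name and replace runs of other characters with underscores."""
    return re.sub(r"[^a-z0-9]+", "_", role.lower()).strip("_")

for role in ["Judge/Justice", "Precedent Cited", "Counsel for Plaintiff/Petitioner"]:
    print(role, "->", to_relationship_type(role))
```

Since the type is interpolated into the Cypher text rather than passed as a parameter, restricting it to letters, digits, and underscores also keeps unexpected characters out of the query.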

### Find all precedents cited in the graph

As an example of how we can now use the graph, we will find all precedents cited in the corpus.

In [None]:
driver = GraphDatabase.driver(uri, auth=(username, password))
with driver.session() as session:
    # Query to find every entity that is cited as a precedent by some case
    result = session.run("MATCH (a)-[:precedent_cited]->() RETURN a.name AS name")
    print("Precedents in the graph:")
    for record in result:
        print(record["name"])

## Populate a vector database with entities

In [None]:
from langchain_huggingface import HuggingFaceEmbeddings

embeddings_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

In [None]:
from langchain_chroma import Chroma

vector_db = Chroma(embedding_function=embeddings_model)

### Add graph nodes to the vector database

For the purpose of illustration, we embed a short description of each entity, built from its name and its role in a case, and store the entity name in the document metadata so that a match can be traced back to the graph. A more sophisticated approach would be to generate richer descriptions of the entities and embed those instead, capturing more context and nuance about each entity indexed.

In [None]:
from langchain_core.documents import Document
names = []
with driver.session() as session:
    # Query for each entity together with the role that links it to a case
    result = session.run("MATCH (e)-[r]->() RETURN e.name AS entity, type(r) AS role")
    for record in result:
        doc = Document(page_content=f"{record['entity']} is a {record['role']}", metadata={"entity": record["entity"]})
        names.append(doc)

ids = vector_db.add_documents(names)
print("Documents added: ", len(ids))
doc_count = len(vector_db._collection.get(include=["documents"])["documents"])
print("Documents total: ", doc_count)
# vector_db.delete_collection()


## Answer questions

### Extract entities from question

This is one type of question that can be asked. It mentions two entities of interest, `Judge Bois` (a judge), and `Durkin v. Snow` (a case used as precedent). To answer the question, we will find cases that have these two entities in common.

In [None]:
question = "How has Judge Bois used Durkin v. Snow to rule on cases?"

response = model.invoke(query.format(question))
print(response)
question_entity_triples = get_triples_from_extract(response, "")
print(question_entity_triples)


### Match entities to the graph

We match the question entities to entities in the graph in order to create a graph query for related cases. We match on a short description of the entity using its role, as extracted from both the question and the case text. This is done by comparing the embeddings of each description for proximity in the semantic embedding space created by the embeddings model. To improve the performance of this match, we could add additional context from the question and the case, and even from the knowledge graph.
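
The idea of accepting the nearest neighbor only when it is close enough can be sketched without a vector database. The example below uses made-up 2-dimensional vectors in place of real embeddings and a squared L2 distance, which is our understanding of what Chroma's default `l2` space reports; treat both the vectors and that detail as assumptions.

```python
def squared_l2(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def best_match(query_vec, index, threshold=1.0):
    """Return the name of the closest vector, or None if it is farther than the threshold."""
    if not index:
        return None
    name, dist = min(
        ((n, squared_l2(query_vec, v)) for n, v in index.items()),
        key=lambda pair: pair[1],
    )
    return name if dist <= threshold else None

# Toy vectors standing in for embeddings of entity descriptions.
toy_index = {"Bois, J.": (1.0, 0.0), "Durkin v. Snow": (0.0, 1.0)}

near = best_match((0.9, 0.1), toy_index)    # close to the first vector
far = best_match((-1.0, -1.0), toy_index)   # far from both vectors
print(near, far)
```

For unit-length vectors, a squared L2 distance of `d` corresponds to a cosine similarity of `1 - d / 2`, so a threshold of 1.0 would accept matches with cosine similarity of at least 0.5.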

In [None]:

def match_entity(entity, threshold=1.0):
    """Return the closest entity in the vector database, or None if it is farther than the threshold.

    Chroma returns a distance (L2-based by default), so lower scores mean closer matches.
    """
    docs_with_score = vector_db.similarity_search_with_score(entity, k=5)
    for doc, score in docs_with_score:
        print(f"{doc.metadata['entity']} -- distance {score}")
    if docs_with_score:
        # Results are ordered from closest to farthest.
        doc, score = docs_with_score[0]
        if score <= threshold:
            return doc.metadata["entity"]
    # No match within the threshold.
    return None

for triple in question_entity_triples:
    entity = triple[0]
    entity_desc = f"{triple[0]} is a {triple[1]}."
    print(f"\nMatching {entity}")
    match = match_entity(entity_desc)
    if match is not None:
        print(f"Match: {match}")

        

### Query the graph for cases

Query for cases given a single entity and its relationship to the case.

In [None]:
def query_for_cases(entity_name, role):
    with driver.session() as session:
        relationship = snakecase(lowercase(role.replace('/', '_')))
        # Bind the entity name as a parameter so that quotes in a name cannot break the query.
        query = f"MATCH (e:Entity {{name: $name}})-[:{relationship}]->(c:Case) RETURN c.name AS name"
        print(query)
        result = session.run(query, {"name": entity_name})
        print("Cases:")
        for record in result:
            print(record["name"])

for triple in question_entity_triples:
    entity, role, case = triple
    entity_match = match_entity(entity)
    # Only query the graph when the entity was matched.
    if entity_match is not None:
        query_for_cases(entity_match, role)

Query for cases given multiple entities and their relationships to the case.
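
Conceptually, chaining one `MATCH` clause per entity on the same case node amounts to intersecting the sets of cases linked to each entity. The sketch below shows that logic over an in-memory list of triples. The `cases_for_all` helper and the toy triples are made up for illustration and are not taken from the dataset.

```python
def cases_for_all(triples, entity_role_pairs):
    """Return the cases that are linked to every (entity, role) pair."""
    case_sets = [
        {case for entity, role, case in triples if (entity, role) == pair}
        for pair in entity_role_pairs
    ]
    return set.intersection(*case_sets) if case_sets else set()

# Made-up triples in the recipe's (entity, role, case) layout.
toy_triples = [
    ("Judge A", "Judge/Justice", "Case 1"),
    ("Judge A", "Judge/Justice", "Case 2"),
    ("Smith v. Jones", "Precedent Cited", "Case 2"),
    ("Smith v. Jones", "Precedent Cited", "Case 3"),
]
shared = cases_for_all(
    toy_triples,
    [("Judge A", "Judge/Justice"), ("Smith v. Jones", "Precedent Cited")],
)
print(shared)
```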

In [None]:
def query_for_cases(entity_role_pairs):
    # Without any matched entities there is nothing to query for.
    if not entity_role_pairs:
        return []
    with driver.session() as session:
        query = ""
        params = {}
        for i, (entity, role) in enumerate(entity_role_pairs):
            relationship = snakecase(lowercase(role.replace('/', '_')))
            # Bind each entity name as a parameter rather than interpolating it into the query text.
            query += f"MATCH (e{i}:Entity {{name: $name{i}}})-[:{relationship}]->(c)\n"
            params[f"name{i}"] = entity
        query += "RETURN c.name AS name"
        print(query)
        result = session.run(query, params)
        cases = []
        print("Cases:")
        for record in result:
            cases.append(record["name"])
            print(record["name"])
        return cases

entity_role_pairs = []
for triple in question_entity_triples:
    entity, role, case = triple
    entity_match = match_entity(entity)
    # Skip entities that could not be matched to the graph.
    if entity_match is not None:
        entity_role_pairs.append((entity_match, role))

cases = query_for_cases(entity_role_pairs)

### Retrieve the case text

We have found a case related to both entities in the question. Let's retrieve the case text.

In [None]:
case_text = [doc.page_content for doc in documents if doc.metadata["name_abbreviation"] == cases[0]][0]
print(case_text)

### Answer the question

Having retrieved the case text, now let's answer the question given the case text.

In [None]:
q = f"""
Answer the question using the following text from one case: \n\n{case_text}

Question: {question}
"""

print(question)
response = model.invoke(q)
print(response)
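
The answering prompt above is sent as plain text, whereas the extraction prompt used the model's role markers. As a possible variation, the sketch below wraps a system and a user message in the same markers that appear in the extraction template. The `build_chat_prompt` helper is hypothetical, and we have not verified whether this changes the quality of the answers.

```python
def build_chat_prompt(system_text, user_text):
    """Wrap a system and a user message in the role markers used by the extraction prompt."""
    return (
        f"<|start_of_role|>system<|end_of_role|>\n{system_text}\n<|end_of_text|>\n"
        f"<|start_of_role|>user<|end_of_role|>\n{user_text}\n<|end_of_text|>\n"
        "<|start_of_role|>assistant<|end_of_role|>"
    )

# Placeholder strings; in the notebook these would be built from case_text and question.
sample_chat_prompt = build_chat_prompt(
    "Answer the question using only the text of the case provided by the user.",
    "Case text: ...\n\nQuestion: ...",
)
print(sample_chat_prompt)
```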