# Retrieval Augmented Generation (RAG) using New Hampshire Case Law
*With IBM Granite Models*

The [New Hampshire Case Law Dataset](https://huggingface.co/datasets/free-law/nh) comes from the Caselaw Access Project via Hugging Face.

## In this notebook
This notebook contains instructions for performing Retrieval Augmented Generation (RAG). RAG is an architectural pattern that can be used to augment the performance of language models by recalling factual information from a knowledge base and adding that information to the model query. The most common approach in RAG is to create dense vector representations of the knowledge base in order to retrieve text chunks that are semantically similar to a given user query.

RAG use cases include:
- Customer service: Answering questions about a product or service using facts from the product documentation.
- Domain knowledge: Exploring a specialized domain (e.g., finance) using facts from papers or articles in the knowledge base.
- News chat: Chatting about current events by calling up relevant recent news articles.

In its simplest form, RAG requires 3 steps:

- Initial setup:
  - Index knowledge-base passages for efficient retrieval. In this recipe, we compute embeddings of the passages with a Granite embedding model and store them in a vector database.
- Upon each user query:
  - Retrieve relevant passages from the database. In this recipe, we use an embedding of the query to retrieve semantically similar passages.
  - Generate a response by feeding the retrieved passages into a large language model, along with the user query.
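
The three steps above can be sketched end to end in plain Python. This is a toy illustration only: the bag-of-words `embed` function stands in for a real embedding model, the in-memory `index` list stands in for a vector database, and the passages and helper names (`retrieve`, `build_prompt`) are invented for the example.

```python
import math
import re
from collections import Counter


def embed(text: str) -> Counter:
    """Toy embedding: a term-count vector (a stand-in for a real embedding model)."""
    return Counter(re.findall(r"[a-z]+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse term-count vectors."""
    dot = sum(count * b[term] for term, count in a.items())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


# Initial setup: index the knowledge-base passages (invented examples).
passages = [
    "A pupil may be suspended for gross misconduct.",
    "The landlord must return the security deposit within thirty days.",
    "A contract requires offer, acceptance, and consideration.",
]
index = [(passage, embed(passage)) for passage in passages]


def retrieve(query_text: str, k: int = 1) -> list[str]:
    """Per query, step 1: return the k passages most similar to the query."""
    query_vec = embed(query_text)
    ranked = sorted(index, key=lambda item: cosine(query_vec, item[1]), reverse=True)
    return [passage for passage, _ in ranked[:k]]


def build_prompt(query_text: str, retrieved: list[str]) -> str:
    """Per query, step 2: combine the retrieved passages with the user query."""
    context = "\n".join(retrieved)
    return f"Answer using only this context:\n{context}\n\nQuestion: {query_text}"


question = "When can a pupil be suspended?"
top_passages = retrieve(question, k=1)
prompt = build_prompt(question, top_passages)
print(prompt)
```

In a real pipeline the final prompt would be sent to the language model; here it is only printed.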

## Prerequisites

To get started, you'll need:
* A [Replicate account](https://replicate.com/) and API token.
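
One common way to supply the token is through the `REPLICATE_API_TOKEN` environment variable, which a later cell reads with `get_env_var`. The helper below is a hypothetical stand-in written for illustration, not the implementation of `get_env_var`; it simply fails early with a clear message when the variable is missing.

```python
import os


def read_api_token(name: str = "REPLICATE_API_TOKEN") -> str:
    """Hypothetical helper: read a token from the environment or fail with a clear error."""
    token = os.environ.get(name, "").strip()
    if not token:
        raise RuntimeError(f"Set the {name} environment variable before running the notebook.")
    return token


# Demonstration with a dummy variable name and value (not a real token).
os.environ["DEMO_API_TOKEN"] = "dummy-value"
demo_token = read_api_token("DEMO_API_TOKEN")
print(len(demo_token))
```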

## Setting up the environment

### Install dependencies

The Granite Community utils package bundles dependencies that the notebooks require.

In [None]:
!pip install git+https://github.com/ibm-granite-community/utils.git \
    langchain_community \
    replicate \
    langchain-huggingface \
    langchain-milvus \
    datasets \
    transformers \
    tiktoken

## Selecting System Components

### Choose your Embeddings Model

Specify the model to use for generating embedding vectors from text.

To use a model from a provider other than Hugging Face, replace this code cell with one from [this Embeddings Model recipe](https://github.com/ibm-granite-community/utils/blob/main/recipes/Components/Langchain_Embeddings_Models.ipynb).

In [None]:
from langchain_huggingface import HuggingFaceEmbeddings

embeddings_model = HuggingFaceEmbeddings(model_name="ibm-granite/granite-embedding-30m-english")

### Choose your Vector Database

Specify the database to use for storing and retrieving embedding vectors.

To connect to a vector database other than Milvus, substitute this code cell with one from [this Vector Store recipe](https://github.com/ibm-granite-community/utils/blob/main/recipes/Components/Langchain_Vector_Stores.ipynb).

In [None]:
from langchain_milvus import Milvus
import tempfile

db_file = tempfile.NamedTemporaryFile(prefix="milvus_", suffix=".db", delete=False).name
print(f"The vector database will be saved to {db_file}")

vector_db = Milvus(
    embedding_function=embeddings_model,
    connection_args={"uri": db_file},
    auto_id=True,
    index_params={"index_type": "AUTOINDEX"},
)

### Choose your LLM
The LLM will be used for answering the question, given the retrieved text.

Follow the instructions in [Getting Started with Replicate](https://github.com/ibm-granite-community/granite-kitchen/blob/cee1513c77429d7ddbf0e5a49b29b7bc9ca0d996/recipes/Getting_Started/Getting_Started_with_Replicate.ipynb), selecting a Granite model from the [`ibm-granite`](https://replicate.com/ibm-granite) org.

To connect to a model on a provider other than Replicate, substitute this code cell with one from the [LLM component recipe](https://github.com/ibm-granite-community/granite-kitchen/blob/main/recipes/Components/Langchain_LLMs.ipynb).

In [None]:
from langchain_community.llms import Replicate
from ibm_granite_community.notebook_utils import get_env_var

model_path = "ibm-granite/granite-3.3-8b-instruct"

model = Replicate(
    model=model_path,
    replicate_api_token=get_env_var('REPLICATE_API_TOKEN'),
)

Get the tokenizer used by your chosen model.

In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_path)

## Acquiring the Data

We will use a New Hampshire case law dataset to help the model answer questions about NH laws.

### Download the documents

Download the [New Hampshire CAP Caselaw](https://huggingface.co/datasets/free-law/nh) dataset from Hugging Face using the `datasets` library.

In [None]:
from langchain_community.document_loaders import HuggingFaceDatasetLoader

# Load the documents from the dataset
loader = HuggingFaceDatasetLoader("free-law/nh", page_content_column="text")
documents = loader.load()
print("Document Count: " + str(len(documents)))

### Add metadata to the documents

Add the `source` field, which is used below, to the metadata.

In [None]:
for doc in documents:
    doc.metadata['source'] = doc.metadata['id']

### Inspect the documents

In [None]:
for doc in documents[:1]:
    print(doc.metadata, "\n")
    print(doc.page_content, "\n")

## Building the Document Database

We'll use the caselaw document database to retrieve the full text of the cases by case id.
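
As a minimal sketch of the lookup this database supports, the example below stores two invented placeholder cases in an in-memory SQLite database and fetches one by id. The `_demo` names and the `fetch_case_text` helper are hypothetical and kept separate from the `data.db` file built in the following cells.

```python
import json
import sqlite3

# In-memory database so the sketch leaves no file behind.
conn_demo = sqlite3.connect(":memory:")
conn_demo.execute(
    "CREATE TABLE data (id INTEGER PRIMARY KEY, metadata TEXT, text TEXT, char_count INTEGER)"
)

# Invented placeholder records, not taken from the dataset.
cases_demo = [
    {"id": 101, "name_abbreviation": "Doe v. Roe", "text": "Placeholder opinion text A."},
    {"id": 102, "name_abbreviation": "Poe v. Moe", "text": "Placeholder opinion text B."},
]

# The connection as a context manager commits once when the block succeeds.
with conn_demo:
    conn_demo.executemany(
        "INSERT INTO data (id, metadata, text, char_count) VALUES (?, ?, ?, ?)",
        [
            (
                item["id"],
                json.dumps({"name_abbreviation": item["name_abbreviation"]}),
                item["text"],
                len(item["text"]),
            )
            for item in cases_demo
        ],
    )


def fetch_case_text(connection, case_id):
    """Return the full text for a case id, or None if the id is unknown."""
    row = connection.execute("SELECT text FROM data WHERE id = ?", (case_id,)).fetchone()
    return row[0] if row else None


print(fetch_case_text(conn_demo, 101))
```

Returning `None` for an unknown id avoids the `TypeError` that indexing a missing row would raise.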

### Create the database file and document table

In [None]:
# Put the documents in a SQLite database, keyed by id
import json
import os
import sqlite3

# remove database file if exists
if os.path.isfile('data.db'):
    os.remove('data.db')

conn = sqlite3.connect('data.db')
c = conn.cursor()

# Create the table if it doesn't exist, with columns for id, metadata, text, and character count
c.execute('''CREATE TABLE IF NOT EXISTS data
             (id INTEGER PRIMARY KEY UNIQUE,
              metadata TEXT,
              text TEXT,
              char_count INTEGER)''')


### Insert the documents into the table

In [None]:
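for doc in documents:
    doc_id = doc.metadata["id"]  # avoid shadowing the built-in id()
    c.execute("INSERT INTO data (id, metadata, text, char_count) VALUES (?,?,?,?)", (doc_id, json.dumps(doc.metadata), doc.page_content, doc.metadata["char_count"]))

# Commit once after all rows are inserted rather than once per row.
conn.commit()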

### Count the documents

In [None]:
c.execute("SELECT count(*) FROM data")
doc_count = c.fetchone()[0]
print(f"Document count: {doc_count}")

## Building the Vector Database

In this example, we take the caselaw text, split it into chunks, derive embedding vectors using the embedding model, and load them into the vector database for querying.

### Split the document into chunks

Split the document into text segments that can fit into the model's context window.
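
The idea behind fixed-size chunking with overlap can be sketched as a sliding window over a token list. This is an illustration under simplifying assumptions: whitespace-separated words stand in for tokenizer tokens, and `split_tokens` is a hypothetical helper, not the splitter used in the cell below.

```python
def split_tokens(tokens, chunk_size, chunk_overlap):
    """Slide a window of chunk_size tokens, advancing by chunk_size - chunk_overlap."""
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be non-negative and smaller than chunk_size")
    step = chunk_size - chunk_overlap
    windows = []
    for start in range(0, len(tokens), step):
        windows.append(tokens[start:start + chunk_size])
        # Stop once a window has reached the end of the token list.
        if start + chunk_size >= len(tokens):
            break
    return windows


words = "the court held that the pupil was entitled to a hearing before expulsion".split()
demo_chunks = split_tokens(words, chunk_size=6, chunk_overlap=2)
for chunk in demo_chunks:
    print(" ".join(chunk))
```

Each chunk repeats the last two words of the previous one, so a sentence cut at a chunk boundary still appears with some surrounding context in the next chunk.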

In [None]:
from langchain_text_splitters import TokenTextSplitter

# Split the documents into chunks
text_splitter = TokenTextSplitter(chunk_size=1000, chunk_overlap=10)
chunks = text_splitter.split_documents(documents)
print("Chunk Count: " + str(len(chunks)))

### Inspect the chunks

In [None]:
import json
for i in range(1):
    print(chunks[i].page_content)
    print(json.dumps(chunks[i].metadata, indent=4))

### Populate the vector database

NOTE: Population of the vector database may take a few minutes depending on your embedding model and service.

In [None]:
ids = vector_db.add_documents(chunks)
print("Document IDs: " + str(ids[:3]))

## Querying the Databases

### Create query text

Here we use a topic of NH law to query the vector database for relevant cases. Because we will consider one case at a time (due to context length restrictions), phrase the query to consider a single case.

In [None]:
query = "Summarize this court case about the Suspension and Expulsion of Pupils, using the IRAC framework (Issue, Rule, Application, Conclusion).\n\n"

### Query the vector database

Query the vector database for cases related to the law. Similar documents are found by proximity of the embedded vector in vector space.
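
To illustrate what "proximity in vector space" means, here is a small brute-force nearest-neighbor search over hand-written three-dimensional vectors. The vectors, the chunk names, and the choice of Euclidean distance are assumptions made for the example; a real vector database uses the metric configured for its index and an approximate search rather than a full scan.

```python
import heapq
import math

# Invented low-dimensional vectors standing in for stored chunk embeddings.
stored_vectors = {
    "chunk-a": [0.9, 0.1, 0.0],
    "chunk-b": [0.1, 0.8, 0.1],
    "chunk-c": [0.8, 0.2, 0.1],
    "chunk-d": [0.0, 0.1, 0.9],
}


def nearest(query_vec, vectors, top_k):
    """Return the top_k (key, distance) pairs with the smallest Euclidean distance."""
    scored = ((math.dist(query_vec, vec), key) for key, vec in vectors.items())
    return [(key, distance) for distance, key in heapq.nsmallest(top_k, scored)]


query_vector = [1.0, 0.0, 0.0]
hits = nearest(query_vector, stored_vectors, top_k=2)
for key, distance in hits:
    print(f"{key}: {distance:.3f}")
```

With a distance metric, smaller scores mean closer matches; with a similarity metric such as cosine similarity, the ordering is reversed.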

In [None]:
k = 10  # the number of docs to retrieve
docs_with_score = vector_db.similarity_search_with_score(query, k=k)

# Get a unique set of docs, keeping the first-ranked chunk for each case.
docs = []
seen_ids = set()
for doc, score in docs_with_score:
    # print(doc.metadata["name_abbreviation"])
    # print(score)
    doc_id = doc.metadata["id"]  # avoid shadowing the built-in id()
    if doc_id not in seen_ids:
        docs.append(doc)
        print(doc_id, " - ", doc.metadata["name_abbreviation"])
        seen_ids.add(doc_id)

### Query the document database

Get the full text of the cases found by the vector search.

First, reduce the results to one document per case id.

In [None]:
# Get a list of unique doc ids.
docs_ids_seen = set()
uq_docs = [doc for doc in docs if not (doc.metadata["id"] in docs_ids_seen or docs_ids_seen.add(doc.metadata["id"]))]

In [None]:
# Retrieve a number of cases.
cases = []

for doc in uq_docs:
    case_id = doc.metadata["id"]
    case_short_name = doc.metadata["name_abbreviation"]

    c.execute("SELECT text FROM data where id = ?", (case_id,))
    case_text = c.fetchone()[0]
    case_length = len(tokenizer.tokenize(case_text))

    # For this recipe, only consider cases that can fit in the 4k context window (along with the 512 token output).
    if case_length < 3500:
        cases.append((case_id, case_short_name, case_text))
        print(f"Case Id: {case_id}")
        print(f"Case Name: {case_short_name}")
        print(f"Case Length: {case_length} tokens\n")
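
The 3500-token cutoff can be read as a simple budget calculation. The sketch below assumes the "4k" window means 4096 tokens and uses a hypothetical 84-token allowance for the system prompt and query, chosen here only so the arithmetic lands on the cutoff used above; the source does not state the actual overhead.

```python
CONTEXT_WINDOW = 4096    # assumed size of the "4k" context window, in tokens
MAX_OUTPUT_TOKENS = 512  # output allowance mentioned in the comment above
PROMPT_OVERHEAD = 84     # hypothetical allowance for the system prompt and query


def max_case_tokens(context_window, max_output_tokens, prompt_overhead):
    """Tokens left for the case text after reserving room for output and prompt."""
    return context_window - max_output_tokens - prompt_overhead


token_limit = max_case_tokens(CONTEXT_WINDOW, MAX_OUTPUT_TOKENS, PROMPT_OVERHEAD)
print(token_limit)  # 3500
```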


## Answering Questions

### Assemble the Chat Prompt

Build a chat prompt template with the law and the retrieved case.
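
Conceptually, a chat prompt template is a list of role-tagged strings with placeholders that are filled in at call time. The plain-Python sketch below mimics that substitution without LangChain; `fill_messages` and the shortened system text are hypothetical and only illustrate the idea.

```python
# Shortened, illustrative templates using the same placeholder names as this recipe.
system_template = "Answer the question based only on the following text.\n\n{case_text}"
human_template = "{input}"


def fill_messages(case_text, user_input):
    """Return (role, content) pairs with the placeholders substituted."""
    return [
        ("system", system_template.format(case_text=case_text)),
        ("human", human_template.format(input=user_input)),
    ]


messages = fill_messages("Placeholder case text.", "Summarize this case.")
for role, content in messages:
    print(f"[{role}] {content}")
```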

In [None]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

system_prompt = (
    "You are an assistant with legal expertise. Answer the question based only on the following text from a NH court case. Do not include any other court cases. \n\n{case_text}"
)

rag_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        ("human", "{input}"),
    ]
)

rag_chain = (
    rag_prompt
    | model
    | StrOutputParser()
)

### Ask questions of the retrieved case in relation to the law

Answer the question about each related case.

In [None]:
for case in cases[:2]:
    (case_id, case_short_name, case_text) = case
    response = rag_chain.invoke(input = {"input": query, "case_text": case_text})
    print(f"Case {case_id}: {case_short_name}\n")
    print(response, "\n\n")