# RAG Example Using NVIDIA API Catalog and LangChain

This notebook introduces how to use LangChain to interact with NVIDIA hosted NIM microservices like chat, embedding, and reranking models to build a simple retrieval-augmented generation (RAG) application.

## Terminology

#### RAG

- RAG is a technique for augmenting LLM knowledge with additional data.
- LLMs can reason about wide-ranging topics, but their knowledge is limited to the public data up to a specific point in time that they were trained on.
- If you want to build AI applications that can reason about private data or data introduced after a model's cutoff date, you need to augment the knowledge of the model with the specific information it needs.
- The process of bringing the appropriate information and inserting it into the model prompt is known as retrieval augmented generation (RAG).

The preceding summary of RAG originates in the LangChain v0.2 tutorial [Build a RAG App](https://python.langchain.com/v0.2/docs/tutorials/rag/) tutorial in the LangChain v0.2 documentation.

#### NIM

- [NIM microservices](https://developer.nvidia.com/blog/nvidia-nim-offers-optimized-inference-microservices-for-deploying-ai-models-at-scale/) are containerized microservices that simplify the deployment of generative AI models like LLMs and are optimized to run on NVIDIA GPUs.
- NIM microservices support models across domains like chat, embedding, reranking, and more from both the community and NVIDIA.

#### NVIDIA API Catalog

- [NVIDIA API Catalog](https://build.nvidia.com/explore/discover) is a hosted platform for accessing a wide range of microservices online.
- You can test models on the catalog and then export them with an NVIDIA AI Enterprise license for on-premises or cloud deployment

#### langchain-nvidia-ai-endpoints

- The [`langchain-nvidia-ai-endpoints`](https://pypi.org/project/langchain-nvidia-ai-endpoints/) Python package contains LangChain integrations for building applications that communicate with NVIDIA NIM microservices.

## Installation and Requirements

Create a Python environment (preferably with Conda) using Python version 3.10.14.
To install Jupyter Lab, refer to the [installation](https://jupyter.org/install) page.

In [117]:
# Requirements
!pip install langchain==0.2.5
!pip install langchain_community==0.2.5
!pip install faiss-gpu # replace with faiss-gpu if you are using GPU
!pip install faiss-cpu
!pip install langchain-nvidia-ai-endpoints==0.1.2
!pip install requests faiss-cpu pdfplumber spacy camelot-py pandas==2 numpy==1.26.4 beautifulsoup4 pymupdf
!pip install lxml unstructured html5lib
!pip install weaviate-client
!pip install -Uqq langchain-weaviate

Collecting weaviate-client
  Downloading weaviate_client-4.8.1-py3-none-any.whl (374 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m374.5/374.5 KB[0m [31m638.6 kB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hCollecting grpcio-health-checking<2.0.0,>=1.57.0
  Downloading grpcio_health_checking-1.66.2-py3-none-any.whl (18 kB)
Collecting httpx<=0.27.0,>=0.25.0
  Using cached httpx-0.27.0-py3-none-any.whl (75 kB)
Collecting grpcio-tools<2.0.0,>=1.57.0
  Downloading grpcio_tools-1.66.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.4/2.4 MB[0m [31m666.1 kB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting grpcio<2.0.0,>=1.57.0
  Downloading grpcio-1.66.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.8/5.8 MB[0m [31m653.4 kB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCol

## Getting Started!

To get started you need an `NVIDIA_API_KEY` to use the NVIDIA API Catalog:

1) Create a free account with [NVIDIA](https://build.nvidia.com/explore/discover).
2) Click on your model of choice.
3) Under Input select the Python tab, and click **Get API Key** and then click **Generate Key**.
4) Copy and save the generated key as NVIDIA_API_KEY. From there, you should have access to the endpoints.

Available NVIDIA api keys;

1. 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

In [79]:
import getpass
import os

nvidia_api_key = getpass.getpass("Enter your NVIDIA API key: ")
assert nvidia_api_key.startswith("nvapi-"), f"{nvidia_api_key[:5]}... is not a valid key"
os.environ["NVIDIA_API_KEY"] = nvidia_api_key

## RAG Example using LLM & Embedding

### 1) Initialize the LLM

The ChatNVIDIA class is part of LangChain's integration (langchain_nvidia_ai_endpoints) with NVIDIA NIM microservices.
It allows access to NVIDIA NIM for chat applications, connecting to hosted or locally-deployed microservices.

Here we will use **mixtral-8x7b-instruct-v0.1**

In [81]:
from langchain_nvidia_ai_endpoints import ChatNVIDIA

llm = ChatNVIDIA(model="mistralai/mixtral-8x7b-instruct-v0.1", max_tokens=1024)

# Can choose any model hosted at Nvidia API Catalog (Uncomment the below code to list the availabe models)
# ChatNVIDIA.get_available_models()

#### Note:
- In this notebook, we have used NVIDIA NIM microservices from the NVIDIA API Catalog.
- The other APIs, ChatNVIDIA, NVIDIAEmbedding, and NVIDIARerank, also support self-hosted NIM microservices.
- Change the `base_url` to your deployed NIM URL.
- Example: `llm = ChatNVIDIA(base_url="http://localhost:8000/v1", model="meta/llama3-8b-instruct")`
- NIM can be also hosted locally using Docker, following the [NVIDIA NIM for LLMs](https://docs.nvidia.com/nim/large-language-models/latest/getting-started.html) documentation.

In [None]:
# Example Code snippet if you want to use a self-hosted NIM
from langchain_nvidia_ai_endpoints import ChatNVIDIA

# connect to an LLM NIM running at localhost:8000, specifying a specific model
llm = ChatNVIDIA(base_url="http://localhost:8000/v1", model="meta/llama3-8b-instruct")

### 2) Intialize the embedding
NVIDIAEmbeddings is a client to NVIDIA embeddings models that provides access to a NVIDIA NIM for embedding. It can connect to a hosted NIM or a local NIM using a base URL

We selected **NV-Embed-QA** as the embedding

In [82]:
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings

embedder = NVIDIAEmbeddings(model="NV-Embed-QA", truncate="END")

### 3) Obtain dataset
I love taxes and work! Lets steal information about taxes and work!

#### In the **dataloader** class defined below, we can either;

A) Walk through a webpage and find all sub-webpages and scrape the parent and children, or;

B) Extract texts, tables and images from pdfs

In [118]:
import os
import requests
import pandas as pd
import urllib.parse  # To handle URL joining
import fitz

from tqdm import tqdm
from io import StringIO
from bs4 import BeautifulSoup, SoupStrainer
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader, DataFrameLoader, CSVLoader, UnstructuredTSVLoader

class DataHandler:
    """
    Masterfully handles data scraping, preprocessing, and other data-related functionalities in this notebook.
    """

    def __init__(self,
                csv_path="./data/csv"):
        self.visited_urls = set() # classwide tracker to prevent repeated visits
        self.webloaders = [] # tracks all urls that have been converted to langchain WebBaseLoaders.
        self.tabular = [] # tracks all tabular data that has been discovered by scraper. Delivers a list of CSVLoader. Cowardly refusing to save to same list
        self.raws = []
        self.csv_path = csv_path

        self.text_splitter = CharacterTextSplitter(chunk_size=400, separator=" ", chunk_overlap=80)

        self.visited_urls = set()   # Classwide tracker to prevent repeated visits when scraping web
        self.tabular_data = []      # Tracks all tabular_data data that has been discovered by scraper. Delivers a list of CSVLoader. Cowardly refusing to save to same list.
        self.textual_data = []      # Tracks all textual data scraped from websites or pdfs
        self.all_data = []  # Just defining a joint list here instead of later during the functional call; in case we do not care about seperating them anymore (both are lists of Documents, just from different
        # base sources)

        if not os.path.exists(self.csv_path):
            os.mkdir(self.csv_path)
    
    
    @staticmethod
    def from_csv(csv_path):
        """
        Small func to read from csv and produce CSVLoaders.
        """
        df = pd.from_csv( csv_path) # so we can yoink its columns
        loader = CSVLoader(file_path=csv_path,
                                    csv_args={  'delimiter': ',',
                                                'quotechar': '"',
                                                'fieldnames': [str(col) for col in df.columns]}
                                    )
        return loader
    

    def csvs_to_loader(self, directory):
        """
        walks through directory, finds all csvs, and saves it into self.tabular 
        """
        for dir, subdir, files in os.walk(directory):
            for file in files:
                if file.endswith(".csv"):
                    fp = os.path.join(dir, file)
                    loader = self.from_csv(fp)
                    self.tabular.append(loader)


    def extract_table_elements(self, url, table_elements):
        """
        Helper func to extract tabular data

        Args:
            - url: url the table is under
            - table_elements: list of table elements
        """
        for table_idx, table in enumerate(table_elements): #TODO: Find better way to index different tables on the same page? not all have class attributes we can ID them with.
            try:
                tags = table.find_all('sup')
                for tag in tags:
                    tag.extract()

                # Attempt to find the title of this table. tableheader elements only tell us (pandas) how to index it, but what we need the header element for context on what this table is about
                tablename = f"table_{table_idx}" # default name presuming none is found
                for headertype in ['h3', 'h4']: # unlikely to lie in h2 or h1? could result in duplicate data. if we find by those.
                    header = table.find_previous(headertype)
                    if header is not None and header.text is not None: # find the closest header
                        tablename = header.text.replace("\n", "").replace("#", '').replace("\r", '').replace("/", '')
                        break
                
                tablename = os.path.basename(url) + " " + f"{tablename}" # what we will call this table, some has really annoying spacing, so maybe .replace(' ', '')?
                print(f"Grabbing table data under tablename {tablename}")
                df = pd.read_html(StringIO(str(table)), header=0)[0] # some tables do not have tableheader <th> tags for first row 
                # which would result in  generic column indices being created, so forcibly set first row as tableheader. 
                df['context'] = [tablename] * df.shape[0]

                if len(tablename) >= os.pathconf('/', 'PC_NAME_MAX'): # prevent shit from exploding because my tablename is damn scuffed but what todo.
                    tablename = tablename[:os.pathconf('/', 'PC_NAME_MAX') - 10]
                csv_path = os.path.join(self.csv_path, tablename + '.csv')
                # print(df)

                # print("csv_path:", csv_path)
                # print("csv_path:", csv_path)
                df.to_csv( csv_path, index=False ) # if true will fuck up columning in csvloader
                # loader = UnstructuredTSVLoader(csv_path, mode='elements')
                loader = CSVLoader(file_path=csv_path,
                                    csv_args={  'delimiter': ',',
                                                'quotechar': '"',
                                                'fieldnames': [str(col) for col in df.columns]}
                                    )
                for row in loader.load()[1:]: # first row is just column indexes, so void
                    self.tabular.append(
                        self.text_splitter.split_text(
                            self.clean_text(row.page_content)
                            )
                        )

            except BaseException as e:
                print(f"Unable to extract table data from url {url} with error {e}, passing!")


    def create_loaders(self):
        """
        Seperate method to create the loaders. Directly appends to textual_data attribute and calls extract_table_elements to handle tabular data.
        """
        for url, soupy_little_guy in self.raws:
            loader = WebBaseLoader(
                    web_paths=(url,),  # No URL fetching as we already have the HTML content
                    bs_kwargs={"parse_only": SoupStrainer(['main'])},
                )
            html_content = loader.load()
            for i in range(len(html_content)):
                self.textual_data.extend(
                    self.text_splitter.split_text(
                        self.clean_text(
                            html_content[i].page_content
                )))

            table_elements = soupy_little_guy.find_all("table") 
            self.extract_table_elements(url=url, table_elements=table_elements)


    def scrape_website(self, base_url, max_depth, depth=0):
        """
        Wraps around a nested function get_from_website, so when ran will define a new function that knows that the sauce base_url is.
        Scuffed? Yes.
        """

        def get_from_website(url, max_depth, depth):
            """
            Recursively scrape a website by visiting links starting from url. 
            Because this is a mostly I/O bound operation, we make a seperate method that actually creates the Loaders.

            Parameters:
                - url:              URL to start scraping from
                - depth:            Current recursion depth
                - max_depth:        Maximum recursion depth to avoid infinite loops

            Returns:
            - appends to self.raws, [url, BeautifulSoup object] created from response content.
            """
            if url in self.visited_urls or depth > max_depth:
                return

            try:
                response = requests.get(url)
                soupy_little_guy = BeautifulSoup(response.content, 'html.parser')
                if response.status_code != 200:
                    return print(f"Failed to retrieve {url}")
            except Exception as e:
                return print(f"Error accessing {url} with error: {e}")
                

            self.visited_urls.add(url)
            print("Current url:", url)
            self.raws.append([url, soupy_little_guy])

            for link in soupy_little_guy.find_all('a', href=True):  # Find all links on the current page
                relative_url = link['href']
                absolute_url = urllib.parse.urljoin(url, relative_url)
                if base_url in absolute_url:  # Avoids external sites
                    get_from_website(absolute_url, max_depth, depth + 1)

        get_from_website(base_url, max_depth, depth)


    def scrape_pdf(self, pdf_path):
        """
        Extracts information from pdf files.
        Texts stay as texts.
        Tables and images...

        Parameters:
            - pdf_path:     PDF file to extract from

        Output:
            - textual_data: List of cleaned strings that represent data from pdf doc
        """
        try:
            pdf_document = fitz.open(pdf_path)
        except:
            return print(f"Unable to open {pdf_path}")
        print(f"Current pdf: {pdf_path}")
        pdf_text = ""
        for page_num in range(len(pdf_document)):
            page = pdf_document.load_page(page_num)
            pdf_text += page.get_text("text")

        self.textual_data.extend(
            self.text_splitter.split_text(
                self.clean_text(pdf_text)
        ))
    
    def clean_text(self, text):
        """
        Cleans text retrieved from sources to reduce the storage needed

        get_from_website(base_url, max_depth, depth)
        Parameters:
            - text:         Original string

        Output:
            - cleaned_text: Cleaned string 
        """
        return text.replace('\n', ' ').replace('\t', ' ').replace('\r', ' ').replace('  ', ' ')
    
    def embed_text(self, text):
        """
        Embeds a given text if necessary

        Parameters:
            - text:     Original string

        Output:
            - text:     Embedded string 
        """
        return embedder.embed_query(text)


# datahandler.recursively_scrape(
#     "https://www.iras.gov.sg/schemes/disbursement-schemes/progressive-wage-credit-scheme",
#     max_depth=0
# )
# base_domains = ["https://www.iras.gov.sg/schemes/disbursement-schemes/progressive-wage-credit-scheme-(pwcs)"]
# for base_domain in base_domains:
#     datahandler.recursively_scrape(base_domain, max_depth)

We define our websites and pdf links below to call the dataloader on.

In [None]:
# Max depth determines whether you wish to look into children of parents websites, else set to 0
max_depth = 2
websites = [
    "https://www.iras.gov.sg/taxes",
    "https://www.iras.gov.sg/schemes",
    "https://www.mom.gov.sg/passes-and-permits",
    "https://www.mom.gov.sg/employment-practices",
    "https://www.mom.gov.sg/workplace-safety-and-health"
]
pdfs = [
    "./data/pdfs/verification-checklist-inspection-of-machines.pdf"
]

datahandler = DataHandler()


In [121]:
for website in websites:
    datahandler.scrape_website(website, max_depth) # Scrape here once and for all

for pdf in pdfs:
    datahandler.scrape_pdf(pdf)

Current url: https://www.iras.gov.sg/schemes
Current url: https://www.iras.gov.sg/schemes/disbursement-schemes
Current url: https://www.iras.gov.sg/schemes/disbursement-schemes/enterprise-innovation-scheme-(eis)
Current url: https://www.iras.gov.sg/schemes/disbursement-schemes/refundable-investment-credit-(ric)
Current url: https://www.iras.gov.sg/schemes/disbursement-schemes/jobs-growth-incentive
Current url: https://www.iras.gov.sg/schemes/disbursement-schemes/progressive-wage-credit-scheme
Current url: https://www.iras.gov.sg/schemes/disbursement-schemes/progressive-wage-credit-scheme/pwcs-glossary
Current url: https://www.iras.gov.sg/schemes/disbursement-schemes/uplifting-employment-credit
Current url: https://www.iras.gov.sg/schemes/disbursement-schemes/senior-employment-credit-(sec)-cpf-transition-offset-(cto)-and-enabling-employment-credit-(eec)
Current url: https://www.iras.gov.sg/schemes/disbursement-schemes/self-review-for-eligibility-of-government-schemes
Current url: https:

In [123]:
datahandler.create_loaders() # only when we actually want to create loaders to debug stuff, so we don't need to scrape within the same function that we're (likely) debugging, saving time.

Grabbing table data under tablename voluntary-disclosure-of-errors-for-reduced-penalties Disclose errors within grace period to avoid penalties
Grabbing table data under tablename voluntary-disclosure-of-errors-for-reduced-penalties             Example 1: Computation of reduced penalty on errors voluntarily disclosed by a company in its Corporate Income Tax Return within and after the grace period        
Grabbing table data under tablename voluntary-disclosure-of-errors-for-reduced-penalties             Example 2: Computation of reduced penalty on errors voluntarily disclosed by an individual in his Individual Income Tax Return within and after the grace period        
Grabbing table data under tablename voluntary-disclosure-of-errors-for-reduced-penalties Voluntary Compliance Initiatives
Grabbing table data under tablename voluntary-disclosure-of-errors-for-reduced-penalties Voluntary Compliance Initiatives
Grabbing table data under tablename basic-guide-for-new-individual-taxpayers 

### And then, we embed the data for use in the vector db.

In [64]:
embedded_text = [datahandler.embed_text(text) for text in datahandler.textual_data]
embedded_tabular = [datahandler.embed_text(text) for text in datahandler.tabular_data]
all_data = embedded_tabular.extend(embedded_text) # maybe neatly tuck this away in some other functionality in datahandler? Idk buddy

### 4) Storing the documents

To build our foundational knowledge base from our collected data and allow for faster retrieval of vector queries, we need to have some form of search system.

#### a) Process the documents into vectorstore and save it to disk

Real world documents can be very long, this makes it hard to fit in the context window of many models. Even for those models that could fit the full post in their context window, models can struggle to find information in very long inputs.

To handle this we’ll split the Document into chunks for embedding and vector storage. More on text splitting [here](https://python.langchain.com/v0.2/docs/concepts/#text-splitters)

In [112]:
# Here we create a faiss vector store from the documents and save it to disk.
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain.text_splitter import CharacterTextSplitter
from tqdm import tqdm

text_splitter = CharacterTextSplitter(chunk_size=400, separator=" ", chunk_overlap=80)

documents = []

def clean_text(text):
    return text.replace('\n', ' ').replace('\t', ' ').replace('\r', ' ').replace('  ', ' ')

for docs in tqdm(datahandler.webloaders, desc='Embedding text from websites...'):  # Had to cut down data to a quarter's worth cuz no more tokens in model to run
    for i in range(len(docs)):
        documents.extend(text_splitter.split_text(clean_text(docs[i].page_content)))


    # except Exception as e:
    #     print(f"Error {e} encountered whilst loading data!")

# for pdf in tqdm(pdf_data, desc='Embedding text from pdfs...'):  # Had to cut down data to a quarter's worth cuz no more tokens in model to run
#     documents.extend(text_splitter.split_text(clean_text(pdf)))

# metadatas = []
# for i, d in enumerate(documents):
#     splits = text_splitter.split_text(d)
#     docs.extend(splits)
#     # metadatas.extend([{"source": sources[i]}] * len(splits))

# You will only need to do this once, later on we will restore the already saved vectorstore
store = FAISS.from_texts(documents, embedder)
VECTOR_STORE = './data/nv_embedding'
store.save_local(VECTOR_STORE)

Embedding text from websites...: 100%|██████████| 622/622 [00:00<00:00, 791.30it/s]
Embedding text from tables...: 100%|██████████| 949/949 [00:00<00:00, 2274.43it/s]


Qualifying Year: 2022.0
First Tier Gross Monthly Wage Ceiling ≤ $2,500: 75%
Second Tier Gross Monthly Wage Ceiling > $2,500 and ≤ $3,000: 45%
context: progressive-wage-credit-scheme-(pwcs)#title12 Is my firm eligible?
['Qualifying Year: 2022.0 First Tier Gross Monthly Wage Ceiling ≤ $2,500: 75% Second Tier Gross Monthly Wage Ceiling > $2,500 and\xa0≤ $3,000: 45% context: progressive-wage-credit-scheme-(pwcs)#title12 Is my firm eligible?']
Qualifying Year: 2023.0
First Tier Gross Monthly Wage Ceiling ≤ $2,500: 75%
Second Tier Gross Monthly Wage Ceiling > $2,500 and ≤ $3,000: 45%
context: progressive-wage-credit-scheme-(pwcs)#title12 Is my firm eligible?
['Qualifying Year: 2023.0 First Tier Gross Monthly Wage Ceiling ≤ $2,500: 75% Second Tier Gross Monthly Wage Ceiling > $2,500 and\xa0≤ $3,000: 45% context: progressive-wage-credit-scheme-(pwcs)#title12 Is my firm eligible?']
Qualifying Year: 2024.0
First Tier Gross Monthly Wage Ceiling ≤ $2,500: 50%
Second Tier Gross Monthly Wage Ceiling

To enable runtime search, we index text chunks by embedding each document split and storing these embeddings in a vector database. Later to search, we embed the query and perform a similarity search to find the stored splits with embeddings most similar to the query.

Then, we can read the previously processed and saved vector store back for use;

In [113]:
from langchain_community.vectorstores import FAISS

# Load the FAISS vectorestore back.
VECTOR_STORE = './data/nv_embedding'
store = FAISS.load_local(VECTOR_STORE, embedder, allow_dangerous_deserialization=True)

#### b) Store all our data into a vectordb (connects vector stores to structured db systems)

We will be using **Weaviate** because it is a free, open-source, scalable and reliable vector database service with decent amounts of documentation online.

The Weaviate DB does not need to be structured like an SQL DB since it already has integrated models/pre-computed embeddings to handle vector embeddings.

Furthermore, it works more with Object-Oriented design rather than rows and columns.

Sources used:
- https://python.langchain.com/docs/integrations/vectorstores/weaviate/
- https://weaviate.io/developers/weaviate/client-libraries/python

Run the following only if you already have the instance of weaviate running on your local http://localhost:8080 and port 50051 open for gRPC traffic

In [83]:
import weaviate
from weaviate.classes.init import AdditionalConfig, Timeout
from weaviate.classes.config import Property, DataType

class WeaviateDB:
    """
    Not-so-masterfully handles the vector database using Weaviate and all DB related functionality.
    """

    def __init__(self):
        self.client = None

    def set_up_db(self):
        """
        Starts up connection to Weaviate DB and creates schema if not already created

        Parameters: None

        Output:     None
        """
        try:
            self.client = weaviate.connect_to_local(
                port=8080,
                grpc_port=50051,
                additional_config=AdditionalConfig(
                    timeout=Timeout(init=30, query=60, insert=120)  # Values in seconds
            ))  
        except Exception as e:
            return print(f"Unable to start up Weaviate DB due to {e}")

    def shut_down_db(self):
        """
        Shuts down connection to Weaviate DB 

        Parameters: None

        Output:     None
        """     
        self.client.close()

    def create_collection(self):
        """
        Creates a new Document collection in the DB to store original documents and embeddings

        Parameters: None

        Output:     None
        """     
        # self.schema = {
        #     "classes": [    
        #         {
        #             "class": "Document",
        #             "description": "A class for storing split documents.",
        #             "properties": [
        #                 {
        #                     "name": "content",
        #                     "dataType": ["text"]
        #                 },
        #                 {
        #                     "name": "embedding",
        #                     "dataType": ["number[]"]
        #                 }
        #             ]
        #         }
        #     ]
        # }
        # Add the schema to Weaviate if not already added
        # if not any(cls['class'] == 'Document' for cls in self.client.schema.get()['classes']):
        #     self.client.schema.create(self.schema)

        self.client.collections.create(
            "Document",
            properties=[
                Property(name="content", data_type=DataType.TEXT),
                Property(name="embedding", data_type=DataType.NUMBER),
            ],
            reranker_config='', # TODO: Try adding nvidia reranker here
            generative_config='' # TODO: Try adding nvidia llm here
        )

    def add_objects(self, original, embedded):
        """
        Adds document and embedding object pairs to the db schema 

        Parameters:
            - original:     Original documents
            - embedded:     Embedded documents

        Output: None
        """
        self.set_up_db()

        # for idx, (doc_chunk, embedding) in enumerate(zip(original, embedded)):
        #     properties = {
        #         "content": doc_chunk,
        #         "embedding": embedding
        #     }
        #     self.client.data_object.create(properties, "Documents")

        documents = self.client.collections.get("Documents")
        documents.config.add_property(
            Property(
                name="onHomepage",
                data_type=DataType.BOOL
            )
        )

        self.shut_down_db()

    def query_data(self):
        self.set_up_db()

        # TODO: Use some NIM API maybe to query data from vectordb 
        # -> Also i didnt make it into a wrapper yet cuz it always gets messy when i try method wrappers that use self too
        documents = self.client.collections.get("Documents")
    
        self.shut_down_db()

In [84]:
weaviatedb = WeaviateDB()
weaviatedb.add_objects(datahandler.textual_data, datahandler.embedded_data)

AttributeError: 'WeaviateClient' object has no attribute 'schema'

### 5) Wrap the restored vectorsore into a retriever and ask our question

In [114]:
from langchain_core.runnables import RunnablePassthrough
from langchain_core.prompts import ChatPromptTemplate

retriever = store.as_retriever()

prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Answer solely based on the following:\n<Documents>\n{context}\n</Documents>",
        ),
        ("user", "{question}"),
    ]
)

# Langchain's LCEL(LangChain Expression Language) Runnable protocol is used to define the chain
# LCEL allows pipe together components and functions
chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)

Case 1:Irrelevant question

In [11]:
print(chain.invoke("What is component two hundred and twenty five million?"))

There is no information provided in the documents about a component or entity called "two hundred and twenty five million." The documents contain information about various financial transactions, mortgage and loan balances, estate duty calculations, and rules related to M&A allowance and motor vehicle expenses, but there is no mention of a "component" with a value of 225 million.


Case 2: Simple questions

In [115]:
print(chain.invoke("How do i file taxes for my company?"))

To file taxes for your company, you need to follow these steps:

1. Sign up for GIRO or PayNow Corporate with your business/organisation's bank account to receive CIT and GST refunds.
2. If your company qualifies, file Form C-S via mytax.iras.gov.sg. You can do so by logging in to mytax.iras.gov.sg.
3. Before filing, you have to be authorized by your company to act for its Corporate Income Tax matters via Corppass. You can refer to the step-by-step guides for assistance on Corppass setup.
4. The type of Corporate Income Tax Return you should file (Form C-S, Form C-S (Lite) or Form C) depends on the qualifying conditions.

If your company is being liquidated, the appointed liquidator will need to request access to myTax Portal on your behalf via Corppass.


Case 3: Complex questions

In [88]:
print(chain.invoke("In the event my foriegn employee is injured at work, how do i report the incident and claim reparations?"))

Based on the documents provided, if your foreign employee is injured at work, you can report the incident to the Ministry of Manpower (MOM) in Singapore at +65 6438 5122. If the work injury compensation insurance is obligatory under the Work Injury Compensation Act (WICA), you may be able to claim input tax for the insurance premiums. However, input tax claims for medical and accident insurance premiums are generally disallowed under the GST (General) Regulations. It is recommended to visit the MOM webpage on WICA or contact them for more detailed information on the claims process and conditions.


Case 4: Specific data retrieval questions

In [89]:
chain.invoke("I have accidentally made an error in my taxes related to GST and would like to voluntarily report it to IRAS. How should I go about this?")
# after embedding tabular data, this question was able to be answered correctly.

'To report and correct the error in your GST (Goods and Services Tax) tax return, you should follow these steps as outlined in the documents provided:\n\n1. Send an electronic request to IRAS for GST F7 (Disclosure of Errors on GST Return). The system to send this request is not specified in the documents, so you may need to check the IRAS website or contact them directly to find out how to submit this request.\n2. After sending the request, you will need to e-File the GST F7 form. This can be done online at the myTax Portal.\n3. You have up to 14 days from the date of your request for the GST F7 to submit the completed form.\n\nIf you are a GST late registrant, you will also need to register for GST online at the myTax Portal.\n\nFor unauthorized GST collections or disclosure of input tax claimed on any supply that was part of a Missing Trader Fraud, you may need to provide additional information or take additional steps. It would be best to consult the IRAS website or contact them di

In [116]:
chain.invoke("I am an employer whose firm qualifies for PWCs. From 2024 onwards, how much Co-Funding can I recieve from the government?")
# Model still struggles with this question, time to explore more complex workflows.

"Based on the documents provided, the gross monthly wage ceiling for co-funding under the Progressive Wage Credit Scheme (PWCS) will be increased from $2,500 to an unspecified amount, starting from 2024. However, the exact amount of co-funding that you can receive as an employer is not mentioned in the documents. It would be best to refer to the official Budget 2024 announcement or check with the relevant authorities for the specific details on co-funding amounts.\n\nAdditionally, the documents mention that employees' average wage must be $4,000 or lower to be eligible for PWCS and that a wage cut-off for PWCS eligibility will apply from 2024 onwards. Employees whose average monthly wage exceeds $4,000 post-wage increase will not be eligible for PWCS. Ensure that your employees' wages comply with these eligibility criteria to qualify for the PWCS co-funding."

### 6) Enhancing accuracy for single data sources

This example demonstrates how a re-ranking model can be used to combine retrieval results and improve accuracy during retrieval of documents.

Typically, reranking is a critical piece of high-accuracy, efficient retrieval pipelines. Generally, there are two important use cases:

- Combining results from multiple data sources
- Enhancing accuracy for single data sources

Here, we focus on demonstrating only the second use case. If you want to know more, check [here](https://github.com/langchain-ai/langchain-nvidia/blob/main/libs/ai-endpoints/docs/retrievers/nvidia_rerank.ipynb)

In [None]:
from langchain_nvidia_ai_endpoints import NVIDIARerank
from langchain_core.runnables import RunnableParallel

# We will narrow the collection to 100 results and further narrow it to 10 with the reranker.
retriever = store.as_retriever(search_kwargs={'k':100}) # typically k will be 1000 for real world use-cases
ranker = NVIDIARerank(model='nv-rerank-qa-mistral-4b:1', top_n=10)

prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Answer solely based on the following context:\n<Documents>\n{context}\n</Documents>",
        ),
        ("user", "{question}"),
    ]
)

reranker = lambda input: ranker.compress_documents(query=input['question'], documents=input['context'])

chain_with_ranker = (
    RunnableParallel({"context": retriever, "question": RunnablePassthrough()})
    | {"context": reranker, "question": lambda input: input['question']}
    | prompt
    | llm
    | StrOutputParser()
)


In [None]:
print(chain_with_ranker.invoke("How do i file taxes for my company?"))

Based on the provided documents, to file taxes for your company, you need to follow these steps:

1. Ensure that you are duly authorized by your company as an 'Approver' for Corporate Tax (Filing and Applications) in Corppass. You can refer to the step-by-step guides for assistance on Corppass setup.
2. Have your Singpass and your company’s Unique Entity Number (UEN)/ Entity ID ready.
3. Visit the mytax.iras.gov.sg website to file the Corporate Income Tax Return for your company.
4. If your company is filing Form C, you need to submit the financial statements/certified accounts and tax computation(s) for the relevant Year of Assessment (YA).
5. If your company meets the qualifying conditions to file Form C-S or Form C-S (Lite), you can choose to file the simplified version, Form C-S (Lite), if your company has an annual revenue of $200,000 or below.

You can also visit the Basic Guide to Corporate Income Tax for Companies page to get help with filing your company’s tax returns for the 

In [None]:
print(chain_with_ranker.invoke("In the event my foriegn employee is injured at work, how do i report the incident and claim reparations?"))

Based on the provided documents, if your foreign employee is injured at work, you can report the incident and claim reparations under the Work Injury Compensation Act (WICA). Specifically, the documents mention that input tax can be claimed for work injury compensation insurance that is obligatory under WICA for both local and foreign employees performing manual work or non-manual work earning $2,600 or less a month.

To report the incident and make a claim, you can visit the Ministry of Manpower (MOM) webpage on WICA or contact MOM at +65 6438 5122. However, it is important to note that medical and accident insurance premiums for your staff are generally not allowable for input tax claims under the GST (General) Regulations, unless the insurance or payment of compensation is obligatory under WICA or under any collective agreement within the meaning of the Industrial Relations Act.

Therefore, it seems that you can report the incident and claim reparations under WICA, but you should ve