# Part 1 pre-acceptance workflow

This notebook is meant to be a prototype for the NVIDIA NIM hackathon project.

## Installation and Requirements

Shits fucked cunt

In [None]:
# Requirements
!pip install langchain==0.2.5
!pip install langchain_community==0.2.5
!pip install faiss-gpu # replace with faiss-gpu if you are using GPU
!pip install faiss-cpu
!pip install langchain-nvidia-ai-endpoints==0.1.2
!pip install requests pdfplumber spacy camelot-py 
!pip install pandas==2 numpy==1.26.4 
!pip install beautifulsoup4 
!pip install pymupdf
!pip install lxml
!pip install unstructured

## Getting Started!

You need an `NVIDIA_API_KEY` to use the NVIDIA API Catalog:

1) Create a free account with [NVIDIA](https://build.nvidia.com/explore/discover).
2) Click on your model of choice.
3) Under Input select the Python tab, and click **Get API Key** and then click **Generate Key**.
4) Copy and save the generated key as NVIDIA_API_KEY. From there, you should have access to the endpoints.
5) If at any point downstream NVIDIA complains about insufficient credits, tell them to shut the fuck up and make a new gmail. If their stock price is so high, they should share some of those profits with me, an NVIDIA investor (i own 3 shares rn)

In [1]:
import getpass
import os

nvidia_api_key = getpass.getpass("Enter your NVIDIA API key: ")
assert nvidia_api_key.startswith("nvapi-"), f"{nvidia_api_key[:5]}... is not a valid key"
os.environ["NVIDIA_API_KEY"] = nvidia_api_key

## LLM & Embedding

### 1) Initialize the LLM

The ChatNVIDIA class is part of LangChain's integration (langchain_nvidia_ai_endpoints) with NVIDIA NIM microservices.
It allows access to NVIDIA NIM for chat applications, connecting to hosted or locally-deployed microservices.

Here we will use **mixtral-8x7b-instruct-v0.1**

Note: You can use any model hosted at the NVIDIA API catalog using 'ChatNVIDIA.get_available_models()'

In [2]:
from langchain_nvidia_ai_endpoints import ChatNVIDIA

llm = ChatNVIDIA(model="mistralai/mixtral-8x7b-instruct-v0.1", max_tokens=1024)

#### Note:
- In this notebook, we have used NVIDIA NIM microservices from the NVIDIA API Catalog.
- The other APIs, ChatNVIDIA, NVIDIAEmbedding, and NVIDIARerank, also support self-hosted NIM microservices.
- Change the `base_url` to your deployed NIM URL.
- Example: `llm = ChatNVIDIA(base_url="http://localhost:8000/v1", model="meta/llama3-8b-instruct")`
- NIM can be also hosted locally using Docker, following the [NVIDIA NIM for LLMs](https://docs.nvidia.com/nim/large-language-models/latest/getting-started.html) documentation. This is only true if you are the son of a rich oil tycoon, and have a few H100s sitting around in your basement.

In [3]:
from langchain_nvidia_ai_endpoints import ChatNVIDIA

# connect to an LLM NIM running at localhost:8000, specifying a specific model
llm = ChatNVIDIA(base_url="http://localhost:8000/v1", model="meta/llama3-8b-instruct")

### 2) Intialize the embedding
NVIDIAEmbeddings is a client to NVIDIA embeddings models that provides access to a NVIDIA NIM for embedding. It can connect to a hosted NIM or a local NIM using a base URL

We selected **NV-Embed-QA** as the embedding

In [3]:
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings

embedder = NVIDIAEmbeddings(model="NV-Embed-QA", truncate="END")

### 3) Obtain dataset
I love taxes and work! Lets steal information about taxes and work!

#### In the **DataHandler** class defined below, we can;

A) Walk through a webpage and find all sub-webpages and scrape the parent and children,

B) Extract texts, tables and images from pdfs

C) Split documents 

- Real world documents can be very long, this makes it hard to fit in the context window of many models. Even for those models that could fit the full post in their context window, models can struggle to find information in very long inputs.

- To handle this we’ll split the Document into chunks for embedding and vector storage. More on text splitting [here](https://python.langchain.com/v0.2/docs/concepts/#text-splitters)

In [12]:
import os
import requests
import pandas as pd
import urllib.parse  # To handle URL joining
import fitz

from tqdm import tqdm
from io import StringIO
from bs4 import BeautifulSoup, SoupStrainer
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader, DataFrameLoader, CSVLoader, UnstructuredTSVLoader

class DataHandler:
    """
    Masterfully handles data scraping, preprocessing, and other data-related functionalities in this notebook.
    """

    def __init__(self, csv_path="./data/csv"):
        self.webloaders = [] # tracks all urls that have been converted to langchain WebBaseLoaders.
        self.tabular = [] # tracks all tabular data that has been discovered by scraper. Delivers a list of CSVLoader. Cowardly refusing to save to same list
        self.raws = []
        self.csv_path = csv_path

        self.text_splitter = CharacterTextSplitter(chunk_size=1024, separator=" ", chunk_overlap=64)

        self.visited_urls = set()   # Classwide tracker to prevent repeated visits when scraping web
        self.tabular_data = []      # Tracks all tabular_data data that has been discovered by scraper. Delivers a list of CSVLoader. Cowardly refusing to save to same list.
        self.textual_data = []      # Tracks all textual data scraped from websites or pdfs
        self.all_data = []  # Just defining a joint list here instead of later during the functional call; in case we do not care about seperating them anymore (both are lists of Documents, just from different
        # base sources)

        if not os.path.exists(self.csv_path):
            os.mkdir(self.csv_path)
    
    @staticmethod
    def from_csv(csv_path):
        """
        Small func to read from csv and produce CSVLoaders.
        """
        df = pd.from_csv( csv_path) # so we can yoink its columns
        loader = CSVLoader(file_path=csv_path,
                                    csv_args={  'delimiter': ',',
                                                'quotechar': '"',
                                                'fieldnames': [str(col) for col in df.columns]}
                                    )
        return loader

    def csvs_to_loader(self, directory):
        """
        walks through directory, finds all csvs, and saves it into self.tabular 
        """
        for dir, subdir, files in os.walk(directory):
            for file in files:
                if file.endswith(".csv"):
                    fp = os.path.join(dir, file)
                    loader = self.from_csv(fp)
                    self.tabular.append(loader)


    def extract_table_elements(self, url, table_elements):
        """
        Helper func to extract tabular data

        Args:
            - url: url the table is under
            - table_elements: list of table elements
        """
        for table_idx, table in enumerate(table_elements): #TODO: Find better way to index different tables on the same page? not all have class attributes we can ID them with.
            try:
                tags = table.find_all('sup')
                for tag in tags:
                    tag.extract()

                # Attempt to find the title of this table. tableheader elements only tell us (pandas) how to index it, but what we need the header element for context on what this table is about
                tablename = f"table_{table_idx}" # default name presuming none is found
                for headertype in ['h3', 'h4']: # unlikely to lie in h2 or h1? could result in duplicate data. if we find by those.
                    header = table.find_previous(headertype)
                    if header is not None and header.text is not None: # find the closest header
                        tablename = self.clean_text(header.text).replace("#", '').replace("/", '')
                        break
                
                tablename = os.path.basename(url) + f" {tablename}" # what we will call this table, some has really annoying spacing, so maybe .replace(' ', '')?
                print(f"Grabbing table data under tablename {tablename}")
                df = pd.read_html(StringIO(str(table)), header=0)[0] # some tables do not have tableheader <th> tags for first row 
                # which would result in  generic column indices being created, so forcibly set first row as tableheader. 
                df['context'] = [tablename] * df.shape[0]

                if len(tablename) >= os.pathconf('/', 'PC_NAME_MAX'): # prevent shit from exploding because my tablename is damn scuffed but what todo.
                    tablename = tablename[:os.pathconf('/', 'PC_NAME_MAX') - 10]
                csv_path = os.path.join(self.csv_path, tablename + '.csv')

                # print("csv_path:", csv_path)
                df.to_csv( csv_path, index=False ) # if true will fuck up columning in csvloader
                # loader = UnstructuredTSVLoader(csv_path, mode='elements')
                loader = CSVLoader(file_path=csv_path,
                                    csv_args={  'delimiter': ',',
                                                'quotechar': '"',
                                                'fieldnames': [str(col) for col in df.columns]}
                                    )
                for row in loader.load()[1:]: # first row is just column indexes, so void
                    self.tabular.append(
                        self.text_splitter.split_text(
                            self.clean_text(row.page_content)
                            )
                        )

            except BaseException as e:
                print(f"Unable to extract table data from url {url} with error {e}, passing!")

    def create_loaders(self):
        """
        Seperate method to create the loaders. Directly appends to textual_data attribute and calls extract_table_elements to handle tabular data.
        """
        for url, soupy_little_guy in self.raws:
            loader = WebBaseLoader(
                    web_paths=(url,),  # No URL fetching as we already have the HTML content
                    bs_kwargs={"parse_only": SoupStrainer(['main'])},
                )
            html_content = loader.load()
            for i in range(len(html_content)):
                self.textual_data.extend(
                    self.text_splitter.split_text(
                        self.clean_text(
                            html_content[i].page_content
                )))

            table_elements = soupy_little_guy.find_all("table") 
            self.extract_table_elements(url=url, table_elements=table_elements)


    def scrape_website(self, base_url, max_depth, depth=0):
        """
        Wraps around a nested function get_from_website, so when ran will define a new function that knows that the sauce base_url is.
        Scuffed? Yes.
        """

        def get_from_website(url, max_depth, depth):
            """
            Recursively scrape a website by visiting links starting from url. 
            Because this is a mostly I/O bound operation, we make a seperate method that actually creates the Loaders.

            Parameters:
                - url:              URL to start scraping from
                - depth:            Current recursion depth
                - max_depth:        Maximum recursion depth to avoid infinite loops

            Returns:
                - appends to self.raws, [url, BeautifulSoup object] created from response content.
            """
            if url in self.visited_urls or depth > max_depth:
                return

            try:
                response = requests.get(url)
                soupy_little_guy = BeautifulSoup(response.content, 'html.parser')
                if response.status_code != 200:
                    return print(f"Failed to retrieve {url}")
            except Exception as e:
                return print(f"Error accessing {url} with error: {e}")
                

            self.visited_urls.add(url)
            print("Current url:", url)
            self.raws.append([url, soupy_little_guy])

            for link in soupy_little_guy.find_all('a', href=True):  # Find all links on the current page
                relative_url = link['href']
                absolute_url = urllib.parse.urljoin(url, relative_url)
                if base_url in absolute_url:  # Avoids external sites
                    get_from_website(absolute_url, max_depth, depth + 1)

        get_from_website(base_url, max_depth, depth)


    def scrape_pdf(self, pdf_path):
        """
        Extracts information from pdf files.
        Texts stay as texts.
        Tables and images...

        Parameters:
            - pdf_path:     PDF file to extract from

        Output:
            - textual_data: List of cleaned strings that represent data from pdf doc
        """
        try:
            pdf_document = fitz.open(pdf_path)
        except:
            return print(f"Unable to open {pdf_path}")
        print(f"Current pdf: {pdf_path}")
        pdf_text = ""
        for page_num in range(len(pdf_document)):
            page = pdf_document.load_page(page_num)
            pdf_text += page.get_text("text")

        self.textual_data.extend(
            self.text_splitter.split_text(
                self.clean_text(pdf_text)
        ))
    
    def clean_text(self, text):
        """
        Cleans text retrieved from sources to reduce the storage needed

        get_from_website(base_url, max_depth, depth)
        Parameters:
            - text:         Original string

        Output:
            - cleaned_text: Cleaned string 
        """
        return text.replace('\n', ' ').replace('\t', ' ').replace('\r', ' ').replace('  ', ' ')
    
    def embed_text(self, text):
        """
        Embeds a given text if necessary

        Parameters:
            - text:     Original string

        Output:
            - text:     Embedded string 
        """
        return embedder.embed_query(text)

We define our websites and pdf links below to call the dataloader on.

In [13]:
# Max depth determines whether you wish to look into children of parents websites, else set to 0
max_depth = 0
websites = [
    "https://www.iras.gov.sg/taxes",
    "https://www.iras.gov.sg/schemes",
    "https://www.mom.gov.sg/passes-and-permits",
    "https://www.mom.gov.sg/employment-practices",
    "https://www.mom.gov.sg/workplace-safety-and-health"
]

datahandler = DataHandler()
for website in websites:
    datahandler.scrape_website(website, max_depth)

if os.path.isdir('./data/pdfs'):
    for pdf in os.listdir('./data/pdfs'):
        datahandler.scrape_pdf(os.path.join('./data/pdfs', pdf))

Current url: https://www.iras.gov.sg/taxes
Current url: https://www.iras.gov.sg/schemes
Current url: https://www.mom.gov.sg/passes-and-permits
Current url: https://www.mom.gov.sg/employment-practices
Current url: https://www.mom.gov.sg/workplace-safety-and-health
Current pdf: ./data/pdfs\verification-checklist-inspection-of-machines.pdf


In [14]:
datahandler.create_loaders() # only when we actually want to create loaders to debug stuff, so we don't need to scrape within the same function that we're (likely) debugging, saving time.

Grabbing table data under tablename passes-and-permits table_0
Unable to extract table data from url https://www.mom.gov.sg/passes-and-permits with error module 'os' has no attribute 'pathconf', passing!
Grabbing table data under tablename passes-and-permits Overseas Networks & Expertise Pass
Unable to extract table data from url https://www.mom.gov.sg/passes-and-permits with error module 'os' has no attribute 'pathconf', passing!
Grabbing table data under tablename passes-and-permits Work Permit for performing artiste
Unable to extract table data from url https://www.mom.gov.sg/passes-and-permits with error module 'os' has no attribute 'pathconf', passing!
Grabbing table data under tablename passes-and-permits Training Work Permit
Unable to extract table data from url https://www.mom.gov.sg/passes-and-permits with error module 'os' has no attribute 'pathconf', passing!
Grabbing table data under tablename passes-and-permits Letter of Consent for Dependant’s Pass holders who are busines

And then, we embed the textual and tabular data gathered for use in the vector db.

In [17]:
datahandler.embedded_data = [datahandler.embed_text(text) for text in tqdm(datahandler.textual_data, desc='Embedding textual data...')] + [datahandler.embed_text(text) for text in tqdm(datahandler.tabular_data, desc='Embedding tabular data...')]
datahandler.textual_data.extend(datahandler.tabular_data)

Embedding textual data...: 100%|██████████| 13/13 [00:07<00:00,  1.66it/s]
Embedding tabular data...: 0it [00:00, ?it/s]


### 4) Storing the documents

To build our foundational knowledge base from our collected data and allow for faster retrieval of vector queries, we need to have some form of search system.

#### a) Process the documents into vectorstore and save it to disk

Vectorstores are good when we wish to store small datasets mainly in memory, thus using less storage.

In this case, we use FAISS, which is a high-performace library that is efficient for similarity search and clustering of dense vectors.

In [11]:
# Here we create a faiss vector store from the documents and save it to disk.
from langchain_community.vectorstores import FAISS

# You will only need to do this once, later on we will restore the already saved vectorstore
store = FAISS.from_texts(datahandler.textual_data, embedder)
VECTOR_STORE = './data/nv_embedding'
store.save_local(VECTOR_STORE)

To enable runtime search, we index text chunks by embedding each document split and storing these embeddings in a vector database. Later to search, we embed the query and perform a similarity search to find the stored splits with embeddings most similar to the query.

Then, we can read the previously processed and saved vector store back for use;

In [12]:
from langchain_community.vectorstores import FAISS

# Load the FAISS vectorestore back.
VECTOR_STORE = './data/nv_embedding'
store = FAISS.load_local(VECTOR_STORE, embedder, allow_dangerous_deserialization=True)

#### b) Store all our data into a vectordb (connects vector stores to structured db systems)

We will be using **Milvus DB** since it is an open-source, distributed vector DB designed for high-performance vector similarity search across massive datasets (which is what we kind of will have), and makes use of popular libraries like FAISS or Annoy for its vector searching.

We have the option to use;
1) MilvusVectorStore (Milvus Lite) which is easier to implement and can be more easily tied with ML frameworks --> Milvus-lite does not have a windows version (only supports linux and mac)
2) MilvusClient (Python SDK) WITHOUT Docker which was used in a demo to build a RAG system --> Milvus-lite does not have a windows version (only supports linux and mac)
3) Direct DB WITH Docker (Milvus Standalone) to directly manage and control the connections as well as increase flexibility --> What im using first since it is the most straighforward (im trying to find if milvus standalone can be used locally instead of through docker)

Sources used:
- https://milvus.io/docs/quickstart.md
- https://milvus.io/docs/integrate_with_langchain.md
- https://milvus.io/docs/multimodal_rag_with_milvus.md
- https://github.com/milvus-io/bootcamp/tree/master/bootcamp/tutorials/quickstart/apps/multimodal_rag_with_milvus (Possible demo to use)

##### Note:
Previously, I thought of using **Weaviate** because it is a free, open-source, scalable and reliable vector database service with decent amounts of documentation online.

The Weaviate DB works more with Object-Oriented design rather than rows/columns and does not need to be structured like an SQL DB since it already has integrated models/pre-computed embeddings to handle vector embeddings.

However, I realised it **could not** be GPU accelerated during its vector retrieval/indexing.

Sources used:
- https://python.langchain.com/docs/integrations/vectorstores/weaviate/
- https://weaviate.io/developers/weaviate/client-libraries/python

In [None]:
# !pip install langchain-milvus
# !pip install -U pymilvus

In [90]:
from pymilvus import MilvusClient, DataType, CollectionSchema, FieldSchema
import hashlib

class MilvusDB:
    """
    Not-so-masterfully handles the vector database using Milvus and all DB related functionality.
    """

    def __init__(self, uri):
        self.uri = uri
        self.client = None
        self.collection_name = None

        self.similarity_threshold = 0.99  # Similarity between new input data and existing data
        self.relevancy_threshold = 0.99   # Similarity between queried data and query
        self.set_up_db()

    def set_up_db(self):
        """
        Starts up connection to Weaviate DB and creates schema if not already created

        Parameters: None

        Output:     None
        """
        self.client = MilvusClient(uri=self.uri)

    def load_collection(self):
        """
        Loads collection into RAM for faster retrieval

        Parameters: None

        Output:     None
        """
        self.client.load_collection(
            collection_name=self.collection_name,
            replica_number=1 # Number of replicas to create on query nodes. Max value is 1 for Milvus Standalone, and no greater than `queryNode.replicas` for Milvus Cluster.
        )

    def release_collection(self):
        """
        Releases the collection from memory to save memory usage

        Parameters: None

        Output:     None
        """
        self.client.release_collection(
            collection_name=self.collection_name
        )

    def create_collection(self, collection_name, dimensions):
        """
        Creates a new collection in the DB

        Parameters: 
            - collection_name:  Name of collection to make
            - dimensions:       Number of dimensions for vector data

        Output:     None
        """     
        # Checks if the client already has the collection in their Milvus instance
        # if self.client.has_collection(collection_name):  # TODO: Change this later
        #     self.client.drop_collection(collection_name)

        # Defines a schema to follow when creating the DB
        id_field = FieldSchema(name="id", dtype=DataType.VARCHAR, max_length=512, is_primary=True)
        embedding_field = FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dimensions)
        text_field = FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=1152)  # Use VARCHAR for string types
        source_field = FieldSchema(name="source", dtype=DataType.VARCHAR, max_length=1152)  # Optional metadata field
        schema = CollectionSchema(fields=[id_field, embedding_field, text_field, source_field])

        # Creates the collection
        self.client.create_collection(
            collection_name=collection_name,
            dimension=dimensions,
            schema=schema,
            metric_type="COSINE",               # Inner product distance; or COSINE or L2 (Euclidean)
            consistency_level="Strong",     # Strong consistency level
        )
        self.collection_name = collection_name

        # Creates an index for more efficient similarity search later on based on the metric_type and index_type
        self.index_params = MilvusClient.prepare_index_params()
        self.index_params.add_index(
            field_name="embedding",
            metric_type="COSINE",
            index_type="IVF_FLAT",
            index_name="embedding_index",
            params={ "nlist": 128 }
        )
        self.client.create_index(
            collection_name=self.collection_name,
            index_params=self.index_params,
            sync=False
        )

    def insert_data(self, original, embedded):
        """
        Adds document and embedding object pairs to the DB collection if not alreadt inside 

        Parameters:
            - original:     Original documents
            - embedded:     Embedded documents

        Output: None
        """
        self.load_collection()
        data = []

        for i, embedded_line in enumerate(tqdm(embedded, desc="Inserting data...")):
            unique_id = self.generate_unique_id(embedded_line)
            if not self.check_for_similar_vectors(embedded_line) and not self.check_for_similar_ids(unique_id):
                data.append({"id": unique_id, "embedding": embedded_line, "text": original[i], "source": "None"})

        self.client.insert(collection_name=self.collection_name, data=data)
        self.release_collection()

    def generate_unique_id(self, data):
        """
        Generates a unique hash ID based on the vector or text data

        Parameters:
            - data: The vector or text data used to generate the hash

        Returns:
            - id:   A unique hash ID as a string
        """
        data_str = str(data)
        unique_id = hashlib.sha256(data_str.encode()).hexdigest()
        return unique_id

    def check_for_similar_vectors(self, embedding, top_k=5):
        """
        Checks the DB for vectors that are similar to the input embedding (based on distance metric like cosine similarity or Euclidean distance)

        Parameters:
            - embedding:    Embedded documents
            - top_k:        Top K number of documents that are similar to the input embedding

        Output: 
            - check:        True or False value of whether a similar vector has been found
        """
        try:
            search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
            results = self.client.search(
                collection_name=self.collection_name,
                data=[embedding],
                anns_field="embedding",  # Adjust field name based on your Milvus schema
                search_params=search_params,
                limit=top_k
            )
            for result in results:
                for vector in result:
                    if vector['distance'] >= self.similarity_threshold:
                        return True  
            return False
        except Exception as e:
            print(f"Error checking if vector exists: {e}")
            return False

    def check_for_similar_ids(self, id):
        """
        Checks the DB for ids that are similar to the new one created

        Parameters:
            - id:       Unique id generated for new row

        Output: 
            - check:    True or False value of whether a similar vector has been found
        """
        try:
            results = self.client.query(
                collection_name=self.collection_name,
                filter=f"id == '{id}'",  
                output_fields=["id"], limit=1000 
            )
            return True if results else False
        except Exception as e:
            print(f"Error checking if ID exists: {e}")
            return False

    def retrieve_data(self, question):
        """
        Retrieves vector data from DB based on embedded question

        Parameters:
            - question:     Embedded question as a vector

        Output:
            - search_res:   Results of vector retrieval
        """
        self.load_collection()
        search_res = self.client.search(
            collection_name=self.collection_name,
            data=[question],  
            limit=3,  # Return top 3 results
            search_params={"metric_type": "COSINE", "params": {}},  # Inner product distance
            output_fields=["text"],  # Return the text field
        )
        filtered_res = [res for res in search_res[0] if res['distance'] > self.relevancy_threshold]
        self.release_collection()
        return filtered_res

In [91]:
database = MilvusDB('http://localhost:19530')
database.create_collection("Documents", 1024)
database.insert_data(datahandler.textual_data + datahandler.tabular_data, datahandler.embedded_data)

Inserting data...: 100%|██████████| 13/13 [00:05<00:00,  2.25it/s]


### 5) Using data to answer questions

With our stored embedded data, we can retrieve relevant vectors stored in our vectorstore/db to answer embedded questions.

#### A) Using vectorstore to answer

In [None]:
from langchain_core.runnables import RunnablePassthrough
from langchain_core.prompts import ChatPromptTemplate

context = store.as_retriever()

prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Answer solely based on the following:\n<Documents>\n{context}\n</Documents>",
        ),
        ("user", "{question}"),
    ]
)

# Langchain's LCEL(LangChain Expression Language) Runnable protocol is used to define the chain
# LCEL allows pipe together components and functions
chain = (
    {"context": context, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)

#### B) Using VectorDB to answer

What happens if our vectordb does not have the relevant information needed though?

#### C) Using web sources to answer

Using Tavily, we can create a search engine that calls abstract searching, scraping, filtering and extracting from online sources.

Only issue is, it takes too long to get the answer out, so it should only be used as a final resort.

In [98]:
import os
os.environ["TAVILY_API_KEY"] = "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"

In [99]:
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_nvidia_ai_endpoints import ChatNVIDIA
from langgraph.prebuilt import ToolNode
from langgraph.graph import StateGraph, MessagesState
from typing import Literal

class SearchEngine:
    def __init__(self):
        self.search_tool = TavilySearchResults(max_results=3)
        self.llm = ChatNVIDIA(model="meta/llama-3.1-8b-instruct")

        self.tools = [self.search_tool]  # Add other tools here if needed
        self.llm_with_tools = self.llm.bind_tools(self.tools)
        self.workflow = self.build_workflow()

    def build_workflow(self):
        """Builds the state graph workflow."""
        workflow = StateGraph(MessagesState)
        
        # Add nodes for agent (LLM) and tools (search)
        workflow.add_node("agent", self.call_model)
        workflow.add_node("tools", ToolNode(self.tools))
        
        # Add conditional logic to determine if tools should be used
        workflow.add_edge("__start__", "agent")
        workflow.add_conditional_edges("agent", self.should_continue)
        workflow.add_edge("tools", "agent")
        
        return workflow.compile()

    def call_model(self, state):
        """Invoke the LLM with the current state."""
        messages = state["messages"]
        response = self.llm_with_tools.invoke(messages)
        return {"messages": [response]}

    def should_continue(self, state):
        """Determine whether to continue invoking tools or end the workflow."""
        messages = state["messages"]
        last_message = messages[-1]
        if last_message.tool_calls:
            return "tools"
        return "__end__"

    def query(self, user_query, location_context=' in singapore'):
        """Main method to process the user's query."""
        initial_state = MessagesState({"messages": [self.llm_with_tools.invoke(user_query + location_context)]})
        return self.workflow.invoke(initial_state)

In [103]:
def ask(question):
    retrieved_data = database.retrieve_data(embedder.embed_query(question))
    if len(retrieved_data) > 0:
        retrieved_lines_with_distances = [
            (res["entity"]["text"], res["distance"]) for res in retrieved_data[0]
        ]
        context = '\n'.join(
            [line_with_distance[0] for line_with_distance in retrieved_lines_with_distances]
        )
        prompt = f" \
            Context from database:\n{context}\n\n\
            User's question: {question}\n\n\
            Please generate an informative response.\
        "
        return llm.invoke(prompt).content
    print("No data found in database. Searching online...")
    search_engine = SearchEngine()
    response = search_engine.query(question)

    for message in response['messages']:
        print(message.content)
    

#### Now, we can try asking different types of questions:

Case 1: Irrelevant question

In [11]:
print(chain.invoke("What is component two hundred and twenty five million?"))

There is no information provided in the documents about a component or entity called "two hundred and twenty five million." The documents contain information about various financial transactions, mortgage and loan balances, estate duty calculations, and rules related to M&A allowance and motor vehicle expenses, but there is no mention of a "component" with a value of 225 million.


In [101]:
print(ask("What is component two hundred and twenty five million?"))

I can't find that information.  Is there anything else that I can help you with?

[{"url": "https://www.nbcnews.com/", "content": "judge asks in granting bail\nAsian America\nBrooklyn woman arrested after allegedly throwing hot coffee at a man in a Palestinian scarf\nU.S. news\nMichigan police make arrest in the disappearance of a woman missing since 2021\nAsian America\nBrooklyn woman arrested after allegedly throwing hot coffee at a man in a Palestinian scarf\nU.S. news\nMichigan police make arrest in the disappearance of a woman missing since 2021\nHealth\nHealth news\nMysterious dog respiratory illness may be caused by a new type of bacterial infection, researchers say\nCoronavirus\nFour more free Covid tests will be available to U.S. households\nHealth news\nThis type of belly fat is linked to increased risk of Alzheimer's, research finds\nHealth news\nNew weight loss drugs change how people think of Thanksgiving and other holiday meals\nHealth news\nDeadly listeria outbreak linke

Case 2: Simple questions

In [None]:
print(chain.invoke("How do i file taxes for my company?"))

In [104]:
print(ask("How do i file taxes for my company?"))

No data found in database. Searching online...
I don't have specific information about filing taxes for your company. However, I can guide you to a suitable resource in Singapore. You may wish to refer to the Inland Revenue Authority of Singapore (IRAS) for a step-by-step guide on how to file taxes in Singapore. 

Meantime, I will not make a function call in this instance as you can easily access the IRAS website yourself to obtain the necessary information.

[{"url": "https://www.pilotoasia.com/guide/personal-income-tax-singapore", "content": "Filing of tax returns is required if your annual income is or more. Starting from YA 2024, the top marginal Personal Income Tax rate will be increased from 22% to 24%. Starting from YA 2024, the personal income tax rate in Singapore for non-tax residents will be set at 24%. Apart from a few exceptions, overseas income is exempted from taxation."}, {"url": "https://mytax.iras.gov.sg/", "content": "At AXS Stations if you are a DBS/POSB customer (f

Case 3: Complex questions

In [None]:
print(chain.invoke("In the event my foriegn employee is injured at work, how do i report the incident and claim reparations?"))

Case 4: Realistic questions

In [13]:
chain.invoke("I am an employer whose firm qualifies for PWCs. From 2025 onwards, how much Co-Funding can I recieve from the government?")
# In its current state, the model does an ethan for this quite hard question, although it at least does retrieve some (few?) relevant PWCS info.

'From 2025 onwards, as an employer whose firm qualifies for PWCS (Productivity Works Credits Scheme), you can receive co-funding from the government of up to 50% for the first tier of wage increases and 15% to 30% for the second tier of wage increases. This co-funding support applies to wage increases given in qualifying year 2025 and onwards. The gross monthly wage ceiling for PWCS co-funding will be increased to $3,000 in qualifying years 2025 and 2026.\n\nPlease note that the specific rates and details of the co-funding support may be subject to changes or updates in the PWCS guidelines. It is recommended to consult the official guidelines or contact the relevant authorities for the most accurate and up-to-date information.'

### 6) Enhancing accuracy for single data sources

This example demonstrates how a re-ranking model can be used to combine retrieval results and improve accuracy during retrieval of documents.

Typically, reranking is a critical piece of high-accuracy, efficient retrieval pipelines. Generally, there are two important use cases:

- Combining results from multiple data sources
- Enhancing accuracy for single data sources

Here, we focus on demonstrating only the second use case. If you want to know more, check [here](https://github.com/langchain-ai/langchain-nvidia/blob/main/libs/ai-endpoints/docs/retrievers/nvidia_rerank.ipynb)

In [None]:
from langchain_nvidia_ai_endpoints import NVIDIARerank
from langchain_core.runnables import RunnableParallel

# We will narrow the collection to 100 results and further narrow it to 10 with the reranker.
retriever = store.as_retriever(search_kwargs={'k':100}) # typically k will be 1000 for real world use-cases
ranker = NVIDIARerank(model='nv-rerank-qa-mistral-4b:1', top_n=10)

prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Answer solely based on the following context:\n<Documents>\n{context}\n</Documents>",
        ),
        ("user", "{question}"),
    ]
)

reranker = lambda input: ranker.compress_documents(query=input['question'], documents=input['context'])

chain_with_ranker = (
    RunnableParallel({"context": retriever, "question": RunnablePassthrough()})
    | {"context": reranker, "question": lambda input: input['question']}
    | prompt
    | llm
    | StrOutputParser()
)


In [None]:
print(chain_with_ranker.invoke("How do i file taxes for my company?"))

Based on the provided documents, to file taxes for your company, you need to follow these steps:

1. Ensure that you are duly authorized by your company as an 'Approver' for Corporate Tax (Filing and Applications) in Corppass. You can refer to the step-by-step guides for assistance on Corppass setup.
2. Have your Singpass and your company’s Unique Entity Number (UEN)/ Entity ID ready.
3. Visit the mytax.iras.gov.sg website to file the Corporate Income Tax Return for your company.
4. If your company is filing Form C, you need to submit the financial statements/certified accounts and tax computation(s) for the relevant Year of Assessment (YA).
5. If your company meets the qualifying conditions to file Form C-S or Form C-S (Lite), you can choose to file the simplified version, Form C-S (Lite), if your company has an annual revenue of $200,000 or below.

You can also visit the Basic Guide to Corporate Income Tax for Companies page to get help with filing your company’s tax returns for the 

In [None]:
print(chain_with_ranker.invoke("In the event my foriegn employee is injured at work, how do i report the incident and claim reparations?"))

Based on the provided documents, if your foreign employee is injured at work, you can report the incident and claim reparations under the Work Injury Compensation Act (WICA). Specifically, the documents mention that input tax can be claimed for work injury compensation insurance that is obligatory under WICA for both local and foreign employees performing manual work or non-manual work earning $2,600 or less a month.

To report the incident and make a claim, you can visit the Ministry of Manpower (MOM) webpage on WICA or contact MOM at +65 6438 5122. However, it is important to note that medical and accident insurance premiums for your staff are generally not allowable for input tax claims under the GST (General) Regulations, unless the insurance or payment of compensation is obligatory under WICA or under any collective agreement within the meaning of the Industrial Relations Act.

Therefore, it seems that you can report the incident and claim reparations under WICA, but you should ve