### IRIS GraphRAG Demo

This notebook demonstrates how to use IRIS Vector Search capabilities in a GraphRAG application.

The following cell installs all of the requirements. The Jupyter image should already have them installed, but running this cell anyway is advised, just to be safe.

In [None]:
! pip install -U langchain_community arxiv tiktoken langchainhub pymilvus langchain langgraph tavily-python sentence-transformers langchain-milvus langchain-ollama langchain-huggingface beautifulsoup4 langchain-experimental neo4j json-repair langchain-openai langchain-ollama

This is some basic setup for the LangChain application.

In [1]:
from dotenv import load_dotenv
from langchain.globals import set_verbose, set_debug
import os

# Load variables from a local .env file first so they are visible to the
# os.environ lookup below.
load_dotenv()

# Number of papers to work with; also embedded in the exported CSV file names.
max_papers = 20

# Directory that receives the exported CSV files. Set GRAPHRAG_DATA_PATH (in the
# environment or in .env) to override the default. os.path.join(..., "")
# guarantees a trailing separator, because the export cells build file names by
# plain string concatenation (data_path + "docs" + ...).
data_path = os.path.join(
    os.environ.get("GRAPHRAG_DATA_PATH", "/home/jevyan/workspace/data/"), ""
)

# Set langchain variables
set_debug(False)
set_verbose(False)

Set your OpenAI API key here; it is used by the LLM.

In [6]:
### LLM
import os

# Replace the placeholder with your own OpenAI API key before running.
# Note: this assignment overwrites any OPENAI_API_KEY already present in the
# environment (for example one loaded from .env).
os.environ["OPENAI_API_KEY"] = "<insert_token_here>"

# Name of the OpenAI chat model used by the demo.
gpt4omini = "gpt-4o-mini"

model = gpt4omini

For our project we are using immunology and clinical trials papers from arXiv. Since the data is already loaded, we commented these cells out, but you can add your own data extraction here.

In [7]:
import arxiv
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_milvus import Milvus
from langchain_community.embeddings import HuggingFaceEmbeddings

# The arXiv download and chunking code below is disabled by wrapping it in a
# string literal, so running this cell defines neither `docs` nor `doc_splits`.
# The CSV-export cell reads `docs` and the graph-transformer cell reads
# `doc_splits`; both names must exist before those cells are run.
''' Uncomment and replace with your own data if desired
search_query = "immunology OR 'clinical trials' OR 'neuroscience'"
max_results = max_papers

# Fetch papers from arXiv
client = arxiv.Client()
search = arxiv.Search(
    query=search_query, max_results=max_results, sort_by=arxiv.SortCriterion.Relevance
)

docs = []
for result in client.results(search):
    docs.append(
        {"title": result.title, "summary": result.summary, "url": result.entry_id, "authors": result.authors}
    )

text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=2000, chunk_overlap=50
)
doc_splits = text_splitter.create_documents(
    [doc["summary"]+" "+doc["title"]+""+str(doc["authors"]) for doc in docs], metadatas=docs
)

print(f"Number of papers: {len(docs)}")
print(f"Number of chunks: {len(doc_splits)}")
'''

In [None]:
# Export one pipe-delimited row per paper: docid|title|abstract|url|authors,
# where `authors` is a comma-separated list of the paper's authors.
# Reads `docs`, `data_path` and `max_papers`, which must already be defined.
filename = data_path + "docs" + str(max_papers) + ".csv"
skipped_rows = 0

# Explicit UTF-8 so non-ASCII titles and author names are written intact instead
# of depending on the locale's default encoding.
with open(filename, "w", encoding="utf-8") as file:
    print("docid|title|abstract|url|authors", file=file)
    for i, doc in enumerate(docs):
        # A line break inside a field would split the row, so flatten them.
        abstract = doc["summary"].replace("\n", " ")
        title = doc["title"].replace("\n", " ")
        authors = ",".join(str(author) for author in doc["authors"])
        row = f"{i}|{title}|{abstract}|{doc['url']}|{authors}\n"
        try:
            # The whole row goes out in a single write, so a failure can never
            # leave a partially written line behind.
            file.write(row)
        except UnicodeEncodeError:
            # Still best-effort, but the loss is counted and reported below.
            skipped_rows += 1

if skipped_rows:
    print(f"Skipped {skipped_rows} paper rows that could not be encoded")

In [None]:
# GraphRAG Setup
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_core.documents import Document
from langchain_experimental.llms.ollama_functions import OllamaFunctions
from langchain_experimental.graph_transformers.diffbot import DiffbotGraphTransformer
from langchain_openai import ChatOpenAI
from langchain_ollama import ChatOllama

# Chat model that drives the entity/relationship extraction.
graph_llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini")

# Restrict the extraction to the node types, node properties and relationship
# types listed here.
graph_transformer = LLMGraphTransformer(
    llm=graph_llm,
    allowed_nodes=["Paper", "Author", "Topic"],
    node_properties=["title", "summary", "url", "author"],
    allowed_relationships=["AUTHORED", "DISCUSSES", "RELATED_TO"],
)

# Requires `doc_splits` to be defined already.
graph_documents = graph_transformer.convert_to_graph_documents(doc_splits)

print(f"Graph documents: {len(graph_documents)}")
# Only preview the first graph document when there is one; indexing an empty
# result would raise IndexError.
if graph_documents:
    print(f"Nodes from 1st graph doc:{graph_documents[0].nodes}")
    print(f"Relationships from 1st graph doc:{graph_documents[0].relationships}")
else:
    print("No graph documents were produced; check that doc_splits is not empty")

In [None]:
# Export every extracted node as a pipe-delimited row: docid|entityid|type.
# `docid` is the position of the graph document inside `graph_documents`.
filename = data_path + "entities" + str(max_papers) + ".csv"
skipped_rows = 0

# Explicit UTF-8 so non-ASCII entity names do not depend on the locale's
# default encoding.
with open(filename, "w", encoding="utf-8") as file:
    print("docid|entityid|type", file=file)
    for i, doc in enumerate(graph_documents):
        for node in doc.nodes:
            try:
                # Single write per row: a failure cannot leave a partial line.
                file.write(f"{i}|{node.id}|{node.type}\n")
            except UnicodeEncodeError:
                # Still best-effort, but the loss is counted and reported below.
                skipped_rows += 1

if skipped_rows:
    print(f"Skipped {skipped_rows} entity rows that could not be encoded")

In [None]:
# Export every extracted relationship as a pipe-delimited row:
# docid|source|sourcetype|target|targettype|type.
# `docid` is the position of the graph document inside `graph_documents`.
filename = data_path + "relations" + str(max_papers) + ".csv"
skipped_rows = 0

# Explicit UTF-8 so non-ASCII entity names do not depend on the locale's
# default encoding.
with open(filename, "w", encoding="utf-8") as file:
    print("docid|source|sourcetype|target|targettype|type", file=file)
    for i, doc in enumerate(graph_documents):
        for rel in doc.relationships:
            row = (
                f"{i}|{rel.source.id}|{rel.source.type}"
                f"|{rel.target.id}|{rel.target.type}|{rel.type}\n"
            )
            try:
                # Single write per row: a failure cannot leave a partial line.
                file.write(row)
            except UnicodeEncodeError:
                # Still best-effort, but the loss is counted and reported below.
                skipped_rows += 1

if skipped_rows:
    print(f"Skipped {skipped_rows} relation rows that could not be encoded")

In [None]:
# load iris module
import iris
import pandas as pd
import warnings

# Install a blanket "ignore" filter so warnings are not shown in cell output.
warnings.simplefilter(action="ignore")

In [None]:
import os

# Connection settings. Each value can be supplied through the environment (or
# the .env file loaded by load_dotenv) so credentials need not be edited into
# the notebook; the fallbacks are the demo's original values.
hostname = os.environ.get("IRIS_HOSTNAME", "iris")
port = int(os.environ.get("IRIS_PORT", 1972))
namespace = os.environ.get("IRIS_NAMESPACE", "IRISAPP")
username = os.environ.get("IRIS_USERNAME", "SuperUser")
password = os.environ.get("IRIS_PASSWORD", "SYS")

# connect using a "<host>:<port>/<namespace>" connection string
connection = iris.connect(f"{hostname}:{port}/{namespace}", username, password)

In [None]:
# Create an IRIS object from the open connection.
irispy = iris.createIRIS(connection)
# Template call: "classname" and "methodname" are placeholders to be replaced.
# NOTE(review): `args` is not assigned in any cell visible here - define it (or
# replace it with the real arguments) before running this cell.
irispy.classMethodValue("classname","methodname",args)

In [None]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI
import ast

def extract_query_entities(query, model_name="gpt-4o"):
    """Ask a chat model to extract the entities mentioned in ``query``.

    Args:
        query: The user question to analyse.
        model_name: OpenAI chat model to use; defaults to the model this
            function always used before the parameter existed.

    Returns:
        A list with the extracted entities, as parsed from the model reply.

    Raises:
        ValueError, SyntaxError: propagated from ``ast.literal_eval`` when the
            reply does not contain a parseable Python literal.
    """
    import re  # local import: no import cell of the notebook provides `re`

    prompt_text = '''Based on the following example, extract entities from the user provided queries.
                Below are a number of example queries and their extracted entities. Provide only the entities.
                'How many wars was George Washington involved in' -> ['War', 'George Washington'].\n
                'What are the relationships between the employees' -> ['relationships','employees'].\n

                For the following query, extract entities as in the above example.\n query: {content}'''

    llm = ChatOpenAI(temperature=0, model_name=model_name)
    prompt = ChatPromptTemplate.from_template(prompt_text)
    chain = prompt | llm | StrOutputParser()
    response = chain.invoke({"content": query})

    # The model may wrap the list in a code fence or add surrounding text, so
    # parse only the span from the first "[" to the last "]" when there is one.
    list_match = re.search(r"\[.*\]", response, re.DOTALL)
    literal = list_match.group(0) if list_match else response.strip()

    # literal_eval only evaluates Python literals, never arbitrary code.
    parsed = ast.literal_eval(literal)
    if isinstance(parsed, (list, tuple, set)):
        return list(parsed)
    # A single bare literal (e.g. one quoted entity) is returned as a
    # one-element list so callers always receive a list.
    return [parsed]

# Try the extractor on a sample question and show what it returns.
sample_query = "What are the most common Phase I trials?"
entities = extract_query_entities(sample_query)
print(entities)