In [24]:
from langchain.chains import RetrievalQA
from langchain.document_loaders import TextLoader
from langchain.embeddings.google_palm import GooglePalmEmbeddings
from langchain.llms import GooglePalm
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma

### Embedding HIPAA

In [25]:
with open(".streamlit/secrets.toml", "r") as file:
    GOOGLE_PALM_API_KEY = file.read().split('=')[1]

In [26]:
loader = TextLoader("data/hipaa.txt")
documents = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
texts = text_splitter.split_documents(documents)

In [27]:
len(texts)

766

In [28]:
max([len(doc.page_content) for doc in texts])

1000

In [35]:
embeddings = GooglePalmEmbeddings(google_api_key = GOOGLE_PALM_API_KEY)
chunk = 100
for i in range(len(texts)//chunk + 1):
    Chroma.from_documents(texts[chunk*i:min(chunk*(i+1), len(texts))], embeddings, persist_directory="./chroma_db")
    print(f"Finished {chunk*i} to {chunk*(i+1)-1}...")

Finished 0 to 99...
Finished 100 to 199...
Finished 200 to 299...
Finished 300 to 399...
Finished 400 to 499...
Finished 500 to 599...
Finished 600 to 699...
Finished 700 to 799...


### Running RAG Queries

In [45]:
docsearch = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)

In [46]:
qa = RetrievalQA.from_chain_type(llm=GooglePalm(google_api_key=GOOGLE_PALM_API_KEY), chain_type="stuff", retriever=docsearch.as_retriever())

In [48]:
qa.run("What does HIPAA stand for?")

'Health Insurance Portability and Accountability Act of 1996'