In [5]:
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import RetrievalQA
import os
from jilm.model import JILMLangModel


In [6]:
from jilm.document_loader import DocumentLoader
from jilm.settings import CHROMA_SETTINGS, PERSIST_DIRECTORY

In [7]:
embeddings_model_name = os.environ.get('EMBEDDINGS_MODEL_NAME')
embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)


[2023-05-19 23:48:41,074] {SentenceTransformer.py:66} INFO - Load pretrained SentenceTransformer: all-mpnet-base-v2
[2023-05-19 23:48:45,602] {SentenceTransformer.py:105} INFO - Use pytorch device: cuda


In [8]:
doc = DocumentLoader.load_single_document("../README.md")


[nltk_data] Downloading package punkt to /home/jbp/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /home/jbp/nltk_data...


[2023-05-19 23:48:59,222] {xml.py:96} INFO - Reading document from string ...
[2023-05-19 23:48:59,226] {html.py:99} INFO - Reading document ...


[nltk_data]   Unzipping taggers/averaged_perceptron_tagger.zip.


In [9]:
# Create and store locally vectorstore
db = Chroma.from_documents([doc], embeddings, persist_directory=PERSIST_DIRECTORY, client_settings=CHROMA_SETTINGS)
db.persist()
db = None

[2023-05-19 23:49:06,055] {__init__.py:91} INFO - Running Chroma using direct local API.
[2023-05-19 23:49:06,173] {ctypes.py:22} INFO - Successfully imported ClickHouse Connect C data optimizations
[2023-05-19 23:49:06,179] {ctypes.py:31} INFO - Successfully import ClickHouse Connect C/Numpy optimizations
[2023-05-19 23:49:06,232] {json_impl.py:45} INFO - Using python library for writing JSON byte strings
[2023-05-19 23:49:06,633] {duckdb.py:461} INFO - No existing DB found in tmp/vector-db, skipping load
[2023-05-19 23:49:06,635] {duckdb.py:473} INFO - No existing DB found in tmp/vector-db, skipping load


Batches: 100%|██████████| 1/1 [00:02<00:00,  2.98s/it]


ValueError: Expected metadata value to be a str, int, or float, got ../README.md

In [None]:
retriever = db.as_retriever()

In [None]:
#llm = JILMLangModel(retriever=retriever, embeddings=embeddings, max_tokens=1000, chunk_size=64, chunk_overlap=0)
llm = JILMLangModel(callbacks=[StreamingStdOutCallbackHandler()], retriever=retriever, embeddings=embeddings, max_tokens=1000, chunk_size=64, chunk_overlap=0)

In [None]:
qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)

In [None]:
query = "What is the name of the project?"
res = qa(query)    
answer, docs = res['result'], res['source_documents']

In [None]:
for document in docs:
    print("\n> " + document.metadata["source"] + ":")
    print(document.page_content)