In [None]:
!python3 -m pip install --upgrade langchain deeplake sagemaker tiktoken

In [None]:
import os
import getpass

from langchain.embeddings.sagemaker_endpoint import SagemakerEndpointEmbeddings

# Ask for the Activeloop token only when it is not already present in the
# environment, so the notebook can be re-run (or run unattended) without
# prompting again or overwriting a token that was exported beforehand.
if not os.environ.get("ACTIVELOOP_TOKEN"):
    os.environ["ACTIVELOOP_TOKEN"] = getpass.getpass("Activeloop Token:")

# Embedding function handed to the DeepLake vector store in later cells.
# NOTE(review): constructed without arguments here — confirm whether the
# installed langchain version needs endpoint_name / region_name /
# content_handler to be supplied for the SageMaker embeddings endpoint.
embeddings = SagemakerEndpointEmbeddings()

In [None]:
!git clone https://github.com/apache/spark

In [None]:
import os
from langchain.document_loaders import TextLoader

# Walk the cloned repository and load every file that can be read as UTF-8
# text, splitting each one into document chunks as it is loaded.
root_dir = "spark"
docs = []
skipped = []  # (path, error repr) for every file that could not be loaded
for dirpath, dirnames, filenames in os.walk(root_dir):
    # Prune in place so os.walk never descends into Git's own metadata
    # directory; its contents are not part of the source tree being indexed.
    dirnames[:] = [d for d in dirnames if d != ".git"]
    for file in filenames:
        path = os.path.join(dirpath, file)
        try:
            loader = TextLoader(path, encoding="utf-8")
            docs.extend(loader.load_and_split())
        except Exception as e:
            # Best effort: files that fail to load (presumably binary or
            # non-UTF-8 content) are skipped, but recorded instead of being
            # silently dropped so the outcome can be inspected afterwards.
            skipped.append((path, repr(e)))

print(f"Loaded {len(docs)} document chunks from {root_dir!r}; skipped {len(skipped)} files.")

In [None]:
from langchain.text_splitter import CharacterTextSplitter

# Split the loaded documents with a chunk size of 1000 and no overlap between
# consecutive chunks; `texts` is what gets embedded and stored later on.
text_splitter = CharacterTextSplitter(
    chunk_overlap=0,
    chunk_size=1000,
)
texts = text_splitter.split_documents(docs)

In [None]:
# Account segment of the hub:// dataset path; change this to your own account.
username = "eschizoid"
# Derived from `username` so the account name is written in exactly one place.
dataset_path = f"hub://{username}/spark"

In [None]:
from langchain.vectorstores import DeepLake

# Create the DeepLake dataset and upload the embedded chunks. The path comes
# from the shared `dataset_path` variable so that this cell and the read-only
# cell that reopens the dataset always point at the same location.
# NOTE(review): `public=True` presumably makes the dataset publicly readable —
# confirm that this is intended before running against your own account.
db = DeepLake(
    dataset_path=dataset_path,
    embedding_function=embeddings,
    public=True,
)
# Re-running this cell calls add_documents again for the same `texts`.
db.add_documents(texts)

In [None]:
# Reopen the dataset in read-only mode; `db` is rebound to this handle and is
# what the retriever in the next cell is built from.
db = DeepLake(
    embedding_function=embeddings,
    dataset_path=dataset_path,
    read_only=True,
)

In [None]:
# Build a retriever on top of the vector store and set its search options in
# one step: distance metric "cos" and k = 20 results per query.
retriever = db.as_retriever()
retriever.search_kwargs.update(distance_metric="cos", k=20)

In [None]:
import json

from langchain import SagemakerEndpoint
from langchain.chains import RetrievalQA
from langchain.llms.sagemaker_endpoint import LLMContentHandler


class ContentHandler(LLMContentHandler):
    """Convert between prompts and the JSON payloads exchanged with the endpoint.

    Requests are serialized to JSON by ``transform_input`` and responses are
    parsed from JSON by ``transform_output``.
    """

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs) -> bytes:
        """Serialize the prompt and extra model arguments into a JSON request body.

        Args:
            prompt: Text to send to the model.
            model_kwargs: Mapping of additional top-level payload fields;
                ``None`` is treated as an empty mapping.

        Returns:
            The JSON document encoded as UTF-8 bytes.
        """
        # The key has to be the literal string "prompt". Using the variable as
        # the key would produce {"<prompt text>": "<prompt text>"}.
        # NOTE(review): confirm the field name the deployed endpoint expects.
        payload = {"prompt": prompt, **(model_kwargs or {})}
        return json.dumps(payload).encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        """Extract the generated text from the endpoint response.

        Args:
            output: The response body, either as raw bytes or as an object
                exposing ``read()`` that returns bytes.

        Returns:
            The ``"generated_text"`` field of the first element of the
            decoded JSON response.
        """
        # Accept both a readable stream and plain bytes, as the annotation
        # promises bytes while the body may arrive as a stream object.
        raw = output.read() if hasattr(output, "read") else output
        response_json = json.loads(raw.decode("utf-8"))
        return response_json[0]["generated_text"]


# LLM backed by the SageMaker endpoint named "dolly-v2-12b" in us-east-1,
# authenticated through the "default" credentials profile. Requests and
# responses are translated by the ContentHandler defined above in this cell.
model = SagemakerEndpoint(
    content_handler=ContentHandler(),
    credentials_profile_name="default",
    region_name="us-east-1",
    endpoint_name="dolly-v2-12b",
)

# Question-answering chain that combines the LLM with the retriever.
qa = RetrievalQA.from_llm(llm=model, retriever=retriever)