In [None]:
!python3 -m pip install --upgrade langchain deeplake sagemaker tiktoken
!aws configure set aws_access_key_id[REDACTED]
!aws configure set aws_secret_access_key[REDACTED]
!aws configure set default.region us-east-1
!export "ACTIVELOOP_TOKEN"=[REDACTED]

In [None]:
import sagemaker

# One SageMaker session and IAM execution role shared by every call below;
# the session's default bucket is where the packaged model archive lives.
sess = sagemaker.Session()
sagemaker_session_bucket = sess.default_bucket()
role = sagemaker.get_execution_role()

# Used both as the S3 prefix of the model archive and as the endpoint name.
# Other models tried with this notebook: "mpt-7b-instruct", "dolly-v2-12b", "flan-t5-xxl".
model_name = "all-MiniLM-L6-v2"

In [None]:
from sagemaker.huggingface.model import HuggingFaceModel

# Packaged model archive for `model_name`, stored under the session's default bucket.
model_artifact_uri = f"s3://{sess.default_bucket()}/{model_name}/model.tar.gz"

huggingface_model = HuggingFaceModel(
    model_data=model_artifact_uri,
    role=role,
    # Framework versions used to select the Hugging Face serving container.
    transformers_version="4.26",
    pytorch_version="1.13",
    py_version="py39",
    model_server_workers=1,
)

In [None]:
# Deploy the model as a real-time endpoint named after `model_name`.
# `update_endpoint=True` was dropped: SageMaker Python SDK v2 removed that argument from
# deploy() (it is ignored with a warning), so it never updated an existing endpoint.
# Re-running this cell while an endpoint with this name exists will fail; delete the
# endpoint first or use Predictor.update_endpoint() instead.
# NOTE(review): ml.g5.8xlarge is large for a sentence-embedding model (ml.g5.4xlarge was
# the previous choice) — size to cost.
predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.8xlarge",
    endpoint_name=model_name,
    model_data_download_timeout=3600,  # seconds
    container_startup_health_check_timeout=3600,  # seconds
)

In [None]:
import json
from typing import Dict, List

from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler


class EmbeddingsHandler(EmbeddingsContentHandler):
    """Serialize text batches for, and parse vectors from, the embedding endpoint.

    Subclasses EmbeddingsContentHandler (not the generic ContentHandlerBase) because
    SagemakerEndpointEmbeddings validates its `content_handler` against that type.

    NOTE(review): the JSON shapes used here — {"inputs": [...]} in, {"vectors": [...]} out —
    must match the inference code packaged in model.tar.gz; confirm against it.
    """

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, inputs: List[str], model_kwargs: Dict) -> bytes:
        """Encode a batch of texts plus any model kwargs as a UTF-8 JSON request body."""
        return json.dumps({"inputs": inputs, **model_kwargs}).encode("utf-8")

    def transform_output(self, output) -> List[List[float]]:
        """Decode the response body into one embedding vector per input text.

        `output` is a readable stream (it is `.read()` here), not raw bytes.
        """
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json["vectors"]


# Must name the endpoint deployed for the embedding model.
embeddings = SagemakerEndpointEmbeddings(
    endpoint_name="all-MiniLM-L6-v2",
    region_name="us-east-1",
    content_handler=EmbeddingsHandler(),
)

In [None]:
!git clone https://github.com/apache/spark

In [None]:
import os
from langchain.document_loaders import TextLoader

root_dir = "spark"
docs = []
skipped_files = []  # paths TextLoader raised on (e.g. files that are not valid UTF-8)
for dirpath, dirnames, filenames in os.walk(root_dir):
    # Prune .git in place so os.walk never descends into repository metadata,
    # which is not project source and would otherwise be indexed wherever it decodes.
    dirnames[:] = [name for name in dirnames if name != ".git"]
    for file in filenames:
        file_path = os.path.join(dirpath, file)
        try:
            loader = TextLoader(file_path, encoding="utf-8")
            docs.extend(loader.load_and_split())
        except Exception:
            # Best-effort ingestion: skip unreadable files, but record them instead of
            # discarding the failure silently.
            skipped_files.append(file_path)

print(f"Loaded {len(docs)} document chunks from {root_dir!r}; skipped {len(skipped_files)} files.")

In [None]:
from langchain.text_splitter import CharacterTextSplitter

# Re-chunk the loaded documents to a target of ~1000 characters with no overlap
# between neighbouring chunks; each chunk becomes one embedded/retrievable unit.
text_splitter = CharacterTextSplitter(
    chunk_overlap=0,
    chunk_size=1000,
)
texts = text_splitter.split_documents(docs)

In [None]:
# Activeloop account that owns the vector store; the dataset path is derived from it
# so the two cannot drift apart.
username = "eschizoid"
dataset_path = f"hub://{username}/spark"

In [None]:
from langchain.vectorstores import DeepLake

# Ingest: embed every chunk through the SageMaker embedding endpoint and write it to the
# dataset. Uses `dataset_path` so ingestion and the read-only reopen target the same dataset.
# NOTE(review): nothing here clears an existing dataset, so re-running presumably appends
# the same chunks again — recreate the dataset first if a clean re-ingest is wanted.
db = DeepLake(dataset_path=dataset_path, embedding_function=embeddings)
db.add_documents(texts);  # trailing ';' suppresses echoing the returned id list

In [None]:
# Reopen the ingested dataset read-only for querying.
db = DeepLake(
    dataset_path=dataset_path,
    embedding_function=embeddings,
    read_only=True,
)

In [None]:
retriever = db.as_retriever()
# Rank chunks by cosine distance and return the 20 closest for each query.
retriever.search_kwargs.update({"distance_metric": "cos", "k": 20})

In [None]:
import json

from langchain import SagemakerEndpoint
from langchain.chains import RetrievalQA
from langchain.llms.sagemaker_endpoint import LLMContentHandler

# The QA model must be a *text-generation* endpoint. The embedding endpoint
# "all-MiniLM-L6-v2" answers with {"vectors": ...}, not generated text, so it cannot serve here.
# TODO(review): set this to the name of the deployed LLM endpoint
# (e.g. flan-t5-xxl, dolly-v2-12b or mpt-7b-instruct).
llm_endpoint_name = "flan-t5-xxl"


class ContentHandler(LLMContentHandler):
    """Serialize prompts for, and parse completions from, a Hugging Face text-generation endpoint."""

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
        """Encode the prompt and generation kwargs as a UTF-8 JSON request body.

        The prompt goes under the literal "inputs" key and generation settings under
        "parameters". The previous `{prompt: prompt}` used the prompt text itself as the key.
        """
        payload = {"inputs": prompt, "parameters": model_kwargs}
        # Return bytes as declared (and as the embeddings handler does), not a str.
        return json.dumps(payload).encode("utf-8")

    def transform_output(self, output) -> str:
        """Extract the generated text from the response stream (`output` is `.read()`)."""
        response_json = json.loads(output.read().decode("utf-8"))
        # Expected response shape: [{"generated_text": ...}, ...]; the first candidate is used.
        return response_json[0]["generated_text"]


model = SagemakerEndpoint(
    endpoint_name=llm_endpoint_name,
    region_name="us-east-1",
    credentials_profile_name="default",
    content_handler=ContentHandler(),
)

# NOTE(review): the retriever returns k=20 chunks of ~1000 characters each; stuffed into
# one prompt, that may exceed the LLM's context window — consider lowering k.
qa = RetrievalQA.from_llm(model, retriever=retriever)