In [1]:
import logging
import os
import sys
import textwrap
import time
from contextlib import contextmanager

import faiss
import requests
from langchain.chat_models import ChatOpenAI
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.llms import OpenAI
from llama_index import (
    GPTVectorStoreIndex,
    LangchainEmbedding,
    LLMPredictor,
    ServiceContext,
    SimpleDirectoryReader,
    StorageContext,
)
from llama_index.vector_stores.faiss import FaissVectorStore
#from llama_index.indices.vector_store import ChatGPTRetrievalPluginIndex


# Send log records to stdout at INFO level. force=True makes basicConfig replace
# any handlers already attached to the root logger. For verbose tracing, use the
# DEBUG variant instead:
#logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True)
logging.basicConfig(stream=sys.stdout, level=logging.INFO, force=True)

In [2]:
@contextmanager
def timer(title):
    """Print start/finish messages around a ``with`` block, including elapsed time.

    Args:
        title: Label shown in square brackets in both messages.

    Yields:
        None. The managed block runs while the timer is active.

    The finish message is printed from a ``finally`` clause, so the elapsed
    time is reported even when the managed block raises; the exception itself
    still propagates to the caller unchanged.
    """
    # perf_counter is a monotonic clock, so the measured interval cannot be
    # distorted by system clock adjustments.
    t0 = time.perf_counter()
    print(f"[{title}] - start.")
    try:
        yield
    finally:
        print(f"[{title}] - done in {time.perf_counter() - t0 :.1f}s.")

In [3]:
# from https://github.com/jerryjliu/llama_index/blob/v0.5.27/gpt_index/indices/vector_store/vector_indices.py#L686

from typing import Any, Callable, Dict, Optional, Sequence, Type
from requests.adapters import Retry

from llama_index.data_structs.node import Node
from llama_index.indices.service_context import ServiceContext
from llama_index.data_structs.data_structs import IndexDict
from llama_index.data_structs.struct_type import IndexStructType
from llama_index.vector_stores import ChatGPTRetrievalPluginClient


class ChatGPTRetrievalPluginIndexDict(IndexDict):
    """IndexDict subclass used as the index struct of the ChatGPT Retrieval Plugin index.

    It only overrides the struct-type tag; everything else comes from IndexDict.
    """

    @classmethod
    def get_type(cls) -> IndexStructType:
        """Return the IndexStructType member identifying this index struct."""
        plugin_struct_type = IndexStructType.CHATGPT_RETRIEVAL_PLUGIN
        return plugin_struct_type


class ChatGPTRetrievalPluginIndex(GPTVectorStoreIndex):
    """ChatGPTRetrievalPlugin index.

    This index directly interfaces with any server that hosts
    the ChatGPT Retrieval Plugin interface:
    https://github.com/openai/chatgpt-retrieval-plugin.

    Args:
        nodes (Optional[Sequence[Node]]): Forwarded to GPTVectorStoreIndex.
        index_struct (Optional[ChatGPTRetrievalPluginIndexDict]): Forwarded to
            GPTVectorStoreIndex.
        service_context (Optional[ServiceContext]): Service context container;
            forwarded to GPTVectorStoreIndex.
        endpoint_url (Optional[str]): URL of the plugin server. Required when
            ``vector_store`` is not given.
        bearer_token (Optional[str]): Bearer token for the plugin server.
            Required when ``vector_store`` is not given.
        retries (Optional[Retry]): Passed to the ChatGPTRetrievalPluginClient
            that is built when ``vector_store`` is not given.
        batch_size (int): Passed to that same client.
        vector_store (Optional[ChatGPTRetrievalPluginClient]): Ready-made client.
            When given, ``endpoint_url``, ``bearer_token``, ``retries`` and
            ``batch_size`` are not used.
        **kwargs: Forwarded to GPTVectorStoreIndex. A ``storage_context`` entry
            is intercepted: its docstore and index store are kept, but its
            vector store is replaced by the plugin client.

    Raises:
        ValueError: If ``vector_store`` is None and ``endpoint_url`` or
            ``bearer_token`` is None.
    """

    index_struct_cls: Type[IndexDict] = ChatGPTRetrievalPluginIndexDict

    def __init__(
        self,
        nodes: Optional[Sequence[Node]] = None,
        index_struct: Optional[ChatGPTRetrievalPluginIndexDict] = None,
        service_context: Optional[ServiceContext] = None,
        endpoint_url: Optional[str] = None,
        bearer_token: Optional[str] = None,
        retries: Optional[Retry] = None,
        batch_size: int = 100,
        vector_store: Optional[ChatGPTRetrievalPluginClient] = None,
        **kwargs: Any,
    ) -> None:
        """Build (or accept) the plugin client and attach it to the index."""

        if vector_store is None:
            if endpoint_url is None:
                raise ValueError("endpoint_url is required.")
            if bearer_token is None:
                raise ValueError("bearer_token is required.")
            vector_store = ChatGPTRetrievalPluginClient(
                endpoint_url,
                bearer_token,
                retries=retries,
                batch_size=batch_size,
            )

        # Attach the client through a StorageContext, the same way the FAISS
        # store is attached elsewhere in this notebook, instead of passing a bare
        # `vector_store=` keyword to the parent constructor.
        # NOTE(review): presumably the parent only reads its vector store from
        # storage_context and silently ignores an unknown `vector_store` keyword;
        # confirm against the installed llama_index version.
        caller_context = kwargs.pop("storage_context", None)
        if caller_context is None:
            storage_context = StorageContext.from_defaults(vector_store=vector_store)
        else:
            # Keep the caller's document/index stores and swap in the plugin
            # client only.
            # NOTE(review): assumes StorageContext exposes `docstore` and
            # `index_store` attributes — confirm.
            storage_context = StorageContext.from_defaults(
                docstore=caller_context.docstore,
                index_store=caller_context.index_store,
                vector_store=vector_store,
            )

        super().__init__(
            nodes=nodes,
            index_struct=index_struct,
            service_context=service_context,
            storage_context=storage_context,
            **kwargs,
        )

# indexの作成

In [4]:
# --- Configuration -----------------------------------------------------------
llm_name = "chatgpt"  # "chatgpt", "gpt3"
vector_search_method = "qdrant"  # "simple", "faiss", "qdrant"
embed_model_name = "default"  # "default", "oshizo/sbert-jsnli-luke-japanese-base-lite"
# MEMO: this runs in a CPU-only environment, so the HuggingFace embedding model
# is left disabled (embed_model_name stays "default").

# --- LLM ---------------------------------------------------------------------
if llm_name == "gpt3":
    llm = OpenAI(
        model_name="text-davinci-003",
        max_tokens=1024,
        temperature=0,
        frequency_penalty=0.02,
    )
elif llm_name == "chatgpt":
    llm = ChatOpenAI(
        model_name="gpt-3.5-turbo",
        max_tokens=1024,
        temperature=0,
    )
else:
    # Fail here with a clear message instead of a NameError on `llm` below.
    raise ValueError(f"Unknown llm_name: {llm_name!r} (expected 'gpt3' or 'chatgpt')")

# --- Embedding model ---------------------------------------------------------
if embed_model_name == "default":
    # None lets ServiceContext.from_defaults pick its own default embedding model.
    embed_model = None
else:
    embed_model = LangchainEmbedding(HuggingFaceEmbeddings(model_name=embed_model_name))

documents = SimpleDirectoryReader("data").load_data()
llm_predictor = LLMPredictor(llm=llm)
service_context = ServiceContext.from_defaults(
    llm_predictor=llm_predictor,
    embed_model=embed_model,
)


# --- Index -------------------------------------------------------------------
with timer("make index"):
    if vector_search_method == "simple":
        index = GPTVectorStoreIndex.from_documents(
            documents,
            service_context=service_context,
        )

    elif vector_search_method == "faiss":
        # Dimensionality of text-embedding-ada-002 vectors. This must match the
        # embedding model actually in use, so it is only valid with the
        # "default" embed_model_name.
        d = 1536
        # Inner-product index; this equals cosine similarity only when the
        # vectors are unit-normalized.
        # NOTE(review): assumed to hold for the default embeddings — confirm.
        faiss_index = faiss.IndexFlatIP(d)
        vector_store = FaissVectorStore(faiss_index=faiss_index)
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        index = GPTVectorStoreIndex.from_documents(
            documents,
            storage_context=storage_context,
            service_context=service_context,
        )

    elif vector_search_method == "qdrant":
        # Goes through a ChatGPT Retrieval Plugin server (presumably backed by
        # Qdrant, given the option name). BEARER_TOKEN must be set in the
        # environment; the index class raises ValueError when it is None.
        index = ChatGPTRetrievalPluginIndex.from_documents(
            documents,
            endpoint_url="http://chatgpt-retrieval-plugin:7000",
            bearer_token=os.getenv("BEARER_TOKEN"),
            service_context=service_context,
        )

    else:
        # Fail here with a clear message instead of a NameError on `index` below.
        raise ValueError(
            f"Unknown vector_search_method: {vector_search_method!r} "
            "(expected 'simple', 'faiss' or 'qdrant')"
        )


# Create the query engine.
with timer("make query engine"):
    query_engine = index.as_query_engine(verbose=True)

[make index] - start.
INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total LLM token usage: 0 tokens
INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total embedding token usage: 15402 tokens
[make index] - done in 1.5s.
[make query engine] - start.
[make query engine] - done in 0.0s.


# 質問する

In [5]:
# Question sent to the index (Japanese for "Please tell me what Bocchi-chan is
# bad at."). textwrap.dedent strips the shared leading indentation; the newlines
# at the start and end of the literal are kept.
prompt = textwrap.dedent(
    """
    ぼっちちゃんの苦手なことを教えてください
    """
)

# Run the query under the timer; `res` holds the response object for later cells.
with timer("ask question"):
    res = query_engine.query(prompt)

[ask question] - start.
INFO:llama_index.token_counter.token_counter:> [retrieve] Total LLM token usage: 0 tokens
INFO:llama_index.token_counter.token_counter:> [retrieve] Total embedding token usage: 31 tokens
INFO:llama_index.token_counter.token_counter:> [get_response] Total LLM token usage: 1125 tokens
INFO:llama_index.token_counter.token_counter:> [get_response] Total embedding token usage: 0 tokens
[ask question] - done in 2.4s.


In [6]:
# Show the response through its str() form.
print(res)

ぼっちちゃんは人と接することが極度に苦手である。


# TODO
- [ ] LlamaHubでPDF、Twitter、Slack、Wikiを引っ張ってこれるか試す
- [ ] notionはできる？
- [x] 画像の扱いを整理する
- [x] chatgpt plugin試す

# indexを保存

In [None]:
# Persist the index's storage context to disk. No persist_dir is passed, so the
# library default applies — presumably "./storage", the directory the load cell
# reads from; confirm against the installed llama_index version.
index.storage_context.persist()

# indexをロード

In [None]:
# Rebuild an index from what was persisted under ./storage.
# NOTE(review): this rebinds `storage_context` and `index`, replacing the objects
# created earlier in the notebook.
# NOTE(review): no vector_store is passed to from_defaults, so presumably only
# the default vector store is restored; the FAISS / retrieval-plugin variants
# likely need their own store passed in — confirm.
from llama_index import StorageContext, load_index_from_storage

storage_context = StorageContext.from_defaults(persist_dir="./storage")
index = load_index_from_storage(storage_context)