In [1]:
import os
import torch
import nest_asyncio
from dotenv import load_dotenv

from transformers import pipeline, TextStreamer

from llama_index import download_loader, SummaryPrompt, LLMPredictor, GithubRepositoryReader, GPTVectorStoreIndex, GPTTreeIndex, GPTListIndex, PromptHelper, SimpleDirectoryReader, load_index_from_storage, StorageContext, ServiceContext, LangchainEmbedding
from llama_index import Document
from llama_index.langchain_helpers.text_splitter import TokenTextSplitter
from llama_index.node_parser import SimpleNodeParser, NodeParser

from langchain.llms import HuggingFacePipeline
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

# Let asyncio.run() be called while Jupyter's own event loop is already running
# (presumably needed by readers that fetch concurrently, e.g. concurrent_requests below).
nest_asyncio.apply()
# Load variables from a local .env file into the process environment; later cells
# read GITHUB_TOKEN and BEARER_TOKEN via os.getenv.
load_dotenv()

In [2]:
# Prompt budget for the LLM, in tokens.
max_input_size = 512     # maximum prompt size handed to the LLM
num_output = 128         # tokens reserved for the generated answer
max_chunk_overlap = 20   # overlap allowed between text chunks when packing the prompt

prompt_helper = PromptHelper(
    max_input_size=max_input_size,
    num_output=num_output,
    max_chunk_overlap=max_chunk_overlap,
)

## Define LLM

In [3]:
# Dolly v2 (12B) served through a Hugging Face text-generation pipeline.
# - torch_dtype=bfloat16 and device_map="auto" let transformers place the weights on the available device(s).
# - trust_remote_code=True: presumably required because the model repo ships its own pipeline code — confirm.
# - return_full_text=True: per the Dolly v2 model card, LangChain's HuggingFacePipeline expects the
#   prompt + completion back (it strips the prompt itself).
# - max_new_tokens=num_output: cap generation at the same output budget that the PromptHelper
#   reserves for the answer, so prompt packing and completion length agree.
generate_text = pipeline(
    model="databricks/dolly-v2-12b",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
    return_full_text=True,
    max_new_tokens=num_output,
)

# Wrap the raw pipeline for LangChain, then for llama_index.
hf_pipeline = HuggingFacePipeline(pipeline=generate_text)
llm_predictor = LLMPredictor(llm=hf_pipeline)

## Define Embedding model

In [6]:
# Sentence-transformers model used to embed nodes for every vector index below.
# The name matches the default that HuggingFaceEmbeddings() loaded in the recorded run.
hf_embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
embed_model = LangchainEmbedding(hf_embeddings)

INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: sentence-transformers/all-mpnet-base-v2
INFO:sentence_transformers.SentenceTransformer:Use pytorch device: cuda


In [7]:
# Bundle the LLM, prompt budget and embedding model; every index below is built with this context.
service_context = ServiceContext.from_defaults(
    embed_model=embed_model,
    llm_predictor=llm_predictor,
    prompt_helper=prompt_helper,
)

In [8]:
# Split documents on spaces into 256-token chunks with a 20-token overlap;
# all data collectors below turn their documents into nodes with this parser.
text_splitter = TokenTextSplitter(
    separator=" ",
    chunk_size=256,
    chunk_overlap=20,
)
parser = SimpleNodeParser(text_splitter=text_splitter)

## RSS feed data collector

In [4]:
# RSS Feed: fetch the PyTorch blog feed and build a vector index over it.
RSS_FEED_URLS = ["https://pytorch.org/feed.xml"]

RssReader = download_loader("RssReader")
rss_reader = RssReader()
rss_feed_documents = rss_reader.load_data(RSS_FEED_URLS)

rss_feed_nodes = parser.get_nodes_from_documents(rss_feed_documents)
rss_feed_index = GPTVectorStoreIndex(rss_feed_nodes, service_context=service_context)

## GitHub data collector

In [5]:
# GitHub Repo
# download_loader() fetches the LlamaHub "github_repo" loader; the import below must run after it
# (presumably because the module only exists under llamahub_modules once downloaded).
# The import rebinds GithubRepositoryReader to the LlamaHub class, which is built from a GithubClient.
download_loader("GithubRepositoryReader")

from llama_index.readers.llamahub_modules.github_repo import GithubRepositoryReader, GithubClient

# Repository to index (this looks like a fork of pytorch.github.io).
# To index upstream instead, use owner="pytorch", repo="pytorch.github.io" and optionally restrict it with
#   filter_directories=(["docs", "_posts", "_getting_started", "_news", "_mobile"], GithubRepositoryReader.FilterType.INCLUDE),
#   filter_file_extensions=([".md"], GithubRepositoryReader.FilterType.INCLUDE),
GITHUB_OWNER = "jagadeeshi2i"
GITHUB_REPO = "pytorch.github.io-1"
GITHUB_BRANCH = "master"

github_client = GithubClient(os.getenv("GITHUB_TOKEN"))

# A single loader, built with the GithubClient as its first argument (this reader takes the client,
# not a github_token= keyword).
loader = GithubRepositoryReader(
    github_client,
    owner=GITHUB_OWNER,
    repo=GITHUB_REPO,
    use_parser=False,
    verbose=False,
    concurrent_requests=10,
)
github_docs = loader.load_data(branch=GITHUB_BRANCH)
github_site_nodes = parser.get_nodes_from_documents(github_docs)
github_site_index = GPTVectorStoreIndex(github_site_nodes, service_context=service_context)

## Vector data collector

In [6]:
# PyTorch Docs
# ./faiss/index.pkl is a binary pickle (presumably the docstore half of a LangChain FAISS.save_local
# directory, next to index.faiss — confirm). Reading it with SimpleDirectoryReader decodes the pickle
# bytes as text, so the index would be built from garbage. Instead, load the store and pull the
# original LangChain documents out of its docstore.
# NOTE(review): FAISS.load_local unpickles index.pkl — only load directories you produced yourself.
# The embedding object is only required to construct the store; it is not used to extract texts
# (this loads the sentence-transformers model a second time).
pytorch_faiss_store = FAISS.load_local("./faiss", HuggingFaceEmbeddings())
pytorch_docs = [
    Document.from_langchain_format(pytorch_faiss_store.docstore.search(doc_id))
    for doc_id in pytorch_faiss_store.index_to_docstore_id.values()
]
assert pytorch_docs, "no documents found in ./faiss docstore"

pytorch_docs_nodes = parser.get_nodes_from_documents(pytorch_docs)
pytorch_docs_index = GPTVectorStoreIndex(pytorch_docs_nodes, service_context=service_context)

## ChatGPT Plugin data collector

In [None]:
ChatGPTRetrievalPluginReader = download_loader("ChatGPTRetrievalPluginReader")

# Local ChatGPT retrieval plugin server and the query sent to it.
# NOTE(review): "text query" looks like a placeholder — replace it with a real query.
PLUGIN_ENDPOINT_URL = "http://localhost:8000"
PLUGIN_QUERY = "text query"

bearer_token = os.getenv("BEARER_TOKEN")
plugin_reader = ChatGPTRetrievalPluginReader(
    endpoint_url=PLUGIN_ENDPOINT_URL,
    bearer_token=bearer_token,
)

plugin_documents = plugin_reader.load_data(PLUGIN_QUERY)
plugin_nodes = parser.get_nodes_from_documents(plugin_documents)
plugin_index = GPTVectorStoreIndex(plugin_nodes, service_context=service_context)
plugin_index.storage_context.persist(persist_dir="plugin_data")

## Save indexes

In [None]:
# Record each index's id and persist each index to its own directory.
rss_index_id, pytorch_docs_index_id, github_site_index_id = (
    rss_feed_index.index_id,
    pytorch_docs_index.index_id,
    github_site_index.index_id,
)

for index_to_save, persist_dir in (
    (rss_feed_index, "rss_feed"),
    (pytorch_docs_index, "pytorch_data"),
    (github_site_index, "github_data"),
):
    index_to_save.storage_context.persist(persist_dir=persist_dir)