# Notebook to load a Hugging Face dataset (a subset of Wikipedia) and index it into Elasticsearch to power semantic search and RAG architectures

In [None]:
!python3 -m pip install -qU langchain
!pip install jq

In [None]:
from datasets import load_from_disk
from dotenv import load_dotenv
import os
import getpass
from langchain.document_loaders import JSONLoader
import json
from pathlib import Path
from pprint import pprint
from langchain.vectorstores import ElasticsearchStore
from langchain.text_splitter import CharacterTextSplitter
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk
import time

In [None]:
 # Variables
# Path of the saved dataset directory passed to load_from_disk; fill in before running.
DATASET_PATH = ""
# Name of the Elasticsearch index the documents are written to; fill in before running.
INDEX_NAME = ""

In [None]:

# Name of the dataset column holding the article text.
# NOTE(review): not referenced by the visible cells (the loader hard-codes
# 'text' as its content key) -- confirm whether it is still needed.
page_content_column = "text"

# Load the dataset saved on disk and export it to a JSON file that the
# JSONLoader cell reads back.
articles = load_from_disk(DATASET_PATH)
# Dataset.to_json returns the amount of data written, not a dataset, so keep
# that number under its own name instead of overwriting `articles`.
bytes_written = articles.to_json('./temp.json')



In [None]:
# Define the metadata extraction function.
def metadata_func(record: dict, metadata: dict) -> dict:
    """Copy selected fields of a dataset record into the document metadata.

    Args:
        record: One parsed record of the exported dataset.
        metadata: Metadata dict supplied by the caller; updated in place.

    Returns:
        The same ``metadata`` dict with ``url``, ``title`` and ``ID`` set
        (``None`` for any of those keys missing from ``record``).
    """
    for field in ("url", "title", "ID"):
        metadata[field] = record.get(field)
    return metadata

In [None]:
%%capture
# Splitter configured with a chunk size of 500 and an overlap of 50.
text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=50)

# Loader for the exported JSON file: one record per line, the record's
# "text" field is the document content and metadata_func fills the metadata.
loader = JSONLoader(
    file_path='./temp.json',
    jq_schema='.',
    content_key='text',
    json_lines=True,
    metadata_func=metadata_func,
)

# Load every record and split it into chunks in one step.
data = loader.load_and_split(text_splitter)

# NOTE(review): the %%capture magic at the top of this cell captures the
# cell output, so this preview is probably not displayed here -- confirm.
pprint(data[0].dict())

In [None]:
# Set up your Elasticsearch connection and assign the client to `es`; the cells below use it.
#es=

In [None]:
# LangChain store bound to the target index on the existing `es` connection.
vector_store = ElasticsearchStore(
    es_connection=es,
    index_name=INDEX_NAME,
)

# Send only the first chunk through LangChain with the sparse-vector
# retrieval strategy -- presumably so that the index and its ingest pipeline
# get created before the bulk load; TODO confirm.
documents = vector_store.from_documents(
    data[0:1],
    es_connection=es,
    index_name=INDEX_NAME,
    strategy=ElasticsearchStore.SparseVectorRetrievalStrategy(),
)

In [None]:
# Update index settings to remove the default pipeline and raise the refresh interval
bulk_load_settings = {
    "index.refresh_interval": "5m",
    "index.default_pipeline": "_none",
}
index_settings = {"settings": bulk_load_settings}

# Apply the settings to the target index and show the server's reply
response = es.indices.put_settings(
    index=INDEX_NAME,
    body=index_settings,
)
pprint(response)


In [None]:
# Bulk indexing function
def index_documents(documents):
    """Bulk-index ``documents`` into ``INDEX_NAME`` and return the success count.

    Args:
        documents: Iterable of dicts; each one becomes the ``_source`` of a
            new document in the index.

    Returns:
        The number of documents the bulk helper reported as successfully
        indexed.
    """
    actions = [{"_index": INDEX_NAME, "_source": doc} for doc in documents]

    # raise_on_error=False keeps the load going past per-document failures;
    # the helper then hands those failures back as a list instead of raising.
    success, errors = bulk(es, actions=actions, raise_on_error=False)

    # Surface failures instead of discarding them, so a partially indexed
    # batch does not go unnoticed.
    if errors:
        print(f"{len(errors)} document(s) failed to index; first error: {errors[0]}")
    return success

In [None]:
# Inspect the first chunk as a plain dict before bulk indexing.
pprint(data[0].dict())

In [None]:
# Index the chunks in batches of `batch_size` through index_documents.
batch_size = 1000
batch = []
for i in range(0, len(data), batch_size):
    # Clamp the upper bound so the final, partial batch is reported correctly.
    batch_end = min(i + batch_size, len(data))
    print(f'Batching docs: {i} to {batch_end} of {len(data)}')
    batch = []
    for doc in data[i:batch_end]:
        doc = doc.dict()
        # Copy the chunk body into a "text" field; "page_content" is kept too.
        doc['text'] = doc['page_content']
        batch.append(doc)
    # Print how many documents of this batch were indexed successfully.
    print(index_documents(batch))


In [None]:
# Perform a manual refresh on the index before running the query below
response = es.indices.refresh(index=INDEX_NAME)

# Ingest pipeline applied to every document matched by the update-by-query
PIPELINE_NAME = ".elser_model_1_sparse_embedding"

# Match only documents that do not have a "vector.model_id" field yet
# (presumably those the pipeline has not processed -- TODO confirm)
update_by_query_body = {
    "query": {"bool": {"must_not": [{"exists": {"field": "vector.model_id"}}]}}
}

# Start the update-by-query without waiting for it to finish
response = es.update_by_query(
    index=INDEX_NAME,
    body=update_by_query_body,
    pipeline=PIPELINE_NAME,
    wait_for_completion=False,
)

# Print the response
print(response)

In [None]:
# Id of the task to wait for; fill in before running.
TASK_ID = ''

# Define the polling interval in seconds
POLL_INTERVAL = 20

# Poll for the task status
while True:
    response = es.tasks.get(task_id=TASK_ID)
    task_status = response.get("completed")

    if task_status:
        # "completed" only says the task has finished; a failed task reports
        # its failure under "error", so check that before claiming success.
        task_error = response.get("error")
        if task_error:
            print(f"Task finished with an error: {task_error}")
        else:
            print("Task completed successfully.")
        break

    print(f"Task still running. Checking again in {POLL_INTERVAL} seconds.")
    time.sleep(POLL_INTERVAL)