# Reindex the State of the Union index with named-entity recognition (NER)

In [None]:
import os
from dotenv import load_dotenv
load_dotenv(".env", override=True)

from elasticsearch import Elasticsearch

# Decide where to connect: an Elastic Cloud deployment id takes precedence
# over a plain cluster URL. Both variants authenticate with the same
# ELASTIC_USER / ELASTIC_PASSWORD pair (read only once a target is chosen).
if 'ELASTIC_CLOUD_ID' in os.environ:
    connection_target = {"cloud_id": os.environ['ELASTIC_CLOUD_ID']}
elif 'ELASTIC_URL' in os.environ:
    connection_target = {"hosts": os.environ['ELASTIC_URL']}
else:
    connection_target = None
    print("env needs to set either ELASTIC_CLOUD_ID or ELASTIC_URL")

if connection_target is None:
    es = None
else:
    es = Elasticsearch(
        **connection_target,
        basic_auth=(os.environ['ELASTIC_USER'], os.environ['ELASTIC_PASSWORD']),
        request_timeout=30,
    )
    # Quick connectivity check: prints the cluster's tagline from the info API.
    print(es.info()['tagline'])

In [None]:
## Utility functions from week 2
from elasticsearch import Elasticsearch, helpers
from tqdm import tqdm

def delete_index(index_name):
    """Delete ``index_name`` from the cluster if it exists; otherwise do nothing.

    Uses the module-level ``es`` client and prints progress messages.
    """
    # Guard clause: nothing to do when the index is absent.
    if not es.indices.exists(index=index_name):
        return
    print(f"Index '{index_name}' exists. Deleting...")
    es.indices.delete(index=index_name)
    print(f"Index '{index_name}' deleted.")

def create_index_with_mapping(index_name, properties, dynamic_templates=None):
    """Ensure ``index_name`` exists and apply the given field mappings to it.

    Creates the index (with default settings) when it does not exist yet,
    then sends a put-mapping request with ``properties`` and, when supplied,
    ``dynamic_templates``. Uses the module-level ``es`` client.

    Args:
        index_name: name of the index to create/update.
        properties: mapping ``properties`` dict (field name -> field mapping).
        dynamic_templates: optional list of dynamic template definitions.

    Returns:
        The put-mapping API response.
    """
    if not es.indices.exists(index=index_name):
        es.indices.create(index=index_name)

    mapping_args = {"index": index_name, "properties": properties}
    # Only send dynamic templates when some were given (falsy values such as
    # None or [] are omitted, as before).
    if dynamic_templates:
        mapping_args["dynamic_templates"] = dynamic_templates
    return es.indices.put_mapping(**mapping_args)

def batchify(docs, batch_size):
    """Yield consecutive slices of ``docs`` holding at most ``batch_size`` items.

    ``docs`` must support ``len()`` and slicing (e.g. a list); each batch is a
    slice of the same type and the last batch may be shorter.

    Raises:
        ValueError: on first iteration, if ``batch_size`` is smaller than 1.
            Without this check a negative size silently yields no batches,
            which would drop every document.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
    for start in range(0, len(docs), batch_size):
        yield docs[start:start + batch_size]

def bulkLoadIndex(index_name, json_docs, batch_size=None):
    """Bulk-index ``json_docs`` into ``index_name``, printing any per-document errors.

    Documents are sent in batches (one ``helpers.bulk`` call per batch) with a
    tqdm progress bar. Uses the module-level ``es`` client.

    Args:
        index_name: target index.
        json_docs: list of document bodies; each is indexed as ``_source``.
        batch_size: documents per batch. When omitted, a module-level
            ``BATCH_SIZE`` is used if one is defined, otherwise 500.
    """
    if batch_size is None:
        # BATCH_SIZE is not defined anywhere in this notebook, so reading it
        # directly raised NameError; honour it when present, else fall back.
        batch_size = globals().get("BATCH_SIZE", 500)

    batches = list(batchify(json_docs, batch_size))

    for batch in tqdm(batches, desc=f"Batches of size {batch_size}"):
        # Wrap each document in the action format helpers.bulk expects.
        bulk_docs = [
            {"_op_type": "index", "_index": index_name, "_source": doc}
            for doc in batch
        ]
        # With raise_on_error=False, helpers.bulk returns (success_count, errors)
        # instead of raising on the first failed document.
        _, errors = helpers.bulk(es, bulk_docs, raise_on_error=False)
        for error in errors:
            print(error)

def changeEsRefreshInterval(es, index_name, refresh_interval):
    """Set the ``index.refresh_interval`` setting on ``index_name``.

    Args:
        es: Elasticsearch client used to send the request.
        index_name: index whose settings are updated.
        refresh_interval: new refresh interval value, passed through unchanged
            to Elasticsearch.

    Returns:
        The put-settings API response (previously computed and discarded).
    """
    settings = {"index": {"refresh_interval": refresh_interval}}
    # Pass the body through the named `settings` parameter, matching the
    # keyword style used for the other API calls in this notebook.
    return es.indices.put_settings(index=index_name, settings=settings)


In [None]:
index_name = "genai_state_of_the_union"

# Fetch one document from the source index to use as sample input for the
# pipeline simulations below.
sample_hits = es.search(index=index_name, size=1)["hits"]["hits"]
if not sample_hits:
    raise RuntimeError(
        f"Index '{index_name}' has no documents to sample; load it before running this notebook."
    )
source = sample_hits[0]["_source"]

In [None]:
model_id = "distilbert-base-cased-finetuned-conll03-english"
# The model id as deployed in Elasticsearch: the Hugging Face id prefixed
# with "elastic__".
es_model_id = "elastic__" + model_id

# Inference processor: the document's `text` field is fed to the model
# input named `text_field`.
inference = {
    "inference": {
        "model_id": es_model_id,
        "field_map": {"text": "text_field"},
    }
}

processors = [inference]

es.ingest.put_pipeline(id="sotu_ner", processors=processors)

# Dry-run the pipeline against the sample document fetched earlier.
docs = [{"_source": source}]

value = (
    es.ingest.simulate(id='sotu_ner', docs=docs)
    .body["docs"][0]["doc"]["_source"]
)

# The inference processor adds its output under the `ml` field of the document:
#   ml.inference.predicted_value -> the enriched (entity-annotated) text
#   ml.inference.entities        -> the list of recognised entities
print(value["ml"]["inference"]["predicted_value"][:100])
print(value["ml"]["inference"]["entities"][:2])

In [None]:
# Painless script that flattens the inference output into top-level fields:
#   ner_text                             <- ml.inference.predicted_value
#   person / organization / misc / location <- de-duplicated entity strings,
#                                             grouped by the model's class label.
# Entities whose class label is not in convertMap are skipped; without that
# check a null key would be written into the document. A missing entity
# list is also tolerated.
script_processor = {"script": {
    "lang": "painless",
    "source": """

Map convertMap = new HashMap();
convertMap.put("PER", "person");
convertMap.put("ORG", "organization");
convertMap.put("MISC", "misc");
convertMap.put("LOC", "location");

ctx["ner_text"] = ctx["ml"]["inference"]["predicted_value"];
def entities = ctx["ml"]["inference"]["entities"];
if (entities != null) {
    for (entity in entities) {
        String class_name = entity["class_name"];
        String key = convertMap.get(class_name);
        if (key == null) {
            continue;
        }
        String entity_value = entity["entity"];
        if (!ctx.containsKey(key)) {
            ctx[key] = [];
        }
        if (!ctx[key].contains(entity_value)) {
            ctx[key].add(entity_value);
        }
    }
}
"""
}}

# Drop the raw inference output once it has been copied into the flat fields.
remove_processor = {"remove": {"field": "ml"}}

processors = [
    inference,
    script_processor,
    remove_processor,
]

es.ingest.put_pipeline(id="sotu_ner", processors=processors)

# Simulate the full pipeline on the sample document and show resulting fields.
value = es.ingest.simulate(id='sotu_ner', docs=docs)

value["docs"][0]["doc"]["_source"].keys()


In [None]:
destination_index = "genai_state_of_the_union_ner"
delete_index(index_name=destination_index)

# Any field whose name matches ^facet_.* is mapped as a keyword.
dynamic_templates = [
    {
        "fields_with_prefix": {
            "match_pattern": "regex",
            "match": "^facet_.*",
            "mapping": {"type": "keyword"},
        }
    }
]

properties = {
    "administration": {"type": "keyword"},
    "date": {"type": "keyword"},
    "date_iso": {"type": "date"},
    "text": {"type": "text"},
    # Entity lists written by the sotu_ner script processor.
    "person": {"type": "keyword"},
    "organization": {"type": "keyword"},
    "misc": {"type": "keyword"},
    "location": {"type": "keyword"},
    # Model-annotated text: stored for retrieval but not indexed for search.
    "ner_text": {
        "type": "text",
        "index": False,
        "store": True,
    },
    "url": {
        "type": "text",
        "fields": {
            "keyword": {
                "type": "keyword",
                "ignore_above": 1024,
            }
        },
    },
}
create_index_with_mapping(
    index_name=destination_index,
    properties=properties,
    dynamic_templates=dynamic_templates,
)

index_name = "genai_state_of_the_union"
# Use a dedicated name for the reindex source spec: rebinding `source` would
# clobber the sample document that the pipeline-simulation cells build `docs`
# from, breaking them if they are re-run after this cell.
reindex_source = {
    "index": index_name
}
dest_name = destination_index
dest = {
    "index": dest_name,
    "pipeline": "sotu_ner",
}

# Run the reindex asynchronously and keep the task id for progress polling.
task_id = es.reindex(source=reindex_source, dest=dest, wait_for_completion=False)["task"]
print(task_id)

In [None]:
from elasticsearch.client import TasksClient
import time

# Seconds to wait between polls of the reindex task.
poll_interval_seconds = 5

tasks = TasksClient(client=es)
is_completed = False
while not is_completed:
    tasks_api = tasks.get(task_id=task_id)
    is_completed = tasks_api["completed"]
    status = tasks_api["task"]["status"]
    done_count = status["created"]
    total_count = status["total"]
    print(f"Processing ... {done_count}/{total_count}")
    # Only wait when another poll is needed; avoids a pointless final sleep.
    if not is_completed:
        time.sleep(poll_interval_seconds)

# A finished task may carry a top-level `error` or per-document `failures`
# in its `response`; report them instead of silently printing "Done".
task_body = tasks_api.body
task_error = task_body.get("error")
if task_error:
    print(f"Reindex task failed: {task_error}")
for failure in (task_body.get("response") or {}).get("failures", []):
    print(failure)
print("Done")