### Elastic search request api

In [7]:
import json
import logging

import requests


def elasticsearch_curl(uri="http://localhost:9200/", json_body="", verb="get"):
    # pass header option for content type if request has a
    # body to avoid Content-Type error in Elasticsearch v6.0
    headers = {
        "Content-Type": "application/json",
    }
    resp = None
    try:
        # make HTTP verb parameter case-insensitive by converting to lower()
        if verb.lower() == "get":
            resp = requests.get(uri, headers=headers, data=json_body)
        elif verb.lower() == "post":
            resp = requests.post(uri, headers=headers, data=json_body)
        elif verb.lower() == "put":
            resp = requests.put(uri, headers=headers, data=json_body)
        elif verb.lower() == "del":
            resp = requests.delete(uri, headers=headers, data=json_body)
            return None

        # read the text object string
        try:
            resp_text = json.loads(resp.text)
        except:
            resp_text = resp.text

        # catch exceptions and print errors to terminal
    except Exception as error:
        logging.warning("resp:", resp)
        logging.warning("uri:", uri)
        logging.warning("\nelasticsearch_curl() error:", error)
        resp_text = None

    # return the Python dict of the request
    #     print ("resp_text:", resp_text)
    return resp_text


def del_all_scroll():
    response = elasticsearch_curl(
        uri="http://localhost:9200/_search/scroll/_all", verb="del"
    )
    return response


def del_pit(pit):
    json_data = json.dumps({"id": pit})
    response = elasticsearch_curl(
        uri="http://localhost:9200/_pit", json_body=json_data, verb="del"
    )
    return response


response = elasticsearch_curl("http://localhost:9200/wikipedia_sentences/_count")
# print(response)

### Post a sentence to the index

In [8]:
def post_sentence(dict_object):
    # Post one example to news_event_frame index
    if dict_object:
        json_object = json.dumps(dict_object)
    else:
        return None
    response = elasticsearch_curl(
        "http://localhost:9200/wikipedia_sentences/_doc",
        verb="post",
        json_body=json_object,
    )
    return response


def query_sentence(phrase, size=20):
    result = []
    for should_phrase in [
        "is a",
        "is not a",
        "have a",
        "does not have a",
        "is capable of",
        "is not capable of",
    ]:
        dict_object = {
            "query": {
                "bool": {
                    "must": [{"match_phrase": {"doc": phrase}}],
                    "should": [{"match_phrase": {"doc": should_phrase}}],
                },
            },
            "size": size,
        }

        json_object = json.dumps(dict_object)
        response = elasticsearch_curl(
            "http://localhost:9200/wikipedia_sentences/_search",
            verb="post",
            json_body=json_object,
        )
        for item in response["hits"]["hits"]:
            score = item["_score"]
            doc = item["_source"]["doc"]
            result.append((doc, score))
    return result


def query_sentence_match(phrase):
    dict_object = {"query": {"match": {"doc": phrase}}, "size": 1}
    json_object = json.dumps(dict_object)
    response = elasticsearch_curl(
        "http://localhost:9200/wikipedia_sentences/_search",
        verb="post",
        json_body=json_object,
    )
    result = []
    for item in response["hits"]["hits"]:
        score = item["_score"]
        doc = item["_source"]["doc"]
        result.append((doc, score))
    return result


def bulk_post(bulk_dict_data, index_name="wikipedia_sentences"):
    if len(bulk_dict_data) < 1:
        return None
    # Post multiple examples to an index
    # A list of data dict
    meta_json = json.dumps({"index": {"_index": index_name, "_type": "_doc"}})
    data_to_post = (
        "\n".join(meta_json + "\n" + json.dumps(d) for d in bulk_dict_data) + "\n"
    )
    response = elasticsearch_curl(
        f"http://localhost:9200/_bulk", verb="post", json_body=data_to_post,
    )
    return response


# phrase = "entity"
# test_data_dict = [{"query": {"match_phrase": {"doc": phrase}}, "size": 1} for i in range(5)]
# response = post_sentence(test_data_dict)
# # print(response)
# response = bulk_post(test_data_dict)
# print(response)

### Insert sentences into elastic index

In [9]:
data = []
jsonfilename = "../data/WikiData/mask-filling/wd_train_50.txt"
with open(jsonfilename) as f:
    input_lines = f.readlines()
    for line in input_lines:
        data.append(line)
    print(len(data))
jsonfilename = "../data/WikiData/mask-filling/wd_test_50.txt"
with open(jsonfilename) as f:
    input_lines = f.readlines()
    for line in input_lines:
        data.append(line)
    print(len(data))
jsonfilename = "../data/WikiData/mask-filling/wd_dev_50.txt"
with open(jsonfilename) as f:
    input_lines = f.readlines()
    for line in input_lines:
        data.append(line)
    print(len(data))

4124
4438
4851


In [None]:
import random
import string
from random import sample

from tqdm.auto import tqdm

count = lambda l1, l2: sum([1 for x in l1 if x in l2])


def count_punc(s):
    return count(s, set(string.punctuation))


def count_keyword(s, key):
    count = 0
    s = s.replace(".", "").replace("\n", "")
    for w in s.split(" "):
        if w.lower() == key:
            count += 1
    return count


source_obj_set = {}
for item in tqdm(data):
    ## Get source and objective
    sens, obj = item.split("\t")
    obj = obj.strip()
    sen_li = sens.split(". ")
    for sen in sen_li[:-1]:
        sen_ws = sen.split(" ")
        for w in sen_ws:
            if sen.replace(w, "<MASK>") + "." == sen_li[-1]:
                source = w
                source_sent = sen
                break
    ## Get top 10 sentences from wikipedia by source
    #         print(f"source:{source}\n target:{obj}\n source_sent:{source_sent}")
    if source in source_obj_set:
        if obj not in source_obj_set[source]:
            source_obj_set[source].append(obj)
    else:
        source_obj_set[source] = []


source_neural_sent_dic = {}
outfilename = "../data/WikiData/mask-filling/wd_neutral_sent.txt"
with open(outfilename, "w") as f:
    out_sentence = (
        "SOURCE" + "\t" + "NEUTRAL_SENTENCE" + "\t" + "MAX_REPLACE_QUERY_SCORE\n"
    )
    f.write(out_sentence)
    for source, obj_li in tqdm(source_obj_set.items()):
        if len(obj_li) < 1:
            continue
        size = 20
        top_query_sent = query_sentence(
            source, size
        )  # a list of sentences that contain the source word
        top_query_sent = list(set(top_query_sent))
        query_replace_score = []
        ### For each of the 10 sentence, reformulate it by the object word, and then get the one with the lowest similary score in its top-1 result.
        ### This means the word in that sentence is unlikely repalced by the object.
        for sent in tqdm(top_query_sent):
            if len(sent[0].replace("  ", " ").split(" ")) < 4:
                query_replace_score.append(200)
                continue

            if count_punc(sent[0]) > 1:  ## Less than one source word
                query_replace_score.append(100)
                continue

            all_obj_top_scores = []
#             if len(obj_li) > 10:
#                 obj_li = sample(obj_li, 10)
            for obj in obj_li:
                query_replace_sen = sent[0].replace(source, obj)
                top_1_sen, top_1_score = query_sentence_match(query_replace_sen)[0]
                all_obj_top_scores.append(top_1_score)
            query_top_1_score = max(all_obj_top_scores)
            query_replace_score.append(query_top_1_score)
        print(source)
        print(query_replace_score)
        for best_sen, sore in zip(top_query_sent, query_replace_score):
            if sore < 50:
                out_sentence = (
                    source + "\t" + best_sen[0].strip() + "\t" + str(sore) + "\n"
                )
                f.write(out_sentence)
                f.flush()

HBox(children=(FloatProgress(value=0.0, max=4851.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=706.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=31.0), HTML(value='')))


uc
[200, 100, 57.358994, 100, 100, 100, 100, 100, 100, 100, 40.899837, 47.899254, 65.614426, 100, 100, 59.821938, 200, 100, 41.77512, 100, 100, 66.98061, 100, 57.358994, 100, 100, 100, 100, 100, 100, 200]


HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))


gymnasium
[100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 38.448093, 100, 57.25715, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 89.11842, 100, 100, 100, 100, 100, 100, 200, 100, 100, 100, 54.5495, 100, 100, 54.5495, 100, 100, 100, 57.25715, 100, 100, 100, 100, 100, 100, 100, 100]


HBox(children=(FloatProgress(value=0.0, max=86.0), HTML(value='')))


mathematical
[100, 100, 108.50187, 100, 113.84152, 52.704918, 100, 107.66448, 100, 100, 48.37364, 66.1816, 30.291853, 100, 62.595108, 100, 55.551216, 68.043915, 100, 100, 79.89993, 100, 82.105965, 100, 54.084682, 61.32752, 100, 44.916874, 100, 46.625267, 100, 100, 100, 100, 89.08565, 62.101326, 100, 100, 100, 51.106262, 100, 100, 100, 62.6825, 100, 48.788563, 100, 100, 100, 100, 68.51367, 100, 100, 100, 100, 71.73747, 100, 100, 100, 100, 100, 21.860834, 64.29395, 100, 100, 70.95816, 100, 100, 60.08189, 100, 100, 49.675346, 100, 49.675346, 100, 73.74263, 25.536543, 100, 100, 36.63915, 53.156094, 100, 104.32438, 100, 100, 100]


HBox(children=(FloatProgress(value=0.0, max=110.0), HTML(value='')))


days
[58.56957, 100, 35.31755, 54.25983, 24.524075, 100, 25.900764, 28.595562, 100, 100, 100, 100, 100, 81.13696, 100, 55.869713, 100, 75.66709, 100, 100, 100, 100, 100, 100, 100, 100, 69.620544, 32.56594, 42.67208, 100, 62.498466, 82.98729, 100, 65.488045, 42.47584, 45.343254, 37.91435, 100, 40.83455, 100, 100, 100, 100, 100, 100, 100, 78.78427, 100, 100, 65.561646, 100, 60.394157, 100, 100, 73.47807, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 32.979664, 100, 100, 31.913574, 134.8746, 100, 100, 100, 63.120903, 100.81957, 100, 100, 48.937088, 83.24594, 32.41822, 55.45412, 50.46963, 100, 100, 61.25139, 100, 100, 100, 44.57302, 63.280766, 66.6198, 100, 100, 100, 100, 100, 78.335266, 100, 100, 100, 100, 100, 100, 100, 43.415913, 82.03022, 31.19347, 100, 100]


HBox(children=(FloatProgress(value=0.0, max=73.0), HTML(value='')))


nest
[100, 100, 100, 41.516632, 100, 26.27335, 100, 100, 100, 83.95829, 100, 87.15492, 84.36905, 100, 100, 63.175404, 46.3727, 28.385782, 100, 21.5762, 100, 100, 100, 68.65329, 100, 81.53687, 100, 100, 100, 56.406254, 38.97204, 26.539463, 100, 45.811123, 73.49348, 28.942444, 200, 100, 100, 109.11208, 100, 100, 60.31894, 100, 59.990257, 66.59789, 54.02215, 100, 101.78922, 100, 45.432133, 59.93465, 38.629753, 100, 100, 55.139145, 122.93384, 100, 83.06426, 100, 43.523823, 100, 100, 50.627716, 100, 100, 100, 99.31623, 100, 37.670166, 100, 100, 100]


HBox(children=(FloatProgress(value=0.0, max=65.0), HTML(value='')))


ali
[100, 56.23985, 56.20364, 100, 100, 100, 60.685204, 100, 100, 100, 50.15868, 50.986416, 56.94111, 56.563522, 100, 37.01421, 100, 100, 100, 100, 100, 66.12709, 100, 100, 100, 200, 77.89262, 100, 100, 100, 100, 120.34441, 85.296074, 100, 100, 100, 100, 78.28835, 100, 100, 53.370678, 100, 200, 100, 100, 100, 100, 100, 100, 48.132366, 100, 96.610435, 100, 100, 100, 42.85795, 100, 100, 100, 60.580723, 100, 100, 25.315002, 100, 100]


HBox(children=(FloatProgress(value=0.0, max=68.0), HTML(value='')))


google
[100, 100, 76.243416, 100, 100, 100, 100, 100, 100, 100, 100, 116.471886, 100, 100, 100, 100, 100, 100, 64.111694, 100, 100, 83.60434, 122.77535, 100, 100, 73.88432, 100, 100, 100, 100, 136.3017, 100, 100, 79.20551, 74.73652, 50.168476, 100, 100, 52.837276, 74.246796, 100, 100, 100, 66.07609, 100, 100, 92.585396, 100, 100, 66.73795, 100, 100, 47.051365, 100, 85.23278, 73.112526, 100, 100, 100, 100, 100, 100, 100, 125.09341, 48.92394, 100, 68.55482, 100]


HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


palaces
[100, 100, 200, 100, 200, 37.351433, 100, 46.831036, 100, 100, 100, 100, 38.3732, 200, 78.31159, 200, 100, 23.628101, 52.211, 100, 100, 58.314514, 100, 100, 200, 57.01868, 103.19612, 100, 46.5107, 100, 100, 38.09427, 70.89281, 51.02264, 200, 200, 100, 100, 100, 46.968754]


HBox(children=(FloatProgress(value=0.0, max=62.0), HTML(value='')))


museums
[46.8237, 100, 31.444002, 56.878536, 81.191795, 100, 100, 100, 30.003878, 33.67836, 100, 100, 100, 30.064648, 100, 30.5154, 53.139694, 44.119602, 31.46045, 100, 100, 30.249163, 43.321976, 24.138826, 100, 100, 100, 48.50317, 100, 100, 60.286896, 38.048656, 100, 100, 46.62393, 49.863533, 34.463905, 100, 100, 98.38253, 100, 94.243935, 100, 85.36799, 100, 46.022076, 100, 100, 100, 100, 100, 100, 43.160885, 44.083656, 100, 100, 59.650425, 100, 100, 39.2573, 100, 25.519093]


HBox(children=(FloatProgress(value=0.0, max=53.0), HTML(value='')))


mosques
[100, 100, 100, 100, 100, 76.20507, 100, 26.188261, 100, 29.723818, 100, 100, 100, 100, 57.858887, 58.356697, 100, 100, 60.04093, 53.37375, 200, 27.024565, 100, 27.314497, 28.161234, 100, 37.83715, 100, 200, 100, 28.897276, 29.619898, 38.62193, 32.995712, 100, 100, 100, 27.225958, 32.535187, 26.951815, 200, 88.313774, 67.58986, 26.940926, 31.854511, 100, 27.782486, 100, 100, 29.520285, 30.963575, 100, 55.99937]


HBox(children=(FloatProgress(value=0.0, max=34.0), HTML(value='')))


ek
[100, 100, 100, 88.468, 200, 100, 100, 67.27725, 53.222065, 100, 100, 100, 73.17489, 72.86018, 69.443756, 93.992455, 100, 100, 100, 100, 100, 100, 100, 69.01781, 61.995087, 100, 100, 100, 76.056076, 100, 100, 100, 100, 128.14949]


HBox(children=(FloatProgress(value=0.0, max=107.0), HTML(value='')))


street
[100, 100, 100, 100, 100, 60.583797, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 55.791965, 41.942345, 100, 64.16228, 100, 99.33436, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 76.69603, 100, 100, 100, 100, 54.922688, 100, 49.640545, 100, 100, 56.3699, 124.07309, 100, 100, 100, 67.172356, 100, 46.763206, 41.44028, 100, 41.20164, 100, 50.464027, 34.257473, 112.827576, 100, 100, 100, 100, 100, 88.81118, 42.33989, 100, 100, 41.721313, 58.24052, 100, 100, 100, 100, 57.63301, 100, 100, 100, 100, 100, 100, 100, 100, 100, 50.903294, 100, 100, 100, 100, 36.452003, 100, 100, 71.061966, 100, 100, 100, 100, 83.93806, 100, 56.3699, 100, 100, 100, 100, 100, 66.5737, 100, 100]


HBox(children=(FloatProgress(value=0.0, max=59.0), HTML(value='')))


freeway
[100, 100, 100, 59.853825, 100, 100, 100, 100, 100, 100, 100, 100, 40.247192, 100, 100, 100, 100, 100, 26.9222, 100, 100, 100, 100, 67.28872, 66.888275, 100, 100, 100, 100, 100, 83.261955, 100, 100, 100, 100, 100, 78.659485, 100, 96.622215, 100, 100, 71.99592, 66.799484, 100, 63.635204, 100, 24.497517, 100, 100, 100, 70.07817, 100, 100, 100, 49.165165, 100, 100, 46.674828, 100]


HBox(children=(FloatProgress(value=0.0, max=62.0), HTML(value='')))


orbital
[200, 100, 60.55385, 100, 100, 200, 100, 100, 100, 100, 100, 43.928093, 34.73517, 100, 100, 100, 45.962074, 74.27097, 101.65354, 96.92519, 100, 100, 100, 38.6983, 44.74498, 100, 55.18387, 100, 100, 100, 100, 100, 39.42178, 100, 49.816505, 51.048805, 100, 100, 127.852646, 100, 100, 100, 100, 100, 64.34703, 200, 61.516907, 100, 100, 100, 100, 42.040367, 200, 100, 100, 36.85577, 78.66459, 100, 62.258316, 100, 41.96491, 100]


HBox(children=(FloatProgress(value=0.0, max=62.0), HTML(value='')))


highways
[100, 100, 100, 63.29042, 100, 31.655645, 100, 100, 100, 100, 100, 100, 100, 55.726845, 100, 100, 100, 100, 100, 26.43754, 45.92136, 100, 46.070915, 47.02238, 42.507767, 100, 26.315243, 100, 100, 100, 40.00044, 44.360897, 48.888725, 100, 100, 36.035927, 100, 88.56392, 55.601192, 51.423946, 28.984623, 100, 41.201138, 100, 83.89229, 100, 100, 100, 41.990192, 100, 42.87456, 100, 100, 44.641804, 100, 89.33334, 100, 100, 100, 44.090485, 41.113655, 63.33937]


HBox(children=(FloatProgress(value=0.0, max=77.0), HTML(value='')))


theorem
[100, 88.29884, 100, 100, 100, 100, 100, 100, 100, 48.92577, 100, 69.10511, 51.819237, 100, 100, 100, 100, 100, 100, 200, 100, 100, 83.31275, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 52.14514, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 89.499214, 100, 100, 36.72422, 100, 74.18657, 100, 100, 100, 100, 100, 100, 100, 100, 40.441895, 100, 100, 100, 100, 47.20585, 62.77384, 100, 100]


HBox(children=(FloatProgress(value=0.0, max=51.0), HTML(value='')))

In [18]:
source

'gentian'

In [19]:
for source, obj_li in tqdm(source_obj_set.items()):
    print(source)

HBox(children=(FloatProgress(value=0.0, max=326.0), HTML(value='')))

energy
animal
fish
person
element
covering
drug
illumination
appearance
leader
individual
chemical
mammal
man
wood
insect
bark
quantity
food
weather
fabric
nourishment
instrument
platform
vertebrate
traveler
game
sport
football
location
region
trait
material
emotion
building
merchant
vehicle
quality
science
day
delicacy
surface
servant
disorder
barrier
disease
room
attendant
period
activity
road
clothing
decoration
paper
meal
reptile
machine
craft
feeling
ballroom
house
officer
earth
construction
parent
compound
herb
pain
storm
light
church
salamander
commodity
motor
pigeon
education
lamp
mineral
vegetable
cat
joy
art
pup
boat
sound
cloth
bag
dam
book
hair
assistant
garment
formation
gear
precipitation
shape
cutlery
radiation
water
metal
tissue
bread
worker
emperor
job
bottle
corn
coat
biome
bush
medicine
noise
flask
painting
spacecraft
restaurant
table
criminal
official
bone
chicken
shelter
passenger
wave
blubber
solid
tube
server
ground
dish
season
cloud
belief
engine
artist
interval

In [21]:
len(source_obj_set)

326