In [1]:
import torch
import faiss
import nltk
import pickle
import numpy as np

from sentence_transformers import SentenceTransformer, util
from tqdm import tqdm

In [2]:
def _load_pickle(path):
    # NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Train / test splits; keys are doc ids (matched against doc_to_text keys and seed doc ids below).
train_dict = _load_pickle("../data/news_narrative_splits/split1/news_narratives_train.pkl")
test_dict = _load_pickle("../data/news_narrative_splits/split1/news_narratives_test.pkl")

In [3]:
# Rebuild each document's running text from the one-token-per-line narrative file:
# a "<doc_id> ID" line opens document ID, every following "<word> W" line appends "W ".
news_filename = '../data/event_knowledge_2.0/news_narratives.txt'
doc_to_text = {}
last_doc_id = ''
with open(news_filename) as news_narratives:
    for line in tqdm(news_narratives):
        if line == '\n':
            continue
        if line.startswith('<doc_id>'):
            last_doc_id = line[9:].replace('\n', '')
            doc_to_text[last_doc_id] = ''
        elif line.startswith('<word>'):
            doc_to_text[last_doc_id] += line[7:].replace('\n', ' ')
print(len(doc_to_text))
print(doc_to_text['WPB_ENG_20100127.0025.1:6'])

563065it [00:00, 1328311.31it/s]

74589
Within days of his inauguration , President Barack Obama signed executive orders to close the military prison at Guantanamo Bay within one year and to end torture in interrogation . He missed the Jan. 22 deadline to close Guantanamo but reaffirmed this month that he intends to close the prison as soon as possible . Obama has maintained other elements of the previous administration 's methods to capture and hold terrorism suspects . He has kept the military commission system to try certain terrorism suspects after strengthening evidentiary rules on behalf of defendants . He also preserved the authority to capture terrorism suspects in foreign countries , a practice known as extraordinary rendition . But he tightened the rules for where those captures can be made , limiting them to countries that do not have an effective rule of law . 





In [4]:
# Split every TRAINING document into sentences to form the retrieval corpus.
# The same `sent_detector` is used later to re-split retrieved documents, so the
# sentence positions computed here line up with those computed there.
sent_detector = nltk.data.load('tokenizers/punkt/english.pickle')
corpus = []            # sentence strings from train documents only
corpusidx_to_doc = {}  # corpus position -> doc id; each document's sentences are contiguous
for doc_id, text in tqdm(doc_to_text.items()):
    if doc_id not in train_dict:
        continue
    first_pos = len(corpus)
    corpus.extend(sent_detector.tokenize(text))
    corpusidx_to_doc.update(dict.fromkeys(range(first_pos, len(corpus)), doc_id))

100%|██████████| 74589/74589 [00:02<00:00, 27326.43it/s]


In [5]:
# Sentence encoder used for the corpus embeddings, queries, and event comparisons.
model = SentenceTransformer(model_name_or_path='all-mpnet-base-v2')

In [None]:
# Sentence embeddings for `corpus` (train sentences only). Encoding the whole corpus is
# expensive, so the result is cached on disk; delete the cache file to force a recompute.
emb_path = '../data/news_narrative_splits/split1/news_narratives_train_emb.pt'
try:
    corpus_embeddings = torch.load(emb_path, map_location=torch.device('cpu'))
except FileNotFoundError:
    corpus_embeddings = model.encode(corpus, convert_to_tensor=True, show_progress_bar=True)
    torch.save(corpus_embeddings, emb_path)
# FAISS needs a numpy array. .cpu() is a no-op for the loaded tensor and moves freshly
# encoded embeddings off the GPU if encode() ran there.
corpus_embeddings = corpus_embeddings.cpu().numpy()
# Index row i must correspond to corpus[i]; a cache built from a different corpus
# would silently map every search hit to the wrong sentence.
assert corpus_embeddings.shape[0] == len(corpus), (
    f"stale embedding cache: {corpus_embeddings.shape[0]} rows vs {len(corpus)} sentences"
)
print(corpus_embeddings.shape)

In [7]:
def process_query(query, model):
    """Embed one query string and return it as a single-row numpy batch.

    FAISS `search` expects a 2-D array of queries, so the encoder's output
    gets a new leading axis of length 1.
    """
    embedding = model.encode(query, convert_to_numpy=True)
    return embedding[np.newaxis, ...]

In [8]:
# Build index using cosine similarity as distance metric
def construct_index(corpus_embeddings, use_gpu=None):
    """Build a flat (exact) inner-product FAISS index over L2-normalised embeddings.

    For unit-length vectors the inner product equals cosine similarity, so scores
    returned by ``index.search`` are cosine similarities.

    Args:
        corpus_embeddings: 2-D array of shape (n, dim); faiss.normalize_L2 expects a
            C-contiguous float32 numpy array. NOTE: it is normalised IN PLACE, so
            the caller's array is modified.
        use_gpu: True forces a GPU index (needs a GPU-enabled faiss build), False
            forces a CPU index, None (default) uses the GPU only when faiss reports
            at least one available GPU.

    Returns:
        The populated FAISS index.
    """
    dim = corpus_embeddings.shape[1]
    index = faiss.index_factory(dim, 'Flat', faiss.METRIC_INNER_PRODUCT)
    if use_gpu is None:
        # get_num_gpus() returns 0 on CPU-only builds, where StandardGpuResources
        # does not exist; the CPU flat index gives the same exact results.
        use_gpu = faiss.get_num_gpus() > 0
    if use_gpu:
        res = faiss.StandardGpuResources()
        index = faiss.index_cpu_to_gpu(res, 0, index)
    faiss.normalize_L2(corpus_embeddings)
    index.add(corpus_embeddings)
    return index

In [9]:
# NOTE: construct_index L2-normalises `corpus_embeddings` in place before adding it.
index = construct_index(corpus_embeddings=corpus_embeddings)

In [10]:
# debugging function
def print_nearest_sents(sent_idxs, corpus):
    """Return the distinct corpus sentences found at the given indices.

    Despite the name nothing is printed; the set is returned so a notebook cell
    can display it. Repeated indices and repeated sentences collapse.
    """
    return {corpus[pos] for pos in set(sent_idxs)}

In [None]:
# Example Querying
# sent = 'doctors removed her adrenal glands'
# q = process_query(sent, model)
# faiss.normalize_L2(q)
# D, I = index.search(q, 5)
# print(I)
# print(D)
# print_nearest_sents(I[0], corpus)

In [11]:
def get_events_before_after(query, k, index, corpus, doc_to_text, corpusidx_to_doc, model, anchor=None,
                            window=3, verbose=False):
    """Collect sentences that precede / follow the corpus sentences most similar to `query`.

    The query is embedded and L2-normalised, and the top-`k` nearest corpus sentences
    are retrieved from `index`. For each hit, its source document is re-split with the
    global `sent_detector`. Sentences up to `window` positions before the hit go to the
    "before" set, and those up to `window` positions after go to the "after" set. The
    hit itself is excluded, as are one-word sentences.

    Args:
        query: text to search for.
        k: number of nearest corpus sentences to retrieve.
        index: FAISS index whose row i corresponds to corpus[i].
        corpus: list of corpus sentences.
        doc_to_text: doc id -> full document text.
        corpusidx_to_doc: corpus position -> doc id.
        model: sentence encoder passed to process_query.
        anchor: optional word; hits whose space-split tokens do not contain it are ignored.
        window: neighbouring sentences to collect on each side (default 3).
        verbose: print a short retrieval summary.

    Returns:
        (events_before, events_after) as sets of sentence strings. Duplicates from
        overlapping windows or repeated documents collapse.
    """
    events_before = set()
    events_after = set()
    docs_found = set()
    related_sentences = []

    q = process_query(query, model)
    faiss.normalize_L2(q)
    _, I = index.search(q, k)
    for idx in I[0]:
        if idx < 0:  # FAISS pads with -1 when the index holds fewer than k vectors
            continue
        if anchor and anchor not in corpus[idx].split(" "):  # keep only hits containing the anchor word
            continue
        doc_id = corpusidx_to_doc[idx]
        related_sentences.append(corpus[idx])
        docs_found.add(doc_id)
        anchor_idx = corpus_idx_to_idx_in_doc(idx, corpusidx_to_doc)
        sentences = sent_detector.tokenize(doc_to_text[doc_id])
        for i, s in enumerate(sentences):
            if len(s.split(" ")) < 2:  # single-word sentences aren't useful
                continue
            if abs(i - anchor_idx) > window:  # outside the +/- window around the hit
                continue
            if i < anchor_idx:
                events_before.add(s)
            elif i > anchor_idx:
                events_after.add(s)
    if verbose:
        print("Query:", query)
        print("Queried top", k, "sentences; Found", len(docs_found), "unique documents")
        print("Related Sentences:", len(related_sentences))
    return events_before, events_after

In [12]:
def corpus_idx_to_idx_in_doc(corpus_idx, corpusidx_to_doc):
    """Return the position of corpus sentence `corpus_idx` within its own document.

    Assumes all sentences of a document occupy consecutive corpus positions, which
    holds for the corpus built by appending each document's sentences in order. The
    function walks backwards until the document id changes or the start of the corpus
    is reached. It then returns the offset from the document's first sentence.

    Args:
        corpus_idx: position in the corpus (key of `corpusidx_to_doc`).
        corpusidx_to_doc: corpus position -> doc id.

    Returns:
        0-based sentence index inside the document. Sentences of the very first
        document now get their real index; the previous version returned -1 for them.
    """
    docid = corpusidx_to_doc[corpus_idx]
    start = corpus_idx
    while start > 0 and corpusidx_to_doc[start - 1] == docid:
        start -= 1
    return corpus_idx - start

In [13]:
def load_seeds(seed_path):
    """Parse a seed file of ``<doc_id> - <anchor sentence>`` lines into a dict.

    Only the FIRST " - " separates the id from the sentence, so anchor sentences
    that themselves contain " - " are kept whole. Otherwise they would be truncated
    and never match the test document's sentences. Blank lines are skipped. A later
    line with the same doc id overwrites an earlier one.

    Args:
        seed_path: path of the seed text file.

    Returns:
        dict mapping doc id -> stripped anchor sentence.

    Raises:
        ValueError: a non-blank line has no " - " separator.
    """
    doc_sent_map = {}
    with open(seed_path) as seeds:
        for line_no, line in enumerate(seeds, start=1):
            if not line.strip():
                continue
            doc_id, sep, sentence = line.partition(" - ")
            if not sep:
                raise ValueError(f"{seed_path}:{line_no}: missing ' - ' separator: {line!r}")
            doc_sent_map[doc_id] = sentence.strip()
    return doc_sent_map

205


In [None]:
# Evaluation seeds: doc id -> anchor sentence of that test document.
dsm = load_seeds(seed_path='../data/news_narrative_splits/split1/seeds/bomb_seeds_test.txt')
print(len(dsm))

In [14]:
# See if proposed event is similar to events in corpus_events (above threshold)
def report_top_sim(proposed_event, corpus_events, model, threshold):
    """Return True if `proposed_event` is cosine-similar (> `threshold`) to any corpus event.

    Args:
        proposed_event: sentence to test.
        corpus_events: list of candidate sentences; an empty list yields False.
        model: sentence encoder with ``encode(..., convert_to_tensor=True)``.
        threshold: strict lower bound on cosine similarity.

    Returns:
        bool.
    """
    if len(corpus_events) == 0:
        # Nothing to match; avoids depending on how encode() handles an empty batch.
        return False
    pe_embedding = model.encode(proposed_event, convert_to_tensor=True)
    corpus_events_embeddings = model.encode(corpus_events, convert_to_tensor=True)
    # One batched call: the 1-D query is compared against every row -> shape (1, n).
    sims = util.cos_sim(pe_embedding, corpus_events_embeddings)
    return bool((sims > threshold).any())

In [25]:
# Determine if proposed event happened before or after by seeing if the top-5 most similar sentences were from the before or after sets
def before_or_after(proposed_event, events_before, events_after, model, top_k=5, verbose=False):
    """Majority vote: does `proposed_event` look like it happened BEFORE the anchor?

    Both event lists are embedded together, and the `top_k` events most cosine-similar
    to `proposed_event` are selected. The function returns True when a strict majority
    of those neighbours come from `events_before`. With the default top_k=5 and at least
    5 events, this is the ">= 3 of the top 5" rule. With fewer events, the majority is
    taken over the neighbours that exist, rather than requiring 3 (which could never be
    reached with fewer than 3 events).

    Args:
        proposed_event: sentence to classify.
        events_before: list of sentences known to precede the anchor.
        events_after: list of sentences known to follow the anchor.
        model: sentence encoder with ``encode(..., convert_to_tensor=True)``.
        top_k: number of nearest events that vote (default 5).
        verbose: print the best match index and the top-k indices (debugging).

    Returns:
        bool; False ("after") when there are no events at all.
    """
    all_events = events_before + events_after
    if not all_events:
        return False  # no evidence; np.argmax/argsort on an empty array would fail
    pe_embedding = model.encode(proposed_event, convert_to_tensor=True)
    events_embeddings = model.encode(all_events, convert_to_tensor=True)
    sims = util.cos_sim(pe_embedding, events_embeddings)[0].cpu().numpy()
    top_idx = np.argsort(-sims)[:top_k]
    if verbose:
        print(np.argmax(sims))
        print(top_idx)
    # Indices below len(events_before) refer to "before" events in all_events.
    count_before = int(np.sum(top_idx < len(events_before)))
    return 2 * count_before > len(top_idx)

In [None]:
# Evaluation
# For each seed (doc id -> anchor sentence), retrieve events that occur before/after
# corpus sentences similar to the anchor. Then check whether every other sentence of
# the test document is placed on the correct side of the anchor.
# Scoring style: USE_TOP_SIM = False -> before_or_after majority vote (default);
#                USE_TOP_SIM = True  -> report_top_sim threshold match against the expected side.
USE_TOP_SIM = False
TOP_SIM_THRESHOLD = 0.6
RETRIEVAL_K = 100

sents_before = 0
sents_before_correct = 0
sents_after = 0
sents_after_correct = 0
seeds_skipped = 0
for doc_id, anchor_sent in dsm.items():
    doc_sents = test_dict[doc_id]
    # Position of the anchor inside the test document (last exact match wins).
    anchor_idx = -1
    for i, s in enumerate(doc_sents):
        if s == anchor_sent:
            anchor_idx = i
    if anchor_idx == -1:
        # Without a known anchor position every sentence would be scored as "after".
        seeds_skipped += 1
        continue

    events_before, events_after = get_events_before_after(
        anchor_sent, RETRIEVAL_K, index, corpus, doc_to_text, corpusidx_to_doc, model)
    # Sorted so encode order and tie-breaking don't depend on set iteration order,
    # which varies between interpreter runs for strings.
    events_before = sorted(events_before)
    events_after = sorted(events_after)

    for i, s in enumerate(doc_sents):
        if i == anchor_idx:
            continue  # the anchor itself is neither before nor after
        is_before = i < anchor_idx
        if USE_TOP_SIM:
            expected_events = events_before if is_before else events_after
            correct = report_top_sim(s, expected_events, model, TOP_SIM_THRESHOLD)
        else:
            correct = before_or_after(s, events_before, events_after, model) == is_before
        if is_before:
            sents_before += 1
            sents_before_correct += int(correct)
        else:
            sents_after += 1
            sents_after_correct += int(correct)

# nan instead of ZeroDivisionError when one side has no sentences.
before_acc = sents_before_correct / sents_before if sents_before else float("nan")
after_acc = sents_after_correct / sents_after if sents_after else float("nan")
print("From", len(dsm), "Seeds")
if seeds_skipped:
    print("Skipped", seeds_skipped, "seeds whose anchor sentence was not found in test_dict")
print("\t", sents_before_correct, "/", sents_before, "=", before_acc, "On sentences BEFORE anchor")
print("\t", sents_after_correct, "/", sents_after, "=", after_acc, "On sentences AFTER anchor")

In [28]:
# Full Query Output
# Retrieve before/after context for one hand-written query. The results are sets, so a
# sentence reached from several hits (e.g. several top results in the same document)
# is counted once. Iterate over `before` / `after` to inspect the sentences themselves.
query = 'smoke bomb burned several people'
before, after = get_events_before_after(query, 100, index, corpus, doc_to_text, corpusidx_to_doc, model)
separator = "-" * 30
for label, events in (("Events Before:", before), ("Events After:", after)):
    print(separator)
    print(label, len(events))

Query: smoke bomb burned several people
Queried top 100 sentences; Found 94 unique documents
Related Sentences: 100
------------------------------
Events Before: 164
------------------------------
Events After: 168
