In [1]:
import torch
import faiss
import nltk
import pickle
import numpy as np

from sentence_transformers import SentenceTransformer, util
from tqdm import tqdm

In [2]:
with open("news_narratives_train.pkl", "rb") as outfile:
    train_dict = pickle.load(outfile)
with open("news_narratives_test.pkl", "rb") as outfile:
    test_dict = pickle.load(outfile)

In [3]:
news_filename = 'news_narratives.txt'
doc_to_text = {}
last_doc_id = ''
with open(news_filename) as news_narratives:
    for line in tqdm(news_narratives):
        if line == '\n':
            continue
        if line[:8] == '<doc_id>':
            doc_to_text[line[9:].replace('\n', '')] = ''
            last_doc_id = line[9:].replace('\n', '')
        if line[:6] == '<word>':
            doc_to_text[last_doc_id] += line[7:].replace('\n', ' ')
print(len(doc_to_text))
print(doc_to_text['WPB_ENG_20100127.0025.1:6'])

563065it [00:00, 1328311.31it/s]

74589
Within days of his inauguration , President Barack Obama signed executive orders to close the military prison at Guantanamo Bay within one year and to end torture in interrogation . He missed the Jan. 22 deadline to close Guantanamo but reaffirmed this month that he intends to close the prison as soon as possible . Obama has maintained other elements of the previous administration 's methods to capture and hold terrorism suspects . He has kept the military commission system to try certain terrorism suspects after strengthening evidentiary rules on behalf of defendants . He also preserved the authority to capture terrorism suspects in foreign countries , a practice known as extraordinary rendition . But he tightened the rules for where those captures can be made , limiting them to countries that do not have an effective rule of law . 





In [4]:
sent_detector = nltk.data.load('tokenizers/punkt/english.pickle')
corpus = [] # only has train items
corpusidx_to_doc = {}
idx = 0
for k,v in tqdm(doc_to_text.items()):
    if k not in train_dict.keys():
        continue
    sentences = sent_detector.tokenize(v)
    for s in sentences:
        corpus.append(s)
        corpusidx_to_doc[idx] = k
        idx += 1

100%|██████████| 74589/74589 [00:02<00:00, 27326.43it/s]


In [5]:
model = SentenceTransformer('all-mpnet-base-v2')

In [None]:
corpus_embeddings = model.encode(corpus, convert_to_tensor=True, show_progress_bar=True)

In [None]:
# with open("news_narratives_train_2_emb.pt", "wb") as outfile:
#     torch.save(corpus_embeddings, outfile)

In [6]:
corpus_embeddings = torch.load('news_narratives_train_emb.pt', map_location=torch.device('cpu')).numpy()
print(corpus_embeddings.shape)

(171749, 768)


In [7]:
def process_query(query, model):
    return np.expand_dims(model.encode(query, convert_to_numpy=True), axis=0)

In [8]:
def construct_index(corpus_embeddings):
    index_f = faiss.index_factory(corpus_embeddings.shape[1], 'Flat', faiss.METRIC_INNER_PRODUCT)
    res = faiss.StandardGpuResources()
    index = faiss.index_cpu_to_gpu(res, 0, index_f)
    faiss.normalize_L2(corpus_embeddings)
    index.add(corpus_embeddings)
    return index

In [9]:
index = construct_index(corpus_embeddings)

In [None]:
sent = 'doctors removed her adrenal glands'
q = process_query(sent, model)
# print(q.shape)
faiss.normalize_L2(q)
D, I = index.search(q, 5)
print(I)
print(D)

In [10]:
def print_nearest_sents(sent_idxs, corpus):
    sents = []
    for idx in set(sent_idxs):
        sents.append(corpus[idx])
    return set(sents)

In [None]:
print_nearest_sents(I[0], corpus)

In [11]:
def get_events_before_after(query, k, index, corpus, doc_to_text, corpusidx_to_doc, model, anchor=None):
    
    events_before = []
    events_after = []
    
    docs_found = []
    related_sentences = []
    
    q = process_query(query, model)
    faiss.normalize_L2(q)
    D, I = index.search(q, k)
    sent_idxs = I[0]
    for idx in sent_idxs:
        if anchor and anchor not in corpus[idx].split(" "): # Filtering for retrieved sentences w/o anchor
            continue
        related_sentences.append(corpus[idx])
        doc = doc_to_text[corpusidx_to_doc[idx]]
        docs_found.append(corpusidx_to_doc[idx])
        anchor_idx = corpus_idx_to_idx_in_doc(idx, corpusidx_to_doc)
        sentences = sent_detector.tokenize(doc)
#         print(anchor_idx, len(sentences)-1)
        for i, s in enumerate(sentences):
            if len(s.split(" ")) < 2: # Single word sentences aren't useful
                continue
            if i < anchor_idx - 3 or i > anchor_idx + 3:
                continue
            if i < anchor_idx:
                events_before.append(s)
            elif i > anchor_idx:
#                 print(s)
                events_after.append(s)
#     print("Query:", query)
#     print("Queried top", k, "sentences; Found", len(set(docs_found)) , "unique documents")
#     print("Related Sentences:", len(related_sentences))
#     for s in related_sentences:
#         print(s)
    return set(events_before), set(events_after)

In [12]:
def corpus_idx_to_idx_in_doc(corpus_idx, corpusidx_to_doc):
    docid = corpusidx_to_doc[corpus_idx]
    i = corpus_idx - 1
    while i >= 0:
        if corpusidx_to_doc[i] != docid:
            return corpus_idx - i - 1
        i -= 1
    return -1

In [13]:
def load_seeds(seed_path):
    doc_sent_map = {}
    with open(seed_path) as seeds:
        for line in seeds:
            text = line.split(" - ")
            doc_sent_map[text[0]] = text[1].strip()
    return doc_sent_map

dsm = load_seeds('cancer_seeds_test.txt')
print(len(dsm))
# print(dsm['NYT_ENG_19970909.0496.13:16'])

205


In [14]:
def report_top_sim(proposed_event, corpus_events, model, threshold):
    is_similar = False
    pe_embedding = model.encode(proposed_event, convert_to_tensor=True)
    corpus_events_embeddings = model.encode(corpus_events, convert_to_tensor=True)
    sims = []
    for embedding in corpus_events_embeddings:
        sim = util.cos_sim(pe_embedding, embedding)
        sims.append(sim.detach()[0][0].cpu().numpy())
    for i, s in enumerate(sims):
        if s > threshold:
#             print(proposed_event)
#             print(s, "-", corpus_events[i])
#             print("---"*5)
            is_similar = True
            break
    return is_similar

In [25]:
def before_or_after(proposed_event, events_before, events_after, model):
    pe_embedding = model.encode(proposed_event, convert_to_tensor=True)
    all_events = events_before + events_after
    events_embeddings = model.encode(all_events, convert_to_tensor=True)
    sims = []
    for embedding in events_embeddings:
        sim = util.cos_sim(pe_embedding, embedding)
        sims.append(sim.detach()[0][0].cpu().numpy())
    sims = np.array(sims)
    print(np.argmax(sims))
    max_sim_idx = np.argsort(-sims)[:5]
    print(max_sim_idx)
    count_before = 0
    for idx in max_sim_idx:
        if idx < len(events_before):
            count_before += 1
#     max_sim_idx = np.argmax(sims)
#     if max_sim_idx >= len(events_before):
#         return False
#     return True
    return count_before >= 3

In [26]:
sents_before = 0
sents_before_correct = 0
sents_after = 0
sents_after_correct = 0
for doc_id, anchor_sent in dsm.items():
#     print("Anchor Sent:", anchor_sent)
    events_before, events_after = get_events_before_after(anchor_sent, 100, index, corpus, doc_to_text, corpusidx_to_doc, model)
#     print("Events Before:", len(events_before))
#     for e in events_before:
#         print(e)

#     print("---"*10)
#     print("Events After:", len(events_after))
#     for e in events_after:
#         print(e)
#     print("---"*10)
    anchor_idx = -1
    for i, s in enumerate(test_dict[doc_id]):
        if anchor_sent == s:
            anchor_idx = i
#     print(anchor_idx)
    for i, s in enumerate(test_dict[doc_id]):
        was_before = before_or_after(s, list(events_before), list(events_after), model)
        if i < anchor_idx:
            sents_before += 1
#             if report_top_sim(s, list(events_before), model, 0.6):
#                 sents_before_correct += 1
            if was_before:
                sents_before_correct += 1
        elif i > anchor_idx:
            sents_after += 1
#             if report_top_sim(s, list(events_after), model, 0.6):
#                 sents_after_correct += 1
            if not was_before:
                sents_after_correct += 1
print("From", len(dsm), "Seeds")
print("\t", sents_before_correct, "/", sents_before, "=", sents_before_correct/sents_before, "On sentences BEFORE anchor")
print("\t", sents_after_correct, "/", sents_after, "=", sents_after_correct/sents_after, "On sentences AFTER anchor")

165
[165  53 178  55  92]
55
[ 55 178 143  27  40]
244
[244 256 214 157  34]
121
[121 146  23 222 202]
319
[319 239  40 332  89]
56
[ 56 122  35  13 274]
171
[171 105 276 250 147]
171
[171 300 147 105 279]
320
[320 120 122 151 285]
128
[128  57 247 312  11]
214
[214 120 247 181  61]
218
[218 247 105 172 341]
218
[218   7 102 264 288]
179
[179 107 184 141 297]
15
[ 15 276 212 129 325]
15
[ 15 276   2 300 212]
272
[272 244 212   3 222]
212
[212 176 276 301 324]
248
[248  90 175 278 251]
63
[ 63  27 305 283 267]
82
[ 82 283  63 265 260]
197
[197 287 115 207 120]
217
[217 232  53 155 256]
320
[320 255 183 196   9]
151
[151  82  24 284 309]
300
[300 299 267 144 206]
206
[206  51 144  83 222]
220
[220 273 319 276 302]
261
[261 227 291 196 197]
172
[172 287 187 279 208]
273
[273 276  14 261 302]
190
[190 113 186  86 237]
244
[244 267  14  28 211]
160
[160 254 237 213 150]
215
[215 103 254 287 111]
200
[200 219  85 223  94]
238
[238  30 322  62 267]
166
[166  54 211 196  21]
118
[118  62 101  

97
[ 97 139 261  69  96]
282
[282 284  98 250 291]
310
[310 291 266  48 300]
31
[ 31 166 171 305 267]
39
[ 39 128 257  75  25]
216
[216  20 252 177  49]
216
[216 224 218 110 252]
318
[318 334 238 216 227]
216
[216 112 227   1 154]
105
[269 105 107 328  71]
266
[266  10 275 335 153]
268
[268 253 106 105 269]
23
[ 23  29  91  38 278]
105
[269 105 230 218 332]
304
[304  88 162 321 197]
88
[ 88 197 253 162  39]
304
[304 321 104  30  88]
109
[109  58 188  61 193]
89
[ 89 247 139 254 317]
89
[ 89 267 281 115  75]
22
[ 22   4 149 178 115]
149
[149  22 248 314 270]
162
[162 143   1  29 318]
213
[213 111 189 275 105]
331
[331 311 195   8 236]
162
[162 284 255  46   8]
101
[101 143 255  27  73]
170
[170  65 210 213  45]
15
[ 15 234  54  33 188]
94
[ 94  46   8 267 292]
192
[192  88 223 142  22]
179
[179 193 207  23 169]
72
[ 72  83 207  69 212]
167
[167 207 232 104 204]
238
[238  82 248  72  98]
203
[203  64 248 277 283]
293
[293 141 300 122 227]
267
[267 244 259 186  54]
64
[ 64  72 313 265 237

108
[108 156 170 280  62]
235
[235 234  54 177  98]
151
[151  34  73 205 244]
146
[146 283 114 306 226]
17
[ 17   8 325   6 238]
173
[173 107 113  91 273]
170
[170 328 186 214 135]
320
[320  99 136 117 193]
199
[199 289 310 205 303]
310
[310 325 199 205 152]
197
[197 105 270 257  94]
112
[112  31  71 218  62]
282
[282 263 232 148 138]
193
[193   7  74 252 138]
171
[171 211 238 185 318]
171
[171 318 280  67 173]
39
[ 39  27 301 186 174]
256
[256 185 136  68  83]
39
[ 39 220 283 175 247]
24
[ 24 186 268  27  11]
3
[  3  98  78 177 174]
276
[276 300 112 240 138]
178
[178 149 121 156 280]
8
[  8 156 226 280 262]
286
[286 156 113 326 315]
156
[156 286   8 262  43]
339
[339 309 187 166  70]
195
[195  78 144  11  46]
302
[302 124 315  65  79]
264
[264  77 269  49 217]
204
[204  70 302 292 321]
302
[302 204 293 116 272]
6
[  6 198 326 215 155]
211
[211 185 233 159 132]
97
[ 97 164  80 279 250]
267
[267 222   6 166 261]
142
[142 189 134 190 279]
40
[ 40 160 244 101 131]
88
[ 88  72  80 281 102]

In [28]:
query = 'smoke bomb burned several people'
before, after = get_events_before_after(query, 100, index, corpus, doc_to_text, corpusidx_to_doc, model)
# set of events disregards duplicate sentences (may come up when multiple sentences in top query results link back to same doc)
print("---"*10)
print("Events Before:", len(before))
# for e in before:
#     print(e)

print("---"*10)
print("Events After:", len(after))
# for e in after:
#     print(e)

Query: smoke bomb burned several people
Queried top 100 sentences; Found 94 unique documents
Related Sentences: 100
------------------------------
Events Before: 164
------------------------------
Events After: 168


In [None]:
proposed_event = 'area was taped off and people were not allowed in.'
report_top_sim(proposed_event, list(after), model, 0.4)