In [1]:
#@title Setup common imports and functions
import json
import nltk
import os
import pprint
import random
import simpleneighbors
import urllib
from IPython.display import HTML, display

import tensorflow.compat.v2 as tf
import tensorflow_hub as hub
from tensorflow_text import SentencepieceTokenizer

# Fetch NLTK's 'punkt' data package once at setup time; presumably this is
# what nltk.tokenize.sent_tokenize needs when the SQuAD paragraphs are split
# into sentences.
# NOTE(review): newer NLTK releases may ask for 'punkt_tab' instead - confirm
# against the NLTK version installed in this environment.
nltk.download('punkt')


def download_squad(url):
  """Download a SQuAD JSON document and return it parsed.

  Args:
    url: Location of the JSON document, in any form accepted by
      `urllib.request.urlopen`.

  Returns:
    The object produced by `json.load` for the downloaded document.

  Raises:
    Whatever `urllib.request.urlopen` or `json.load` raise; nothing is caught
    here.
  """
  # A bare `import urllib` does not guarantee that the `request` submodule has
  # been loaded, so import it explicitly where it is used.
  import urllib.request

  # The context manager closes the response even if JSON parsing fails.
  with urllib.request.urlopen(url) as response:
    return json.load(response)

def extract_sentences_from_squad_json(squad):
  """Split every SQuAD paragraph into (sentence, context) pairs.

  Args:
    squad: Parsed SQuAD JSON. Only squad['data'][*]['paragraphs'][*]['context']
      is read.

  Returns:
    A list of unique (sentence, context) tuples, where `context` is the full
    paragraph the sentence was tokenized from. Pairs appear in the order they
    are first encountered, so the result is identical from run to run.
  """
  pairs = []
  for article in squad['data']:
    for paragraph in article['paragraphs']:
      context = paragraph['context']
      pairs.extend(
          (sentence, context)
          for sentence in nltk.tokenize.sent_tokenize(context))
  # dict.fromkeys drops duplicates while keeping first-seen order. Going
  # through a set instead would give an order that depends on string hashing
  # and therefore changes between kernel restarts.
  return list(dict.fromkeys(pairs))

def extract_questions_from_squad_json(squad):
  """Collect (question, first answer text) pairs from parsed SQuAD JSON.

  Args:
    squad: Parsed SQuAD JSON. Reads
      squad['data'][*]['paragraphs'][*]['qas'][*] and, for each entry, its
      'question' and 'answers' fields.

  Returns:
    A list of unique (question, answer_text) tuples in first-seen order.
    Entries whose 'answers' list is empty are skipped; when several answers
    are present only the first one's 'text' is used.
  """
  pairs = [
      (qa['question'], qa['answers'][0]['text'])
      for article in squad['data']
      for paragraph in article['paragraphs']
      for qa in paragraph['qas']
      if qa['answers']
  ]
  # dict.fromkeys drops duplicates while keeping first-seen order. Going
  # through a set instead would give an order that depends on string hashing
  # and therefore changes between kernel restarts.
  return list(dict.fromkeys(pairs))

def output_with_highlight(text, highlight):
  """Render `text` as an HTML list item with every `highlight` match in bold.

  Args:
    text: The sentence to render.
    highlight: Substring to wrap in <b>...</b>. Matches are exact,
      case-sensitive and non-overlapping, scanned left to right. If it is
      empty (or None) nothing is highlighted.

  Returns:
    A string of the form "<li> ...</li>\n". The characters &, < and > coming
    from `text` and `highlight` are HTML-escaped so they show up literally
    instead of being parsed as markup.
  """
  import html  # local import: only this helper needs it

  # An empty needle matches at position 0 forever, which used to make the
  # search loop spin without end; treat it as "nothing to highlight".
  if not highlight:
    return "<li> " + html.escape(text, quote=False) + "</li>\n"

  # str.split cuts at the same non-overlapping, left-to-right matches that a
  # repeated str.find would locate, so joining the pieces with the bolded
  # needle rebuilds the sentence with every match wrapped.
  bold = '<b>' + html.escape(highlight, quote=False) + '</b>'
  pieces = (html.escape(piece, quote=False) for piece in text.split(highlight))
  return "<li> " + bold.join(pieces) + "</li>\n"

def display_nearest_neighbors(query_text, answer_text=None):
  """Embed a question, look up its nearest sentences and render them as HTML.

  Relies on three notebook globals that are assigned in other cells and must
  exist before this is called: `model` (the loaded TF Hub module), `index`
  (the built SimpleNeighbors index) and `num_results` (how many neighbours to
  request).

  Args:
    query_text: The question to embed with the model's 'question_encoder'
      signature.
    answer_text: Optional known answer. When given, it is shown under the
      question and every occurrence of it inside the retrieved sentences is
      highlighted via `output_with_highlight`.

  Returns:
    None. The result is shown with IPython's `display`.
  """
  import html  # local import: only this function needs it

  query_embedding = model.signatures['question_encoder'](tf.constant([query_text]))['outputs'][0]
  search_results = index.nearest(query_embedding, n=num_results)

  # Question and answer are free text; escape them so characters such as
  # '<' or '&' are displayed rather than interpreted as markup.
  safe_query = html.escape(query_text, quote=False)

  if answer_text:
    result_md = '''
    <p>Random Question from SQuAD:</p>
    <p>&nbsp;&nbsp;<b>%s</b></p>
    <p>Answer:</p>
    <p>&nbsp;&nbsp;<b>%s</b></p>
    ''' % (safe_query, html.escape(answer_text, quote=False))
  else:
    result_md = '''
    <p>Question:</p>
    <p>&nbsp;&nbsp;<b>%s</b></p>
    ''' % safe_query

  # Close the paragraph before opening the list: an <ol> is not valid inside
  # a <p>, and the tag used to be left open.
  result_md += '''
    <p>Retrieved sentences :</p>
    <ol>
  '''

  if answer_text:
    # Sentences and answer are passed through untouched because the helper
    # has to search the raw sentence text for the raw answer text.
    for s in search_results:
      result_md += output_with_highlight(s, answer_text)
  else:
    for s in search_results:
      result_md += '<li>' + html.escape(s, quote=False) + '</li>\n'

  result_md += "</ol>"
  display(HTML(result_md))

[nltk_data] Downloading package punkt to
[nltk_data]     /Users/ilhambintang/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
%%time
# Download the chosen SQuAD split and derive the two corpora used by the other
# cells: `sentences` (a list of (sentence, context) pairs) and `questions`
# (a list of (question, answer) pairs). The trailing #@param comment is a
# form annotation listing the selectable dataset URLs - leave it intact.
squad_url = 'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json' #@param ["https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json", "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json", "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json", "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json"]

squad_json = download_squad(squad_url)
sentences = extract_sentences_from_squad_json(squad_json)
questions = extract_questions_from_squad_json(squad_json)
print("%s sentences, %s questions extracted from SQuAD %s" % (len(sentences), len(questions), squad_url))

# Sanity check: show one randomly chosen pair. No random seed is set, so a
# different pair is printed on each run.
print("\nExample sentence and context:\n")
sentence = random.choice(sentences)
print("sentence:\n")
pprint.pprint(sentence[0])  # the sentence itself
print("\ncontext:\n")
pprint.pprint(sentence[1])  # the full paragraph it was taken from
print()

10455 sentences, 10552 questions extracted from SQuAD https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json

Example sentence and context:

sentence:

('On May 28, 2012, Jacksonville was hit by Tropical Storm Beryl, packing winds '
 'up to 70 miles per hour (113 km/h) which made landfall near Jacksonville '
 'Beach.')

context:

('Jacksonville has suffered less damage from hurricanes than most other east '
 'coast cities, although the threat does exist for a direct hit by a major '
 'hurricane. The city has only received one direct hit from a hurricane since '
 '1871; however, Jacksonville has experienced hurricane or near-hurricane '
 'conditions more than a dozen times due to storms crossing the state from the '
 'Gulf of Mexico to the Atlantic Ocean, or passing to the north or south in '
 'the Atlantic and brushing past the area. The strongest effect on '
 'Jacksonville was from Hurricane Dora in 1964, the only recorded storm to hit '
 'the First Coast with sustained hurr

In [3]:
%%time
# Load the Universal Sentence Encoder QA module from TF Hub. The resulting
# `model` is a notebook global: other cells call
# model.signatures['question_encoder'] and model.signatures['response_encoder']
# on it to embed text.
module_url = "https://tfhub.dev/google/universal-sentence-encoder-qa/3"
model = hub.load(module_url)

CPU times: user 27.8 s, sys: 2.87 s, total: 30.7 s
Wall time: 36.2 s


In [6]:
%%time

batch_size = 100

encodings = model.signatures['response_encoder'](
  input=tf.constant([sentences[0][0]]),
  context=tf.constant([sentences[0][1]]))
index = simpleneighbors.SimpleNeighbors(
    len(encodings['outputs'][0]), metric='angular')

print('Computing embeddings for %s sentences' % len(sentences))
slices = zip(*(iter(sentences),) * batch_size)
num_batches = int(len(sentences) / batch_size)
for n, s in enumerate(slices):
  response_batch = list([r for r, c in s])
  context_batch = list([c for r, c in s])
  encodings = model.signatures['response_encoder'](
    input=tf.constant(response_batch),
    context=tf.constant(context_batch)
  )
  for i in range(len(response_batch)):
    index.add_one(response_batch[i], encodings['outputs'][i])

index.build()
print('simpleneighbors index for %s sentences built.' % len(sentences))

Computing embeddings for 10455 sentences
simpleneighbors index for 10455 sentences built.
CPU times: user 14min 42s, sys: 1min 5s, total: 15min 47s
Wall time: 12min 33s


In [None]:
# How many nearest neighbours to retrieve; display_nearest_neighbors reads
# this as a global, so it has to be assigned before that function is called.
num_results = 25

# Each entry of `questions` is a (question, answer) pair. Pick one at random
# (unseeded, so it differs per run) and show the sentences retrieved for it
# with the known answer highlighted.
query = random.choice(questions)
display_nearest_neighbors(query[0], query[1])