In [1]:
import html
import json
import os
import pprint
import random
import urllib
import urllib.request

import nltk
import simpleneighbors
import tensorflow.compat.v2 as tf
import tensorflow_hub as hub
from IPython.display import HTML, display
from tensorflow_text import SentencepieceTokenizer
from tqdm.notebook import tqdm

In [2]:
def download_squad(url):
  """Fetch `url` and return its body parsed as JSON.

  Args:
    url: Any URL accepted by `urllib.request.urlopen`.

  Returns:
    The decoded JSON document (whatever `json.load` produces for the body).

  Raises:
    urllib.error.URLError: if the resource cannot be opened.
    json.JSONDecodeError: if the body is not valid JSON.
  """
  # The context manager closes the underlying connection even when JSON
  # parsing fails, instead of leaving the response open until garbage
  # collection.
  with urllib.request.urlopen(url) as response:
    return json.load(response)

def extract_sentences_from_squad_json(squad):
  """Collect unique (sentence, context) pairs from a SQuAD-style dict.

  Every paragraph's 'context' string is split into sentences with
  `nltk.tokenize.sent_tokenize`; each sentence is paired with the full
  context it came from.

  Args:
    squad: Mapping with a 'data' list whose items hold a 'paragraphs' list;
      each paragraph must provide a 'context' string.

  Returns:
    A list of (sentence, context) tuples without duplicates, in order of
    first appearance.
  """
  pairs = (
      (sentence, paragraph['context'])
      for article in squad['data']
      for paragraph in article['paragraphs']
      for sentence in nltk.tokenize.sent_tokenize(paragraph['context'])
  )
  # dict.fromkeys drops duplicates while keeping insertion order (guaranteed
  # for dicts since Python 3.7). A plain set() would hand back a different
  # order on every run because of string-hash randomization, which makes
  # anything built from this list irreproducible.
  return list(dict.fromkeys(pairs))

def extract_questions_from_squad_json(squad):
  """Collect unique (question, first answer text) pairs from a SQuAD-style dict.

  Entries whose 'answers' list is empty are skipped. Only the first answer
  of each entry is used.

  Args:
    squad: Mapping with a 'data' list whose items hold a 'paragraphs' list;
      each paragraph must provide a 'qas' list of dicts with 'question' and
      'answers' keys.

  Returns:
    A list of (question, answer_text) tuples without duplicates, in order of
    first appearance.
  """
  pairs = (
      (qas['question'], qas['answers'][0]['text'])
      for article in squad['data']
      for paragraph in article['paragraphs']
      for qas in paragraph['qas']
      if qas['answers']
  )
  # Order-preserving de-duplication (dict insertion order, Python 3.7+);
  # list(set(...)) would give a run-dependent order and make later random
  # sampling irreproducible even with a fixed seed.
  return list(dict.fromkeys(pairs))

def output_with_highlight(text, highlight):
  """Render `text` as an HTML list item with each `highlight` match in bold.

  Matches are exact, case-sensitive and non-overlapping, scanned left to
  right. `text` and `highlight` are treated as plain text: the characters
  &, < and > are HTML-escaped so they are displayed literally instead of
  being interpreted as markup.

  Args:
    text: The sentence to render.
    highlight: Substring to wrap in <b>...</b>. When empty or None nothing
      is highlighted.

  Returns:
    A string of the form "<li> ...</li>\n".
  """
  if not highlight:
    # str.find('') always returns 0, so scanning for an empty needle would
    # never make progress; there is nothing to highlight anyway.
    return "<li> " + html.escape(text, quote=False) + "</li>\n"

  bold = '<b>' + html.escape(highlight, quote=False) + '</b>'
  # Split on the raw needle first and escape the pieces afterwards, so that
  # matching is done on the original text rather than on escaped entities.
  pieces = [html.escape(piece, quote=False) for piece in text.split(highlight)]
  return "<li> " + bold.join(pieces) + "</li>\n"

def display_nearest_neighbors(query_text, answer_text=None):
  """Display the nearest indexed results for a question as an HTML list.

  The question is embedded with the module-level `model`'s
  'question_encoder' signature and looked up in the module-level `index`;
  `num_results` (also module-level) controls how many hits are requested.
  All three must be defined before this function is called.

  Args:
    query_text: The question string to search for.
    answer_text: Optional known answer. When truthy it is shown under the
      question, and each result is rendered through `output_with_highlight`
      so occurrences of the answer are bolded.

  Returns:
    None. The rendered HTML is shown via IPython's `display`.
  """
  query_embedding = model.signatures['question_encoder'](tf.constant([query_text]))['outputs'][0]
  search_results = index.nearest(query_embedding, n=num_results)

  # Question, answer and results are plain text; escape &, < and > before
  # splicing them into markup so they render literally.
  safe_query = html.escape(query_text, quote=False)

  if answer_text:
    result_md = '''
    <p>Random Question from SQuAD:</p>
    <p>&nbsp;&nbsp;<b>%s</b></p>
    <p>Answer:</p>
    <p>&nbsp;&nbsp;<b>%s</b></p>
    ''' % (safe_query, html.escape(answer_text, quote=False))
  else:
    result_md = '''
    <p>Question:</p>
    <p>&nbsp;&nbsp;<b>%s</b></p>
    ''' % (safe_query,)

  # The paragraph is closed before the list starts: an <ol> may not be
  # nested inside a <p>.
  result_md += '''
    <p>Retrieved sentences :</p>
    <ol>
  '''

  if answer_text:
    for s in search_results:
      # Pass the raw strings through; rendering of the list item (including
      # the bolding) is delegated to output_with_highlight.
      result_md += output_with_highlight(s, answer_text)
  else:
    for s in search_results:
      result_md += '<li>' + html.escape(s, quote=False) + '</li>\n'

  result_md += "</ol>"
  display(HTML(result_md))