In [118]:
import numpy as np
import math
from collections import defaultdict
from bert_embedding import BertEmbedding
from scipy import spatial
import re
import os

In [158]:
## CONSTANTS

## Min/max distance between two points in dataset
dmin = 0.0604
dmax = 0.8

## Approximation factor
C = 1.005

## Size of hash function output (number of random hyperplanes, i.e. bits, per hash function)
K = 10

## Number of hash functions drawn (one hash table is built per hash function)
L = 20

## 1 - delta = Success probability of PLEB
delta = 0.001
## Number of independent hash table constructions per PLEB instance.
## Assuming each construction succeeds with probability >= 1/2, t repetitions fail
## with probability <= (1/2)**t, so t = ceil(log2(1/delta)) is needed to reach 1 - delta.
## (log2(1/(1 - delta)) would evaluate to 1 for every delta <= 0.5.)
## NOTE(review): preprocessing cost scales linearly with this value.
num_hash_table_constructions = math.ceil(math.log(1.0/delta, 2))

In [120]:
## PREPROCESSING - get sentences from text file

## The data file is looked up relative to the parent of the current working directory.
dirname = os.path.dirname(os.path.abspath(''))
data_path = os.path.join(dirname, 'data/datacleaned_valid.txt')
with open(data_path, encoding="utf8") as fp:
    phrases = fp.read().split('\n')

## Keep non-blank lines containing at least one ASCII letter; split each on commas
## and drop the empty pieces.
phrases_list = [
    [piece for piece in line.strip().split(',') if piece]
    for line in phrases
    if line.strip() and re.search('[a-zA-Z]', line)
]

## Flatten the per-line pieces into one list of sentences.
sentences = []
for phrase in phrases_list:
    sentences.extend(phrase)

bert_embedding = BertEmbedding()

In [129]:
def get_bert_sentence(sentence):
    """Return a single 768-dimensional vector for `sentence`.

    The sentence is passed to the module-level `bert_embedding` model and the
    vectors found in every element after the first of the model's result for
    that sentence are averaged.

    Parameters
    ----------
    sentence : str
        Text to embed.

    Returns
    -------
    numpy.ndarray
        Mean of the accumulated vectors, shape (768,). All zeros when the
        model yields no vectors.
    """
    total = np.zeros(768,)
    # presumably a (tokens, token_vectors) pair for the sentence -- TODO confirm against bert_embedding
    sentence_embedding = bert_embedding([sentence])[0]
    num_vectors = 0
    for word_embedding in sentence_embedding[1:]:
        for token in word_embedding:
            total += token
            num_vectors += 1
    if num_vectors == 0:
        # Nothing was accumulated; avoid a division by zero.
        return total
    # Divide by the number of vectors actually summed (not by the number of
    # groups iterated by the outer loop) so the result is a true mean.
    return np.divide(total, num_vectors)

def get_bert_sentences(sentences):
    """Embed every sentence and build lookups in both directions.

    Returns a 3-tuple:
      * list of embeddings, in input order;
      * mapping sentence -> embedding;
      * mapping str(embedding) -> sentence (the string form of the array is
        used as the key).
    For repeated sentences / repeated string forms the last one seen wins.
    """
    embeddings = []
    embedding_by_sentence = defaultdict(list)
    sentence_by_embedding_repr = defaultdict(list)
    for text in sentences:
        vector = get_bert_sentence(text)
        embeddings.append(vector)
        embedding_by_sentence[text] = vector
        sentence_by_embedding_repr[str(vector)] = text
    return embeddings, embedding_by_sentence, sentence_by_embedding_repr

def sentence_pair_dist(sentence1, sentence2):
    """Cosine distance between the embeddings of two sentences."""
    first_vector, second_vector = [
        get_bert_sentence(text) for text in (sentence1, sentence2)
    ]
    return spatial.distance.cosine(first_vector, second_vector)

In [65]:
## PREPROCESSING - get bert embeddings of sentences
## Produces the list of embeddings plus sentence -> embedding and
## str(embedding) -> sentence lookups used by the cells below.
bert_sentences, sentence_to_bert_sentence, bert_sentence_to_sentence = get_bert_sentences(sentences)

In [128]:
def round_dot_product(dot_product):
    """Map a dot product to a bit: 1 if strictly positive, otherwise 0."""
    return 1 if dot_product > 0 else 0


def random_hyperplane_round(bert_sentence, hyperplane):
    """Return the bit for `bert_sentence` against one hyperplane.

    The bit is 1 when the dot product of the two vectors is strictly
    positive and 0 otherwise.
    """
    return round_dot_product(np.dot(bert_sentence, hyperplane))


def get_random_hyperplane_hash_function(bert_sentence_size, size=K):
    """Draw a hash function as a list of `size` random hyperplanes.

    Each hyperplane is a vector of `bert_sentence_size` samples from
    `np.random.normal` with its default parameters.
    """
    hyperplanes = []
    for _ in range(size):
        hyperplanes.append(np.random.normal(size=bert_sentence_size))
    return hyperplanes


def eval_hyperplane_hash_function(hash_function, bert_sentence):
    """Evaluate a hash function on an embedding.

    Returns a list with one 0/1 entry per hyperplane in `hash_function`,
    in the same order.
    """
    bits = []
    for hyperplane in hash_function:
        bits.append(random_hyperplane_round(bert_sentence, hyperplane))
    return bits


def construct_hash_table(bert_sentences):
    """Build one hash table over `bert_sentences`.

    A fresh random hyperplane hash function is drawn, sized to the length of
    the first embedding, and every embedding is appended to the bucket keyed
    by the string form of its hash value.

    Returns the pair (hash_function, hash_table); `hash_table` is a
    defaultdict(list) of buckets.
    """
    dimension = len(bert_sentences[0])
    hash_function = get_random_hyperplane_hash_function(bert_sentence_size=dimension)
    buckets = defaultdict(list)
    for vector in bert_sentences:
        bucket_key = str(eval_hyperplane_hash_function(hash_function, vector))
        buckets[bucket_key].append(vector)
    return hash_function, buckets


def construct_hash_tables(bert_sentences, num_tables=L):
    """Build `num_tables` independent hash tables over `bert_sentences`.

    Parameters
    ----------
    bert_sentences : sequence of embeddings to index.
    num_tables : int, number of tables to build (defaults to L).

    Returns
    -------
    list of (hash_function, hash_table) pairs, one per table.
    """
    # Honour the num_tables argument rather than always building L tables.
    return [construct_hash_table(bert_sentences) for _ in range(num_tables)]

In [207]:
def solve_pleb(dist, hash_table_constructions, r2, query_bert_sentence, query_cutoff=4*L):
    """Look for an indexed embedding within distance `r2` of the query.

    Parameters
    ----------
    dist : callable(candidate, query) -> number, the distance function.
    hash_table_constructions : iterable of constructions, each an iterable of
        (hash_function, hash_table) pairs.
    r2 : distance threshold; a candidate is accepted when dist < r2.
    query_bert_sentence : embedding of the query.
    query_cutoff : number of rejected candidates after which the current
        construction is abandoned and the next one is tried.

    Returns
    -------
    The first candidate found with dist(candidate, query) < r2, or None when
    no construction yields one.
    """
    for hash_tables in hash_table_constructions:
        num_queries = 0
        for hash_function, hash_table in hash_tables:
            key = str(eval_hyperplane_hash_function(hash_function, query_bert_sentence))
            # Use .get() rather than indexing: the tables are defaultdicts, and
            # indexing a missing key would insert an empty bucket on every query.
            for candidate_bert_nn in hash_table.get(key, ()):
                if dist(candidate_bert_nn, query_bert_sentence) < r2:
                    return candidate_bert_nn
                num_queries += 1
                if num_queries >= query_cutoff:
                    break
            # Budget for this construction is spent; move on to the next one.
            if num_queries >= query_cutoff:
                break
    return None

In [208]:
## One PLEB instance per radius; radii grow geometrically by a factor of C
## starting from dmin/2.
num_pleb_instances = math.ceil(
    math.log((dmax/(C-1.0))/(dmin/(2.0*C)), C)
)
pleb_instances = []
for index in range(num_pleb_instances):
    pleb_instance = {
        'r2': dmin/(2.0)*(C**index),
        ## Independent repetitions of the full set of hash tables for this radius.
        'hash_table_construction': [
            construct_hash_tables(bert_sentences)
            for _ in range(num_hash_table_constructions)
        ],
    }
    pleb_instances.append(pleb_instance)

In [209]:
def solve_ann(query_sentence, pleb_instances):
    """Find an approximate nearest neighbor of `query_sentence` in the dataset.

    Binary-searches `pleb_instances` for the lowest-index instance on which
    `solve_pleb` (with cosine distance) returns a candidate, then maps that
    candidate embedding back to its sentence through the module-level
    `bert_sentence_to_sentence` lookup.

    Parameters
    ----------
    query_sentence : str, the sentence to search for.
    pleb_instances : sequence of dicts with keys 'r2' and
        'hash_table_construction'.

    Returns
    -------
    The matched sentence, or None when no instance yields a candidate
    (including when `pleb_instances` is empty).
    """
    query_bert_sentence = get_bert_sentence(query_sentence)
    # Take the search range from the argument itself instead of relying on a
    # global defined elsewhere in the kernel.
    num_instances = len(pleb_instances)
    lo = 0
    high = num_instances - 1
    mid = num_instances // 2
    bert_ann = None
    while lo <= high:
        pleb_instance = pleb_instances[mid]
        candidate_bert_nn = solve_pleb(
            spatial.distance.cosine,
            pleb_instance['hash_table_construction'],
            pleb_instance['r2'],
            query_bert_sentence,
        )
        if candidate_bert_nn is None:
            # Nothing found here: try instances at higher indices.
            lo = mid + 1
        else:
            # Found one: remember it and try instances at lower indices.
            bert_ann = candidate_bert_nn
            high = mid - 1
        mid = (lo + high) // 2
    if bert_ann is None:
        # Do not look up str(None): the lookup is a defaultdict and would
        # silently insert a junk key and return an empty list.
        return None
    return bert_sentence_to_sentence[str(bert_ann)]

In [211]:
## query_sent is the sentence whose approximate nearest neighbor (ann) you want to find.
## Call solve_ann with query_sent and pleb_instances; pleb_instances is built during
## preprocessing, so once the cells above have run it does not need to be touched again.
## solve_ann returns an approximate nearest neighbor from the dataset.

query_sent = 'i can\'t hear you'
ann = solve_ann(query_sent, pleb_instances)
if ann:
    print(ann)
    print(sentence_pair_dist(query_sent, ann))
else:
    ## No candidate was found; skip the distance computation, which needs a sentence.
    print('No approximate nearest neighbor found')

 i canâ€™t hear you let me turn the volume up
0.23102311702435696
