## Load the pre-trained sentence-embedding model

In [2]:
from sentence_transformers import SentenceTransformer

# Pre-trained Sentence-BERT checkpoint; encode() turns each text into a fixed-size vector.
# NOTE(review): upstream sentence-transformers docs flag this checkpoint as deprecated
# (low-quality embeddings); a newer model such as 'all-mpnet-base-v2' may be preferable — confirm.
model = SentenceTransformer(model_name_or_path='bert-base-nli-mean-tokens')

In [16]:
# Smoke test: embed a handful of sentences and report the shape/type of each vector.
sentences = ['This framework generates embeddings for each input sentence',
            'Sentences are passed as a list of string.',
            'The quick brown fox jumps over the lazy dog.',
            'hello world!!']
sentence_embeddings = model.encode(sentences)

for text, vec in zip(sentences, sentence_embeddings):
    print("Sentence:", text)
    print("Embedding:", vec.shape, type(vec))
    print("")

Sentence: This framework generates embeddings for each input sentence
Embedding: (768,) <class 'numpy.ndarray'>

Sentence: Sentences are passed as a list of string.
Embedding: (768,) <class 'numpy.ndarray'>

Sentence: The quick brown fox jumps over the lazy dog.
Embedding: (768,) <class 'numpy.ndarray'>

Sentence: hello world!!
Embedding: (768,) <class 'numpy.ndarray'>



In [None]:
# Inspect the batch of smoke-test embeddings. The encode result is bound to
# `sentence_embeddings` (no `embeddings` name exists, which raised NameError).
# Show the shape rather than dumping the full array: one 768-dim row per sentence.
sentence_embeddings.shape

## Examine the Quora Question Pairs dataset

In [4]:
import pyspark
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

# Tab-separated Quora file, resolved relative to the notebook's working directory.
data_path = "quora_duplicate_questions.tsv"

# getOrCreate() returns an already-running session if there is one.
spark = SparkSession.builder.getOrCreate()
# Low-level SparkContext handle of that session.
sc = spark.sparkContext

In [5]:
# Read the TSV (first line is the header). Without inferSchema every column is a string.
curr_data = spark.read.csv(data_path, header=True, sep='\t')
# Cast the id columns to integers. In Spark's non-ANSI cast mode an unparsable
# value becomes null.
for id_column in ('qid1', 'qid2', 'id'):
    curr_data = curr_data.withColumn(id_column, col(id_column).cast("Int"))
# The file contains empty/malformed lines; na.drop() removes every row with a null
# in any column (this also drops rows whose ids failed the cast above).
curr_data = curr_data.na.drop()
print("take some sample data")
curr_data.show()

take some sample data
+---+----+----+---------------------+--------------------+------------+
| id|qid1|qid2|            question1|           question2|is_duplicate|
+---+----+----+---------------------+--------------------+------------+
|  0|   1|   2| What is the step ...|What is the step ...|           0|
|  1|   3|   4| What is the story...|What would happen...|           0|
|  2|   5|   6| How can I increas...|How can Internet ...|           0|
|  3|   7|   8| Why am I mentally...|Find the remainde...|           0|
|  4|   9|  10| Which one dissolv...|Which fish would ...|           0|
|  5|  11|  12| Astrology: I am a...|I'm a triple Capr...|           1|
|  6|  13|  14|  Should I buy tiago?|What keeps childe...|           0|
|  7|  15|  16| How can I be a go...|What should I do ...|           1|
|  8|  17|  18|When do you use シ...|"When do you use ...|           0|
|  9|  19|  20| Motorola (company...|How do I hack Mot...|           0|
| 10|  21|  22| Method to find se...|What a

In [8]:
print("num of total question pairs", curr_data.select('id').distinct().count())
# question_id -> question_text for each side of the pair. Collected Rows are
# tuples of (qid, text), so dict() can consume them directly.
left_questions = dict(curr_data.select('qid1', 'question1').collect())
right_questions = dict(curr_data.select('qid2', 'question2').collect())
# Combine the two dictionaries; for a qid present on both sides the right-hand text wins.
all_questions = dict(left_questions)
all_questions.update(right_questions)
print('total unique questions', len(all_questions))

num of total question pairs 404275
total unique questions 537915


In [9]:
# NOTE: question ids (qid1/qid2) are not consecutive: max - min + 1 exceeds the count.
# NOTE(review): some holes may be rows removed by na.drop() during loading — unverified.
print(max(all_questions), min(all_questions), len(all_questions))

537933 1 537915


## Convert questions into feature vectors

In [10]:
# (qid, text) pairs ordered by qid. Keys are unique, so the tuple sort never compares texts.
# The feature vectors below are produced in this same order.
all_question_list = sorted(all_questions.items())
all_question_list[:3]

[(1, 'What is the step by step guide to invest in share market in india?'),
 (2, 'What is the step by step guide to invest in share market?'),
 (3, 'What is the story of Kohinoor (Koh-i-Noor) Diamond?')]

In [11]:
# Embed every question text, in all_question_list order. With convert_to_numpy=True
# the result is a 2-D numpy array of shape (len(all_question_list), 768): row i is the
# embedding of all_question_list[i] (this model emits 768-dim vectors; see the smoke test above).
# NOTE: skip this cell and the next if you want to load the feature vectors from disk instead.
feature_vec = model.encode([text for _, text in all_question_list], show_progress_bar=True, convert_to_numpy=True)

Batches: 100%|██████████| 67240/67240 [18:11<00:00, 61.60it/s]


In [12]:
# Persist the embeddings so a later session can reload them instead of re-encoding.
import pickle

with open('feature_vec.pickle', mode='wb') as pickle_out:
    pickle.dump(feature_vec, pickle_out, pickle.HIGHEST_PROTOCOL)

In [2]:
# Reload the embeddings written by the save cell, so the expensive encode step can be skipped.
# Bind them to `feature_vec` — the name the encode cell defines and the graph TODO relies on —
# so downstream cells work on a fresh kernel. `b` is kept as an alias for backward compatibility.
# SECURITY: pickle.load can execute arbitrary code; only load a file this notebook produced.
import pickle
with open('feature_vec.pickle', 'rb') as handle:
    feature_vec = pickle.load(handle)
b = feature_vec

In [None]:
# TODO: build a graph
# Row i of feature_vec is the embedding of all_question_list[i], so these positions can serve as graph node indices.