In [1]:
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

In [1]:
import pyspark

In [2]:
!python3 --version

Python 3.7.9


In [3]:
# first import a few utility methods that we'll use later on
from IPython.display import Image, HTML, display
# check PySpark is running
from pyspark.context import SparkContext
from pyspark.sql.session import SparkSession
# Reuse the running SparkContext / SparkSession when one already exists, so
# re-running this cell does not construct another session wrapper by hand.
sc = SparkContext.getOrCreate()
spark = SparkSession.builder.getOrCreate()

In [4]:
# Show the installed PySpark version.
pyspark.__version__

'2.4.5'

In [20]:
# Relative path of the directory that holds the input data files.
PATH_TO_DATA = "../data"

In [48]:
# metadata.csv holds COVID-related scientific articles; its columns include
# the title and the abstract. The first row is the header and column types
# are inferred from the data.
covid_df = (
    spark.read
    .options(header=True, inferSchema=True)
    .csv(f"{PATH_TO_DATA}/metadata.csv")
)

In [49]:
# Report how many rows were loaded, then preview the first five.
print(f"Number of articles: {covid_df.count()}")
print("Sample of articles:")
covid_df.show(n=5)

Number of articles: 1056660
Sample of articles:
+--------+--------------------+--------+--------------------+--------------------+--------+---------+-------+--------------------+------------+--------------------+--------------+------+----------------+--------+--------------------+--------------------+--------------------+-----+
|cord_uid|                 sha|source_x|               title|                 doi|   pmcid|pubmed_id|license|            abstract|publish_time|             authors|       journal|mag_id|who_covidence_id|arxiv_id|      pdf_json_files|      pmc_json_files|                 url|s2_id|
+--------+--------------------+--------+--------------------+--------------------+--------+---------+-------+--------------------+------------+--------------------+--------------+------+----------------+--------+--------------------+--------------------+--------------------+-----+
|ug7v899j|d1aafb70c066a2068...|     PMC|Clinical features...|10.1186/1471-2334...|PMC35282| 11472636|  no-

In [50]:
# Keep only the columns used downstream: cord_uid, title, doi, abstract, authors, url.
cols_to_keep = ['cord_uid', 'title', 'doi', 'abstract', 'authors', 'url']
# Project by name; unpacking the list passes each column as its own argument
# instead of nesting the list inside a second list.
covid_df = covid_df.select(*cols_to_keep)

In [44]:
# Count the articles whose abstract is missing (235202 rows).
covid_df.where(covid_df["abstract"].isNull()).count()

235202

In [51]:
# Drop the articles that have no abstract.
covid_df = covid_df.where(covid_df["abstract"].isNotNull())

In [52]:
!jupyter nbextension enable --py widgetsnbextension


Enabling notebook extension jupyter-js-widgets/extension...
      - Validating: [32mOK[0m


In [53]:
from sentence_transformers import SentenceTransformer
# This model does well on semantics-related tasks; see its model card:
# https://huggingface.co/sentence-transformers/multi-qa-MiniLM-L6-cos-v1
model = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')


In [54]:
# Display the loaded model's repr.
model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  (2): Normalize()
)

In [28]:
import pickle
# Serialize the model to model.pkl. The context manager closes the file even
# when pickle.dump raises, so the handle cannot leak.
with open("model.pkl", "wb") as pickle_out:
    pickle.dump(model, pickle_out)

In [29]:
# Read model.pkl back through Spark and collect the records onto the driver.
model_rdd_pkl = sc.binaryFiles("model.pkl")
model_rdd_data = model_rdd_pkl.collect()

# Rebuild the Python object from element 1 of the first collected record
# (per the PySpark docs, binaryFiles yields (path, content) pairs).
# NOTE(review): pickle.loads can execute arbitrary code; this is only safe while
# model.pkl is the file written by this notebook and not an untrusted replacement.
model_ = pickle.loads(model_rdd_data[0][1])

In [30]:
# Broadcast the deserialized model; consumers read it through bc_model.value.
bc_model = sc.broadcast(model_)

In [31]:
# Print the broadcast model's repr.
print(bc_model.value)

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  (2): Normalize()
)


In [55]:
def get_embeddings(abstract):
    """Encode one abstract into its sentence embedding.

    Args:
        abstract: Text of the abstract, or None when the value is null.

    Returns:
        The result of calling ``.tolist()`` on the encoder output for
        ``abstract``, or None when ``abstract`` is None.
    """
    # A null column value reaches a Python UDF as None; return a null embedding
    # instead of passing None to the encoder.
    if abstract is None:
        return None
    sentence_embeddings = bc_model.value.encode(abstract)
    return sentence_embeddings.tolist()


In [56]:
import pyspark.sql.functions as f
from pyspark.sql.types import ArrayType, IntegerType,FloatType
from pyspark.ml.linalg import Vectors, VectorUDT
# Wrap get_embeddings as a Spark UDF whose return type is an array of floats.
emb_udf = f.udf(get_embeddings, returnType=ArrayType(FloatType()))

In [57]:
# Add an "abstract_embedding" column computed from each row's abstract.
covid_df_embeddings = covid_df.withColumn(
    "abstract_embedding",
    emb_udf(covid_df["abstract"]),
)

In [58]:
# Preview two rows without truncating the cell contents.
covid_df_embeddings.show(n=2, truncate=False)

+--------+--------------------------------------------------------------------------------------------------------------------------------+---------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

## Elasticsearch

In [32]:
# !python3 -m pip install elasticsearch==8.4.3

In [23]:
import os

import elasticsearch

# Read the connection URL (including any credentials) from the ES_URL
# environment variable so secrets are not committed with the notebook.
# The fallback is the local-development URL and must not be used for a
# shared or production cluster.
ES_URL = os.environ.get("ES_URL", "http://elastic:changeme@localhost:9200/")
es = elasticsearch.Elasticsearch(hosts=ES_URL)

In [24]:
# Show the installed Elasticsearch Python client version.
elasticsearch.__version__

(8, 4, 3)

In [25]:
# Query the cluster's info endpoint through the client created above.
es.info(pretty=True)

ObjectApiResponse({'name': 'Julias-Air.lan', 'cluster_name': 'elasticsearch', 'cluster_uuid': 'Wba4tYECRlKZ92fsuCK1aQ', 'version': {'number': '8.4.3', 'build_flavor': 'default', 'build_type': 'tar', 'build_hash': '42f05b9372a9a4a470db3b52817899b99a76ee73', 'build_date': '2022-10-04T07:17:24.662462378Z', 'build_snapshot': False, 'lucene_version': '9.3.0', 'minimum_wire_compatibility_version': '7.17.0', 'minimum_index_compatibility_version': '7.0.0'}, 'tagline': 'You Know, for Search'})

In [52]:
# Drop any existing "articles" index so it can be recreated from scratch.
# ignore_unavailable=True keeps this cell from raising on a fresh cluster
# where the index does not exist yet.
es.indices.delete(index="articles", ignore_unavailable=True)

ObjectApiResponse({'acknowledged': True})

In [53]:
# Mapping for the "articles" index: a numeric id, the article text, and a
# 384-dimensional dense vector that is indexed with dot-product similarity.
# NOTE(review): Elasticsearch's dot_product similarity expects unit-length
# vectors - confirm the stored embeddings are normalized.
create_articles = {
    "mappings": {
        "properties": {
            "ArticleID": {"type": "long"},
            "ArticleBody": {"type": "text"},
            "ArticleEmbedding": {
                "type": "dense_vector",
                "dims": 384,
                "index": True,
                "similarity": "dot_product",
            },
        }
    }
}

# Create the index with the mapping above. The mapping goes through the typed
# `mappings` keyword because the catch-all `body` parameter is deprecated in
# the 8.x Python client.
res_articles = es.indices.create(index="articles", mappings=create_articles["mappings"])

print("Created indices:")
print(res_articles)



Created indices:
{'acknowledged': True, 'shards_acknowledged': True, 'index': 'articles'}


In [54]:
# write articles data to ES
from pyspark.sql import SQLContext

# `new_wiki_df` is not defined anywhere in this notebook, so index the embedded
# articles DataFrame instead. The abstract and embedding columns are renamed to
# the field names that the "articles" mapping and the kNN query use.
articles_df = (
    covid_df_embeddings
    .withColumnRenamed("abstract", "ArticleBody")
    .withColumnRenamed("abstract_embedding", "ArticleEmbedding")
)
articles_df.write.format("es").save("articles", mode="Overwrite")

# Compare the number of rows written with the number of documents indexed.
num_articles_es = es.count(index="articles")['count']
num_articles_df = articles_df.count()
print("Dataframe count: {}".format(num_articles_df))
print("ES index count:  {}".format(num_articles_es))


KeyboardInterrupt: 

In [90]:
# kNN search request against the "ArticleEmbedding" field: return the single
# nearest document (k=1) to the query vector, with num_candidates=10000.
# The commented-out alternative query vector was removed; its wrapped
# continuation lines were not commented out and broke the dict literal.
query = {
    "field": "ArticleEmbedding",
    "query_vector": [0.002368977, -0.0144982245, 0.023526847, -0.0270676, -0.06152153, -0.06319506, -0.006329024, 0.038582187, -0.10974901, -0.095876604, -0.020383101, 0.09050552, -0.06406877, 0.042892814, -0.074971884, 0.084289216, -0.03736275, -0.010245111, -0.09258829, 0.020025147, 0.07120598, -0.04971169, 0.015024931, 0.020793352, 0.05473224, 0.06566036, -0.068200275, 0.094899155, 0.040251877, -0.017525531, 0.051904332, 0.020421993, 0.013445646, 0.021007864, 0.10727594, 0.028702682, 0.09590714, -0.018366627, 0.019068921, -0.0615405, 0.058271516, 0.01933643, 0.0046409722, 0.09529977, -0.041849006, 0.072505645, 0.04183017, -0.06618785, 0.037792426, 0.04485447, 0.04457694, -0.014782231, 0.101819016, -0.016642287, 0.027641037, 0.054425687, -0.036799554, 0.02884082, 0.056466516, 0.04443891, 0.023617487, -0.048990775, -0.06679463, 0.038091518, -0.041292, 0.018274385, -0.11654477, -0.038729046, -0.055952363, -0.042560954, 0.08443825, -0.027033567, -0.029463334, 0.012365034, -0.035966076, -0.0664877, 0.03612434, 0.0749826, -0.10116313, 0.100956865, -0.016306406, 0.026671264, 0.017784053, 0.088061176, -0.11567393, 0.015200883, -0.03610861, 0.06974517, 0.05692757, 0.0064039323, 0.045728277, -0.05490624, -0.07118619, 0.015667854, -0.009063289, 0.0076261084, -0.032558806, -0.018248437, -0.011964076, -0.020121334, 0.03436083, 0.026972553, -0.06661181, -0.0028853759, 0.025089953, -0.017351832, -0.0010943299, 0.09086529, -0.012296073, -0.0027995806, 0.05712214, 0.011286844, -0.034836028, 0.11548477, -0.035157632, -0.05980753, 0.014466054, -0.008511785, -0.04349793, -0.14256237, 0.013189861, 0.07079003, 0.032072198, -0.028593201, -0.095878094, 0.05606845, 0.0031338737, 2.4481045E-31, -0.099471554, 0.008352207, -0.068235345, 0.029604977, 0.059200346, -0.10011949, 0.019923128, 0.07390007, -0.020234453, 0.00827478, -0.0024624022, 0.037483715, -0.055007525, -0.053796902, -0.012903822, 0.014793636, 0.038752925, -0.06403104, -8.1623264E-4, -0.08712315, 0.010577271, -0.007903737,
0.03631186, 0.0024037657, -0.05908608, -0.1294004, -0.093244046, 0.09330286, 0.026699832, 0.0052673863, -0.005539673, -0.027170707, 0.05001951, 0.030519053, 0.067550994, 0.098150365, -0.02658749, -0.011032814, -0.0147188855, 0.05905123, -0.044007435, -0.012623193, -0.04378616, 0.04204706, -0.010736497, 0.019962404, -0.03962603, 0.0037885748, 0.02496956, 0.0324687, -0.027894758, -0.004452769, -0.041351028, 0.018607322, -0.040638227, 0.018351585, 0.0020137108, -0.036686186, -0.011227237, 0.012679353, -0.13659516, -0.01233588, 0.023829697, 0.037931547, -0.031751633, -0.0519563, 0.05592474, -0.01772243, -0.09141166, 5.722347E-4, 0.037370626, -0.03722956, -0.005987689, -0.0262768, -0.0033156164, -0.037257437, -0.08398864, 0.04437932, 0.03554128, -0.01026183, 0.07097148, 0.0046338914, 0.052947335, 0.066810116, 0.033074893, -0.03641795, 0.05610068, -0.037080638, -0.10845184, 0.0069260546, -0.091040455, 0.01859011, 0.008945483, -0.049684417, 0.032262947, -5.879136E-33, -0.0012220285, 0.007905317, 0.022022778, -0.06956964, -0.06771696, 0.031538192, 0.029543558, -0.05830277, 0.018832017, -0.019282239, -0.069842055, 0.08996895, -0.025959978, 0.114866324, -0.09499167, 0.05315035, -0.01909329, -0.054535788, 0.02521154, -0.0065357136, 0.020912914, 0.04137011, -0.053962063, 0.09136067, -0.02436968, 0.015813138, 0.020087928, 0.072709404, -0.0048399246, 0.008892837, -0.013705127, 0.0050969287, 0.14630456, -0.053477902, 0.026720993, 0.06484497, 0.015059467, -0.018938672, 0.019018753, -0.05743301, 0.0042085447, -0.1301546, 0.013365679, -0.11481724, -0.0039650206, -0.04188711, -0.026372196, -0.01678552, -0.013866285, 0.028431132, -0.005080146, 0.06425022, 0.018923571, 0.09801433, 0.06717092, -0.045013003, 0.05612729, 0.007712787, -0.045470852, -0.0734631, -0.013141988, 0.009614672, 0.012187457, 0.04773729, 0.054260805, 0.028477136, 0.023143142, -0.097899325, -0.0642174, -0.014278599, -0.100660674, -0.016723381, 0.06427751, 0.03287425, -0.08626409, -0.0954718, -0.026407506,
-0.021837747, 0.017445814, -0.06763849, -0.008557039, 0.07117043, 0.05989563, -0.0031953563, 0.016573992, -0.06333501, 0.09330664, -8.207419E-4, -0.032371726, -0.04370505, 0.098739296, -0.031903636, -0.068461895, -0.003606344, -0.036733933, -6.239126E-33, -0.011548932, 0.024357317, -0.038152684, -0.016619805, -0.004227289, -0.06708526, 0.06656596, 0.07441204, 0.04495226, 0.026751146, -0.026529675, 0.03758306, -0.0025080636, -0.021281747, -0.007817577, -0.043852143, 0.0165045, -0.08535518, -0.022916218, -0.011729721, 0.070208274, 0.039096605, 0.00788824, -0.021716163, 0.014755593, -0.074116856, 0.010112023, -0.011099901, 0.0393602, -0.006320954, -0.012658638, 0.026892807, 0.07309131, 0.054369822, -0.0018915823, 0.011408576, 0.033384435, -0.042764015, -0.01669409, -0.039452538, -0.07208219, 0.063401595, -0.03410022, 0.03552125, 0.051395725, 0.035846144, -0.02752153, 0.05370285, 0.047187172, 0.03535646, 0.015923133, 0.05017062, 0.07464302, 0.03602904, -0.03533178, 0.023518767, -0.06925232, 0.0064210133, -0.052782036, 0.029399788, -0.04605264, -0.01240707, 0.01872488, -0.058449537],
    "k": 1,
    "num_candidates": 10000,
}
# Return only the embedding and the body text of each hit.
res = es.search(index="articles", knn=query, source=["ArticleEmbedding", "ArticleBody"])


In [91]:
# Show the "hits" section of the search response.
res["hits"]

{'total': {'value': 1, 'relation': 'eq'},
 'max_score': 0.5844929,
 'hits': [{'_index': 'articles',
   '_id': 'sNA4SYQB6ehvhq6PcDn-',
   '_score': 0.5844929,
   '_source': {'ArticleEmbedding': [0.05728377,
     -0.013098963,
     -0.008899082,
     0.021507302,
     -0.007438295,
     -0.084927954,
     -0.0370988,
     -0.11202134,
     -0.057328977,
     -7.89589e-05,
     0.05069894,
     -0.008978922,
     -0.041873034,
     -0.09128488,
     0.00028431969,
     0.12215205,
     -0.054541413,
     0.09311677,
     -0.008690997,
     0.032057155,
     0.04203597,
     -0.027274426,
     0.062772006,
     -0.009234066,
     0.045994025,
     -0.005340212,
     0.08396787,
     0.019829223,
     0.08709027,
     -0.007518563,
     -0.061248634,
     0.040912937,
     0.007310492,
     0.0432856,
     0.034601554,
     0.020428147,
     -0.030315233,
     -0.012386789,
     0.02627573,
     -0.047554422,
     0.06440479,
     0.088841796,
     0.0846807,
     0.028445017,
     -0.08635