In [1]:
from elasticsearch import Elasticsearch
from sentence_transformers import SentenceTransformer, util

In [2]:
# Client for the local Elasticsearch node that serves the 'articles' index queried below.
client = Elasticsearch(hosts='http://localhost:9200')

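<h1>Adding semantic similarity to our search mechanism</h1>
<p>Elasticsearch retrieves candidate passages by keyword relevance; we then score each candidate by the cosine similarity between its text embedding and the query embedding.</p>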

In [3]:
def search_elasticsearch(query: str, size: int = 50):
    """Run a keyword search against the ``articles`` index.

    Matches ``query`` against the ``text`` and ``article_title`` fields (both
    boosted x3), excludes hits whose ``main_section`` matches "External links",
    and boosts hits whose ``main_section`` matches "Summary".

    Args:
        query: Free-text query string.
        size: Maximum number of hits to request (default 50).

    Returns:
        A list with one dict per hit, in Elasticsearch's relevance order:
        ``source`` is the hit's ``_source`` and ``highlight`` is the list of
        highlighted fragments from ``text`` (an empty list when the hit has
        none, e.g. when it matched only on ``article_title``).
    """
    # Pass the request sections as keyword arguments: the `body=` form is
    # deprecated in the Elasticsearch Python client and emits a warning.
    hits = client.search(
        index='articles',
        size=size,
        query={
            "bool": {
                "must": [
                    {"multi_match": {
                        "query": query,
                        "fields": ["text^3", "article_title^3"]
                    }}
                ],
                # NOTE(review): `match` is a full-text query; if main_section is
                # an analyzed text field this also excludes sections containing
                # only "external" or "links" - confirm against the index mapping.
                "must_not": [
                    {"match": {"main_section": "External links"}}
                ],
                "should": [
                    {"match": {"main_section": {"query": "Summary", "boost": 2}}}
                ]
            }
        },
        highlight={"fields": {"text": {}}},
    )['hits']['hits']

    # Elasticsearch may omit `highlight` (or its `text` entry) for a hit whose
    # `text` field had no highlightable match; indexing directly would KeyError.
    return [
        dict(
            source=hit['_source'],
            highlight=hit.get('highlight', {}).get('text', []),
        )
        for hit in hits
    ]

In [6]:
query = 'Which diseases can be transmitted by animals?'
r = search_elasticsearch(query)
# Guard against an empty result: r[0] would raise IndexError when nothing matches.
if r:
    print(r[0]['highlight'])
else:
    print('No hits for query:', query)

['from plant to plant <em>by</em> insects that feed on plant sap, such as aphids; and viruses in <em>animals</em> <em>can</em> <em>be</em>', 'Norovirus and rotavirus, common causes of viral gastroenteritis, are <em>transmitted</em> <em>by</em> the faecal–oral route', 'HIV is one of several viruses <em>transmitted</em> through sexual contact and <em>by</em> exposure to infected blood.', 'This <em>can</em> <em>be</em> narrow, meaning a virus is capable of infecting few species, or broad, meaning it is capable', 'Immune responses <em>can</em> also <em>be</em> produced <em>by</em> vaccines, <em>which</em> confer an artificially acquired immunity to']


  res = client.search(


In [5]:
# NOTE(review): the load log below shows 'prajjwal1/bert-tiny' is not a
# sentence-transformers checkpoint (a new model with MEAN pooling is created),
# so its embeddings are presumably not tuned for sentence similarity; consider
# a dedicated sentence-embedding model if re-ranking quality matters.
MODEL_NAME = 'prajjwal1/bert-tiny'
model = SentenceTransformer(MODEL_NAME)

No sentence-transformers model found with name /Users/mehrdadmoradi/.cache/torch/sentence_transformers/prajjwal1_bert-tiny. Creating a new one with MEAN pooling.
Some weights of the model checkpoint at /Users/mehrdadmoradi/.cache/torch/sentence_transformers/prajjwal1_bert-tiny were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical 

In [6]:
# Attach to every hit in `r` (mutated in place) a 'score': the cosine similarity
# between the query embedding and the embedding of the hit's text.
# The temporary is deliberately not called `score`: a later cell defines a
# function `score()`, and re-running this cell would overwrite it with a tensor.
query_vector = model.encode(query)

for article in r:
    text = article['source']['text']
    vector = model.encode(text)
    similarity = util.cos_sim(query_vector, vector)
    article.update({'score': similarity.item()})

In [7]:
# Check that the previous cell attached 'score' alongside 'source' and 'highlight'.
first_hit = r[0]
first_hit.keys()

dict_keys(['source', 'highlight', 'score'])

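<h1>Refactoring our functions</h1>
<p>The query DSL, the similarity score, and the search-then-re-rank step are split into reusable functions; the final step keeps the hits most similar to the query.</p>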

In [17]:
def score(text:str, query:str):
    """Return the cosine similarity between the embeddings of ``query`` and ``text``.

    Both strings are embedded with the notebook-global ``model`` (query first),
    and the 1x1 similarity matrix from ``util.cos_sim`` is unwrapped to a float.
    """
    query_embedding, text_embedding = (model.encode(s) for s in (query, text))
    similarity_matrix = util.cos_sim(query_embedding, text_embedding)
    return similarity_matrix.item()

In [18]:
def get_elasticsearch_dsl(query: str, size: int = 50) -> dict:
    """Build the search request for the ``articles`` index.

    The request matches ``query`` against ``text`` and ``article_title`` (both
    boosted x3), excludes hits whose ``main_section`` matches "External links",
    boosts hits whose ``main_section`` matches "Summary", and asks for
    highlights on ``text``.

    Args:
        query: Free-text query string.
        size: Maximum number of hits to request (default 50).

    Returns:
        A new dict whose top-level keys are ``size``, ``query`` and
        ``highlight``.
    """
    return {
        "size": size,
        "query": {
            "bool": {
                "must": [
                    {"multi_match": {
                        "query": query,
                        "fields": ["text^3", "article_title^3"]
                    }}
                ],
                "must_not": [
                    {"match": {"main_section": "External links"}}
                ],
                "should": [
                    {"match": {"main_section": {"query": "Summary", "boost": 2}}}
                ]
            }
        },
        "highlight": {
            "fields": {
                "text": {}
            }
        }
    }

In [27]:
def semantic_search(es_client:Elasticsearch, query:str, top_k: int = 10):
    """Keyword-search the ``articles`` index, then re-rank hits by embedding similarity.

    Args:
        es_client: Elasticsearch client used for the candidate search.
        query: Free-text query string.
        top_k: Number of re-ranked hits to return (default 10).

    Returns:
        Up to ``top_k`` dicts with ``source`` (the hit's ``_source``),
        ``highlight`` (highlighted ``text`` fragments; empty list if none) and
        ``score`` (``score(text, query)``), sorted by descending ``score``.
        Ties keep Elasticsearch's relevance order (the sort is stable).
    """
    # Unpack the DSL as keyword arguments: `body=` is deprecated in the client.
    hits = es_client.search(
        index='articles',
        **get_elasticsearch_dsl(query)
    )['hits']['hits']

    candidates = [
        dict(
            source=hit['_source'],
            # A hit can lack `highlight`/`text` (e.g. matched only on the
            # title); indexing it directly would raise KeyError.
            highlight=hit.get('highlight', {}).get('text', []),
            score=score(hit['_source']['text'], query),
        )
        for hit in hits
    ]

    candidates.sort(key=lambda candidate: candidate['score'], reverse=True)
    return candidates[:top_k]

In [33]:
result = semantic_search(client, 'Which diseases can be transmitted by animals?')
# Use a dedicated loop name: reusing `r` would overwrite the hit list built in
# earlier cells and break re-running them.
for hit in result:
    print(hit['source']['article_title'])

  search_result = es_client.search(
  search_result = es_client.search(


Pandemic
Virus
Swine influenza
Swine influenza
Virus
Virus
Virus
Virus
Pandemic
Swine influenza
