# Semantic Search over crosswords dataset by ParsBERT-FarsTail model

In [None]:
from tqdm import tqdm
import torch
from sentence_transformers import models, SentenceTransformer, util
import random

Using [bert-fa-base-uncased-farstail](https://huggingface.co/m3hrdadfi/bert-fa-base-uncased-farstail) which is fine-tuned on [ParsBERT](https://github.com/hooshvare/parsbert) with [FarsTail dataset](https://github.com/dml-qom/FarsTail). For more info check [Sentence Transofrmers](https://github.com/m3hrdadfi/sentence-transformer)))

In [8]:
def load_st_model(model_name_or_path):
    word_embedding_model = models.Transformer(model_name_or_path)
    pooling_model = models.Pooling(
        word_embedding_model.get_word_embedding_dimension(),
        pooling_mode_mean_tokens=True,
        pooling_mode_cls_token=False,
        pooling_mode_max_tokens=False)
    
    model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
    return model

In [10]:
# Load the Sentence-Transformer

#load from local files
embedder = load_st_model('model/bert-fa-base-uncased-farstail')

#load from hugginface
#embedder = load_st_model('m3hrdadfi/bert-fa-base-uncased-farstail')

In [None]:
#save model to local
#embedder.save("model/bert-fa-base-uncased-farstail")

In [65]:
cw_lines = []

In [66]:
data_files = ["cw.train.tsv", "cw.dev.tsv", "cw.test.tsv"]
index = 0
for file in data_files:
    for i, line in enumerate(open(file, encoding="utf-8")):
        cw_lines.append(line.strip())
        index += 1

In [184]:
print(cw_lines[0], cw_lines[-1], len(cw_lines), sep="\n")

بی حال و سست	کسل
آرام و یواش	اهسته
30157


Embed all dataset records

In [128]:
all_embeddings = embedder.encode(cw_lines, convert_to_tensor=True, show_progress_bar=True)

Batches:   0%|          | 0/943 [00:00<?, ?it/s]

Calculate query similarity against all records of dataset except itself. Create first and top_five matches lists.

In [147]:
top_k = 6

firsts = []
top_fives = []
for i in tqdm(range(len(cw_lines))):
    q = cw_lines[i].split("\t")[0]
    a = cw_lines[i].split("\t")[1]
    
    query_embedding = embedder.encode(q, convert_to_tensor=True, show_progress_bar=False)
    cos_scores = util.pytorch_cos_sim(query_embedding, all_embeddings)[0]
    #cos_scores = util.semantic_search(query_embedding, all_embeddings)[0]
    cos_scores = cos_scores.cpu()
    
    #We use torch.topk to find the highest 5 scores
    top_results = torch.topk(cos_scores, k=top_k)

    first = []
    for j in [0, 1]:
        if top_results[1][j] != i and a in cw_lines[top_results[1][j]]:
            firsts.append({
                    'q': q,
                    'a': a,
                    'found': cw_lines[top_results[1][j]],
                    'score': top_results[0][j]
                })
    
    top_five = []
    top_score = []
    for score, idx in zip(top_results[0], top_results[1]):
        if idx != i and a in cw_lines[idx]:
            top_five.append(cw_lines[idx])
            top_score.append(score)

    if top_five:
        top_fives.append({
            'q': q,
            'a': a,
            'found': top_five,
            'score': top_score
        })
    #for r in cos_scores:
        #print(cw_lines[r['corpus_id']], round(r['score']*100, 2))
    #print('- - '*50)

100%|██████████| 30157/30157 [28:16<00:00, 17.77it/s]


In [149]:
len(firsts), len(top_fives)

(3966, 6853)

Percentage of first highest scores and top five highest scores in record match (existing of answer in question or answer)

In [188]:
len(firsts)/len(cw_lines), round(len(top_fives)/len(cw_lines)*100, 2)

(0.1315117551480585, 22.72)

In [150]:
random.choice(firsts)

{'q': 'پزشک و طبیب',
 'a': 'دکتر',
 'found': 'دکتر و پزشک\tطبیب',
 'score': tensor(0.9382)}

In [151]:
random.choice(top_fives)

{'q': 'زیادی و اضافی',
 'a': 'زاید',
 'found': ['زیادی\tزاید'],
 'score': [tensor(0.7994)]}

Percentage of exact matches of first highst scores

In [153]:
refined_first  = [x for x in firsts if x['a'] == x['found'].split("\t")[1]]

In [187]:
len(refined_first), round(len(refined_first)/len(cw_lines)*100, 2)

(2463, 8.17)

Percentage of exact matches of top five highest scores.

In [175]:
refined_top_fives = []
for five in top_fives:
    for i, f in enumerate(five['found']):
        if five['a'] == f.split("\t")[1]:
            refined_top_fives.append({
                'q': five['q'],
                'a': five['a'],
                'found': f,
                'score': five['score'][i]
            })
            break

In [186]:
len(refined_top_fives), round(len(refined_top_fives)/len(cw_lines)*100, 2)

(4656, 15.44)

In [163]:
random.choice(refined_first)

{'q': 'گرد چیزی گشتن',
 'a': 'طواف',
 'found': 'پیرامون چیزی گشتن\tطواف',
 'score': tensor(0.8122)}