# Class 6 - Doc2Query

[Unicamp - IA368DD: Deep Learning Applied to Search Systems.](https://www.cpg.feec.unicamp.br/cpg/lista/caderno_horario_show.php?id=1779)

Author: Marcus Vinícius Borela de Castro

[GitHub repository](https://github.com/marcusborela/deep_learning_em_buscas_unicamp)

Stage: computing baseline retrieval metrics before expanding the documents with generated queries

# Setting up the environment

In [31]:
import os
import pandas as pd

In [12]:
DIRETORIO_TRABALHO = '/home/borela/fontes/deep_learning_em_buscas_unicamp/local/doc2query'

In [13]:
assert os.path.exists(DIRETORIO_TRABALHO), f"Path {DIRETORIO_TRABALHO} does not exist!"

In [1]:
import os

export JVM_PATH=/usr/lib/jvm/java-11-openjdk-amd64/lib/server/libjvm.so
export JAVA_HOME=/usr/lib/jvm/java-11-openjdk-amd64

In [2]:
os.environ['JVM_PATH'] = '/usr/lib/jvm/java-11-openjdk-amd64/lib/server/libjvm.so'
os.environ['JAVA_HOME'] = '/usr/lib/jvm/java-11-openjdk-amd64'

In [3]:
from pyserini.search.lucene import LuceneSearcher



# Loading the prebuilt trec-covid index

In [4]:
LuceneSearcher.from_prebuilt_index('beir-v1.0.0-trec-covid.flat')

Downloading index at https://rgw.cs.uwaterloo.ca/pyserini/indexes/lucene-index.beir-v1.0.0-trec-covid.flat.20221116.505594.tar.gz...


lucene-index.beir-v1.0.0-trec-covid.flat.20221116.505594.tar.gz: 216MB [02:35, 1.46MB/s]                               


<pyserini.search.lucene._searcher.LuceneSearcher at 0x7f152012a880>

In [74]:
os.getcwd()

'/home/borela/fontes/deep_learning_em_buscas_unicamp/code/aula6_doct2query'

In [75]:
!ls /home/borela/.cache/pyserini/indexes/

lucene-index.beir-v1.0.0-trec-covid.flat.20221116.505594.57b812594b11d064a23123137ae7dade


In [78]:
# Assumes the cache directory holds a single index folder
nome_indice_trec_covid_sem_expansao = os.listdir('/home/borela/.cache/pyserini/indexes/')[0]

In [79]:
print(nome_indice_trec_covid_sem_expansao)

lucene-index.beir-v1.0.0-trec-covid.flat.20221116.505594.57b812594b11d064a23123137ae7dade


In [80]:
caminho_indice_trec_covid_sem_expansao = f'/home/borela/.cache/pyserini/indexes/{nome_indice_trec_covid_sem_expansao}'

In [81]:
os.path.exists(caminho_indice_trec_covid_sem_expansao)

True

In [None]:
!wget https://huggingface.co/datasets/BeIR/trec-covid/resolve/main/queries.jsonl.gz

# Downloading the data and preparing for evaluation

## Queries

In [89]:
from pyserini.search import get_topics

In [94]:
topics = get_topics('covid-round5')
print(f'{len(topics)} queries total')

50 queries total
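
The `queries` variable compared with `topics` in the next cells is not defined in the cells shown in this notebook. A minimal sketch of how it might have been built, assuming it was read from the `queries.jsonl.gz` file downloaded earlier and that each line holds one JSON record. The helper name `load_jsonl_gz` and the tiny stand-in file are hypothetical, used only so the example runs on its own.

```python
import gzip
import json
import os
import tempfile


def load_jsonl_gz(path):
    """Read a gzip-compressed JSON Lines file into a list of dicts."""
    with gzip.open(path, "rt", encoding="utf-8") as handle:
        return [json.loads(line) for line in handle if line.strip()]


# Self-contained demo on a tiny stand-in file (hypothetical). The record
# mirrors the '_id' / 'text' / 'metadata' layout seen in this notebook.
sample = {"_id": "1", "text": "what is the origin of COVID-19",
          "metadata": {"query": "coronavirus origin"}}
with tempfile.TemporaryDirectory() as tmp_dir:
    sample_path = os.path.join(tmp_dir, "queries.jsonl.gz")
    with gzip.open(sample_path, "wt", encoding="utf-8") as handle:
        handle.write(json.dumps(sample) + "\n")
    demo_queries = load_jsonl_gz(sample_path)

print(demo_queries[0]["_id"], demo_queries[0]["text"])
```

With the real file, something like `queries = load_jsonl_gz("queries.jsonl.gz")` would give a list whose entry 0 corresponds to topic 1, which matches the `topics[1], queries[0]` comparison.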


In [101]:
topics[1], queries[0]

({'question': 'what is the origin of COVID-19',
  'query': 'coronavirus origin',
  'narrative': "seeking range of information about the SARS-CoV-2 virus's origin, including its evolution, animal source, and first transmission into humans"},
 {'_id': '1',
  'text': 'what is the origin of COVID-19',
  'metadata': {'query': 'coronavirus origin',
   'narrative': "seeking range of information about the SARS-CoV-2 virus's origin, including its evolution, animal source, and first transmission into humans"}})

In [102]:
topics[50], queries[49]

({'question': 'what is known about an mRNA vaccine for the SARS-CoV-2 virus?',
  'query': 'mRNA vaccine coronavirus',
  'narrative': 'Looking for studies specifically focusing on mRNA vaccines for COVID-19, including how mRNA vaccines work, why they are promising, and any results from actual clinical studies.'},
 {'_id': '50',
  'text': 'what is known about an mRNA vaccine for the SARS-CoV-2 virus?',
  'metadata': {'query': 'mRNA vaccine coronavirus',
   'narrative': 'Looking for studies specifically focusing on mRNA vaccines for COVID-19, including how mRNA vaccines work, why they are promising, and any results from actual clinical studies.'}})

## Test qrels

In [29]:
!wget https://huggingface.co/datasets/BeIR/trec-covid-qrels/raw/main/test.tsv

--2023-04-09 19:14:15--  https://huggingface.co/datasets/BeIR/trec-covid-qrels/raw/main/test.tsv
Resolving huggingface.co (huggingface.co)... 52.85.213.6, 52.85.213.73, 52.85.213.2, ...
Connecting to huggingface.co (huggingface.co)|52.85.213.6|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 980831 (958K) [text/plain]
Saving to: ‘test.tsv’


2023-04-09 19:14:16 (9,56 MB/s) - ‘test.tsv’ saved [980831/980831]



In [30]:
!mv test.tsv {DIRETORIO_TRABALHO}/

In [32]:
qrel = pd.read_csv(f"{DIRETORIO_TRABALHO}/test.tsv", sep="\t", header=None, 
                   skiprows=1, names=["query", "docid", "rel"])

In [33]:
qrel.head()

|   | query | docid | rel |
|---|-------|-------|-----|
| 0 | 1 | 005b2j4b | 2 |
| 1 | 1 | 00fmeepz | 1 |
| 2 | 1 | g7dhmyyo | 2 |
| 3 | 1 | 0194oljo | 1 |
| 4 | 1 | 021q9884 | 1 |
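
For quick lookups, the same judgments can also be held as a nested dictionary keyed by query and then by document. The sketch below uses only the standard library and parses a small TSV string instead of the real file. The function name `parse_qrels` is hypothetical, and the header names in the sample are an assumption about the BEIR file layout; only the fact that one header row is skipped comes from the `skiprows=1` argument used above.

```python
import csv
import io


def parse_qrels(tsv_text):
    """Parse qrels TSV text into {query_id: {docid: relevance}}."""
    reader = csv.reader(io.StringIO(tsv_text), delimiter="\t")
    next(reader)  # skip the header row, as skiprows=1 does in read_csv
    qrels = {}
    for row in reader:
        if not row:  # tolerate blank lines
            continue
        query_id, docid, rel = row
        qrels.setdefault(int(query_id), {})[docid] = int(rel)
    return qrels


# Rows copied from the qrel.head() output; header names are assumed.
sample_tsv = (
    "query-id\tcorpus-id\tscore\n"
    "1\t005b2j4b\t2\n"
    "1\t00fmeepz\t1\n"
    "1\tg7dhmyyo\t2\n"
)
qrels_by_query = parse_qrels(sample_tsv)
print(qrels_by_query[1]["005b2j4b"])
```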


In [42]:
from tqdm import tqdm

In [112]:
# Run all queries in topics and retrieve the top num_max_hits hits for each query
def run_all_queries(file, topics, searcher, num_max_hits=100):
    """
    Run every query stored in the topics dictionary with the given searcher
    and save the results to a text file in TREC run format.
    Also used in the class 2 notebook.

    Parameters:
        file: path of the output file where the query results are saved.
        topics: dictionary of queries to run. Each key is a unique query id and
            each value is a dictionary with the query fields; the 'question'
            field is used as the query text.
        searcher: searcher object used to run the queries (for example, a LuceneSearcher).
        num_max_hits: maximum number of documents returned for each query.

    Returns:
        None. The results are written to the file given in file.

    Notes:
        Uses tqdm to show a progress bar while the queries run.
        The number of completed queries is printed every 100 queries.
    """
    print(f'Running {len(topics)} queries in total')
    with open(file, 'w') as runfile:
        for cnt, query_id in enumerate(tqdm(topics, desc='Running Queries'), start=1):
            query = topics[query_id]['question']
            hits = searcher.search(query, num_max_hits)
            # One line per hit: query id, Q0, docid, rank, score, run tag
            for rank, hit in enumerate(hits, start=1):
                runfile.write(f'{query_id} Q0 {hit.docid} {rank} {hit.score:.6f} SemExpansao\n')
            if cnt % 100 == 0:
                print(f'{cnt} queries completed')
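
Each line written by `run_all_queries` has six whitespace-separated fields: query id, the literal `Q0`, document id, rank, score, and run tag. A small sketch of writing and reading one such line, using hypothetical helper names (`format_run_line`, `parse_run_line`) and values taken from the first hit of the run file produced in this notebook:

```python
def format_run_line(query_id, docid, rank, score, tag):
    """Build one line of a TREC-style run file."""
    return f"{query_id} Q0 {docid} {rank} {score:.6f} {tag}"


def parse_run_line(line):
    """Split one run line back into named fields."""
    query_id, _q0, docid, rank, score, tag = line.split()
    return {"query": query_id, "docid": docid, "rank": int(rank),
            "score": float(score), "system": tag}


line = format_run_line(44, "xfjexm5b", 1, 12.5624, "SemExpansao")
parsed = parse_run_line(line)
print(line)
print(parsed)
```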


In [None]:
searcher = LuceneSearcher(caminho_indice_trec_covid_sem_expansao)

In [85]:
searcher.set_bm25(k1=0.82, b=0.68)  
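
To give a feel for what `k1` and `b` control, here is a sketch of the textbook BM25 score for a single query term. It is an illustration only: Lucene's actual implementation differs in details, and the function names and all numbers below are hypothetical. `k1` governs how quickly repeated occurrences of a term saturate, and `b` governs how strongly longer documents are penalized.

```python
import math


def bm25_idf(num_docs, doc_freq):
    """Inverse document frequency in a common smoothed form."""
    return math.log(1.0 + (num_docs - doc_freq + 0.5) / (doc_freq + 0.5))


def bm25_term_score(tf, doc_freq, doc_len, avg_doc_len, num_docs,
                    k1=0.82, b=0.68):
    """Textbook BM25 contribution of one term to one document's score."""
    # b = 0 disables length normalization; b = 1 applies it fully
    length_norm = 1.0 - b + b * (doc_len / avg_doc_len)
    return bm25_idf(num_docs, doc_freq) * tf * (k1 + 1.0) / (tf + k1 * length_norm)


# Hypothetical collection statistics, chosen only for illustration
short_doc = bm25_term_score(tf=2, doc_freq=50, doc_len=80,
                            avg_doc_len=100, num_docs=1000)
long_doc = bm25_term_score(tf=2, doc_freq=50, doc_len=300,
                           avg_doc_len=100, num_docs=1000)
print(f"short document: {short_doc:.4f}, long document: {long_doc:.4f}")
```

With the same term frequency, the shorter document scores higher, which is the effect of `b` greater than zero.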

In [86]:
path_run = f"{DIRETORIO_TRABALHO}/runs"
path_run_sem_expansao = path_run + '/run-trec-covid-bm25.txt'

In [87]:
path_run, path_run_sem_expansao

('/home/borela/fontes/deep_learning_em_buscas_unicamp/local/doc2query/runs',
 '/home/borela/fontes/deep_learning_em_buscas_unicamp/local/doc2query/runs/run-trec-covid-bm25.txt')

In [None]:
%%time
if not os.path.exists(path_run):
  os.makedirs(path_run)
  print('folder created!')
else:
  print('folder already existed!')


folder created!
CPU times: user 549 µs, sys: 154 µs, total: 703 µs
Wall time: 625 µs


In [None]:
num_max_hits = 1000

In [113]:
run_all_queries(path_run_sem_expansao, topics, searcher, num_max_hits)

Running 50 queries in total


Running Queries: 100%|██████████| 50/50 [00:02<00:00, 22.48it/s]


In [None]:
if not os.path.exists(path_run_sem_expansao):
  # run the first-stage BM25 search
  # code from https://colab.research.google.com/github/castorini/anserini-notebooks/blob/master/pyserini_msmarco_passage_demo.ipynb

  # use the searcher instance and the topics dictionary created above
  searcher.set_bm25(k1=0.82, b=0.68)
  run_all_queries(path_run_sem_expansao, topics, searcher, num_max_hits)

  print("Stage 1 (bm25) data loaded!")
else:
  print("Stage 1 (bm25) data already existed!")

Stage 1 (bm25) data already existed!


In [None]:
assert os.path.exists(path_run_sem_expansao), f"File {path_run_sem_expansao} was not created!"

In [114]:
!head {path_run_sem_expansao}

44 Q0 xfjexm5b 1 12.562400 SemExpansao
44 Q0 28utunid 2 11.335000 SemExpansao
44 Q0 qi1henyy 3 11.334999 SemExpansao
44 Q0 qp77vl6h 4 11.273800 SemExpansao
44 Q0 ugkxxaeb 5 11.187800 SemExpansao
44 Q0 ej76fsxa 6 10.961200 SemExpansao
44 Q0 2r0a357c 7 10.691000 SemExpansao
44 Q0 d8eqifvv 8 10.688500 SemExpansao
44 Q0 4cmeglm3 9 10.485900 SemExpansao
44 Q0 pklvvgd3 10 10.485899 SemExpansao


# Computing metrics

In [118]:
run = pd.read_csv(f"{path_run_sem_expansao}", sep=r"\s+",
                  names=["query", "q0", "docid", "rank", "score", "system"])

In [119]:
run.head()

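|   | query | q0 | docid | rank | score | system |
|---|-------|----|-------|------|-------|--------|
| 0 | 44 | Q0 | xfjexm5b | 1 | 12.5624 | SemExpansao |
| 1 | 44 | Q0 | 28utunid | 2 | 11.335 | SemExpansao |
| 2 | 44 | Q0 | qi1henyy | 3 | 11.334999 | SemExpansao |
| 3 | 44 | Q0 | qp77vl6h | 4 | 11.2738 | SemExpansao |
| 4 | 44 | Q0 | ugkxxaeb | 5 | 11.1878 | SemExpansao |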


In [120]:
run = run.to_dict(orient="list")

In [131]:
run['query'][0], run['docid'][0], run['rank'][0]

(44, 'xfjexm5b', 1)

In [116]:
from evaluate import load

In [117]:
trec_eval = load("trec_eval")

Downloading builder script: 100%|██████████| 5.51k/5.51k [00:00<00:00, 8.88MB/s]


In [127]:
qrel.head()

|   | query | docid | rel |
|---|-------|-------|-----|
| 0 | 1 | 005b2j4b | 2 |
| 1 | 1 | 00fmeepz | 1 |
| 2 | 1 | g7dhmyyo | 2 |
| 3 | 1 | 0194oljo | 1 |
| 4 | 1 | 021q9884 | 1 |


In [128]:
qrel["q0"] = "q0"
qrel = qrel.to_dict(orient="list")

In [133]:
qrel['query'][0], qrel['docid'][0], qrel['rel'][0]

(1, '005b2j4b', 2)

In [134]:
results = trec_eval.compute(predictions=[run], references=[qrel])

In [135]:
results

{'runid': 'SemExpansao',
 'num_ret': 50000,
 'num_rel': 24673,
 'num_rel_ret': 9608,
 'num_q': 50,
 'map': 0.1879641632534631,
 'gm_map': 0.12133101699875738,
 'bpref': 0.33286628496980336,
 'Rprec': 0.2838796736774716,
 'recip_rank': 0.8396666666666667,
 'P@5': 0.6920000000000001,
 'P@10': 0.64,
 'P@15': 0.5973333333333334,
 'P@20': 0.583,
 'P@30': 0.5606666666666666,
 'P@100': 0.4671999999999999,
 'P@200': 0.3922,
 'P@500': 0.27384000000000003,
 'P@1000': 0.19215999999999997,
 'NDCG@5': 0.6342218733576787,
 'NDCG@10': 0.5963435398557583,
 'NDCG@15': 0.5626944780603168,
 'NDCG@20': 0.5439544691895091,
 'NDCG@30': 0.515655668693321,
 'NDCG@100': 0.43453518414003445,
 'NDCG@200': 0.3796410593597176,
 'NDCG@500': 0.3537872525591422,
 'NDCG@1000': 0.4064338720626243}
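
To make the `P@k` and `NDCG@k` entries more concrete, the sketch below recomputes both on toy data held in the same dict-of-lists layout as `run` and `qrel`. It is not the code that `trec_eval` runs. It assumes a document counts as relevant when `rel > 0`, that the gain equals the `rel` value with a `log2(rank + 1)` discount, and that scores are averaged over the queries present in the qrels. All function names and toy values are hypothetical.

```python
import math
from collections import defaultdict


def group_run(run):
    """Turn a dict-of-lists run into {query: [docid, ...]} ordered by rank."""
    ranked = defaultdict(list)
    for query, docid, rank in zip(run["query"], run["docid"], run["rank"]):
        ranked[query].append((rank, docid))
    return {q: [d for _, d in sorted(pairs)] for q, pairs in ranked.items()}


def group_qrel(qrel):
    """Turn a dict-of-lists qrel into {query: {docid: rel}}."""
    judged = defaultdict(dict)
    for query, docid, rel in zip(qrel["query"], qrel["docid"], qrel["rel"]):
        judged[query][docid] = rel
    return dict(judged)


def precision_at_k(ranked, judged, k):
    """Mean fraction of the top-k documents that are relevant (rel > 0)."""
    per_query = []
    for query, rels in judged.items():
        top = ranked.get(query, [])[:k]
        per_query.append(sum(1 for d in top if rels.get(d, 0) > 0) / k)
    return sum(per_query) / len(per_query)


def dcg(gains):
    """Discounted cumulative gain with a log2(rank + 1) discount."""
    return sum(g / math.log2(i + 2) for i, g in enumerate(gains))


def ndcg_at_k(ranked, judged, k):
    """Mean DCG@k divided by the DCG@k of an ideal ordering."""
    per_query = []
    for query, rels in judged.items():
        gains = [rels.get(d, 0) for d in ranked.get(query, [])[:k]]
        ideal_dcg = dcg(sorted(rels.values(), reverse=True)[:k])
        per_query.append(dcg(gains) / ideal_dcg if ideal_dcg > 0 else 0.0)
    return sum(per_query) / len(per_query)


# Toy data (hypothetical): two queries, five judged documents
toy_run = {"query": [1, 1, 1, 2, 2], "docid": ["a", "b", "c", "d", "e"],
           "rank": [1, 2, 3, 1, 2]}
toy_qrel = {"query": [1, 1, 1, 2, 2], "docid": ["a", "b", "c", "d", "e"],
            "rel": [2, 0, 1, 0, 1]}
ranked, judged = group_run(toy_run), group_qrel(toy_qrel)
p_at_1 = precision_at_k(ranked, judged, 1)
p_at_2 = precision_at_k(ranked, judged, 2)
ndcg_at_2 = ndcg_at_k(ranked, judged, 2)
print(p_at_1, p_at_2, round(ndcg_at_2, 4))
```

The precision definition also explains one relationship in the results above: with 1000 hits retrieved for each of the 50 queries, `P@1000` equals `num_rel_ret / num_ret`, that is 9608 / 50000 = 0.19216.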