In [1]:
from sentence_transformers import CrossEncoder
from tqdm.auto import tqdm
import polars
import os

# Question catalogue; the `qid` column drives the per-question file names below.
preguntas = polars.read_csv('../data/preguntas_clean_arg.csv')

# Map every distinct question id (in first-seen order) to its embeddings
# parquet file, failing fast if any of those files is missing.
p_embeddings_paths = dict()
for qid in preguntas['qid'].unique(maintain_order=True):
    path = f'../data/{qid}_embeddings.parquet'
    # isfile() is already False for non-existent paths, so a separate
    # exists() check adds nothing; the message names the offending file.
    assert os.path.isfile(path), f'Missing embeddings file for qid={qid!r}: {path}'
    p_embeddings_paths[qid] = path

  from tqdm.autonotebook import tqdm, trange


In [2]:
# Cross-encoder used below to score (question, text) pairs; placed on the GPU.
reranker_model = CrossEncoder(
    model_name='cross-encoder/ms-marco-MiniLM-L-6-v2',
    device='cuda',
)

In [3]:
# Running extremes of the reranker scores across all questions. Each one starts
# at the neutral element of its reduction so that any real score replaces it
# (a `-1` seed for the maximum would survive if every score were below -1).
absolute_max = float('-inf')
absolute_min = float('inf')

for pregunta in tqdm(preguntas['qid'].unique(maintain_order=True)):
  pregunta_embeddings_path = p_embeddings_paths[pregunta]
  df = polars.read_parquet(pregunta_embeddings_path)

  # Build one (question text, candidate text) pair per row of `df` by looking
  # up the question text of the qid stored in `distance_to`.
  # NOTE(review): the scores are attached back to `df` by position, so this
  # relies on the left join keeping `df`'s row order -- confirm for the
  # installed polars version (newer releases expose `maintain_order=`).
  pairs = (
      df.join(preguntas['qid', 'pregunta'], how='left', left_on='distance_to', right_on='qid')
        .select('pregunta', 'text')
      )

  # A duplicated qid in `preguntas` would fan out rows in the join and
  # misalign scores with `df`; catch that before the expensive model call.
  assert pairs.height == df.height, (
      f'Join changed the row count for {pregunta!r}: {df.height} -> {pairs.height}'
  )

  scores = reranker_model.predict(
      sentences = pairs.rows(),
      batch_size = 100,
      show_progress_bar = True
  )

  # Reduce on the array itself (instead of the element-by-element builtin
  # max/min) and keep the accumulators as plain Python floats.
  absolute_max = max(absolute_max, float(scores.max()))
  absolute_min = min(absolute_min, float(scores.min()))

  # Persist the per-question scores, best match first.
  (df
    .with_columns(score = scores)
    .select('vector_id', 'distance_to' ,'score')
    .rename(dict(distance_to='score_to'))
    .sort('score', descending=True)
    .write_parquet(f'../data/{pregunta}_relevance_scores.parquet')
  )

  0%|          | 0/10 [00:00<?, ?it/s]

Batches:   0%|          | 0/68 [00:00<?, ?it/s]

Batches:   0%|          | 0/68 [00:00<?, ?it/s]

Batches:   0%|          | 0/68 [00:00<?, ?it/s]

Batches:   0%|          | 0/68 [00:00<?, ?it/s]

Batches:   0%|          | 0/68 [00:00<?, ?it/s]

Batches:   0%|          | 0/68 [00:00<?, ?it/s]

Batches:   0%|          | 0/68 [00:00<?, ?it/s]

Batches:   0%|          | 0/68 [00:00<?, ?it/s]

Batches:   0%|          | 0/68 [00:00<?, ?it/s]

Batches:   0%|          | 0/68 [00:00<?, ?it/s]

In [13]:
from fundar import json
import numpy as np

class NumpyJsonEncoder(json.JSONEncoder):
    """JSON encoder that also understands NumPy scalars and arrays.

    NumPy floating and integer scalars are converted to the matching Python
    number and arrays to (nested) lists. Anything else is delegated to the
    base class, which raises ``TypeError`` for unsupported objects.
    """

    def default(self, o):
        """Return a JSON-serializable stand-in for ``o``.

        Raises:
            TypeError: if ``o`` is not a supported NumPy type and the base
                encoder cannot serialize it either.
        """
        if isinstance(o, np.floating):
            return float(o)
        if isinstance(o, np.integer):
            return int(o)
        if isinstance(o, np.ndarray):
            return o.tolist()
        # Without this fallback the method would return None and unsupported
        # objects would silently be written as `null`.
        return super().default(o)

# Persist the global score extremes gathered by the scoring loop
# (presumably consumed later to normalize scores -- inferred from the file name).
extremes = {
    'absolute_max': absolute_max,
    'absolute_min': absolute_min,
}
json.dump(extremes, '../data/normalizer_absolute_extremes.json', cls=NumpyJsonEncoder)