In [1]:
import chromadb
import polars
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction as SentenceTransformer
from tqdm.auto import tqdm
from utils import TaggedVector

class pregunta(str):
    """A question text (``str`` subclass) tagged with its question id.

    Instances compare, hash and print exactly like the underlying string, so they can
    be passed anywhere a plain ``str`` is expected (e.g. as a query text), while also
    carrying ``qid``.

    Args:
        pregunta: The question text.
        qid: Identifier of the question (whatever type the caller supplies; not
            necessarily a string).
    """

    def __new__(cls, pregunta, qid):
        # str is immutable, so the text has to be fixed here rather than in __init__.
        obj = super().__new__(cls, pregunta)
        obj.qid = qid
        return obj

    def __init__(self, pregunta, qid):
        # All state is set in __new__; this no-op only mirrors its signature.
        ...

    @classmethod
    def from_dict(cls, x: dict[str, object]) -> "pregunta":
        """Build an instance from a mapping whose keys are exactly ``pregunta`` and ``qid``.

        Uses ``cls`` rather than the module-level name ``pregunta``, so construction keeps
        working even if that global is later rebound (e.g. by a loop variable), and so
        subclasses get instances of themselves.

        Raises:
            TypeError: if the mapping has missing or extra keys.
        """
        return cls(**x)

    def __repr__(self):
        # repr() the text so embedded quotes are escaped; keyword names match __new__ so
        # the result can be eval'd back into an equal instance.
        return f"{type(self).__name__}(pregunta={str(self)!r}, qid={self.qid!r})"

In [2]:
# Persistent Chroma store at the relative path 'persist/' (resolved against the
# kernel's current working directory).
persist_dir = 'persist/'
client = chromadb.PersistentClient(persist_dir)

In [3]:
# Display the collections present in the persistent store as the cell output.
colecciones = client.list_collections()
colecciones

[Collection(name=imf_publications_arg)]

In [4]:
# Print every collection in the store together with its vector count.
# NOTE(review): assumes list_collections() yields Collection objects (as Out[3] shows);
# some chromadb releases return bare names instead — re-check if the library is upgraded.
print('Colecciones en la base:')
for coleccion in client.list_collections():
    total = coleccion.count()
    print(f"{coleccion.name} : {total} vectores.")

Colecciones en la base:
imf_publications_arg : 10446 vectores.


In [5]:
# Embedding function used to embed query texts against this collection.
# NOTE(review): presumably the same model the collection was built with — confirm,
# otherwise query distances are not comparable.
# (`SentenceTransformer` here is chromadb's SentenceTransformerEmbeddingFunction,
# aliased at import time.)
modelo_embeddings = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
arg = client.get_collection(
    'imf_publications_arg',
    embedding_function=modelo_embeddings,
)

In [6]:
# Load the questions, keep one row per distinct (qid, pregunta) pair in file order,
# and wrap each text in a `pregunta` so it carries its qid.
preguntas_df = (
    polars.read_csv('../data/preguntas_clean_arg.csv')
    .select('qid', 'pregunta')
    .unique(maintain_order=True)
)
# The retrieval loop names its output file after the qid, so a qid paired with two
# different texts would silently overwrite one question's results with another's.
assert preguntas_df['qid'].n_unique() == preguntas_df.height, (
    'duplicate qid with differing pregunta text'
)
preguntas = [pregunta.from_dict(row) for row in preguntas_df.to_dicts()]

In [None]:
# For each question, retrieve the N chunks nearest to it (65% of the collection) and
# write them to ../data/<qid>_embeddings.parquet.
N = round(arg.count() * .65)  # 65% of all chunks; round() of a float already gives an int

# The loop variable is `q`, not `pregunta`: reusing that name would rebind the class
# `pregunta`, so re-running the loading cell (pregunta.from_dict) would then fail.
for q in tqdm(preguntas, desc='preguntas'):
    relevant_chunks = arg.query(
        query_texts=[q],
        n_results=N,
        include=['embeddings', 'documents', 'distances'],
    )

    file_basename = f'{q.qid}_embeddings'
    # A single query text was sent, so each result field holds exactly one inner list.
    [documents] = relevant_chunks['documents']
    [ids] = relevant_chunks['ids']
    [distances] = relevant_chunks['distances']
    [embeddings] = relevant_chunks['embeddings']

    tvs = map(TaggedVector.for_question(q.qid), zip(ids, documents, distances, embeddings))
    # leave=False so the per-question bar disappears once done instead of piling up.
    records = [tv.as_record() for tv in tqdm(tvs, total=len(ids), leave=False)]
    # NOTE(review): descending order puts the *farthest* chunks first; presumably
    # intentional for the downstream consumer — confirm. 'distance' is assumed to be
    # a column produced by TaggedVector.as_record().
    (
        polars.DataFrame(records)
        .sort('distance', descending=True)
        .write_parquet(f'../data/{file_basename}.parquet')
    )

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/6790 [00:00<?, ?it/s]

  0%|          | 0/6790 [00:00<?, ?it/s]

  0%|          | 0/6790 [00:00<?, ?it/s]

  0%|          | 0/6790 [00:00<?, ?it/s]

  0%|          | 0/6790 [00:00<?, ?it/s]

  0%|          | 0/6790 [00:00<?, ?it/s]

  0%|          | 0/6790 [00:00<?, ?it/s]

  0%|          | 0/6790 [00:00<?, ?it/s]

  0%|          | 0/6790 [00:00<?, ?it/s]

  0%|          | 0/6790 [00:00<?, ?it/s]