In [17]:
import json
import os
import sys
from collections import Counter, defaultdict

import numpy as np
import pandas as pd
import torch
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
from scipy.sparse import csr_matrix
from scipy.sparse import vstack as sparse_vstack
from sentence_transformers import SentenceTransformer
from tqdm.notebook import tqdm
from transformers import AutoTokenizer, BasicTokenizer, AutoModelForMaskedLM

sys.path.append(os.pardir)
from splade_vocab.models  import Splade, BEIRSpladeModel, BEIRSpladeModelIDF

In [24]:
# Load the "test" split of the BEIR dataset stored under data_path.
data_path = "/home/gaia_data/iida.h/BEIR/datasets/trec-covid"
loader = GenericDataLoader(data_folder=data_path)
corpus, queries, qrels = loader.load(split="test")

  0%|          | 0/171332 [00:00<?, ?it/s]

In [25]:
# Checkpoint directories of the SPLADE variants, keyed by a short label.
# NOTE(review): the numeric suffix (71694 / 30522) presumably is the vocabulary
# size of the underlying tokenizer -- confirm against the training scripts.
model_pathes = {
    "mlm-splade-71694": (
        "/home/gaia_data/iida.h/BEIR/model/pubmed_abst/bert-base-uncased/splade_model/raw/remove/71694/distilSplade_0.1_0.08_-groups-gcb50243-iida.h-BEIR-model-pubmed_abst-bert-base-uncased-mlm_model-raw-remove--71694-batch_size_40-2022-04-12_08-52-34/"
    ),
    "mlm-splade-30522": (
        "/home/gaia_data/iida.h/BEIR/model/pubmed_abst/bert-base-uncased/splade_model/raw/remove/30522/distilSplade_0.1_0.08_-groups-gcb50243-iida.h-BEIR-model-pubmed_abst-bert-base-uncased-mlm_model-raw-remove--30522-batch_size_24-2022-04-11_23-21-18/"
    ),
    "splade-71694": (
        "/home/gaia_data/iida.h/BEIR/model/pubmed_abst/bert-base-uncased/splade_model_init/raw/remove/71694/distilSplade_0.1_0.08_-groups-gcb50243-iida.h-BEIR-model-pubmed_abst-bert-base-uncased-init_model-raw-remove--71694-batch_size_40-2022-04-24_00-46-31/"
    ),
    "splade": (
        "/home/gaia_data/iida.h/BEIR/model/msmarco/splade/distilSplade_0.1_0.08_bert-base-uncased-batch_size_40-2022-05-01_12-37-20"
    ),
}

In [26]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA instead of crashing.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the "mlm-splade-71694" checkpoint and move the model to the chosen device.
mlm_splade = Splade(model_pathes["mlm-splade-71694"]).to(device)

path /home/gaia_data/iida.h/BEIR/model/pubmed_abst/bert-base-uncased/splade_model/raw/remove/71694/distilSplade_0.1_0.08_-groups-gcb50243-iida.h-BEIR-model-pubmed_abst-bert-base-uncased-mlm_model-raw-remove--71694-batch_size_40-2022-04-12_08-52-34/


In [37]:
# Encode every corpus document with the SPLADE model, one document per forward pass.
all_doc = []
for cid, doc in tqdm(corpus.items()):
    # Title and body are joined into a single passage separated by one space.
    text = " ".join((doc["title"], doc["text"]))
    # Tokenize (truncated to at most 512 tokens) and move each tensor to `device`.
    t_text = {
        k: v.to(device)
        for k, v in mlm_splade.tokenizer(
            text,
            truncation=True,
            max_length=512,
            padding=True,
            return_tensors="pt",
        ).items()
    }
    # Inference only: gradients are not tracked for the forward pass.
    with torch.no_grad():
        e_text = mlm_splade.encode(**t_text)
    # Store the encoding on the host as a NumPy array.
    all_doc.append(e_text.cpu().numpy())

  0%|          | 0/171332 [00:00<?, ?it/s]

In [38]:
# Assemble the per-document vectors in `all_doc` into a single matrix with one
# row per document. Each vector is converted to CSR first and the pieces are
# stacked sparsely, so the full dense (n_docs x vocab) array -- tens of GB for
# this corpus -- is never materialised. The result has the same shape and the
# same values as np.vstack(all_doc), just stored as a scipy CSR matrix.
sp_all_doc = sparse_vstack([csr_matrix(doc_vec) for doc_vec in all_doc], format="csr")
sp_all_doc.shape

(171332, 71694)

In [39]:
# Store the document matrix in SciPy CSR (compressed sparse row) format.
# The name is rebound on purpose so this variable no longer refers to the
# previous representation.
sp_all_doc = csr_matrix(sp_all_doc)

In [40]:
# Size of the CSR matrix in MiB: the three underlying buffers
# (non-zero values, row pointers, column indices) are summed.
csr_buffers = (sp_all_doc.data, sp_all_doc.indptr, sp_all_doc.indices)
sum(buf.nbytes for buf in csr_buffers) / 1024 / 1024

329.5795478820801

In [41]:
# Dense baseline encoder, loaded by its sentence-transformers checkpoint id.
dense_model_name = "sentence-transformers/msmarco-bert-base-dot-v5"
dense_model = SentenceTransformer(dense_model_name)

In [42]:
# Encode the whole corpus with the dense model in batches instead of issuing one
# `encode` call per document: the library batches the inputs, shows its own
# progress bar, and returns the embeddings in input order.
DENSE_BATCH_SIZE = 32  # lower this if the encoder runs out of memory
doc_texts = [doc["title"] + " " + doc["text"] for doc in corpus.values()]
# Kept as a list of per-document vectors (one entry per corpus document, in
# corpus order) so code that stacks or iterates over it keeps working.
all_doc_dense = list(
    dense_model.encode(
        doc_texts,
        batch_size=DENSE_BATCH_SIZE,
        show_progress_bar=True,
        convert_to_numpy=True,
    )
)

  0%|          | 0/171332 [00:00<?, ?it/s]

In [43]:
# Stack the per-document dense vectors row-wise into one array
# (the name is rebound from the collection of vectors to the stacked array).
all_doc_dense = np.vstack(all_doc_dense)
# Size of the stacked array in MiB (bytes / 1024 / 1024).
all_doc_dense.nbytes / 1024 / 1024

501.94921875