Sparse query vector #63

Open
wants to merge 9 commits into base: main
3 changes: 2 additions & 1 deletion beir/retrieval/models/__init__.py
@@ -2,4 +2,5 @@
 from .use_qa import UseQA
 from .sparta import SPARTA
 from .dpr import DPR
-from .bpr import BinarySentenceBERT
+from .bpr import BinarySentenceBERT
+from .splade import SPLADE
13 changes: 8 additions & 5 deletions beir/retrieval/models/sparta.py
@@ -54,8 +54,11 @@ def _compute_sparse_embeddings(self, documents):
         return sparse_embeddings

     def encode_query(self, query: str, **kwargs):
-        return self.tokenizer(query, add_special_tokens=False)['input_ids']
-
+        col = self.tokenizer(query, add_special_tokens=False)['input_ids']
+        row = [0] * len(col)
+        data = [1] * len(col)
+        return csr_matrix((data, (row, col)), shape=(1, len(self.bert_input_embeddings)), dtype=float)
+
     def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int = 16, **kwargs):

         sentences = [(doc["title"] + self.sep + doc["text"]).strip() for doc in corpus]
@@ -69,9 +72,9 @@ def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int = 16, **kwargs):
             doc_embs = self._compute_sparse_embeddings(sentences[start_idx: start_idx + batch_size])
             for doc_id, emb in enumerate(doc_embs):
                 for tid, score in emb:
-                    col[sparse_idx] = start_idx+doc_id
-                    row[sparse_idx] = tid
+                    col[sparse_idx] = tid
+                    row[sparse_idx] = start_idx+doc_id
                     values[sparse_idx] = score
                     sparse_idx += 1

-        return csr_matrix((values, (row, col)), shape=(len(self.bert_input_embeddings), len(sentences)), dtype=np.float)
+        return csr_matrix((values, (row, col)), shape=(len(sentences), len(self.bert_input_embeddings)), dtype=float)
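The net effect of these two hunks: encode_query now returns a sparse [1, n_vocab] row vector with unit weights on the query's token ids, and encode_corpus returns [n_docs, n_vocab] instead of the transposed layout, so scoring reduces to a plain sparse matrix product. A minimal self-contained sketch of that contract (toy vocabulary, ids, and weights, not taken from the repository):

import numpy as np
from scipy.sparse import csr_matrix

n_vocab = 10

# Query hitting token ids 2 and 5 with unit weights -> shape [1, n_vocab]
col = [2, 5]
query_vec = csr_matrix(([1, 1], ([0, 0], col)), shape=(1, n_vocab), dtype=float)

# Three documents with a few weighted terms each -> shape [3, n_vocab]
rows = [0, 0, 1, 2, 2]
cols = [2, 3, 5, 2, 5]
vals = [0.5, 0.1, 0.9, 0.2, 0.4]
doc_mat = csr_matrix((vals, (rows, cols)), shape=(3, n_vocab), dtype=float)

# Score all documents against the query: [3, n_vocab] x [n_vocab, 1] -> [3, 1]
scores = doc_mat.dot(query_vec.T).toarray().ravel()
print(scores)  # [0.5 0.9 0.6]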
54 changes: 54 additions & 0 deletions beir/retrieval/models/splade.py
@@ -0,0 +1,54 @@
from typing import List, Dict

import array
import tqdm
import torch
import numpy as np
import transformers
from scipy import sparse


class SPLADE:
    def __init__(self, model_name_or_path, max_length=256):
        self.model = transformers.AutoModelForMaskedLM.from_pretrained(model_name_or_path)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.max_length = max_length

    def encode(self, text):
        inputs = self.tokenizer(text, max_length=self.max_length, padding=True, truncation=True, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        token_embeddings = outputs[0]
        attention_mask = inputs["attention_mask"]
        sentence_embedding = torch.max(torch.log(1 + torch.relu(token_embeddings)) * attention_mask.unsqueeze(-1), dim=1).values
        return sentence_embedding.cpu().numpy()

    def encode_query(self, query: str, **kwargs) -> sparse.csr_matrix:
        """ returns a csr_matrix of shape [1, n_vocab] """
        output = self.encode(query)
        return sparse.csr_matrix(output)

    def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int, is_queries=False, **kwargs) -> sparse.csr_matrix:
        """ returns a csr_matrix of shape [n_documents, n_vocab] """
        # https://maciejkula.github.io/2015/02/22/incremental-construction-of-sparse-matrices/
        indices = array.array("i")
        indptr = array.array("i")
        data = array.array("f")
        sentences = [(doc["title"] + " " + doc["text"]).strip() for doc in corpus]
        indptr.append(0)
        last_indptr = 0
        for i in tqdm.tqdm(range(0, len(sentences), batch_size), desc="encode_corpus"):
            batch = sentences[i:i+batch_size]
            dense = self.encode(batch)
            nz_rows, nz_cols = np.nonzero(dense)
            nz_values = dense[(nz_rows, nz_cols)]
            data.extend(nz_values)
            local_indptr = np.bincount(nz_rows).cumsum() + last_indptr
            indptr.extend(local_indptr)
            indices.extend(nz_cols)
            last_indptr = local_indptr[-1]
        shape = (len(corpus), self.model.config.vocab_size)
        results = sparse.csr_matrix((data, indices, indptr), shape=shape, dtype=float)
        return results
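For orientation, here is a minimal usage sketch of the class above; the checkpoint path and the two-document corpus are placeholders for illustration, not part of this PR:

from beir.retrieval.models import SPLADE
import numpy as np

# "path/to/distilsplade_max" is a hypothetical local directory; the example script
# below downloads and unpacks the distilsplade_max weights into one.
model = SPLADE("path/to/distilsplade_max", max_length=256)

corpus = [
    {"title": "Python", "text": "Python is a programming language."},
    {"title": "Snakes", "text": "Pythons are large constricting snakes."},
]
doc_mat = model.encode_corpus(corpus, batch_size=2)  # csr_matrix, shape [2, vocab_size]
qry_vec = model.encode_query("what is python")       # csr_matrix, shape [1, vocab_size]

scores = doc_mat.dot(qry_vec.T).toarray().ravel()    # one relevance score per document
print(np.argsort(-scores))                           # document indices, best first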
26 changes: 16 additions & 10 deletions beir/retrieval/search/sparse/sparse_search.py
@@ -2,6 +2,7 @@
 from typing import List, Dict, Union, Tuple
 import logging
 import numpy as np
+import torch

 logger = logging.getLogger(__name__)

@@ -16,22 +17,27 @@ def __init__(self, model, batch_size: int = 16, **kwargs):
     def search(self,
                corpus: Dict[str, Dict[str, str]],
                queries: Dict[str, str],
-               top_k: int, *args, **kwargs) -> Dict[str, Dict[str, float]]:
+               top_k: int, *args, **kwargs
+               ) -> Dict[str, Dict[str, float]]:

         doc_ids = list(corpus.keys())
         query_ids = list(queries.keys())
         documents = [corpus[doc_id] for doc_id in doc_ids]
         logging.info("Computing document embeddings and creating sparse matrix")
-        self.sparse_matrix = self.model.encode_corpus(documents, batch_size=self.batch_size)
-
+        self.sparse_matrix_doc = self.model.encode_corpus(documents, batch_size=self.batch_size)  # [n_doc, n_voc]
         logging.info("Starting to Retrieve...")
         for start_idx in trange(0, len(queries), self.batch_size, desc='query'):
-            qid = query_ids[start_idx]
-            query_tokens = self.model.encode_query(queries[qid])
-            #Get the candidate passages
-            scores = np.asarray(self.sparse_matrix[query_tokens, :].sum(axis=0)).squeeze(0)
-            top_k_ind = np.argpartition(scores, -top_k)[-top_k:]
-            self.results[qid] = {doc_ids[pid]: float(scores[pid]) for pid in top_k_ind}
-
+            local_query_ids = query_ids[start_idx:start_idx + self.batch_size]
+            local_queries = [queries[qid] for qid in local_query_ids]
+            qry_matrix = self.model.encode_query(local_queries)
+            scores = self.sparse_matrix_doc.dot(qry_matrix.transpose()).todense()  # [n_doc, n_voc] x [n_voc, n_qry] -> [n_doc, n_qry]
+            scores = torch.from_numpy(scores)  # [n_doc, n_qry]
+            top_k_values, top_k_indices = torch.topk(scores, top_k, dim=0, sorted=False)  # best docs per query
+            top_k_values = top_k_values.transpose(0, 1).tolist()  # [n_qry, top_k]
+            top_k_indices = top_k_indices.transpose(0, 1).tolist()  # [n_qry, top_k]
+            for i, qid in enumerate(local_query_ids):
+                k_ind = top_k_indices[i]
+                k_val = top_k_values[i]
+                self.results[qid] = {doc_ids[pid]: score for pid, score in zip(k_ind, k_val) if doc_ids[pid] != qid}
         return self.results
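As a sanity check on the new batched scoring, here is a toy-sized sketch of the dot-product-plus-topk pattern used above (4 documents, 2 queries, a 6-term vocabulary; all numbers are made up):

import numpy as np
import torch
from scipy.sparse import csr_matrix

# Toy document matrix [n_doc=4, n_voc=6] and query matrix [n_qry=2, n_voc=6].
doc = csr_matrix(np.array([
    [1., 0., 0., 2., 0., 0.],
    [0., 3., 0., 0., 0., 0.],
    [0., 0., 1., 1., 0., 0.],
    [2., 0., 0., 0., 0., 1.],
]))
qry = csr_matrix(np.array([
    [1., 0., 0., 1., 0., 0.],
    [0., 1., 1., 0., 0., 0.],
]))

scores = torch.from_numpy(doc.dot(qry.T).toarray())           # [n_doc, n_qry] = [4, 2]
top_k_values, top_k_indices = torch.topk(scores, k=2, dim=0)  # best 2 documents per query
print(top_k_indices.T.tolist())  # [[0, 3], [1, 2]]
print(top_k_values.T.tolist())   # [[3.0, 2.0], [3.0, 1.0]]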

64 changes: 64 additions & 0 deletions examples/retrieval/evaluation/sparse/evaluate_splade.py
@@ -0,0 +1,64 @@
from beir import util, LoggingHandler
from beir.retrieval import models
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.sparse import SparseSearch

import logging
import pathlib, os
import random
import shutil

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
#### /print debug information to stdout

dataset = "arguana"

#### Download arguana dataset and unzip the dataset
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
data_path = util.download_and_unzip(url, out_dir)

#### Provide the data path where arguana has been downloaded and unzipped to the data loader
# data folder would contain these files:
# (1) arguana/corpus.jsonl (format: jsonlines)
# (2) arguana/queries.jsonl (format: jsonlines)
# (3) arguana/qrels/test.tsv (format: tsv ("\t"))

corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")

#### Sparse Retrieval using SPLADE ####
url = "https://download-de.europe.naverlabs.com/Splade_Release_Jan22/distilsplade_max.tar.gz"
out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "weights")
os.makedirs(out_dir, exist_ok=True)
filename = os.path.join(out_dir, "splade.tar.gz")
model_dir = os.path.join(out_dir, "distilsplade_max")
if not os.path.exists(model_dir):
    util.download_url(url, filename)
    shutil.unpack_archive(filename, out_dir)
sparse_model = SparseSearch(models.SPLADE(model_dir, max_length=256), batch_size=48)
retriever = EvaluateRetrieval(sparse_model)

#### Retrieve sparse results (format of results is identical to qrels)
results = retriever.retrieve(corpus, queries)

#### Evaluate your retrieval using NDCG@k, MAP@K ...

logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)

#### Print top-k documents retrieved ####
top_k = 10

query_id, ranking_scores = random.choice(list(results.items()))
scores_sorted = sorted(ranking_scores.items(), key=lambda item: item[1], reverse=True)
logging.info("Query : %s\n" % queries[query_id])

# for rank in range(top_k):
#     doc_id = scores_sorted[rank][0]
#     # Format: Rank x: ID [Title] Body
#     logging.info("Rank %d: %s [%s] - %s\n" % (rank+1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))