In [1]:
import jax
import jax.numpy as jnp
from jax.experimental import sparse

from tqdm import tqdm
import json
import os
import jsonpickle


from pyserini.index.lucene import IndexReader
from pyserini.analysis import Analyzer, get_lucene_analyzer

from spare.text2vec import BagOfWords

from functools import partial
#os.environ["CUDA_VISIBLE_DEVICES"]=""

In [2]:
def scatter(input, dim, index, src, reduce=None):
   # Works like PyTorch's scatter. See https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html
   
   dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))
   
   if reduce is None:
       _scatter = jax.lax.scatter
   elif reduce == "add":
       _scatter = jax.lax.scatter_add
   elif reduce == "multiply":
       _scatter = jax.lax.scatter_mul
       
   _scatter = partial(_scatter, dimension_numbers=dnums)
   vmap_inner = partial(jax.vmap, in_axes=(0, 0, 0), out_axes=0)
   vmap_outer = partial(jax.vmap, in_axes=(1, 1, 1), out_axes=1)

   for idx in range(len(input.shape)):
       if idx == dim:
           pass
       elif idx < dim:
           _scatter = vmap_inner(_scatter)
       else:
           _scatter = vmap_outer(_scatter)
           
   return _scatter(input, jnp.expand_dims(index, axis=-1), src)



In [3]:


index_reader = IndexReader(f"beir_datasets/nq/anserini_index")
analyzer = Analyzer(get_lucene_analyzer())

print("build token2id dict")
token2id = {term.term:i for i,term in enumerate(index_reader.terms())}

def tokenizer(text):
    tokens_ids = []
    for token in analyzer.analyze(text.lower()):
        #token_id=token2id[token]
        if token in token2id:
            tokens_ids.append(token2id[token])
        #if token in token2id:
        #    token_id=token2id[token]
        #    if token_id is not None:
        #        tokens_ids.append(token_id)
    return tokens_ids

vocab_size = len(token2id)

bow = BagOfWords(tokenizer, vocab_size)

def text2vec(text):
    b = bow(text)
    
    #dense_vec = jnp.zeros((bow.dim,1))
    
    return scatter(jnp.zeros((bow.dim,)), 0, jnp.array(list(b.keys())), jnp.array(list(b.values()), dtype=jnp.float32))


build token2id dict


In [4]:

with open("beir_datasets/nq/csr_anserini_bm25_12_075/class_info.jsonpickle") as f:
    class_vars = jsonpickle.decode(f.read())
shape = tuple(class_vars.pop("shape"))
devices = jax.devices()

In [5]:
from safetensors import safe_open

tensors = {}
with safe_open("beir_datasets/nq/csr_anserini_bm25_12_075/tensors.safetensors", framework="jax") as f:
    for k in f.keys():
        tensors[k] = f.get_tensor(k)

2023-10-15 23:11:35.369732: W external/xla/xla/service/gpu/nvptx_compiler.cc:703] The NVIDIA driver's CUDA version is 12.0 which is older than the ptxas CUDA version (12.2.140). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [6]:
print(shape)
csr_matrix = sparse.BCSR((tensors["vec_2"], tensors["vec_1"], tensors["vec_0"] ), shape=shape)

(2681468, 997027)


In [7]:
vec = text2vec("hello, my name is tiago")
vec


Array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)

In [8]:
@jax.jit
def scores(csr_matrix, vec):
    return jax.lax.top_k(sparse.csr_matvec(csr_matrix, vec), k=1000)

In [9]:
values, indicies = scores(csr_matrix, vec)

In [10]:
with open(f"beir_datasets/msmarco/relevant_pairs.jsonl") as f:
    questions = list({line["question"] for line in map(json.loads, f)})

In [None]:
def vec_coo_todense_wscatter(bow, vocab_size):

  dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))
  return jax.lax.scatter(jnp.zeros((vocab_size,)), jnp.expand_dims(list(bow.keys()), axis=-1), list(bow.values()), dimension_numbers=dnums)


In [12]:
import time

timer = 0

for question in tqdm(questions):
    
    q_vec = text2vec(question)
    
    s = time.time()
    scores(csr_matrix, q_vec)
    timer += time.time()-s

len(questions)/timer

  0%|          | 0/6980 [00:00<?, ?it/s]


AttributeError: 'list' object has no attribute 'block_until_ready'

In [44]:
indices = [1,3,5]
indices = jnp.array([indices,[1]*len(indices)]).T
data = jnp.array([20,1,1], dtype=jnp.float32)
print(indices.shape, data.shape)
vec = sparse.BCOO((data, indices), shape=(6,1))

(3, 2) (3,)


In [49]:


#jax.lax.scatter(jnp.zeros(6,), jnp.array([1,3,5]), jnp.array([20,1,1], dtype=jnp.float32), 0)

AttributeError: 'int' object has no attribute 'update_window_dims'

In [39]:
vec.data

Array([20.,  1.,  1.], dtype=float32)