In [1]:
import zarr
import numpy as np
from pathlib import Path
import pandas as pd
import numcodecs

In [2]:
# Root of the MS MARCO passage data: the *.zarr stores and the runs/ directory live here.
data_dir = Path("..") / "msmarco-passages"

def get_index(target, source):
    """Return the position in ``source`` of every value in ``target``.

    Parameters
    ----------
    target : array-like
        Values to look up. The result has the same shape and order.
    source : array-like
        Values to search in (e.g. a zarr id column). It is converted once with
        ``np.asarray`` so a chunked store is read a single time.

    Returns
    -------
    numpy.ndarray
        Integer positions such that ``source[result] == target``. If a value
        occurs more than once in ``source``, the position of its first
        occurrence is returned.

    Raises
    ------
    ValueError
        If any value of ``target`` does not occur in ``source``.
    """
    source = np.asarray(source)
    target = np.asarray(target)
    if target.size == 0:
        return np.empty(target.shape, dtype=np.intp)
    if source.size == 0:
        raise ValueError("get_index: source is empty, cannot look up any target value")
    # Stable sort so that, among duplicate ids, searchsorted (side="left")
    # resolves to the earliest occurrence in the original order.
    sorter = np.argsort(source, kind="stable")
    pos = np.searchsorted(source, target, sorter=sorter)
    # searchsorted returns an insertion point, which equals len(source) for values
    # larger than every element; clip so the gather below stays in bounds.
    idx = sorter[np.minimum(pos, source.size - 1)]
    # An insertion point is not a match: verify every hit instead of silently
    # returning the index of a neighbouring element.
    missing = source[idx] != target
    if np.any(missing):
        examples = target[missing][:5].tolist()
        raise ValueError(
            f"get_index: {int(np.count_nonzero(missing))} target value(s) "
            f"not found in source, e.g. {examples}"
        )
    return idx

In [3]:
# Full passage and query stores; the cells below read `.id` and `.text` from them.
# Opened read-only: zarr.open's default mode ("a") would silently create an empty
# group at a mistyped path (failing much later with a confusing attribute error)
# and would leave the source data writable from this notebook.
doc = zarr.open(str(data_dir / "docs.all.zarr"), mode="r")
query = zarr.open(str(data_dir / 'queries.all.zarr'), mode="r")

# Eval (subset to the queries and passages that appear in the dev run)

In [63]:
# Dev "small" run file: tab-separated and headerless, so the column names are supplied here.
run_file = data_dir / "runs" / "run.msmarco-passage.dev.small.tsv"
df = pd.read_csv(run_file, sep="\t", names=["q_id", "doc_id", "rank"])
# Distinct ids in order of first appearance in the run.
q_ids = df["q_id"].unique()
doc_ids = df["doc_id"].unique()

## Query

In [76]:
# Row positions in the full query store for every query id in the dev run.
q_idx = get_index(target=q_ids, source=query.id)

In [79]:
# Persist the eval-run query subset; rows follow the order of q_idx.
z = zarr.open("queries.eval.zarr", mode="w")
z.array("id", data=query.id.oindex[q_idx], chunks=(256,), overwrite=True)
z.array(
    "text",
    data=query.text.oindex[q_idx],
    chunks=(256,),
    overwrite=True,
    object_codec=numcodecs.VLenUTF8(),
)

<zarr.core.Array '/text' (6980,) object>

## Passage

In [77]:
# Row positions in the full passage store for every passage id in the dev run.
doc_idx = get_index(target=doc_ids, source=doc.id)

In [78]:
# Persist the eval-run passage subset; rows follow the order of doc_idx.
z = zarr.open("docs.eval.zarr", mode="w")
z.array("id", data=doc.id.oindex[doc_idx], chunks=(256,), overwrite=True)
z.array(
    "text",
    data=doc.text.oindex[doc_idx],
    chunks=(256,),
    overwrite=True,
    object_codec=numcodecs.VLenUTF8(),
)

<zarr.core.Array '/text' (3823977,) object>

---

# Passage (train/val splits built from the triples files)

In [19]:
# Passage ids referenced by the triples: every column after the first is later
# looked up against doc.id. np.unique flattens its input and already returns
# sorted, de-duplicated values, so no extra set()/sorted() pass is needed.
# The triples stores are opened read-only so they cannot be modified by accident.
train_ids = np.unique(zarr.open(str(data_dir / 'triples.train.zarr/'), mode="r")[:, 1:])
val_ids = np.unique(zarr.open(str(data_dir / 'triples.val.zarr/'), mode="r")[:, 1:])

In [20]:
# Read the full passage id column from zarr once; passing the zarr array
# directly would make every get_index call read and decompress it again.
doc_id_values = doc.id[:]
train_idx = get_index(train_ids, doc_id_values)
val_idx = get_index(val_ids, doc_id_values)

In [21]:
# Number of distinct passages in the (train, val) splits.
tuple(len(split) for split in (train_idx, val_idx))

(1171509, 393267)

## Save

In [22]:
# Save the train-split passages, rows ordered like train_idx.
# NOTE(review): `train_idx` must hold the *passage* indices from the cell above;
# the Query section below rebinds the same name to query indices, so re-run the
# passage cells first if executing out of order.
z = zarr.open("docs.train.zarr", mode="w")
z.array("id", data=doc.id.oindex[train_idx], chunks=(256,), overwrite=True)
z.array("text", data=doc.text.oindex[train_idx], chunks=(256,), object_codec=numcodecs.VLenUTF8(), overwrite=True)

<zarr.core.Array '/text' (1171509,) object>

In [23]:
# Save the validation-split passages, rows ordered like val_idx. The store
# handle is named `z`, like the other save cells, so nothing suggests the train split.
# NOTE(review): `val_idx` must hold the *passage* indices; the Query section
# below rebinds the same name to query indices.
z = zarr.open("docs.val.zarr", mode="w")
z.array("id", data=doc.id.oindex[val_idx], chunks=(256,), overwrite=True)
z.array("text", data=doc.text.oindex[val_idx], chunks=(256,), object_codec=numcodecs.VLenUTF8(), overwrite=True)

<zarr.core.Array '/text' (393267,) object>

# Query (train/val splits built from the triples files)

In [11]:
# Query ids referenced by the triples: column 0 is later looked up against query.id.
# np.unique returns sorted, de-duplicated values directly (no set()/sorted() pass).
# NOTE(review): the execution counts show this section (In[11]-In[18]) ran before
# the Passage section (In[19]-In[23]), and both reuse train_ids/val_ids/train_idx/
# val_idx. Run each section top-down before its Save cells.
train_ids = np.unique(zarr.open(str(data_dir / 'triples.train.zarr/'), mode="r")[:, 0])
val_ids = np.unique(zarr.open(str(data_dir / 'triples.val.zarr/'), mode="r")[:, 0])

In [14]:
# Read the full query id column from zarr once and reuse it for both splits.
query_id_values = query.id[:]
train_idx = get_index(train_ids, query_id_values)
val_idx = get_index(val_ids, query_id_values)

## Save

In [17]:
# Save the train-split queries, rows ordered like train_idx (query indices).
# NOTE(review): the recorded outputs show 1512 rows for both queries.train.zarr and
# queries.val.zarr, while the passage splits differ in size (1171509 vs 393267);
# confirm this is expected and not a stale or cross-wired cell result.
z = zarr.open("queries.train.zarr", mode="w")
z.array("id", data=query.id.oindex[train_idx], chunks=(256,), overwrite=True)
z.array(
    "text",
    data=query.text.oindex[train_idx],
    chunks=(256,),
    overwrite=True,
    object_codec=numcodecs.VLenUTF8(),
)

<zarr.core.Array '/text' (1512,) object>

In [18]:
# Save the validation-split queries, rows ordered like val_idx (query indices).
z = zarr.open("queries.val.zarr", mode="w")
z.array("id", data=query.id.oindex[val_idx], chunks=(256,), overwrite=True)
z.array(
    "text",
    data=query.text.oindex[val_idx],
    chunks=(256,),
    overwrite=True,
    object_codec=numcodecs.VLenUTF8(),
)

<zarr.core.Array '/text' (1512,) object>