In [1]:
from collections import Counter

import requests
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from tqdm.auto import tqdm
from transformers import BertTokenizerFast
from typing import List 


  from .autonotebook import tqdm as notebook_tqdm


In [9]:
pubmed = load_dataset("pubmed_qa", "pqa_labeled", split="train")
pubmed

Dataset({
    features: ['pubid', 'question', 'context', 'long_answer', 'final_decision'],
    num_rows: 1000
})

In [14]:
CHAR_LIMIT = 384

def chunker(contexts: List[str]):
    chunks = []
    chunk = []
    all_contexts = ' '.join(contexts).split('.')
    for context in all_contexts:
        chunk.append(context)
        if len(chunk) >= 3 and len('.'.join(chunk)) > CHAR_LIMIT:
            # surpassed limit so add to chunks and reset
            chunks.append('.'.join(chunk).strip()+'.')
            # add some overlap between passages
            chunk = chunk[-2:]
    # if we finish and still have a chunk, add it
    if chunk is not None:
        chunks.append('.'.join(chunk))
    return chunks

chunks = chunker(pubmed[0]['context']['contexts'])
chunks

['Programmed cell death (PCD) is the regulated death of cells within an organism. The lace plant (Aponogeton madagascariensis) produces perforations in its leaves through PCD. The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature.',
 'The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature. The role of mitochondria during PCD has been recognized in animals; however, it has been less studied during PCD in plants. The following paper elucidates the role of mitochondrial dynamics during developmentally regulated PCD in vivo in A.',
 'The role of mitochondria during PCD has been recognized in animals; however, it has been less

In [15]:
pubmed[0]['context']['contexts']

['Programmed cell death (PCD) is the regulated death of cells within an organism. The lace plant (Aponogeton madagascariensis) produces perforations in its leaves through PCD. The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature. The role of mitochondria during PCD has been recognized in animals; however, it has been less studied during PCD in plants.',
 'The following paper elucidates the role of mitochondrial dynamics during developmentally regulated PCD in vivo in A. madagascariensis. A single areole within a window stage leaf (PCD is occurring) was divided into three areas based on the progression of PCD; cells that will not undergo PCD (NPCD), cells in early stages of PCD (EPCD), and cells in late stages of PCD (LPCD). Window stage leaves were stained with the mitochondrial dye MitoTracker Red CMX

In [16]:
ids = []
for i in range(len(chunks)):
    ids.append(f"{pubmed[0]['pubid']}-{i}")
ids

['21645374-0',
 '21645374-1',
 '21645374-2',
 '21645374-3',
 '21645374-4',
 '21645374-5',
 '21645374-6']

In [17]:
data = []
for record in pubmed:
    chunks = chunker(record['context']['contexts'])
    for i, context in enumerate(chunks):
        data.append({
            'id': f"{record['pubid']}-{i}",
            'context': context
        })

In [19]:
type(data[0:3])

list

In [7]:
device = "cpu"
dense_model = SentenceTransformer(
    'msmarco-bert-base-dot-v5',
    device=device
)

In [8]:
emb = dense_model.encode(data[0]['context'])
emb.shape

dim = dense_model.get_sentence_embedding_dimension()
dim

768

In [9]:
from splade.models.transformer_rep import Splade

sparse_model_id = 'naver/splade-cocondenser-ensembledistil'

sparse_model = Splade(sparse_model_id, agg='max')
sparse_model.to(device)  # move to GPU if possible
sparse_model.eval()

Downloading: 100%|██████████| 670/670 [00:00<00:00, 1.43MB/s]
Downloading: 100%|██████████| 418M/418M [00:17<00:00, 24.5MB/s] 
Downloading: 100%|██████████| 466/466 [00:00<00:00, 3.67MB/s]
Downloading: 100%|██████████| 226k/226k [00:00<00:00, 4.03MB/s]
Downloading: 100%|██████████| 455k/455k [00:00<00:00, 4.01MB/s]
Downloading: 100%|██████████| 112/112 [00:00<00:00, 666kB/s]


In [20]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(sparse_model_id)

# tokens = tokenizer(data[0]['context'], return_tensors='pt')

In [22]:
tokenizer.get_vocab().items()



In [11]:
import torch 

with torch.no_grad():
    sparse_emb = sparse_model(
        d_kwargs=tokens.to(device)
    )['d_rep'].squeeze()
sparse_emb.shape



torch.Size([30522])

In [12]:
indices = sparse_emb.nonzero().squeeze().cpu().tolist()
print(len(indices))

174


In [13]:
values = sparse_emb[indices].cpu().tolist()
sparse = {'indices': indices, 'values': values}
sparse

{'indices': [1000,
  1039,
  1052,
  1997,
  1999,
  2003,
  2024,
  2049,
  2083,
  2094,
  2173,
  2239,
  2278,
  2290,
  2306,
  2331,
  2415,
  2427,
  2523,
  2537,
  2550,
  2565,
  2566,
  2597,
  2644,
  2754,
  2757,
  2832,
  2974,
  3030,
  3081,
  3102,
  3252,
  3269,
  3274,
  3280,
  3370,
  3392,
  3399,
  3508,
  3526,
  3571,
  3581,
  3628,
  3727,
  3740,
  3817,
  3965,
  3968,
  4264,
  4295,
  4372,
  4442,
  4456,
  4574,
  4649,
  4717,
  4730,
  4758,
  4775,
  4870,
  4962,
  4963,
  5080,
  5104,
  5258,
  5397,
  5701,
  5708,
  5920,
  5996,
  6198,
  6210,
  6215,
  6310,
  6418,
  6470,
  6531,
  6546,
  6580,
  6897,
  7053,
  7337,
  7366,
  7403,
  7473,
  7609,
  7691,
  7775,
  7816,
  8475,
  8676,
  8715,
  8761,
  8765,
  8872,
  8979,
  9007,
  9232,
  9448,
  9607,
  9706,
  9890,
  9895,
  9915,
  10012,
  10088,
  10244,
  10267,
  10327,
  10507,
  10708,
  10738,
  11503,
  11568,
  11704,
  11767,
  11798,
  11829,
  11934,
  12222,
  124

In [14]:
idx2token = {idx: token for token, idx in tokenizer.get_vocab().items()}
sparse_dict_tokens = {
    idx2token[idx]: round(weight, 2) for idx, weight in zip(indices, values)
}
# sort so we can see most relevant tokens first
sparse_dict_tokens = {
    k: v for k, v in sorted(
        sparse_dict_tokens.items(),
        key=lambda item: item[1],
        reverse=True
    )
}
sparse_dict_tokens

import pinecone


def builder(records: list):
    ids = [x['id'] for x in records]
    contexts = [x['context'] for x in records]
    # create dense vecs
    dense_vecs = dense_model.encode(contexts).tolist()
    # create sparse vecs
    input_ids = tokenizer(
        contexts, return_tensors='pt',
        padding=True, truncation=True
    )
    with torch.no_grad():
        sparse_vecs = sparse_model(
            d_kwargs=input_ids.to(device)
        )['d_rep'].squeeze()
    # convert to upsert format
    upserts = []
    for _id, dense_vec, sparse_vec, context in zip(ids, dense_vecs, sparse_vecs, contexts):
        # extract columns where there are non-zero weights
        indices = sparse_vec.nonzero().squeeze().cpu().tolist()  # positions
        values = sparse_vec[indices].cpu().tolist()  # weights/scores
        # build sparse values dictionary
        sparse_values = {
            "indices": indices,
            "values": values
        }
        # build metadata struct
        metadata = {'context': context}
        # append all to upserts list as pinecone.Vector (or GRPCVector)
        upserts.append({
            'id': _id,
            'values': dense_vec,
            'sparse_values': sparse_values,
            'metadata': metadata
        })
    return upserts

In [15]:
builder(data[:3])

[{'id': '21645374-0',
  'values': [-0.0860980823636055,
   -0.06404668837785721,
   -0.09067502617835999,
   -0.13883478939533234,
   0.4034903347492218,
   0.045110125094652176,
   0.17842219769954681,
   0.008638600818812847,
   0.39867258071899414,
   -0.12001333385705948,
   -0.0558837428689003,
   0.10405895859003067,
   -0.5984246730804443,
   0.44607430696487427,
   0.07607333362102509,
   0.718574047088623,
   0.1389893740415573,
   -0.03241889178752899,
   0.059661466628313065,
   0.058138348162174225,
   -0.14696797728538513,
   0.02058233693242073,
   0.7175154685974121,
   0.2626684904098511,
   0.1868906319141388,
   -0.27962228655815125,
   -0.43341633677482605,
   -0.3650135099887848,
   -0.40824878215789795,
   0.4922322630882263,
   -0.04993167147040367,
   -0.3248227834701538,
   0.14582441747188568,
   -0.21379970014095306,
   0.06254763901233673,
   -0.031296614557504654,
   -0.5419843792915344,
   -0.16867417097091675,
   -0.44803082942962646,
   -0.075442731380462

In [16]:
import pinecone
import os 

CREATE_INDEX = False

def init_pinecone_connection():
    api_key = os.getenv("PINECONE_API_KEY") or "YOUR_API_KEY"
    env = os.getenv("PINECONE_ENVIRONMENT") or "YOUR_ENVIRONMENT"
    pinecone.init(api_key=api_key, environment=env)

init_pinecone_connection()
index_name = 'pubmed-splade'
if CREATE_INDEX:
    pinecone.create_index(
        index_name,
        dimension=dim,
        metric="dotproduct",
        pod_type="s1"
    )
else:
    index = pinecone.GRPCIndex(index_name)

In [17]:
index.describe_index_stats()
index.upsert(builder(data[:3]))



upserted_count: 3

In [18]:
index = pinecone.GRPCIndex(index_name)

from tqdm.auto import tqdm

batch_size = 64

for i in tqdm(range(0, len(data), batch_size)):
    # extract batch of data
    i_end = min(i+batch_size, len(data))
    batch = data[i:i_end]
    # pass data to builder and upsert
    index.upsert(builder(data[i:i+batch_size]))

100%|██████████| 93/93 [24:08<00:00, 15.58s/it]


In [None]:
len(data), index.describe_index_stats()

In [None]:
def encode(text: str):
    # create dense vec
    dense_vec = dense_model.encode(text).tolist()
    # create sparse vec
    input_ids = tokenizer(text, return_tensors='pt')
    with torch.no_grad():
        sparse_vec = sparse_model(
            d_kwargs=input_ids.to(device)
        )['d_rep'].squeeze()
    # convert to dictionary format
    indices = sparse_vec.nonzero().squeeze().cpu().tolist()
    values = sparse_vec[indices].cpu().tolist()
    sparse_dict = {"indices": indices, "values": values}
    # return vecs
    return dense_vec, sparse_dict

In [None]:
query = "Can clinicians use the PHQ-9 to assess depression in people with vision loss?"
dense, sparse = encode(query)
# query
xc = index.query(
    vector=dense,
    sparse_vector=sparse,
    top_k=2,  # how many results to return
    include_metadata=True
)
xc

In [None]:
def hybrid_scale(dense, sparse, alpha: float):
    # check alpha value is in range
    if alpha < 0 or alpha > 1:
        raise ValueError("Alpha must be between 0 and 1")
    # scale sparse and dense vectors to create hybrid search vecs
    hsparse = {
        'indices': sparse['indices'],
        'values':  [v * (1 - alpha) for v in sparse['values']]
    }
    hdense = [v * alpha for v in dense]
    return hdense, hsparse

In [None]:
hdense, hsparse = hybrid_scale(dense, sparse, alpha=1.0)
# query
xc = index.query(
    vector=hdense,
    sparse_vector=hsparse,
    top_k=2,  # how many results to return
    include_metadata=True
)
xc

In [None]:
hdense, hsparse = hybrid_scale(dense, sparse, alpha=0.0)
# query
xc = index.query(
    vector=hdense,
    sparse_vector=hsparse,
    top_k=2,  # how many results to return
    include_metadata=True
)
xc

In [None]:
query = "Does ibuprofen increase perioperative blood loss during hip arthroplasty?"
dense, sparse = encode(query)
hdense, hsparse = hybrid_scale(dense, sparse, alpha=0.0)  # pure SPARSE
# query
xc = index.query(
    vector=hdense,
    sparse_vector=hsparse,
    top_k=2,  # how many results to return
    include_metadata=True
)
xc

In [None]:
query = "Does ibuprofen increase perioperative blood loss during hip arthroplasty?"
dense, sparse = encode(query)
hdense, hsparse = hybrid_scale(dense, sparse, alpha=1.0)  # pure DENSE
# query
xc = index.query(
    vector=hdense,
    sparse_vector=hsparse,
    top_k=2,  # how many results to return
    include_metadata=True
)
xc