# Compute the embeddings of the passages

In [4]:
from pathlib import Path
import pandas as pd
from tqdm.auto import tqdm 
import json
import torch
# !pip install sentence_transformers
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoTokenizer

In [18]:
CSC = True  # True on the CSC machine; False on the local GTX 1080 Ti workstation
MDL = 1     # key into MODELS below selecting the embedding model

# model number -> (model id, embedding dimension, Pinecone similarity metric)
MODELS = {
    # dim could also be read via model.get_sentence_embedding_dimension()
    1: ("paraphrase-multilingual-mpnet-base-v2", 768, 'cosine'),
    2: ("multi-qa-mpnet-base-dot-v1", 768, 'dotproduct'),  # English only
    3: ("LaBSE", 768, 'cosine'),
    4: ("symanto/sn-xlm-roberta-base-snli-mnli-anli-xnli", 768, 'cosine'),
    # not sure whether the metric should be cosine or dotproduct
    5: ("microsoft/Multilingual-MiniLM-L12-H384", 384, 'dotproduct'),
    # not sure whether the metric should be cosine or dotproduct
    6: ("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1", 384, 'dotproduct'),
    7: ("Muennighoff/SGPT-5.8B-weightedmean-msmarco-specb-bitfit", 4096, 'cosine'),
    # https://docs.cohere.ai/docs/multilingual-language-models
    # NOTE(review): 8 and 9 look like API-hosted models; presumably the
    # SentenceTransformer loading path used later cannot load them — confirm.
    8: ("cohere-multilingual-22-12", 768, 'dotproduct'),
    # Not certain about the metric, but see
    # https://community.openai.com/t/some-questions-about-text-embedding-ada-002-s-embedding/35299/17
    # (James Briggs also uses cosine)
    9: ("openai-text-embedding-ada-002", 1536, 'cosine'),
}

# Fail fast on a bad selection instead of hitting a NameError in a later cell
if MDL not in MODELS:
    raise ValueError(f"Unknown MDL={MDL!r}; expected one of {sorted(MODELS)}")
model_id, dim, metric = MODELS[MDL]

# Limit passages to 384 characters so they stay under the smaller 128-token model
# limit, assuming an average token length of 3 characters (usually it is 3-5).
limit = 384

# Local file used to cache the computed embeddings
embedding_cache_path = '../emb/pubmed-embeddings-{}-chars-{}.json'.format(model_id.replace('/', '_'), limit)

# Use the GPU when available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if not CSC and model_id == "Muennighoff/SGPT-5.8B-weightedmean-msmarco-specb-bitfit":
    device = 'cpu'  # the local GTX 1080 Ti does not have enough memory for this model

# Warn when running on CPU
if device != 'cuda':
    print("==========\n"+
          "WARNING: You are not running on GPU so this may be slow.\n"+
          "==========")

    

In [3]:
from datasets import load_dataset

# Load the 'pqa_labeled' configuration of PubMedQA (1000 records in the 'train' split)
ds = load_dataset(path='pubmed_qa', name='pqa_labeled', split='train')

Found cached dataset pubmed_qa (/users/alwasiti/.cache/huggingface/datasets/pubmed_qa/pqa_labeled/1.0.0/dd4c39f031a958c7e782595fa4dd1b1330484e8bbadd4d9212e5046f27e68924)


In [5]:
# Local JSON file caching the computed embeddings (same path as in the configuration cell)
embedding_cache_path = f"../emb/pubmed-embeddings-{model_id.replace('/', '_')}-chars-{limit}.json"

In [6]:
# Show the first record (pubid, question, context paragraphs, long_answer, final_decision)
ds[0]

{'pubid': 21645374,
 'question': 'Do mitochondria play a role in remodelling lace plant leaves during programmed cell death?',
 'context': {'contexts': ['Programmed cell death (PCD) is the regulated death of cells within an organism. The lace plant (Aponogeton madagascariensis) produces perforations in its leaves through PCD. The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature. The role of mitochondria during PCD has been recognized in animals; however, it has been less studied during PCD in plants.',
   'The following paper elucidates the role of mitochondrial dynamics during developmentally regulated PCD in vivo in A. madagascariensis. A single areole within a window stage leaf (PCD is occurring) was divided into three areas based on the progression of PCD; cells that will not undergo PCD (NPCD), ce

In [7]:
# Dataset summary: column names and number of rows
ds

Dataset({
    features: ['pubid', 'question', 'context', 'long_answer', 'final_decision'],
    num_rows: 1000
})

In [8]:
def chunker(contexts: list, max_chars=None):
    """Split a record's context paragraphs into overlapping passages.

    The paragraphs are joined with spaces and split on '.', then consecutive
    sentences are accumulated until a passage holds at least 3 sentences and is
    longer than ``max_chars`` characters. Such a passage is stripped, made to end
    with '.', and emitted; the next passage starts with its last sentence
    (1-sentence overlap). Because the split is on '.', abbreviations such as
    "A. madagascariensis" are split too.

    Args:
        contexts: list of context paragraph strings.
        max_chars: passage length threshold; when None, the notebook-level
            ``limit`` is used (read at call time).

    Returns:
        List of passage strings. The trailing remainder is emitted (stripped)
        only if it contains text beyond the overlap sentence, so empty input
        yields an empty list.
    """
    if max_chars is None:
        max_chars = limit
    chunks = []
    sentences = ' '.join(contexts).split('.')
    chunk = []
    # True when `chunk` holds non-blank text that no emitted passage contains yet
    has_new_text = False
    for sentence in sentences:
        chunk.append(sentence)
        if sentence.strip():
            has_new_text = True
        if len(chunk) >= 3 and len('.'.join(chunk)) > max_chars:
            # surpassed the limit, so emit the passage and reset
            passage = '.'.join(chunk).strip()
            # the fragment after a final '.' is '', so the passage may already
            # end with '.'; don't produce '..'
            if not passage.endswith('.'):
                passage += '.'
            chunks.append(passage)
            # keep the last sentence as overlap with the next passage
            chunk = chunk[-1:]
            has_new_text = False
    # emit the remainder only if it adds something beyond the overlap sentence;
    # otherwise it would duplicate text already emitted (or be empty)
    if has_new_text:
        chunks.append('.'.join(chunk).strip())
    return chunks

In [9]:
# Try the chunker on the context paragraphs of the first record
contexts = ds[0]['context']['contexts']
chunks = chunker(contexts=contexts)

In [10]:
# Inspect the passages produced for the first record
chunks

['Programmed cell death (PCD) is the regulated death of cells within an organism. The lace plant (Aponogeton madagascariensis) produces perforations in its leaves through PCD. The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature.',
 'PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature. The role of mitochondria during PCD has been recognized in animals; however, it has been less studied during PCD in plants. The following paper elucidates the role of mitochondrial dynamics during developmentally regulated PCD in vivo in A. madagascariensis.',
 'madagascariensis. A single areole within a window stage leaf (PCD is occurring) was divided into three areas based on the progression of PCD; cells that will not undergo PCD (NPCD)

In [11]:
# Give each passage of the first record a unique ID of the form "<pubid>-<passage index>"
ids = [f"{ds[0]['pubid']}-{n}" for n in range(len(chunks))]
ids

['21645374-0', '21645374-1', '21645374-2', '21645374-3', '21645374-4']

In [12]:
# Create the full contexts dataset: one entry per passage of every record, with IDs
# of the form "<pubid>-<passage index>". A comprehension keeps its loop variables
# local, so `chunks` from the single-record demo above is not overwritten.
data = [
    {'id': f"{record['pubid']}-{n}", 'context': passage}
    for record in ds
    for n, passage in enumerate(chunker(record['context']['contexts']))
]

data[:7]

[{'id': '21645374-0',
  'context': 'Programmed cell death (PCD) is the regulated death of cells within an organism. The lace plant (Aponogeton madagascariensis) produces perforations in its leaves through PCD. The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature.'},
 {'id': '21645374-1',
  'context': 'PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature. The role of mitochondria during PCD has been recognized in animals; however, it has been less studied during PCD in plants. The following paper elucidates the role of mitochondrial dynamics during developmentally regulated PCD in vivo in A. madagascariensis.'},
 {'id': '21645374-2',
  'context': 'madagascariensis. A single areole within a window stage leaf (PCD is occurri

In [13]:
# Total number of passages across all records
len(data)

4316

In [14]:
if model_id == "Muennighoff/SGPT-5.8B-weightedmean-msmarco-specb-bitfit":
    # On CSC, load both the tokenizer and the weights through the project scratch
    # cache so neither is downloaded into the default Hugging Face cache.
    hf_cache_kwargs = {'cache_dir': '/scratch/project_2007072/cache'} if CSC else {}
    tokenizer = AutoTokenizer.from_pretrained(model_id, **hf_cache_kwargs)
    dense_model = AutoModel.from_pretrained(model_id, **hf_cache_kwargs).to(device)
    # Deactivate Dropout (There is no dropout in the above models so it makes no difference here but other SGPT models may have dropout)
    dense_model.eval()
    # Token ids of the "specb" brackets: [ ... ] wraps queries, { ... } wraps documents
    SPECB_QUE_BOS = tokenizer.encode("[", add_special_tokens=False)[0]
    SPECB_QUE_EOS = tokenizer.encode("]", add_special_tokens=False)[0]
    SPECB_DOC_BOS = tokenizer.encode("{", add_special_tokens=False)[0]
    SPECB_DOC_EOS = tokenizer.encode("}", add_special_tokens=False)[0]

    def tokenize_with_specb(texts, is_query):
        """Tokenize `texts` and wrap each sequence in the query or document brackets.

        Returns the padded batch as PyTorch tensors moved to `device`.
        """
        # Tokenize without padding
        batch_tokens = tokenizer(texts, padding=False, truncation=True)
        # Add special brackets & pay attention to them
        for seq, att in zip(batch_tokens["input_ids"], batch_tokens["attention_mask"]):
            if is_query:
                seq.insert(0, SPECB_QUE_BOS)
                seq.append(SPECB_QUE_EOS)
            else:
                seq.insert(0, SPECB_DOC_BOS)
                seq.append(SPECB_DOC_EOS)
            att.insert(0, 1)
            att.append(1)
        # Add padding
        batch_tokens = tokenizer.pad(batch_tokens, padding=True, return_tensors="pt")
        return batch_tokens.to(device)

    def get_weightedmean_embedding(batch_tokens, model):
        """Position-weighted mean pooling of the model's last hidden state.

        Token j (1-based position) gets weight j and padded positions get
        weight 0, so later tokens contribute more. Returns a [bs, hid_dim] tensor.
        """
        # Get the embeddings
        with torch.no_grad():
            # Get hidden state of shape [bs, seq_len, hid_dim]. Only the last layer
            # is used, so don't request every layer's hidden states.
            last_hidden_state = model(**batch_tokens, return_dict=True).last_hidden_state

        # Get weights of shape [bs, seq_len, hid_dim]
        # NOTE(review): position weights assume real tokens come before padding
        # (right padding) — confirm the tokenizer's padding side.
        weights = (
            torch.arange(start=1, end=last_hidden_state.shape[1] + 1)
            .unsqueeze(0)
            .unsqueeze(-1)
            .expand(last_hidden_state.size())
            .float().to(last_hidden_state.device)
        )

        # Get attn mask of shape [bs, seq_len, hid_dim]
        input_mask_expanded = (
            batch_tokens["attention_mask"]
            .unsqueeze(-1)
            .expand(last_hidden_state.size())
            .float()
        )

        # Perform weighted mean pooling across seq_len: bs, seq_len, hidden_dim -> bs, hidden_dim
        sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded * weights, dim=1)
        sum_mask = torch.sum(input_mask_expanded * weights, dim=1)

        embeddings = sum_embeddings / sum_mask

        return embeddings
else:
    # Every other model id is loaded as a sentence-transformers model
    dense_model = SentenceTransformer(
        model_id,
        device=device
    )

We need to index every passage in the following format:
```python
{
    'id': 'id-123',
    'values': [0.1, 0.2, ...],  # dense vec
    'metadata': {
        'context': 'some text here',
        'media': 'medical data',
        'language': 'en'
    }  # metadata dict
}
```

In [15]:
def builder(records: list):
    """Embed a batch of passage records and return them in Pinecone upsert format.

    Args:
        records: dicts with 'id' and 'context' keys, as built in `data`.

    Returns:
        One dict per record with keys 'id', 'values' (the dense embedding as a
        list of floats) and 'metadata' (the passage text plus fixed 'media' and
        'language' tags). These plain dicts are what `index.upsert` receives.
    """
    record_ids = [rec['id'] for rec in records]
    texts = [rec['context'] for rec in records]
    if model_id == "Muennighoff/SGPT-5.8B-weightedmean-msmarco-specb-bitfit":
        # SGPT: passages are documents, so they get the '{' ... '}' brackets
        batch_tokens = tokenize_with_specb(texts, is_query=False)
        vectors = get_weightedmean_embedding(batch_tokens, dense_model).tolist()
    else:
        vectors = dense_model.encode(texts).tolist()
    return [
        {
            'id': record_id,
            'values': vector,
            'metadata': {
                'context': text,
                'media': 'medical data',
                'language': 'en'
            }
        }
        for record_id, vector, text in zip(record_ids, vectors, texts)
    ]

In [17]:
# SGPT-5.8B embeds one passage per call (presumably to fit in memory);
# the other models use batches of 64
if model_id == "Muennighoff/SGPT-5.8B-weightedmean-msmarco-specb-bitfit":
    batch_size = 1
else:
    batch_size = 64

upserts = []

for start in tqdm(range(0, len(data), batch_size)):
    # embed one batch of passages and collect the upsert-ready dicts
    batch = data[start:start + batch_size]
    upserts.extend(builder(batch))

  0%|          | 0/4316 [00:00<?, ?it/s]

In [18]:
print("Store file on disc")
# Make sure the cache directory exists, and close the file so the JSON is flushed
embedding_cache_file = Path(embedding_cache_path)
embedding_cache_file.parent.mkdir(parents=True, exist_ok=True)
with embedding_cache_file.open('w', encoding='utf-8') as cache_fh:
    json.dump(upserts, cache_fh, ensure_ascii=False)

Store file on disc


# Loading embeddings and upserting into Pinecone index

In [27]:
# Load the Pinecone API key from credentials.json in the parent directory.
# This file is not included in the repo; create it yourself with one
# "PINECONE_KEY_<model_id>" entry per model.
with open('../credentials.json', encoding='utf-8') as file:
    credentials = json.load(file)

PINECONE_KEY = credentials['PINECONE_KEY_' + model_id]
# Pinecone index names may only contain lowercase ASCII letters, digits and '-',
# so map any other character of the model id (e.g. '/' or '.') to '-'.
# Ids that are already valid (such as the default model's) are unchanged.
# NOTE(review): Pinecone also caps index-name length; long model ids may still be rejected.
index_id = 'med-sem-' + ''.join(
    ch if (ch.isascii() and ch.isalnum()) or ch == '-' else '-'
    for ch in model_id.lower()
)

In [20]:
# Show which cache file the embeddings are read from
embedding_cache_path

'../emb/pubmed-embeddings-paraphrase-multilingual-mpnet-base-v2-chars-384.json'

In [35]:
# Read the cached embeddings to upsert into the DB (closing the file afterwards)
with open(embedding_cache_path, 'r', encoding='utf-8') as cache_fh:
    upserts = json.load(cache_fh)

In [None]:
import pinecone  # pip install pinecone-client

# Connect to Pinecone with the API key loaded from credentials.json
pinecone.init(
    api_key=PINECONE_KEY,  # app.pinecone.io
    environment="us-central1-gcp"  # next to api key in console
)

# Create the index only if it does not exist yet; its dimension and similarity
# metric come from the selected model configuration (dim / metric).
# pod_type can be "p1" or "s1"
if index_id not in pinecone.list_indexes():
    pinecone.create_index(
        index_id,
        dimension=dim,
        metric=metric,
        pod_type="s1"
    )

# A future pinecone-client release is expected to add pinecone.GRPCIndex, which is
# faster and would be preferable to pinecone.Index
index = pinecone.Index(index_id)
# Show index stats (e.g. total_vector_count) to check what is already stored
index.describe_index_stats()

{'dimension': 768,
 'index_fullness': 0.0,
 'namespaces': {},
 'total_vector_count': 0}

In [36]:
batch_size = 64

# Iterate over `upserts` itself (not `data`) so the batches stay correct when the
# embeddings were loaded from the cache file rather than computed in this session.
for start in tqdm(range(0, len(upserts), batch_size)):
    # upsert one batch of precomputed vectors
    index.upsert(upserts[start:start + batch_size])

  0%|          | 0/68 [00:00<?, ?it/s]