# Document Embedding

This notebook uses the [`jina-embeddings-v2-base-en` embedding model](https://huggingface.co/jinaai/jina-embeddings-v2-base-en) to transform abstract texts into a 768-dimensional embedding space.

> **Note from our 2021 study:** In 2021, we used the GPT-3 Embedding API to generate text-similarity embeddings from the input text. Because the API only accepted inputs of fewer than 2048 tokens, we used heuristics to remove long abstracts with more than 2048 tokens. Note that, since January 2022, the GPT-3 Embedding API is no longer free, so you will probably need to pay for it to rerun the 2021 version of this notebook. Cached results from 2021 are available on HuggingFace.

## Input
- `data/pubmed/abstracts_2023.csv.gz` contains the raw, unpreprocessed texts collected from PubMed. To speed things up, duplicate documents are embedded only once; documents with the same PMID are considered duplicates.

## Outputs

- `models/embeddings/jina-embeddings-v2-base-en.safetensors`, in SafeTensors format, contains the PMIDs and the corresponding embedding vectors, one key per document.

### Requirements

You may need a valid HuggingFace access token to run this notebook. You can create one on the [access tokens page](https://huggingface.co/settings/tokens).


In [1]:
# Setup and imports

from pathlib import Path
from tqdm.auto import tqdm
from IPython.display import display

import polars as pl
import torch
from safetensors.torch import save_file, load_file
from sentence_transformers import SentenceTransformer

from src.cogtext.datasets.pubmed import PubMedDataset

In [2]:
HF_MODEL_NAME = 'jinaai/jina-embeddings-v2-small-en'
MAX_TOKENS = 1024
BATCH_SIZE = 64  # number of abstracts to embed at once

OUTPUT_FILE = f'data/embeddings/abstracts2023_{HF_MODEL_NAME.split("/")[1]}.safetensors'

Prepare and clean up the input data:

In [3]:
data = PubMedDataset(year=2023).load()
data = data.unique('pmid')

print(f'Number of unique documents: {len(data)}')

Number of unique documents: 464765
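
The deduplication rule used above (one row per PMID) can be illustrated without Polars. This is only a sketch on made-up records; the field names `pmid` and `abstract` follow the notebook, while `unique_by_pmid` is a hypothetical helper and not part of the project code.

```python
# Made-up records; only the field names `pmid` and `abstract` come from the notebook.
records = [
    {"pmid": 101, "abstract": "Working memory training in older adults."},
    {"pmid": 102, "abstract": "Inhibitory control and task switching."},
    {"pmid": 101, "abstract": "Working memory training in older adults."},
]


def unique_by_pmid(rows):
    """Keep the first row seen for each PMID, preserving input order."""
    seen = {}
    for row in rows:
        seen.setdefault(row["pmid"], row)
    return list(seen.values())


unique_records = unique_by_pmid(records)
print(f"Number of unique documents: {len(unique_records)}")
```

Unlike this sketch, Polars' `DataFrame.unique` does not promise which of the duplicate rows is kept, or the output order, unless `keep` and `maintain_order` are set explicitly.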


In [4]:
# discard abstracts where n_tokens >= MAX_TOKENS

# cheap prefilter: only abstracts with more than 400 spaces (roughly 400 words) get tokenized
candidate_abstracts = data.lazy().with_columns(
    pl.col('abstract')
      .map_elements(lambda x: x.count(' '), return_dtype=pl.Int64)
      .alias('abstract_len')
).filter(pl.col('abstract_len').gt(400)).sort('abstract_len', descending=True).collect()

print(f'tokenizing {len(candidate_abstracts)} candidate articles (n_words > 400)...')
model = SentenceTransformer(HF_MODEL_NAME, device='cuda').eval()

def get_n_tokens(texts: pl.Series, tokenizer=model[0].tokenizer):
    """Helper function to count number of tokens in a batch of texts."""
    o = tokenizer(texts.to_list(), return_attention_mask=False, return_token_type_ids=False)
    n_tokens = [len(tokens) for tokens in o['input_ids']]
    return pl.Series(n_tokens)

# count tokens for each candidate abstract
candidate_abstracts = candidate_abstracts.with_columns(
    pl.col('abstract').map_batches(get_n_tokens).alias('n_tokens'))

# find abstracts where n_tokens >= MAX_TOKENS
pmids_to_discard = candidate_abstracts.filter(pl.col('n_tokens').ge(MAX_TOKENS))['pmid'].to_list()

# filter long abstracts
data = data.lazy().filter(~pl.col('pmid').is_in(pmids_to_discard)).collect()

print(f'Discarded {len(pmids_to_discard)} articles where n_tokens > {MAX_TOKENS - 1}.')

tokenizing 12223 candidate articles (n_words > 400)...
Discarded 511 articles where n_tokens > 1023.
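
The cell above filters in two stages: a cheap space count selects candidates, and only those candidates are passed to the tokenizer. The sketch below shows that pattern in plain Python. Note the assumptions: `count_tokens` is a hypothetical stand-in (two tokens per word) rather than the model's real tokenizer, and the abstracts are synthetic.

```python
MAX_TOKENS = 1024      # same limit as in the notebook
WORD_THRESHOLD = 400   # same prefilter as above: more than 400 spaces


def count_tokens(text):
    """Hypothetical stand-in tokenizer: two tokens per whitespace-separated word."""
    return 2 * len(text.split())


def find_too_long(abstracts, max_tokens=MAX_TOKENS, word_threshold=WORD_THRESHOLD):
    """Return the PMIDs of abstracts with at least `max_tokens` tokens."""
    # Stage 1: cheap proxy; counting spaces is much faster than tokenizing.
    candidates = {pmid: text for pmid, text in abstracts.items()
                  if text.count(" ") > word_threshold}
    # Stage 2: exact token count, computed for the candidates only.
    return [pmid for pmid, text in candidates.items()
            if count_tokens(text) >= max_tokens]


# Synthetic abstracts keyed by PMID.
abstracts = {
    1: "short abstract",
    2: " ".join(["word"] * 450),  # candidate, but 900 tokens: kept
    3: " ".join(["word"] * 600),  # candidate with 1200 tokens: discarded
}
to_discard = find_too_long(abstracts)
kept = {pmid: text for pmid, text in abstracts.items() if pmid not in to_discard}
print(f"Discarded {len(to_discard)} of {len(abstracts)} abstracts.")
```

The prefilter is a heuristic: it relies on the assumption that an abstract with at most 400 spaces never reaches `MAX_TOKENS` tokens. If that assumption is in doubt, lowering the threshold trades speed for safety.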


In [5]:
# pipeline params

tmp_dir = Path('tmp/embeddings') / HF_MODEL_NAME.split("/")[1]  # e.g., tmp/embeddings/jina*
tmp_dir.mkdir(parents=True, exist_ok=True)

# load existing cached embeddings
existing_pmids = []  # list of existing embedded pmids
for f in tmp_dir.glob('*.safetensors'):
  embeddings = load_file(f)
  pmids = [int(k) for k in embeddings.keys()]
  existing_pmids.extend(pmids)

# filter out existing embeddings
data = data.filter(~pl.col('pmid').is_in(existing_pmids))
print(f'Loaded {len(existing_pmids)} existing embeddings.')

Loaded 464254 existing embeddings.
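
The cell above makes the run resumable: PMIDs already present in the cached shard files are skipped. A minimal sketch of the same idea follows, with JSON files standing in for the SafeTensors shards so that it runs with the standard library alone. The shard contents and the helper `load_cached_ids` are made up for illustration.

```python
import json
import tempfile
from pathlib import Path


def load_cached_ids(cache_dir):
    """Collect the document IDs stored as keys in every shard under `cache_dir`."""
    cached = set()
    for shard in Path(cache_dir).glob("*.json"):
        with open(shard) as f:
            # Keys are strings on disk, so convert them back to integers.
            cached.update(int(key) for key in json.load(f))
    return cached


with tempfile.TemporaryDirectory() as tmp:
    # Two shards left over from an earlier, interrupted run (made-up values).
    Path(tmp, "from_101.json").write_text(json.dumps({"101": [0.1], "102": [0.2]}))
    Path(tmp, "from_103.json").write_text(json.dumps({"103": [0.3]}))

    all_pmids = [101, 102, 103, 104, 105]
    cached_ids = load_cached_ids(tmp)
    pending = [pmid for pmid in all_pmids if pmid not in cached_ids]

print(f"Loaded {len(cached_ids)} existing embeddings, {len(pending)} left to embed.")
```

As in the notebook, IDs are written as string keys and converted back to integers before being compared with the `pmid` column.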


In [6]:
# initialize the embedding pipeline
model = SentenceTransformer(HF_MODEL_NAME, device='cuda').eval()
model.max_seq_length = MAX_TOKENS

# embed texts batch by batch; each batch is cached in its own shard file
with torch.no_grad():
  with tqdm(total=len(data)) as pbar:
    n_embedded = 0
    while n_embedded < len(data):
      batch = data.slice(n_embedded, BATCH_SIZE)
      texts = batch['abstract'].to_list()
      pmids = batch.select(pl.col('pmid').cast(pl.Utf8))['pmid'].to_list()
      e = model.encode(texts, convert_to_tensor=True)
      e = dict(zip(pmids, e))  # one tensor per PMID
      save_file(e, tmp_dir / f'from_{pmids[0]}.safetensors')
      n_embedded += len(batch)  # the last batch may be smaller than BATCH_SIZE
      pbar.update(len(batch))

# combine all embeddings into a single file and save
if n_embedded > 0:
  all_embeddings = {}
  for f in tqdm(tmp_dir.glob('*.safetensors')):
    embeddings = load_file(f)
    all_embeddings.update(embeddings)
  save_file(all_embeddings, OUTPUT_FILE)

0it [00:00, ?it/s]
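
The batching scheme used in the embedding loop can be checked in isolation. The sketch below uses made-up PMIDs and a hypothetical `iter_batches` helper; it only demonstrates how the shard names and the running count behave when the number of documents is not a multiple of the batch size.

```python
def iter_batches(items, batch_size):
    """Yield consecutive slices of `items`; the last one may be shorter."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


pmids = [str(p) for p in range(1000, 1150)]  # 150 made-up PMIDs, as strings
n_embedded = 0
shard_names = []
for batch in iter_batches(pmids, batch_size=64):
    # Each shard is named after the first PMID it contains, as in the notebook.
    shard_names.append(f"from_{batch[0]}.safetensors")
    # Count what was actually processed, not the nominal batch size.
    n_embedded += len(batch)

print(shard_names, n_embedded)
```

Counting `len(batch)` rather than the nominal batch size keeps the progress total equal to the number of documents actually embedded.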