# Document Embedding

This notebook uses the [`jina-embeddings-v2-base-en`](https://huggingface.co/jinaai/jina-embeddings-v2-base-en) embedding model to transform abstract texts into a 768-dimensional embedding space.

> **Note from our 2021 study:** In 2021, we used the GPT-3 Embedding API to generate text-similarity embeddings from the input text. Because the API was limited to inputs of fewer than 2048 tokens, we used heuristics to remove abstracts longer than 2048 tokens. Note that, as of January 2022, the GPT-3 Embedding API is no longer free, so you will probably need to pay to use it. Cached results from 2021 are available on HuggingFace.

## Input
- `data/pubmed/abstracts_2023.csv.gz` contains the raw, unpreprocessed texts collected from PubMed. To speed things up, duplicate documents are embedded only once; documents with identical PMIDs are considered duplicates.

## Outputs

- `models/embeddings/jina-embeddings-v2-base-en.safetensors`, in SafeTensors format, contains the PMIDs and their corresponding embedding vectors, with one key per document.

## Requirements

You may need a valid HuggingFace access token to run this notebook. You can create one on the [access tokens page](https://huggingface.co/settings/tokens) of your HuggingFace account settings.
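
One way to supply the token without hard-coding it is an environment variable. The sketch below assumes the token is stored in `HF_TOKEN`, a variable that `huggingface_hub` also reads on its own; the helper name `read_hf_token` is hypothetical and not part of the original notebook.

```python
import os


def read_hf_token(env_var="HF_TOKEN"):
    """Return the access token stored in env_var, or None if unset or blank.

    Hypothetical helper; the default variable name is an assumption.
    """
    token = os.environ.get(env_var, "").strip()
    return token or None


token = read_hf_token()
if token is None:
    print("No access token found in HF_TOKEN.")
```

If a token is found, it can be passed to `huggingface_hub.login(token=token)` before the model is loaded.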


In [1]:
# Setup and imports

from pathlib import Path
from tqdm.auto import tqdm
from IPython.display import display

import polars as pl
import torch
from safetensors.torch import save_file, load_file
from sentence_transformers import SentenceTransformer

from src.cogtext.datasets.pubmed import PubMedDataset

In [2]:
GPT3_MODEL_ID = 'ada'  # 1024-dim embeddings
HF_MODEL_NAME = 'jinaai/jina-embeddings-v2-small-en'  # 512-dim embeddings (the base-en variant is 768-dim)

OUTPUT_FILE = f'data/embeddings/abstracts_2023_{HF_MODEL_NAME.split("/")[1]}-sbert.safetensors'

Prepare and clean up the input data:

In [3]:
data = PubMedDataset(year=2023).load()
data = data.unique('pmid')

print(f'Number of unique documents: {len(data)}')

Number of unique documents: 464765


In [4]:

# 1. Parameters
batch_size = 8          # number of abstracts to embed at once
n_embedded = 0          # number of embedded abstracts
existing_pmids = []     # list of existing embedded pmids
cache_dir = Path('tmp/embeddings') / HF_MODEL_NAME.split("/")[1]
cache_dir.mkdir(parents=True, exist_ok=True)

# 2. load existing cached embeddings
for f in cache_dir.glob('*.safetensors'):
  embeddings = load_file(f)
  pmids = [int(k) for k in embeddings.keys()]
  existing_pmids.extend(pmids)

data = data.filter(~pl.col('pmid').is_in(existing_pmids))
print(f'Loaded {len(existing_pmids)} existing embeddings.')

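# 3. initialize the pipeline (device='cuda' requires a CUDA-capable GPU)
# note: the jina-embeddings-v2 model card loads the model with trust_remote_code=True
model = SentenceTransformer(HF_MODEL_NAME, device='cuda').eval()
model.max_seq_length = 2048  # inputs longer than 2048 tokens are truncated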

# 4. embed texts
with torch.no_grad():
  with tqdm(total=len(data)) as pbar:
    while n_embedded < len(data):
      batch = data.slice(n_embedded, batch_size)  # the last batch may be shorter
      texts = batch['abstract'].to_list()
      pmids = batch.select(pl.col('pmid').cast(pl.Utf8))['pmid'].to_list()
      e = model.encode(texts, convert_to_tensor=True)
      e = dict(zip(pmids, e))  # one embedding tensor per PMID
      # cache each batch so that an interrupted run can resume
      save_file(e, cache_dir / f'from_{pmids[0]}.safetensors')
      # count the actual batch length so the final partial batch does not overshoot
      n_embedded += len(batch)
      pbar.update(len(batch))


Loaded 3168 existing embeddings.


  0%|          | 0/461597 [00:00<?, ?it/s]