In [None]:
import os
from tqdm.notebook import tqdm
import numpy as np
import pandas as pd
import torch
from neurovlm.data import get_data_dir
from neurovlm.models import Specter

# Text Encoding

This notebook encodes (title, abstract) pairs using Specter. Because Specter was trained on scientific (title, abstract) pairs, it is likely to perform well on medium-length queries; MiniLM-L6 is expected to handle short-form queries better.


Specter is used to encode each (title, abstract) pair into a 768-dimensional embedding space.

> A. Singh, M. D'Arcy, A. Cohan, D. Downey, and S. Feldman, “SciRepEval: A Multi-Format Benchmark for Scientific Document Representations,” in Proc. Conf. Empirical Methods in Natural Language Processing (EMNLP), 2022. [Online]. Available: https://api.semanticscholar.org/CorpusID:254018137


In [None]:
# Read the publications table from the package data directory.
# Columns used below: `name` (title), `description` (abstract) and `pmid`.
data_dir = get_data_dir()
pubs_path = data_dir / "publications_less.parquet"
df_pubs = pd.read_parquet(pubs_path)

In [None]:
# Load Specter with the ad hoc query adapter (the adapter name is also baked into
# the chunk file names, so the stacking cell below must use the same value).
adapter = "adhoc_query"
specter = Specter(adapter=adapter)

# Per-batch embeddings are written to this directory and stacked in the next cell.
os.makedirs("specter", exist_ok=True)

# Each paper is encoded as "<title>[SEP]<abstract>". Missing titles/abstracts are
# replaced by "" so a single null value does not abort the run with a TypeError.
papers = [
    title + "[SEP]" + abstract
    for title, abstract in zip(
        df_pubs["name"].fillna(""), df_pubs["description"].fillna("")
    )
]

# Magic number: presumably chosen to fit model memory — TODO confirm / tune.
batch_size = 4

# tqdm infers the number of batches from range() itself; an explicit
# total=len(papers)//batch_size would undercount when the last batch is partial.
for i in tqdm(range(0, len(papers), batch_size), desc="encoding"):

    with torch.no_grad():
        latent = specter(papers[i:i + batch_size])

    # Chunk files are keyed by their starting row (zero-padded). The stacking cell
    # rebuilds exactly these names, so keep the two formats in sync.
    torch.save(
        {"embeddings": latent, "pmid": df_pubs["pmid"].values[i:i + batch_size]},
        f"specter/encoded_text_specter2_{adapter}_{str(i).zfill(5)}.pt",
        pickle_protocol=5,
    )

In [None]:
# Stack the per-batch embedding chunks into one matrix, L2-normalise, and save.
# Depends on `batch_size` and `adapter` from the encoding cell: chunk files are
# located by their starting row, using the same naming scheme they were written with.
latent_text = torch.zeros((len(df_pubs), 768), dtype=torch.float32)
pmids_text = np.zeros(len(df_pubs), dtype=int)

for idx in range(0, len(df_pubs), batch_size):
    # weights_only=False is required because each chunk dict holds a NumPy array of
    # PMIDs (unpickled); only load chunk files this notebook wrote itself.
    # map_location="cpu" allows chunks encoded on a GPU to be stacked on a CPU-only machine.
    chunk = torch.load(
        f"specter/encoded_text_specter2_{adapter}_{str(idx).zfill(5)}.pt",
        weights_only=False,
        map_location="cpu",
    )
    latent_text[idx:idx + batch_size] = chunk["embeddings"]
    pmids_text[idx:idx + batch_size] = chunk["pmid"]

# Chunks were written in df_pubs row order, so the stacked rows already line up with
# df_pubs. Do NOT re-sort by PMID: that only matches df_pubs order when df_pubs is itself
# PMID-sorted. This check also catches stale chunk files left over from a different run.
assert np.array_equal(pmids_text, df_pubs["pmid"].values.astype(int)), (
    "stacked embeddings are not aligned with df_pubs rows"
)

# Row-wise L2 normalisation (eps-guarded, so an all-zero row stays zero instead of NaN).
latent_text = torch.nn.functional.normalize(latent_text, dim=1)
torch.save(latent_text, data_dir / f"latent_text_specter2_{adapter}.pt")