# Exercise 2: Document-level embedding analysis of PubMed papers with SPECTER

In [None]:
from transformers import AutoTokenizer, AutoModel

# Load SPECTER model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("allenai/specter")
model = AutoModel.from_pretrained("allenai/specter")

In [1]:
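# Load the dictionary of papers: keys are PubMed IDs (PMIDs), values hold fields
# such as "ArticleTitle" and "AbstractText"
import json

with open("papers_dict.json", encoding="utf-8") as f:
    papers_dict = json.load(f)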

papers_dict

{'37941028': {'ArticleTitle': 'Mechanism and therapeutic potential of targeting cGAS-STING signaling in neurological disorders.',
  'AbstractText': "DNA sensing is a pivotal component of the innate immune system that is responsible for detecting mislocalized DNA and triggering downstream inflammatory pathways. Among the DNA sensors, cyclic GMP-AMP synthase (cGAS) is a primary player in detecting cytosolic DNA, including foreign DNA from pathogens and self-DNA released during cellular damage, culminating in a type I interferon (IFN-I) response through stimulator of interferon genes (STING) activation. IFN-I cytokines are essential in mediating neuroinflammation, which is widely observed in CNS injury, neurodegeneration, and aging, suggesting an upstream role for the cGAS DNA sensing pathway. In this review, we summarize the latest developments on the cGAS-STING DNA-driven immune response in various neurological diseases and conditions. Our review covers the current understanding of the 

In [None]:
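import torch
import tqdm


def get_abstract(paper):
    """Return a paper's abstract text, or an empty string if it has none."""
    return paper.get("AbstractText") or ""


# A persistent dictionary (e.g. via the `shelve` module) would let us stop and
# restart this loop; here we use an ordinary in-memory dict instead.
embeddings = {}
model.eval()  # inference mode: disables dropout
with torch.no_grad():  # gradients are not needed to compute embeddings
    for pmid, paper in tqdm.tqdm(papers_dict.items()):
        # SPECTER's input format: title + [SEP] + abstract
        data = [paper["ArticleTitle"] + tokenizer.sep_token + get_abstract(paper)]
        inputs = tokenizer(
            data, padding=True, truncation=True, return_tensors="pt", max_length=512
        )
        result = model(**inputs)
        # use the final hidden state of the first ([CLS]) token as the paper embedding
        embeddings[pmid] = result.last_hidden_state[:, 0, :].numpy()[0]

# turn the dictionary into a list ordered like papers_dict
embeddings = [embeddings[pmid] for pmid in papers_dict.keys()]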
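The pooling expression `result.last_hidden_state[:, 0, :]` selects the first ([CLS]) token of every sequence in the batch, and `[0]` then picks the single paper in each batch. The NumPy sketch below uses a random array as a hypothetical stand-in for the model output, so it only illustrates the indexing and shapes, not real embeddings; 768 is the hidden size of SPECTER's BERT-base architecture.

```python
import numpy as np

# Hypothetical stand-in for result.last_hidden_state (converted to NumPy):
# shape (batch_size, sequence_length, hidden_size)
batch_size, seq_len, hidden = 1, 12, 768
last_hidden_state = np.random.default_rng(0).normal(size=(batch_size, seq_len, hidden))

cls_vectors = last_hidden_state[:, 0, :]  # first token of each sequence -> (batch, hidden)
paper_embedding = cls_vectors[0]          # one paper per batch -> (hidden,)

# Stacking one vector per paper gives the (n_papers, hidden) matrix that PCA expects
fake_embeddings = [paper_embedding, paper_embedding * 2, paper_embedding - 1]
X = np.vstack(fake_embeddings)
print(cls_vectors.shape, paper_embedding.shape, X.shape)  # (1, 768) (768,) (3, 768)
```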

In [None]:
# Identify the first three principal components of the paper embeddings
import pandas as pd
from sklearn import decomposition

pca = decomposition.PCA(n_components=3)
embeddings_pca = pd.DataFrame(
    pca.fit_transform(embeddings),
    columns=["PC0", "PC1", "PC2"],
)
# label each point with the search query that retrieved the paper
embeddings_pca["query"] = [paper["query"] for paper in papers_dict.values()]
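For intuition about what `decomposition.PCA` computes here, the following is a minimal from-scratch NumPy sketch (our own, not scikit-learn's code): center the data, take an SVD, and project onto the top right-singular vectors. Its scores should agree with scikit-learn's up to the sign of each component (and small numerical differences if scikit-learn picks a randomized solver), and `ratio` corresponds to `pca.explained_variance_ratio_`. The random matrix is a hypothetical stand-in for the embeddings.

```python
import numpy as np


def pca_svd(X, n_components=3):
    """Project the rows of X onto their first principal components via SVD."""
    X = np.asarray(X, dtype=float)
    X_centered = X - X.mean(axis=0)  # PCA operates on mean-centered data
    # rows of Vt are the principal directions, sorted by decreasing singular value
    _, S, Vt = np.linalg.svd(X_centered, full_matrices=False)
    scores = X_centered @ Vt[:n_components].T  # coordinates along PC0, PC1, ...
    explained_variance = S**2 / (X.shape[0] - 1)
    ratio = explained_variance[:n_components] / explained_variance.sum()
    return scores, ratio


# Demo on random stand-in data (hypothetical: 50 "papers" x 768 dimensions)
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(50, 768))
scores, ratio = pca_svd(X_demo, n_components=3)
print(scores.shape)  # (50, 3)
```

Because the scores are projections onto orthogonal directions of centered data, each column has zero mean and the columns are mutually uncorrelated, with variance decreasing from PC0 to PC2.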

In [None]:
# Make 2-D scatter plots of PC0 vs PC1, PC0 vs PC2 and PC1 vs PC2, color-coded by search query
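One possible way to fill in this cell, sketched with matplotlib (an assumption; the exercise does not prescribe a plotting library). The helper `plot_pc_pairs` is hypothetical: it draws the three panels side by side and colors points by the `query` column, issuing one `scatter` call per query so each gets its own color and legend entry. The demo DataFrame is random placeholder data.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def plot_pc_pairs(df, label_col="query"):
    """Draw PC0 vs PC1, PC0 vs PC2 and PC1 vs PC2, one color per label."""
    pairs = [("PC0", "PC1"), ("PC0", "PC2"), ("PC1", "PC2")]
    fig, axes = plt.subplots(1, 3, figsize=(15, 4.5))
    for ax, (x, y) in zip(axes, pairs):
        # one scatter call per query -> distinct color and legend entry
        for label, group in df.groupby(label_col):
            ax.scatter(group[x], group[y], s=10, alpha=0.7, label=label)
        ax.set_xlabel(x)
        ax.set_ylabel(y)
        ax.set_title(f"{x} vs {y}")
    axes[0].legend(title=label_col, fontsize="small")  # colors are shared across panels
    fig.tight_layout()
    return fig, axes


# Self-contained demo with random, hypothetical data
rng = np.random.default_rng(0)
demo = pd.DataFrame(rng.normal(size=(60, 3)), columns=["PC0", "PC1", "PC2"])
demo["query"] = np.repeat(["query A", "query B", "query C"], 20)
fig, axes = plot_pc_pairs(demo)
```

In the notebook itself, the call would be `plot_pc_pairs(embeddings_pca)` instead of the demo data.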

**Interpretation:**