# Exercise 2: Document-level embedding analysis of PubMed papers with SPECTER

In [2]:
from transformers import AutoTokenizer, AutoModel

# Load SPECTER model and tokenizer.
# The tokenizer and the model must come from the same checkpoint, so the
# identifier is written once and shared by both calls.
SPECTER_CHECKPOINT = "allenai/specter"
tokenizer = AutoTokenizer.from_pretrained(SPECTER_CHECKPOINT)
model = AutoModel.from_pretrained(SPECTER_CHECKPOINT)

  from .autonotebook import tqdm as notebook_tqdm
Downloading (…)okenizer_config.json: 100%|██████████| 321/321 [00:00<00:00, 1.60MB/s]
Downloading (…)lve/main/config.json: 100%|██████████| 612/612 [00:00<00:00, 4.29MB/s]
Downloading (…)solve/main/vocab.txt: 100%|██████████| 222k/222k [00:00<00:00, 8.69MB/s]
Downloading (…)cial_tokens_map.json: 100%|██████████| 112/112 [00:00<00:00, 334kB/s]
Downloading pytorch_model.bin: 100%|██████████| 440M/440M [00:08<00:00, 54.1MB/s] 


In [None]:
# Load dictionary of papers
import json

# Pin the encoding: without it open() uses the platform default, which is not
# UTF-8 everywhere and can mis-decode non-ASCII characters in the JSON text.
with open("papers_dict.json", encoding="utf-8") as f:
    papers_dict = json.load(f)

In [None]:
# Process papers to find SPECTER embeddings
import torch
import tqdm

# Embeddings are collected in a plain dict keyed by PMID; a persistent dictionary
# (e.g. via shelve) could be swapped in so a long run can be stopped and restarted.
# NOTE(review): get_abstract is not defined in the cells visible here — confirm it
# is defined before this cell runs.
embeddings_by_pmid = {}

# Inference only: no_grad() keeps autograd from recording a graph for every
# forward pass, which would otherwise cost memory for no benefit.
with torch.no_grad():
    for pmid, paper in tqdm.tqdm(papers_dict.items()):
        # one-element batch: title and abstract joined by the tokenizer's separator token
        data = [paper["ArticleTitle"] + tokenizer.sep_token + get_abstract(paper)]
        inputs = tokenizer(
            data, padding=True, truncation=True, return_tensors="pt", max_length=512
        )
        result = model(**inputs)
        # take the first token of the sequence as the document embedding;
        # the trailing [0] drops the batch axis (the batch holds a single paper)
        embeddings_by_pmid[pmid] = result.last_hidden_state[:, 0, :].numpy()[0]

# turn our dictionary into a list, in the same order as papers_dict
embeddings = [embeddings_by_pmid[pmid] for pmid in papers_dict]

In [None]:
# Identify the first three principal components of the paper embeddings
import pandas as pd
from sklearn import decomposition

# random_state only matters if scikit-learn picks a randomized SVD solver;
# fixing it keeps the components identical across reruns in that case.
pca = decomposition.PCA(n_components=3, random_state=0)
embeddings_pca = pd.DataFrame(
    pca.fit_transform(embeddings),
    columns=['PC0', 'PC1', 'PC2']
)
# Label each row with its search query; relies on `embeddings` having been built
# in papers_dict iteration order so rows and labels line up.
embeddings_pca["query"] = [paper["query"] for paper in papers_dict.values()]

In [None]:
# Plot 2-D scatter plots for PC0 vs PC1, PC0 vs PC2, PC1 vs PC2, color-coded by search query

**Interpretation:**