PubMedBERT embedding model on Hugging Face:
https://huggingface.co/NeuML/pubmedbert-base-embeddings

The environment used to run this code is named "transformer".

In [2]:
import os
import json
import subprocess
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModel

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# Load the node table (comma-separated, pandas' default delimiter) and pull out
# the `node_name` column as a plain Python list for embedding.
nodes = pd.read_csv('data/nodes_snake.csv')
node_names = nodes.loc[:, 'node_name'].tolist()

In [7]:
# Some node names are NaN (missing) in the CSV. Replace missing values with ""
# and coerce everything else to str so the tokenizer only ever receives strings.
node_names = ["" if pd.isna(name) else str(name) for name in node_names]

# Invariant check: after the comprehension above every entry must be a string.
# (The previous print-loop could never fire; an assert fails loudly if this
# invariant is ever broken by an upstream change.)
assert all(isinstance(name, str) for name in node_names), "node_names must contain only strings"

In [8]:
# Mean Pooling - Take attention mask into account for correct averaging
def meanpooling(output, mask):
    """Average token embeddings, ignoring positions where the attention mask is 0.

    Args:
        output: model output whose first element holds the per-token embeddings;
            expected shape (batch, seq_len, hidden) -- the sum below reduces dim 1.
        mask: attention mask broadcastable to the embeddings after adding a
            trailing axis, i.e. expected shape (batch, seq_len).

    Returns:
        Tensor of shape (batch, hidden): the masked mean of the token embeddings.
    """
    token_embeddings = output[0]
    # Broadcast the mask over the hidden dimension so padding tokens contribute zero.
    weights = mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = (token_embeddings * weights).sum(dim=1)
    # Guard against division by zero for an all-padding row.
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts


In [10]:
# Load model from HuggingFace Hub
tokenizer = AutoTokenizer.from_pretrained("neuml/pubmedbert-base-embeddings")
model = AutoModel.from_pretrained("neuml/pubmedbert-base-embeddings")


In [11]:
# Embed every node name and stack the results into one (n_names, hidden) tensor.
# Per-name embeddings are collected in a list and concatenated ONCE at the end;
# concatenating inside the loop re-copied the whole accumulated tensor on every
# iteration, which is quadratic in the number of nodes.
embeddings = []
with torch.no_grad():  # inference only: no autograd graph is needed
    for name in node_names:
        # Tokenize one name at a time; truncation caps each input at 512 tokens.
        inputs = tokenizer(name, padding=True, truncation=True, max_length=512, return_tensors='pt')
        output = model(**inputs)
        # Mean pooling over non-padding tokens -> one row per name.
        embeddings.append(meanpooling(output, inputs['attention_mask']))

# torch.cat needs at least one tensor; with no names fall back to an empty
# (0, hidden) tensor, matching the original empty seed tensor.
attr = torch.cat(embeddings, dim=0) if embeddings else torch.empty((0, model.config.hidden_size))

print("Sentence embeddings:")
print(attr)

Sentence embeddings:
tensor([[-0.0797, -0.4923, -0.3390,  ..., -0.7394,  0.0348, -0.1092],
        [-0.5147, -0.7015, -0.2551,  ..., -0.5804,  0.5219, -0.2627],
        [ 0.2139, -0.7698,  0.1177,  ..., -0.4510, -0.0762, -0.1129],
        ...,
        [-0.3003, -0.0267,  0.5011,  ...,  0.0611,  0.2385, -0.0895],
        [-0.4699, -0.5239,  0.1228,  ..., -0.1884, -1.1497,  0.0826],
        [-0.2361,  0.4123,  0.2347,  ..., -0.6276,  0.4826, -0.1514]])


In [12]:
# Show the embedding matrix shape and confirm there is exactly one row per node name,
# so the rows can be aligned with the node table's index when saving.
print(attr.shape)
assert attr.shape[0] == len(node_names), "expected one embedding row per node name"

torch.Size([129375, 768])


In [15]:
# Label each embedding row with its node's `node_index` and write the matrix to CSV
# (index included, so the first column is node_index followed by the embedding dims).
attr_df = pd.DataFrame(attr.numpy(), index=nodes['node_index'])
attr_df.to_csv("data/emb_pubmedbert_all_nodes.csv", index=True)

In [17]:
# (rows, columns) of the saved embedding table
(len(attr_df.index), len(attr_df.columns))

(129375, 768)