In [1]:
import numpy as np
import torch
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModel  

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# SapBERT checkpoint: the [CLS] vector of its last hidden layer is used as the entity-name embedding.
MODEL_NAME = "cambridgeltl/SapBERT-from-PubMedBERT-fulltext"
MAX_LENGTH = 25  # every name is padded/truncated to exactly this many tokens
bs = 128  # batch size during inference
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME).to(device)
model.eval()  # inference only: make sure dropout is disabled

# replace with your own list of entity names
all_names = ["covid-19", "Coronavirus infection", "high fever", "Tumor of posterior wall of oropharynx"]


def embed_names(names, tokenizer, model, batch_size=128, max_length=25, device=None):
    """Embed entity names using the [CLS] token of the model's last hidden layer.

    Args:
        names: list of strings to embed.
        tokenizer: Hugging Face tokenizer matching ``model``.
        model: Hugging Face encoder whose output exposes ``last_hidden_state``.
        batch_size: number of names encoded per forward pass.
        max_length: each name is padded/truncated to exactly this many tokens.
        device: device to run on; defaults to the device of ``model``'s parameters.

    Returns:
        numpy array of shape ``(len(names), hidden_size)`` (float32 for a
        default-precision model); an empty ``(0, hidden_size)`` array when
        ``names`` is empty.
    """
    if device is None:
        device = next(model.parameters()).device
    batches = []
    # Only forward activations are needed, so skip building the autograd graph
    # (saves GPU memory; .detach() afterwards does not avoid the graph being built).
    with torch.no_grad():
        for start in tqdm(range(0, len(names), batch_size)):
            toks = tokenizer(
                names[start:start + batch_size],
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            toks = {k: v.to(device) for k, v in toks.items()}
            # Position 0 is the [CLS] token; its representation is the name embedding.
            cls_rep = model(**toks).last_hidden_state[:, 0, :]
            batches.append(cls_rep.cpu().numpy())
    if not batches:
        # np.concatenate([]) raises ValueError; return a correctly-shaped empty matrix instead.
        return np.empty((0, model.config.hidden_size), dtype=np.float32)
    return np.concatenate(batches, axis=0)


all_embs = embed_names(all_names, tokenizer, model, batch_size=bs, max_length=MAX_LENGTH, device=device)

Downloading (…)okenizer_config.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 198/198 [00:00<00:00, 390kB/s]
Downloading (…)lve/main/config.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 462/462 [00:00<00:00, 4.31MB/s]
Downloading (…)solve/main/vocab.txt: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 226k/226k [00:00<00:00, 1.05MB/s]
Downloading (…)cial_tokens_map.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 112/112 [00:00<00:00, 182kB/s]
Downloading pytorch_model.bin: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 438M/438M [00:16<00:00, 27.4MB/s]
100%|███████████████████████████████████

In [3]:
# Sanity check before inspecting: there must be exactly one embedding row per input name.
assert all_embs.shape[0] == len(all_names), "expected one embedding row per name"
all_embs

array([[-0.64542055, -0.38714662, -0.21302707, ...,  0.23982953,
         0.80409443,  0.42334706],
       [-1.1885477 , -0.2745586 ,  0.28520143, ..., -0.12637031,
         0.80622345,  0.06786576],
       [-0.13091089,  0.43986952, -0.14277111, ..., -0.3178458 ,
         0.37847698,  0.15484768],
       [-0.86763996,  0.00546674, -0.38051993, ...,  0.18258967,
         0.89631283, -0.29539046]], dtype=float32)

In [1]:
# Show the visible GPUs, driver/CUDA versions and current per-GPU memory usage.
# NOTE(review): the execution count restarts at In [1] here, so this cell was run in a
# separate kernel session from the cells above; consider moving it to the top of the notebook.
!nvidia-smi

Fri Oct  6 07:55:46 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.125.06   Driver Version: 525.125.06   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0 Off |                  N/A |
|  0%   31C    P8    23W / 185W |      0MiB /  8192MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  Off  | 00000000:02:00.0 Off |                  N/A |
|  0%   31C    P8    21W / 185W |      0MiB /  8192MiB |      0%      Default |
|       