In [1]:
import urllib.request
import sqltables
import json
from tqdm.notebook import tqdm
from transformers import AutoTokenizer, AutoModel

In [2]:
db = sqltables.sqlite3.Database("submissions.sqlite3")
submissions = db.open_table("submissions")
    
if "specter" in list(db.tables):
    specter = db.open_table("specter")
else:
    specter = db.create_table(name="specter", column_names=["arxiv_id", "paper_info"])

In [3]:
processed = set(row.arxiv_id for row in specter)
len(processed)

0

In [4]:
batch_size = 32

In [5]:
rows = list(submissions.view("""
select _.* from _ left join specter using (arxiv_id)
where specter.arxiv_id is null
group by arxiv_id
"""))
len(rows)

7069

In [6]:
tokenizer = AutoTokenizer.from_pretrained('allenai/specter')
model = AutoModel.from_pretrained('allenai/specter')

for i in range(0, len(rows), batch_size):
    print(f"{i}/{len(rows)}")
# for i in [0]:
    batch_rows = rows[i:(i+batch_size)]
    title_abs = [row.title + tokenizer.sep_token + row.abstract for row in batch_rows]
    inputs = tokenizer(title_abs, padding=True, truncation=True, return_tensors="pt", max_length=512)
    result = model(**inputs)
    # take the first token in the batch as the embedding
    embeddings = result.last_hidden_state[:, 0, :]
    specter_rows = []
    for row, embedding in zip(batch_rows, embeddings.tolist()):
        paper_info = {"embedding": {"model": "specter@local", "vector": embedding}}
        specter_rows.append([row.arxiv_id, json.dumps(paper_info)])
    specter.insert(specter_rows)

tokenizer_config.json:   0%|          | 0.00/321 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/222k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

0/7069
32/7069
64/7069
96/7069
128/7069
160/7069
192/7069
224/7069
256/7069
288/7069
320/7069
352/7069
384/7069
416/7069
448/7069
480/7069
512/7069
544/7069
576/7069
608/7069
640/7069
672/7069
704/7069
736/7069
768/7069
800/7069
832/7069
864/7069
896/7069
928/7069
960/7069
992/7069
1024/7069
1056/7069
1088/7069
1120/7069
1152/7069
1184/7069
1216/7069
1248/7069
1280/7069
1312/7069
1344/7069
1376/7069
1408/7069
1440/7069
1472/7069
1504/7069
1536/7069
1568/7069
1600/7069
1632/7069
1664/7069
1696/7069
1728/7069
1760/7069
1792/7069
1824/7069
1856/7069
1888/7069
1920/7069
1952/7069
1984/7069
2016/7069
2048/7069
2080/7069
2112/7069
2144/7069
2176/7069
2208/7069
2240/7069
2272/7069
2304/7069
2336/7069
2368/7069
2400/7069
2432/7069
2464/7069
2496/7069
2528/7069
2560/7069
2592/7069
2624/7069
2656/7069
2688/7069
2720/7069
2752/7069
2784/7069
2816/7069
2848/7069
2880/7069
2912/7069
2944/7069
2976/7069
3008/7069
3040/7069
3072/7069
3104/7069
3136/7069
3168/7069
3200/7069
3232/7069
3264/7069
3296/70

In [7]:
db.close()