In [22]:
import urllib.request
import sqltables
import json
from tqdm.notebook import tqdm
from transformers import AutoTokenizer, AutoModel

In [23]:
db = sqltables.sqlite3.Database("submissions.sqlite3")
submissions = db.open_table("submissions")
    
if "specter" in list(db.tables):
    specter = db.open_table("specter")
else:
    specter = db.create_table(name="specter", column_names=["arxiv_id", "paper_info"])

In [24]:
processed = set(row.arxiv_id for row in specter)
len(processed)

25324

In [25]:
batch_size = 32

In [26]:
rows = list(submissions.view("""
select _.* from _ left join specter using (arxiv_id)
where specter.arxiv_id is null
group by arxiv_id
"""))
len(rows)

3078

In [27]:
tokenizer = AutoTokenizer.from_pretrained('allenai/specter')
model = AutoModel.from_pretrained('allenai/specter')

for i in range(0, len(rows), batch_size):
    print(f"{i}/{len(rows)}")
# for i in [0]:
    batch_rows = rows[i:(i+batch_size)]
    title_abs = [row.title + tokenizer.sep_token + row.abstract for row in batch_rows]
    inputs = tokenizer(title_abs, padding=True, truncation=True, return_tensors="pt", max_length=512)
    result = model(**inputs)
    # take the first token in the batch as the embedding
    embeddings = result.last_hidden_state[:, 0, :]
    specter_rows = []
    for row, embedding in zip(batch_rows, embeddings.tolist()):
        paper_info = {"embedding": {"model": "specter@local", "vector": embedding}}
        specter_rows.append([row.arxiv_id, json.dumps(paper_info)])
    specter.insert(specter_rows)

0/3078
32/3078
64/3078
96/3078
128/3078
160/3078
192/3078
224/3078
256/3078
288/3078
320/3078
352/3078
384/3078
416/3078
448/3078
480/3078
512/3078
544/3078
576/3078
608/3078
640/3078
672/3078
704/3078
736/3078
768/3078
800/3078
832/3078
864/3078
896/3078
928/3078
960/3078
992/3078
1024/3078
1056/3078
1088/3078
1120/3078
1152/3078
1184/3078
1216/3078
1248/3078
1280/3078
1312/3078
1344/3078
1376/3078
1408/3078
1440/3078
1472/3078
1504/3078
1536/3078
1568/3078
1600/3078
1632/3078
1664/3078
1696/3078
1728/3078
1760/3078
1792/3078
1824/3078
1856/3078
1888/3078
1920/3078
1952/3078
1984/3078
2016/3078
2048/3078
2080/3078
2112/3078
2144/3078
2176/3078
2208/3078
2240/3078
2272/3078
2304/3078
2336/3078
2368/3078
2400/3078
2432/3078
2464/3078
2496/3078
2528/3078
2560/3078
2592/3078
2624/3078
2656/3078
2688/3078
2720/3078
2752/3078
2784/3078
2816/3078
2848/3078
2880/3078
2912/3078
2944/3078
2976/3078
3008/3078
3040/3078
3072/3078


In [28]:
db.close()