In [1]:
import sys
import os
sys.path.append(os.path.abspath("../"))
import simi

import pandas as pd
from sentence_transformers import evaluation, losses, models, InputExample, SentenceTransformer
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import DataLoader

# Single seed for the whole notebook: reused for the sklearn train/eval/test splits and used here to
# seed torch's global RNG, so torch-driven randomness (e.g. the shuffled training DataLoader) is
# reproducible across Restart & Run All. Seeded before model construction so any freshly
# initialised weights are covered too. NOTE(review): GPU kernels may still be nondeterministic.
RANDOM_STATE = 1
torch.manual_seed(RANDOM_STATE)

# Sentence-embedding model = SPECTER2 base transformer + pooling layer that reduces token embeddings
# to one fixed-size vector per input text.
# POOLING_MODE: "mean" (mean over token embeddings, the models.Pooling default) or "cls" (CLS token).
POOLING_MODE = "mean"
EMBEDDING = models.Transformer('allenai/specter2_base')
POOLING = models.Pooling(EMBEDDING.get_word_embedding_dimension(), pooling_mode=POOLING_MODE)

MODEL = SentenceTransformer(modules=[EMBEDDING, POOLING])

2023-09-08 23:14:56.594539: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [2]:
# Load the triplet dataset. Columns "a", "p", "n" are used below as the (anchor, positive, negative)
# texts of each training/evaluation triplet.
df = pd.read_csv("anchor-arxiv-dataset.csv", index_col=0)
# TODO: the mathberta tokenizer also recognizes latex within [MATH]...[/MATH] -> convert latex markup in titles accordingly
# NOTE(review): this notebook fine-tunes allenai/specter2_base, not mathberta — confirm the TODO above still applies.

# Fail fast: a missing column raises KeyError here, and a NaN text would otherwise only surface as a
# tokenizer error somewhere inside the multi-hour training run.
assert df[["a", "p", "n"]].notna().all().all(), "triplet columns a/p/n contain missing values"

# Seeded so the preview shows the same rows on every run.
df.sample(5, random_state=RANDOM_STATE)

Unnamed: 0,title,abstract,categories,doi,a,p,n
27623,Total nonnegativity of infinite Hurwitz matric...,In this paper we fully describe functions ge...,"('math.CV', 'math.CA')",10.1007/s11785-013-0344-0,An alternative criterion for entire functions ...,The results are based on a connection between ...,We construct a bounded domain $\Omega$ in $\m...
20264,Clifford Prolate Spheroidal wave Functions,"In the present paper, we introduce the multi...","('math.CA',)",,Then we investigate the role of the CPSWFs in ...,"In the present paper, we introduce the multid...",We prove the total positivity of the Narayana...
16502,Analysis of the Energy Decay of a Degenerated ...,"In this paper, we study a system of thermoel...","('math.AP',)",,"In this paper, we study a system of thermoela...",In the first case and under special assumption...,The paper deals with output feedback stabiliz...
17091,Well-posedness for the continuity equation for...,We prove well-posedness of linear scalar con...,"('math.AP', 'math.CA')",,"As an application, we obtain uniqueness of sol...",We prove well-posedness of linear scalar cons...,Unstable behavior is `discouraged' by the runn...
68469,Isomorphisms between determinantal point proce...,We prove the Bernoulli property for determin...,"('math.PR', 'math.DS')",,"For this purpose, we also prove the Bernoulli ...","As its continuum version, we prove an isomorph...","For each system, we derive a closed-form expre..."


In [3]:
# Two-stage split giving train/eval/test ≈ 80/10/10 of df:
#   1) hold out 10% of all rows as the test set;
#   2) split the remaining 90% with train_size=0.888889 (≈ 8/9), i.e. ≈ 80% / 10% of the full data.
# Both stages reuse RANDOM_STATE so the partition is reproducible.
X_treval, X_test = train_test_split(df, random_state=RANDOM_STATE, train_size=0.9)
X_train, X_eval = train_test_split(X_treval, random_state=RANDOM_STATE, train_size=0.888889)
print(" ".join(f"{name}: {len(part)}" for name, part in [("train", X_train), ("eval", X_eval), ("test", X_test)]))

train: 73760 eval: 9221 test: 9221


In [4]:
# Evaluation data: one (a, p, n) InputExample per row of the eval split, wrapped in a TripletEvaluator
# that MODEL.fit runs during training.
eval_examples = (
    X_eval
    .reset_index(drop=True)
    .apply(lambda row: InputExample(texts=[row["a"], row["p"], row["n"]]), axis=1)
)
evaluator = evaluation.TripletEvaluator.from_input_examples(eval_examples)

In [5]:
# Re-training setup: turn every training row into an (a, p, n) triplet, serve them in shuffled
# mini-batches, and optimise a triplet loss on MODEL.
retrain_examples = (
    X_train
    .reset_index(drop=True)
    .apply(lambda row: InputExample(texts=[row["a"], row["p"], row["n"]]), axis=1)
)
retrain_dataloader = DataLoader(retrain_examples, batch_size=8, shuffle=True)
# TODO: triplet_margin is a hyperparameter to optimize
retrain_loss = losses.TripletLoss(model=MODEL, triplet_margin=5)

In [6]:
%%time
# finetune model
MODEL.fit(train_objectives=[(retrain_dataloader, retrain_loss)], evaluator=evaluator, epochs=10,
        output_path="specter2+mp+retrain_anchor_arxiv") 

Epoch:   0%|          | 0/10 [00:00<?, ?it/s]

Iteration:   0%|          | 0/9220 [00:00<?, ?it/s]

Iteration:   0%|          | 0/9220 [00:00<?, ?it/s]

Iteration:   0%|          | 0/9220 [00:00<?, ?it/s]

Iteration:   0%|          | 0/9220 [00:00<?, ?it/s]

Iteration:   0%|          | 0/9220 [00:00<?, ?it/s]

Iteration:   0%|          | 0/9220 [00:00<?, ?it/s]

Iteration:   0%|          | 0/9220 [00:00<?, ?it/s]

Iteration:   0%|          | 0/9220 [00:00<?, ?it/s]

Iteration:   0%|          | 0/9220 [00:00<?, ?it/s]

Iteration:   0%|          | 0/9220 [00:00<?, ?it/s]

CPU times: user 10h 32min 28s, sys: 2h 2min 15s, total: 12h 34min 44s
Wall time: 8h 53min 35s


In [7]:
%%time
# evaluate Spearman-Pearson-rank-coefficient on test data
spear_ranc = MODEL.evaluate(evaluator)
spear_ranc

CPU times: user 4min 28s, sys: 15.7 s, total: 4min 43s
Wall time: 1min 46s


0.8958898167226982