In [1]:
import simi

import pandas as pd
from sentence_transformers import evaluation, losses, models, InputExample, SentenceTransformer
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

# Token-level encoder loaded from the 'allenai/specter2_base' checkpoint.
EMBEDDING = models.Transformer(model_name_or_path='allenai/specter2_base')

# Pooling layer sized to the encoder's word-embedding dimension; no pooling mode
# is passed, so the library default applies.
POOLING = models.Pooling(
    word_embedding_dimension=EMBEDDING.get_word_embedding_dimension(),
)
# Alternative kept for reference (CLS pooling):
# POOLING = models.Pooling(EMBEDDING.get_word_embedding_dimension(), pooling_mode="cls")

# Sentence encoder = transformer followed by pooling.
MODEL = SentenceTransformer(modules=[EMBEDDING, POOLING])

# Seed shared by the data splits and the sampled previews in this notebook.
RANDOM_STATE = 1

def similarity_scoring(df, model, append=False):
    """Score every title pair in ``df`` with the cosine similarity of its embeddings.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide the columns ``title_a`` and ``title_b``.
    model
        Passed through unchanged to ``simi.model_embeddings``.
    append : bool, default False
        If True, return ``df`` joined on its index with the scores as an extra
        ``cosine-sim`` column; otherwise return only the scores.

    Returns
    -------
    pandas.Series or pandas.DataFrame
        A Series named ``cosine-sim`` indexed like ``df``, or (``append=True``)
        ``df`` merged with that Series.
    """
    if len(df.index) == 0:
        # With no rows, DataFrame.apply(axis=1) has nothing to infer a per-row
        # scalar from and does not yield a Series of scores, so the wrap below
        # would fail. Build the empty result explicitly instead.
        scores = pd.Series(index=df.index, name="cosine-sim", dtype="float64")
    else:
        def _pair_score(row):
            embeddings = simi.model_embeddings(model, [row["title_a"], row["title_b"]])
            # NOTE(review): [0][0] is presumably the a-vs-b similarity; this depends on
            # what simi.pairwise_cosine_similarity returns -- confirm against simi.
            return simi.pairwise_cosine_similarity(embeddings)[0][0]

        # One embedding call per row, exactly as before.
        scores = df.apply(_pair_score, axis=1)
        scores = pd.Series(scores, index=df.index, name="cosine-sim")

    if append:
        return df.merge(scores.to_frame(), left_index=True, right_index=True)
    return scores

2023-09-08 23:06:15.056599: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [2]:
# Load the labelled title pairs (first CSV column becomes the index) and make
# sure the label column is floating point.
df = pd.read_csv("class-zbmath-dataset.csv", index_col=0).astype({"label": "float"})
# TODO: the mathberta tokenizer also recognizes LaTeX within [MATH]...[/MATH]
#       -> convert the LaTeX markup in the titles accordingly
df.sample(5)

Unnamed: 0,title_a,MSC_a,MSC2_a,title_b,MSC_b,MSC2_b,label
278544,The inert doublet model of dark matter revisited,83F05,"('81V22',)",Aligned natural inflation: monodromies of two ...,83F05,"('81V22',)",1.0
352494,The effect of contact on the decohesion of lam...,74M15,"('74K10', '74E30')",2-D normal compliances for elastic and viscoel...,74M15,"('74E30', '74D05', '74Q15', '74Q20')",1.0
7670,A Markovian stochastic model of the profit's g...,91B62,"('90C40', '90C90')",Stochastic evolutionary dynamics resolve the t...,91A22,"('91A15',)",0.0
394974,Measure of departure from symmetric associatio...,62H17,"('62H20',)",Parameter estimation in nonlinear regression m...,62J02,"('62E20', '62F10', '62F25')",0.0
331493,A wavelet-based model of one-dimensional perio...,42C40,"('65T60',)",Triangular summability and Lebesgue points of ...,42B08,"('42B10', '42B25')",0.0


In [3]:
# Two-stage split: first hold out 10 % of the rows as the test set ...
X_treval, X_test = train_test_split(
    df,
    train_size=0.9,
    random_state=RANDOM_STATE,
)
# ... then divide the remainder into training and evaluation parts.
X_train, X_eval = train_test_split(
    X_treval,
    train_size=0.888889,  # 0.888889 * 0.9 ~= 0.8 of the full data set
    random_state=RANDOM_STATE,
)
print("train:", len(X_train), "eval:", len(X_eval), "test:", len(X_test))

train: 351472 eval: 43935 test: 43935


In [4]:
# Evaluation data: one InputExample per row of the evaluation split.
def _eval_row_to_example(row):
    """Wrap one row's title pair and label in an InputExample."""
    return InputExample(texts=[row["title_a"], row["title_b"]], label=row["label"])


eval_examples = X_eval.reset_index(drop=True).apply(_eval_row_to_example, axis=1)
evaluator = evaluation.EmbeddingSimilarityEvaluator.from_input_examples(eval_examples)

In [5]:
# Re-training inputs: one InputExample per row of the training split, a shuffled
# loader over them, and the loss to optimise.
def _train_row_to_example(row):
    """Wrap one row's title pair and label in an InputExample."""
    return InputExample(texts=[row["title_a"], row["title_b"]], label=row["label"])


# reset_index(drop=True) renumbers the rows 0..n-1 -- presumably so the loader's
# integer lookups address the intended rows; confirm before removing it.
retrain_examples = X_train.reset_index(drop=True).apply(_train_row_to_example, axis=1)
retrain_dataloader = DataLoader(dataset=retrain_examples, batch_size=8, shuffle=True)

# Alternative losses kept for reference:
#retrain_loss = losses.SoftmaxLoss(model=MODEL, num_labels=2,
#        sentence_embedding_dimension=MODEL.get_sentence_embedding_dimension())
#retrain_loss = losses.ContrastiveLoss(model=MODEL)
retrain_loss = losses.CosineSimilarityLoss(model=MODEL)

In [6]:
%%time
# Fine-tune MODEL on the training pairs for 10 epochs; `evaluator` is handed to
# fit() and the result is written to `output_path`.
# NOTE(review): the progress bars flooded the notebook output ("IOPub message
# rate exceeded" in the recorded output) -- consider show_progress_bar=False
# after confirming it against the installed sentence-transformers version.
MODEL.fit(
    train_objectives=[(retrain_dataloader, retrain_loss)],
    evaluator=evaluator,
    epochs=10,
    output_path="specter2+mp+retrain_class_zbmath",
)

Epoch:   0%|          | 0/10 [00:00<?, ?it/s]

Iteration:   0%|          | 0/43934 [00:00<?, ?it/s]

Iteration:   0%|          | 0/43934 [00:00<?, ?it/s]

Iteration:   0%|          | 0/43934 [00:00<?, ?it/s]

Iteration:   0%|          | 0/43934 [00:00<?, ?it/s]

Iteration:   0%|          | 0/43934 [00:00<?, ?it/s]

Iteration:   0%|          | 0/43934 [00:00<?, ?it/s]

Iteration:   0%|          | 0/43934 [00:00<?, ?it/s]

Iteration:   0%|          | 0/43934 [00:00<?, ?it/s]

Iteration:   0%|          | 0/43934 [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [7]:
%%time
# Score the fine-tuned model with `evaluator`.
# NOTE(review): `evaluator` was built from X_eval (the evaluation split) in an
# earlier cell, so this is NOT a test-set score; build a second evaluator from
# X_test if a test-set figure is wanted.
# Presumably the returned value is a Spearman rank correlation of the embedding
# similarities against the labels -- confirm with the evaluator's documentation.
spear_ranc = MODEL.evaluate(evaluator)
spear_ranc

CPU times: user 9min 5s, sys: 37.7 s, total: 9min 43s
Wall time: 1min 31s


0.5370127892961544

In [8]:
%%time
# Cosine-similarity score for every title pair of the test split, appended to
# the test frame as the "cosine-sim" column.
X_test_score = similarity_scoring(df=X_test, model=MODEL, append=True)

CPU times: user 7min 51s, sys: 16.5 s, total: 8min 7s
Wall time: 5min 38s


In [9]:
# Widen displayed columns so long titles are not cut off in the tables below.
pd.set_option("display.max_colwidth", 160)

In [10]:
# Reproducible preview of scored test pairs.
X_test_score.sample(n=15, random_state=RANDOM_STATE)

Unnamed: 0,title_a,MSC_a,MSC2_a,title_b,MSC_b,MSC2_b,label,cosine-sim
15960,Global exact quadratization of continuous-time nonlinear control systems,93C10,"('93C15', '93A10', '34H05', '34H99', '34A34', '53A04')",Mixed \(\mathcal{H}_2/\mathcal{H}_\infty\) control of hidden Markov jump systems,93E03,"('60J75', '93B36', '93C55', '93C05')",0.0,-0.091918
186573,Edge operators with conditions of Toeplitz type,58J40,"('35S15', '47G30', '35A17', '35J70', '58J32')",A Bismut type theorem for subelliptic heat semigroups,58J20,"('35H20', '47D06')",0.0,0.057823
181977,1-cohomology and splitting of group extensions,20E22,"('20J99', '20E07')",On some products of nilpotent groups,20E22,"('20F16', '20F18', '20E07', '20F14', '20H25')",1.0,0.292061
299961,Reputation in the long-run with imperfect monitoring,91A20,"('91A05',)",Parallel repetition via fortification: analytic view and the quantum case,91A20,"('81P40', '81P45', '91A05', '91A06', '91A12', '91A80')",1.0,0.238634
226982,Traces and quasi-traces on the Boutet de Monvel algebra.,58J42,"('35S15',)",The local and global parts of the basic zeta coefficient for operators on manifolds with boundary,58J42,"('35S15',)",1.0,0.037171
306772,"A cyclic weight algorithm of decoding the \((47, 24, 11)\) quadratic residue code",94B35,"('94B40',)",A result on the weight distributions of binary quadratic residue codes,94B35,"('94B40',)",1.0,0.965497
205892,Properties of the Székely-Móri symmetry criterion statistics in the case of binary vectors,60E05,"('62E20', '62H10')",On deformation technique of the hyperbolic secant distribution,60E05,"('60E10', '62E17', '62E20')",1.0,0.856848
218358,Simultaneous visibility representations of plane \(st\)-graphs using L-shapes,05C62,"('05C10', '05C85', '68R10')",On the minimum order of graphs with given semigroup,05C99,"('05C65', '20M30')",0.0,0.671497
77078,The unsteady MHD boundary-layer flow on a shrinking sheet,76W05,"('76N20', '76M45')",Meridional trapping and zonal propagation of inertial waves in a rotating fluid shell,76U05,"('76B55', '86A05')",0.0,0.034627
63150,On unified contact metric manifold,53C15,"('53C25',)",Two characterizations of the Chern connection,53C10,"('53A55', '53B05', '58A20', '58A32')",0.0,-0.002014


In [11]:
# Summary statistics of the scores for pairs labelled 1.
X_test_score.loc[X_test_score["label"] == 1].describe()

Unnamed: 0,label,cosine-sim
count,21860.0,21860.0
mean,1.0,0.585343
std,0.0,0.318441
min,1.0,-0.204495
25%,1.0,0.313754
50%,1.0,0.651043
75%,1.0,0.878361
max,1.0,0.999384


In [12]:
# Summary statistics of the scores for pairs labelled 0.
X_test_score.loc[X_test_score["label"] == 0].describe()

Unnamed: 0,label,cosine-sim
count,22075.0,22075.0
mean,0.0,0.205488
std,0.0,0.271258
min,0.0,-0.266328
25%,0.0,-0.002529
50%,0.0,0.104405
75%,0.0,0.355813
max,0.0,0.996512
