In [1]:
import simi

import itertools
import json
import random
import re
import zipfile

import nltk.data
import pandas as pd
import random
from sentence_transformers import evaluation, losses, models, InputExample, SentenceTransformer
from sklearn.model_selection import train_test_split
import sqlalchemy
from tqdm.auto import tqdm
# Register tqdm with pandas so DataFrame.progress_apply (used when building the text column) is available.
tqdm.pandas()

# Zip archive holding the arXiv metadata dump, and the name of the JSON member that is read from it.
ARXIV_ZIP = "arxiv dump/arxiv-metadata-oai-snapshot-version111.json.zip"
DUMP_JSON = "arxiv-metadata-oai-snapshot.json"
# Seed value; in this cell it is only used to seed Python's built-in `random` module.
RANDOM_STATE = 1

random.seed(RANDOM_STATE)

2023-12-16 20:37:31.083390: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [2]:
def join_text(title, abstract, tokenizer=None):
    """Merge a paper's title and abstract into one whitespace-normalised string.

    The title is trimmed of surrounding whitespace and trailing periods and
    joined to the abstract as ``"<title>. <abstract>"``. Every run of
    whitespace (newlines included) is collapsed to a single space, the result
    is split into sentences and the sentences are re-joined with single spaces.

    Args:
        title: Title string.
        abstract: Abstract text; interpolated with ``str.format``.
        tokenizer: Optional object exposing ``tokenize(text)`` that returns a
            list of sentence strings. When omitted, NLTK's English Punkt
            tokenizer is loaded on first use and reused on later calls.

    Returns:
        The cleaned text, or ``""`` when the sentences cannot be joined.
    """
    if tokenizer is None:
        # Load the Punkt model once and keep it on the function object instead
        # of requesting it again for every row.
        tokenizer = getattr(join_text, "_punkt", None)
        if tokenizer is None:
            tokenizer = nltk.data.load('tokenizers/punkt/PY3/english.pickle')
            join_text._punkt = tokenizer

    # str.strip("") removes nothing; strip() drops the surrounding whitespace so
    # that rstrip(".") can actually reach a trailing period.
    title = title.strip().rstrip(".")

    # After this substitution no newline is left, so the sentences need no
    # further per-sentence newline replacement.
    text = re.sub(r"\s+", " ", "{}. {}".format(title, abstract))
    sentences = tokenizer.tokenize(text)
    try:
        return " ".join(sentences)
    except TypeError:
        # Non-string items from the tokenizer: signal "no usable text".
        return ""

In [3]:
%%time
# Stream the metadata dump (one JSON object per line) straight out of the zip
# archive. Plain tuples are collected and turned into a DataFrame once at the
# end; building a one-row DataFrame per line and concatenating them is far slower.
records = []
with zipfile.ZipFile(ARXIV_ZIP) as za:
    with tqdm(total=za.getinfo(DUMP_JSON).file_size, unit="b", unit_divisor=1024, unit_scale=True, desc=DUMP_JSON) as pb:
        with za.open(DUMP_JSON) as f:
            for l in f:
                j = json.loads(l)
                records.append((j["title"], j["abstract"], j["categories"], j["doi"]))
                # Progress is measured in uncompressed bytes read.
                pb.update(len(l))
df = pd.DataFrame(records, columns=["title", "abstract", "categories", "doi"])

# "cat.A cat.B" -> ("cat.A", "cat.B")
df["categories"] = df["categories"].map(lambda c: tuple(c.split()))

# Keep only records that carry at least one category starting with "math" or "stat".
math_cats = sorted(c for c in set(itertools.chain.from_iterable(df["categories"])) if c.startswith(("math", "stat")))
math_cat_set = set(math_cats)  # built once instead of once per row
# .copy() so the column assignment below writes to an independent frame, not a slice.
df = df[df["categories"].map(lambda c: not math_cat_set.isdisjoint(c))].copy()

df["text"] = df.progress_apply(lambda r: join_text(r["title"], r["abstract"]), axis=1)
# join_text returns "" when it could not produce text; drop those rows.
df = df[df["text"] != ""]

# Alternative that keeps the title instead of the merged text:
#df = df[["title", "categories"]]
df = df[["text", "categories"]]

# Keep only records with exactly one category.
df = df[df.categories.map(lambda c: len(c)==1)]

arxiv-metadata-oai-snapshot.json:   0%|          | 0.00/3.37G [00:00<?, ?b/s]

  0%|          | 0/650485 [00:00<?, ?it/s]

CPU times: user 9min 12s, sys: 11.5 s, total: 9min 24s
Wall time: 9min 21s


In [4]:
# Replace the index with a fresh 0..n-1 range; drop=True discards the old labels
# instead of inserting them as a column.
df = df.reset_index(drop=True)

In [5]:
# Spot-check five rows. random.seed() does not seed pandas' sampling, so the
# draw is pinned explicitly to make the displayed rows reproducible.
df.sample(5, random_state=RANDOM_STATE)

Unnamed: 0,text,categories
216876,Subgradient Ellipsoid Method for Nonsmooth Con...,"(math.OC,)"
146529,Generalized gap acceptance models for unsignal...,"(math.PR,)"
36020,Discrete time approximation of decoupled Forwa...,"(math.PR,)"
51646,Analytic approximation in $L^p$ and coinvarian...,"(math.CV,)"
181380,An introduction to Bent Jorgensen's ideas. We ...,"(stat.OT,)"


## Model definitions

In [10]:
# Prerequisites for the models
def _pooled(transformer, **pooling_options):
    """Return a SentenceTransformer made of `transformer` followed by a Pooling layer.

    The pooling layer is sized from the transformer's word-embedding dimension;
    any keyword arguments are forwarded to models.Pooling.
    """
    pooling = models.Pooling(transformer.get_word_embedding_dimension(), **pooling_options)
    return SentenceTransformer(modules=[transformer, pooling])


bert_base = models.Transformer("bert-base-uncased")
bert_mp = _pooled(bert_base)

mathbert_base = models.Transformer("witiko/mathberta")
mathbert_mp = _pooled(mathbert_base)
mathbert_cls = _pooled(mathbert_base, pooling_mode="cls")

bert_mlm_base = models.Transformer("./bert+re-train_mlm_abstracts_arxiv")
bert_mlm_mp = _pooled(bert_mlm_base)

sbert = SentenceTransformer("all-mpnet-base-v2")
# alternative: sbert = SentenceTransformer('all-distilroberta-v1')

specter2_base = models.Transformer("allenai/specter2_base")

# Definition of models that are evaluated.
# Each entry is (label, source): a source given as a string is a checkpoint
# name/path handed to SentenceTransformer, anything else is an already-built model.
_MODEL_SOURCES = [
    ("Bert+MP", bert_mp),
    ("Bert+MP+class-arx", "./bert+mean-pooling+retrain_class_arxiv"),
    ("Bert+MP+class-zbm", "./bert+mean-pooling+retrain_class_zbmath"),

    ("Mathbert+CLS", mathbert_cls),
    ("Mathbert+MP+class-arx", "./mathbert+mean-pooling+retrain_class_arxiv"),
    ("Mathbert+MP+class-zbm", "./mathbert+mean-pooling+retrain_class_zbmath"),

    ("Bert+TSDAE+MP", "./bert+mean-pooling+re-train_tsdae_abstracts_arxiv"),
    ("Bert+TSDAE+MP+class-arx", "./bert+mean-pooling+re-train_tsdae_abstracts_arxiv+retrain_class_arxiv"),
    ("Bert+TSDAE+MP+class-zbm", "./bert+mean-pooling+re-train_tsdae_abstracts_arxiv+retrain_class_zbmath"),

    ("Bert+MLM+MP", bert_mlm_mp),
    ("Bert+MLM+MP+class-arx", "./bert+re-train_mlm_abstracts_arxiv+mean-pooling+retrain_class_arxiv"),
    ("Bert+MLM+MP+class-zbm", "./bert+re-train_mlm_abstracts_arxiv+mean-pooling+retrain_class_zbmath"),
    ("Bert+MLM+MP+class-zbm+anch-arx", "./bert+re-train_mlm_abstracts_arxiv+mean-pooling+retrain_class_zbmath_anchor_arxiv"),
    ("Bert+MLM+MP+class-arx+anch-arx+class-zbm", "./bert+re-train_mlm_abstracts_arxiv+mean-pooling+retrain_class_arxiv_anchor_arxiv_class_zbmath"),

    ("SBert", sbert),
    ("SBert+ret_class-arx", "./sbert+retrain_class_arxiv"),
    ("SBert+ret_class-zbm", "./sbert+retrain_class_zbmath"),

    ("Specter 2+MP", _pooled(specter2_base)),
    ("Specter 2.0+MP+class-arx", "./specter2+mp+retrain_class_arxiv"),
    ("Specter 2.0+MP+class-zbm", "./specter2+mp+retrain_class_zbmath"),
    ("Specter 2+MP+class-zbm+anch-arx", "./specter2+mp+retrain_class_zbmath_anchor_arxiv"),
    ("Specter 2+MP+anch-arx+class-arx+class-zbm", "./specter2+mp+retrain_anchor_arxiv_class_arxiv_class_zbmath"),

    ("SGPT", "Muennighoff/SGPT-125M-weightedmean-nli-bitfit"),
    ("SGPT+class-arx", "./gpt+retrain_class_arxiv"),
    ("SGPT+class-zbm", "./gpt+retrain_class_zbmath"),
]

# combine models into dict (insertion order follows the table above)
eval_models = {
    label: SentenceTransformer(source) if isinstance(source, str) else source
    for label, source in _MODEL_SOURCES
}

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at witiko/mathberta were not used when initializing RobertaModel: ['lm_head.layer_norm.weight', 'lm_head.dense.weight'

In [7]:
%%time

# Encode the text column with every model in eval_models and store the vectors
# in one "embedding (<label>)" column per model.
for name, model in tqdm(eval_models.items(), total=len(eval_models), desc="Models"):
    # Title-only alternative:
    #embeddings = list(simi.model_embeddings(model, df["title"], show_progress_bar=True))
    embeddings = list(simi.model_embeddings(model, df["text"], show_progress_bar=True))
    # Column assignment aligns a Series on index labels. Pinning the index to
    # df.index keeps row i paired with embedding i even when df does not carry
    # a 0..n-1 index, and raises on a length mismatch instead of filling NaN.
    df["embedding ({})".format(name)] = pd.Series(embeddings, index=df.index)

Models:   0%|          | 0/24 [00:00<?, ?it/s]

Batches:   0%|          | 0/8510 [00:00<?, ?it/s]

Batches:   0%|          | 0/8510 [00:00<?, ?it/s]

Batches:   0%|          | 0/8510 [00:00<?, ?it/s]

Batches:   0%|          | 0/8510 [00:00<?, ?it/s]

Batches:   0%|          | 0/8510 [00:00<?, ?it/s]

Batches:   0%|          | 0/8510 [00:00<?, ?it/s]

Batches:   0%|          | 0/8510 [00:00<?, ?it/s]

Batches:   0%|          | 0/8510 [00:00<?, ?it/s]

Batches:   0%|          | 0/8510 [00:00<?, ?it/s]

Batches:   0%|          | 0/8510 [00:00<?, ?it/s]

Batches:   0%|          | 0/8510 [00:00<?, ?it/s]

Batches:   0%|          | 0/8510 [00:00<?, ?it/s]

Batches:   0%|          | 0/8510 [00:00<?, ?it/s]

Batches:   0%|          | 0/8510 [00:00<?, ?it/s]

Batches:   0%|          | 0/8510 [00:00<?, ?it/s]

Batches:   0%|          | 0/8510 [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Batches:   0%|          | 0/8510 [00:00<?, ?it/s]

Batches:   0%|          | 0/8510 [00:00<?, ?it/s]

Batches:   0%|          | 0/8510 [00:00<?, ?it/s]

Batches:   0%|          | 0/8510 [00:00<?, ?it/s]

Batches:   0%|          | 0/8510 [00:00<?, ?it/s]

Batches:   0%|          | 0/8510 [00:00<?, ?it/s]

Batches:   0%|          | 0/8510 [00:00<?, ?it/s]

Batches:   0%|          | 0/8510 [00:00<?, ?it/s]

CPU times: user 1d 3h 41min 30s, sys: 49min 47s, total: 1d 4h 31min 18s
Wall time: 11h 54min 49s


In [8]:
# Spot-check five rows including the new embedding columns. random.seed() does
# not seed pandas' sampling, so the draw is pinned explicitly for reproducibility.
df.sample(5, random_state=RANDOM_STATE)

Unnamed: 0,text,categories,embedding (Bert+MP),embedding (Bert+MP+class-arx),embedding (Bert+MP+class-zbm),embedding (Mathbert+CLS),embedding (Mathbert+MP+class-arx),embedding (Mathbert+MP+class-zbm),embedding (Bert+TSDAE+MP),embedding (Bert+TSDAE+MP+class-arx),...,embedding (SBert),embedding (SBert+ret_class-arx),embedding (SBert+ret_class-zbm),embedding (Specter 2+MP),embedding (Specter 2.0+MP+class-arx),embedding (Specter 2.0+MP+class-zbm),embedding (Specter 2+MP+class-zbm+anch-arx),embedding (Specter 2+MP+anch-arx+class-arx+class-zbm),embedding (SGPT),embedding (SGPT+class-zbm)
14713,Heisenberg Idempotents on Unipotent Groups. Le...,"(math.RT,)","[-0.44673637, -0.117981836, 0.36835602, -0.182...","[-0.048817247, -0.32580274, -0.25932664, 0.486...","[-0.09339113, 0.55284745, 0.40907928, -0.47770...","[-0.029157367, 0.008551015, 0.016446434, 0.000...","[0.11753585, 0.47042903, 0.70793957, 1.3361214...","[-0.66416395, -0.5353767, -0.18511598, 0.39433...","[0.0252847, 0.052180655, 0.08451672, -0.157209...","[-0.869118, -0.9987617, -0.11206023, -0.209534...",...,"[-0.020525329, 0.022543224, -0.021192962, -0.0...","[-0.009449902, -0.028253904, -0.020228831, 0.0...","[0.03998939, -0.0033082857, -0.02631805, 0.016...","[0.23881957, 0.49273804, -0.017457835, 0.33633...","[0.53644896, 1.2513939, 2.0663762, -0.22485189...","[-0.002482295, 0.5286821, -0.25448278, -0.4108...","[1.4593894, 0.9060594, 0.8759439, 0.6162333, 1...","[0.9604389, 1.1320069, 1.1406655, 0.10618342, ...","[-0.7986909, -0.64885604, 1.0200502, -0.413885...","[0.28802395, 0.012298746, 0.21507654, 0.580377..."
87627,"(10, k) Reversible Multiples. We consider the ...","(math.GM,)","[-0.16335775, -0.06617161, 0.5106309, -0.15558...","[-0.9212624, 0.18871203, 0.11060377, 0.3638353...","[0.514042, 1.1183357, -0.55408525, -0.18689784...","[-0.012707563, 0.0381937, -0.0010121999, -0.01...","[1.012272, -0.10622712, -0.3762327, -0.3519919...","[0.27579632, -0.008231211, -0.3095563, 0.10317...","[0.13058585, 0.3956283, 0.35346895, -0.1929293...","[-0.032497577, 0.3461761, -0.27086192, -0.7554...",...,"[-0.014818843, -0.047378317, 0.0044712634, 0.0...","[0.0702136, 0.0017042622, 0.048090816, 0.05911...","[-0.0005425008, 0.008336401, -0.0090653915, -0...","[0.82968575, -0.13845076, -0.26792967, 0.28962...","[0.48242113, -0.27038956, 0.4837663, 0.9690887...","[0.8829145, -2.0528686, -1.0972743, 1.2415856,...","[1.4551079, -1.7115458, -0.6988823, 0.94670486...","[-0.0485003, -0.35216787, -0.35405836, 1.47451...","[-0.8695109, 0.8555988, 0.103301875, -0.973842...","[0.46707067, 1.6612374, 0.46632856, 0.33273768..."
189826,Averages of long Dirichlet polynomials. We con...,"(math.NT,)","[-0.18444498, -0.11001027, 0.32445168, -0.1935...","[-1.1984326, -0.28368613, 0.66014385, 0.139267...","[0.5973368, 0.013618904, 0.42929646, -0.302245...","[0.0021970184, 0.04423874, 0.00942348, 0.02889...","[0.91820145, -0.109373674, -0.25007832, -0.385...","[0.7929841, -0.24122973, -0.023994127, -0.1269...","[-0.0069884663, 0.089529134, 0.097870685, 0.08...","[-0.1290751, 0.29044434, -0.20054877, -0.81433...",...,"[-0.0637499, 0.0035657492, -0.019513607, 0.059...","[0.07701481, 0.012056987, 0.070335194, 0.05495...","[0.053555913, 0.06480057, 0.042535853, -0.0155...","[0.28643018, 0.57273793, 0.14995858, 0.0399353...","[1.1886165, -0.52009845, 0.78884196, 1.2583576...","[-0.21264921, -0.51283103, -0.18349276, 0.9446...","[-0.66535896, 1.1507801, 0.42430782, -0.088890...","[0.5022082, 0.34535408, 1.0854611, 0.67368925,...","[-0.08527299, -1.1973345, 0.26673844, -0.69951...","[0.65142787, -0.7052756, -0.8345003, 0.7399770..."
255083,A Pieri-Chevalley formula for K(G/B). The ring...,"(math.RT,)","[-0.28712058, -0.043906942, 0.2484594, -0.2178...","[0.027420793, -0.24388905, -0.05215356, 0.9634...","[0.2687007, 0.12688278, 0.39017326, 0.211602, ...","[-0.027314896, 0.02201781, 0.006237851, -0.000...","[0.06423249, 0.7417011, 0.8310071, 0.7334632, ...","[-1.008708, -0.42312184, 0.3800256, -0.3231577...","[0.19318767, -0.06424971, 0.044713028, -0.0326...","[-1.9310452, -0.7054479, 0.28916106, -0.071887...",...,"[-0.046625063, -0.04837608, -0.029482346, 0.06...","[-0.040590793, 0.008438152, 0.0034349842, -0.0...","[0.014794293, 0.077488184, -0.025242452, -0.02...","[0.21402888, 0.03785907, -0.41144055, 0.143087...","[-0.37271148, -0.34471813, 0.79242104, -0.7578...","[-0.19237441, 0.028460102, 0.28109065, 0.52579...","[0.3184207, -0.70041174, 0.50199157, -0.395749...","[-0.48638898, -1.0513365, 0.16423285, 0.882752...","[0.53998387, 1.5832734, -0.006798582, -0.77460...","[0.2028964, 1.8658131, 0.5162393, 0.3256827, -..."
179824,Quantum Euclidean Spaces with Noncommutative D...,"(math.OA,)","[-0.28733668, -0.23055288, 0.38958988, -0.2490...","[0.57414204, -0.24860351, -0.107429475, 0.3834...","[0.42945963, 0.24710399, -0.23467553, -0.36466...","[-0.022679398, 0.033999454, 0.021820456, -0.01...","[0.2736504, 0.26398134, -0.08840776, 0.3314680...","[-1.0973357, -0.56398743, 0.051087204, -0.9894...","[0.1188062, 0.280105, 0.17062257, -0.17981789,...","[-0.39327112, -0.76186806, 0.5398385, -0.72155...",...,"[0.06680033, -0.0007584461, -0.04221967, 0.006...","[0.053765, 0.022055607, -0.0035675038, 0.01215...","[0.018427951, 0.07720141, -0.001431768, -0.044...","[0.095985636, 0.4896472, 0.33805853, -0.001373...","[0.25898236, 1.255652, 2.136975, -0.4343123, 0...","[0.60509795, -0.5543667, 0.9901327, -0.1548276...","[0.385985, 0.46274638, 1.0840086, -0.5587508, ...","[0.75580287, 1.1679249, 1.113982, 0.14102475, ...","[-2.6578612, 0.11497679, 0.08613574, -0.988827...","[-0.40530223, -0.95392776, -0.42633703, 0.0678..."


In [13]:
# Persist the frame (text, categories and the embedding columns) as a pickle file.
# The commented-out line is presumably the output name for the title-only run — confirm before re-enabling.
#df.to_pickle("arxiv-title-embeddings-single-class-multi.pkl")
df.to_pickle("arxiv-full-text-embeddings-single-class-multi.pkl")