In [1]:
!pip install sentence-transformers

Collecting sentence-transformers
  Downloading sentence-transformers-2.2.2.tar.gz (85 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.0/86.0 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l- done
Building wheels for collected packages: sentence-transformers
  Building wheel for sentence-transformers (setup.py) ... [?25l- \ | done
[?25h  Created wheel for sentence-transformers: filename=sentence_transformers-2.2.2-py3-none-any.whl size=125938 sha256=d5d44ea43576e56e594e5211f890d92187f5a1694c8aa429c673c531776b6457
  Stored in directory: /root/.cache/pip/wheels/bf/06/fb/d59c1e5bd1dac7f6cf61ec0036cc3a10ab8fecaa6b2c3d3ee9
Successfully built sentence-transformers
Installing collected packages: sentence-transformers
Successfully installed sentence-transformers-2.2.2
[0m

In [2]:
!ls "../input/arxiv"

arxiv-metadata-oai-snapshot.json


In [3]:
import json
import pandas as pd
import os
import re
import string


DATA_PATH = "../input/arxiv/arxiv-metadata-oai-snapshot.json"
YEAR_CUTOFF = 2000
YEAR_PATTERN = r"(19|20[0-9]{2})"
ML_CATEGORY = ["cs.LG", "cs.AI", "stats.ML", "cs.CL", "cs.CV", "cs.SI", "cs.SE"]

In [4]:
from nltk import tokenize
def find_what(text):
    sentences = tokenize.sent_tokenize(text)
    what_keywords = ["presents", "present", "constructs", "construct", "build", "in this paper", "this paper attempts"]
    whats = []
    for sent in sentences:
        sent = ' '.join(sent.split(' ')[:4])
        if any(w in sent.lower() for w in what_keywords):
            whats.append(sent)
    return whats

In [5]:
def find_why(text):
    sentences = tokenize.sent_tokenize(text)
    why_keywords = ["however", "tend to", "can't", "didn't", "despite"]
    whys = []
    for sent in sentences:
        sent = ' '.join(sent.split(' ')[:4])
        if any(w in sent.lower() for w in why_keywords):
            whys.append(sent)
    return whys

In [6]:
def process(paper: dict):
    paper = json.loads(paper)
    if paper['journal-ref']:
        years = [int(year) for year in re.findall(YEAR_PATTERN, paper['journal-ref'])]
        years = [year for year in years if (year <= 2022 and year >= 1991)]
        year = min(years) if years else None
    else:
        year = None
    
    what = find_what(paper['abstract'])
    why = find_why(paper['abstract'])
    return {
        'id': paper['id'],
        'title': paper['title'],
        'year': year,
        'authors': paper['authors'],
        'categories': ','.join(paper['categories'].split(' ')),
        'abstract': paper['abstract'],
        'what': " | ".join(what),
        'why': " | ".join(why)
    }

def papers():
    with open(DATA_PATH, 'r') as f:
        for paper in f:
            paper = process(paper)
            contains_cat = False
            for cat in ML_CATEGORY:
                if cat in paper['categories']:
                    contains_cat = True
                    break
                    
            if paper['year']:
                if paper['year'] >= YEAR_CUTOFF and contains_cat:
                    yield paper

In [7]:
df = pd.DataFrame(papers())
len(df)

30290

In [8]:
# Avg length of the abstracts
df.abstract.apply(lambda a: len(a.split())).mean()

164.04179597226806

In [9]:
def clean_description(description: str):
    if not description:
        return ""
    # remove unicode characters
    description = description.encode('ascii', 'ignore').decode()

    # remove punctuation
    description = re.sub('[%s]' % re.escape(string.punctuation), ' ', description)

    # clean up the spacing
    description = re.sub('\s{2,}', " ", description)

    # remove urls
    #description = re.sub("https*\S+", " ", description)

    # remove newlines
    description = description.replace("\n", " ")

    # remove all numbers
    #description = re.sub('\w*\d+\w*', '', description)

    # split on capitalized words
    description = " ".join(re.split('(?=[A-Z])', description))

    # clean up the spacing again
    description = re.sub('\s{2,}', " ", description)

    # make all words lowercase
    description = description.lower()

    return description

In [10]:
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

Downloading:   0%|          | 0.00/1.18k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/10.6k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/571 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/116 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/39.3k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/438M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/239 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/363 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/13.1k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/349 [00:00<?, ?B/s]

In [11]:
# Create embeddings from the title and abstract
emb = model.encode(df.apply(lambda r: clean_description(r['title'] + ' ' + r['abstract']), axis=1).tolist())

Batches:   0%|          | 0/947 [00:00<?, ?it/s]

In [12]:
# Add embeddings to df
df = df.reset_index().drop('index', axis=1)
df['vector'] = emb.tolist()

In [13]:
import pickle

# Export to file!
with open('arxiv_embeddings_10000.pkl', 'wb') as f:
    data = pickle.dumps(df)
    f.write(data)