In [2]:
import json, re
from datasets import load_dataset, load_from_disk, Dataset, DatasetDict

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def process_arxiv_id_with_regex(arxiv_id):
    """Return the bare numeric arXiv identifier of a prefixed identifier.

    Identifiers shaped like ``arXiv:<digits>.<digits>`` with an optional
    ``v<digits>`` version suffix are reduced to ``<digits>.<digits>``.
    Any string that does not have that shape is returned unchanged.
    """
    found = re.match(r"^arXiv:(\d+\.\d+)(?:v\d+)?$", arxiv_id)
    # Group 1 holds only the numeric part; prefix and version are outside it.
    return found.group(1) if found else arxiv_id

In [4]:
# Load the metadata mapping that later cells query with version-less arXiv ids.
# The encoding is pinned to UTF-8 so decoding does not depend on the machine's
# locale default.
with open("/scratch/lamdo/arxiv_dataset/arxivid2metadata.json", encoding="utf-8") as f:
    arxivid2metadata = json.load(f)

In [6]:

# Load the previously saved dataset from disk; later cells index it by split
# name (e.g. ds["train"]) and read its "arxiv_id" and "label" columns.
ds = load_from_disk("/scratch/lamdo/arxiv_classification/arxiv_data")

In [7]:
# Show the first training record to check which fields each row carries.
ds["train"][0]

{'label': 8, 'arxiv_id': 'arXiv:1611.03253v1'}

In [8]:
# Build, for every split, a list of records that pair each labelled paper with
# its title and abstract. Papers without a metadata entry are dropped.
#
# The pattern is a raw string: '\s' inside a plain string literal is an invalid
# escape sequence and triggers a warning on recent Python versions. It is
# compiled once here instead of being re-parsed for every title and abstract.
whitespace_run = re.compile(r"\s+")

processed_dataset = {}
for split in ["train", "validation", "test"]:
    original_split_arxiv_ids = ds[split]["arxiv_id"]
    labels = ds[split]["label"]

    out = []
    for original_id, label in zip(original_split_arxiv_ids, labels):
        # Metadata is looked up by the processed id, while the record keeps
        # the original identifier as it appears in the dataset.
        paper_metadata = arxivid2metadata.get(process_arxiv_id_with_regex(original_id))
        if not paper_metadata:
            continue
        out.append({
            "label": label,
            "arxiv_id": original_id,
            # Collapsing whitespace runs already covers newlines, so no
            # separate newline replacement is needed.
            "title": whitespace_run.sub(" ", paper_metadata["title"]),
            "abstract": whitespace_run.sub(" ", paper_metadata["abstract"]),
        })
    processed_dataset[split] = out

In [10]:
# Train/test splits holding only the paper id and its label.
dataset_test = DatasetDict({
    split: Dataset.from_list(
        [{"paper_id": line["arxiv_id"], "label": line["label"]} for line in processed_dataset[split]]
    )
    for split in ["train", "test"]
})

# Concatenate the train and test records once, so every column below is built
# from the same list and the merged list is not rebuilt for each column.
evaluation_rows = processed_dataset["train"] + processed_dataset["test"]

# Single "evaluation" split with one column per field, train rows first.
dataset = DatasetDict({
    "evaluation": Dataset.from_dict({
        "doc_id": [line["arxiv_id"] for line in evaluation_rows],
        "title": [line["title"] for line in evaluation_rows],
        "abstract": [line["abstract"] for line in evaluation_rows],
        "label": [line["label"] for line in evaluation_rows],
    }),
})

In [11]:
# Display the DatasetDict summary (split names, features, row counts).
dataset

DatasetDict({
    evaluation: Dataset({
        features: ['doc_id', 'title', 'abstract', 'label'],
        num_rows: 27395
    })
})

In [12]:
# Write the train/test DatasetDict (paper_id and label per row) to disk.
dataset_test.save_to_disk("/scratch/lamdo/arxiv_classification/arxiv_data_t+a_test/")

Saving the dataset (1/1 shards): 100%|██████████| 25200/25200 [00:00<00:00, 1304846.25 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 2195/2195 [00:00<00:00, 387870.63 examples/s]


In [13]:
# Write the single-split "evaluation" DatasetDict (doc_id, title, abstract, label) to disk.
dataset.save_to_disk("/scratch/lamdo/arxiv_classification/arxiv_data_t+a/") 

Saving the dataset (1/1 shards): 100%|██████████| 27395/27395 [00:00<00:00, 648092.76 examples/s]
