In [1]:
from datasets import DatasetDict, Dataset, ClassLabel, load_dataset
import pandas as pd
import glob
from sklearn.model_selection import train_test_split
from src.utils import map_category


In [2]:
# Load the interim JSON-lines shards and reduce them to a (text, label) frame.
data_path = "data/interim/part-*.json"

# glob.glob returns matches in arbitrary, filesystem-dependent order. Sorting
# fixes the row order of the concatenated frame, which the seeded
# train/val/test split later in the notebook depends on for reproducibility.
json_files = sorted(glob.glob(data_path))
if not json_files:
    # Fail early with a readable message instead of pandas'
    # "No objects to concatenate" error.
    raise FileNotFoundError(f"No input files match {data_path!r}")

stream_df = pd.concat(
    [pd.read_json(file, lines=True) for file in json_files],
    ignore_index=True,
)

# map_category comes from src.utils; it converts each main_category value
# into the label used for classification.
stream_df["label"] = stream_df["main_category"].apply(map_category)
# Model input is the title and the summary separated by a newline.
stream_df["text"] = stream_df["title"] + "\n" + stream_df["summary"]
stream_df = stream_df[["text", "label"]]


In [3]:
# Class distribution of the stream data, most frequent label first.
stream_label_counts = stream_df["label"].value_counts()
stream_label_counts


label
cs          4747
math        1913
cond-mat     795
physics      711
astro-ph     700
quant-ph     521
eess         485
hep          453
stat         229
gr-qc        190
nucl         101
q-bio         86
econ          79
nlin          47
math-ph       43
q-fin         43
Name: count, dtype: int64

In [4]:
# Stratified 70 / 15 / 15 split of the stream data: first hold out 30 %,
# then cut that hold-out in half for validation and test. Both calls use the
# same seed so the split is repeatable for a given row order.
split_seed = 42

train_df, temp_df = train_test_split(
    stream_df, test_size=0.3, random_state=split_seed, stratify=stream_df["label"]
)
val_df, test_df = train_test_split(
    temp_df, test_size=0.5, random_state=split_seed, stratify=temp_df["label"]
)

# Show the per-class counts of the training portion.
train_label_counts = train_df["label"].value_counts()
train_label_counts

label
cs          3323
math        1339
cond-mat     556
physics      498
astro-ph     490
quant-ph     365
eess         340
hep          317
stat         160
gr-qc        133
nucl          71
q-bio         60
econ          55
nlin          33
q-fin         30
math-ph       30
Name: count, dtype: int64

In [5]:
# Pull the auxiliary dataset from the Hugging Face Hub and have it hand back
# pandas objects when indexed.
aux_data = load_dataset("real-jiakai/arxiver-with-category")
aux_data.set_format(type="pandas")

# Build the same (text, label) layout as the stream data. The keyword
# arguments of assign are evaluated in order, so `text` sees the cleaned title.
aux_df = (
    aux_data["train"][:]
    .assign(
        label=lambda frame: frame["primary_category"].apply(map_category),
        # Collapse the newline-plus-two-spaces sequence in titles to one space.
        title=lambda frame: frame["title"].str.replace("\n  ", " "),
        text=lambda frame: frame["title"] + "\n" + frame["abstract"],
    )
    [["text", "label"]]
)


In [6]:
# Class distribution of the auxiliary data, most frequent label first.
aux_label_counts = aux_df["label"].value_counts()
aux_label_counts


label
cs          26733
math         9611
cond-mat     4660
astro-ph     4453
physics      4163
quant-ph     2930
hep          2900
eess         2839
stat         1546
gr-qc        1187
q-bio         677
nucl          512
math-ph       365
econ          288
nlin          262
q-fin         231
Name: count, dtype: int64

In [7]:
# Augment the stream training split with the auxiliary dataset.
#
# Guard against evaluation leakage: an auxiliary example whose text is
# identical to a validation or test example must not end up in training.
# NOTE(review): this catches exact text matches only; near-duplicates
# (e.g. whitespace differences) would need a normalised key -- confirm
# whether the two sources can overlap.
held_out_texts = pd.concat([val_df["text"], test_df["text"]], ignore_index=True)
aux_train_df = aux_df[~aux_df["text"].isin(held_out_texts)]

all_train_df = pd.concat([train_df, aux_train_df], ignore_index=True)

# Per-class counts of the combined training frame.
all_train_df["label"].value_counts()


label
cs          30056
math        10950
cond-mat     5216
astro-ph     4943
physics      4661
quant-ph     3295
hep          3217
eess         3179
stat         1706
gr-qc        1320
q-bio         737
nucl          583
math-ph       395
econ          343
nlin          295
q-fin         261
Name: count, dtype: int64

In [8]:
# Report the size and class balance of the combined training frame.
augmented_shape = all_train_df.shape
print(f"Shape after augmentation: {augmented_shape}")

augmented_label_counts = all_train_df["label"].value_counts()
augmented_label_counts


Shape after augmentation: (71157, 2)


label
cs          30056
math        10950
cond-mat     5216
astro-ph     4943
physics      4661
quant-ph     3295
hep          3217
eess         3179
stat         1706
gr-qc        1320
q-bio         737
nucl          583
math-ph       395
econ          343
nlin          295
q-fin         261
Name: count, dtype: int64

In [9]:
# Wrap the three pandas frames as Hugging Face datasets (the pandas index is
# dropped rather than stored as a column).
split_frames = {
    "train": all_train_df,
    "validation": val_df,
    "test": test_df,
}
all_stream_data = DatasetDict(
    {
        split_name: Dataset.from_pandas(frame, preserve_index=False)
        for split_name, frame in split_frames.items()
    }
)

# Encode the string labels as a ClassLabel feature. The class names are the
# alphabetically sorted labels found in the training frame.
labels = sorted(all_train_df["label"].unique())
class_label = ClassLabel(names=labels)
all_stream_data = all_stream_data.cast_column("label", class_label)

# Persist the processed dataset for later use.
all_stream_data.save_to_disk("data/processed/all_stream_data")


Casting the dataset:   0%|          | 0/71157 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/1671 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/1672 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/71157 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1671 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1672 [00:00<?, ? examples/s]