In [None]:
!pip install datasets transformers --quiet
!pip install ipywidgets --user --quiet

# Dataset preparation

In [None]:
import tensorflow as tf
from tensorflow.keras import activations, optimizers, losses
from transformers import DistilBertTokenizer, TFDistilBertForSequenceClassification
from datasets import load_dataset
import pickle

# Optional row cap for quick experiments (e.g. 100_000); None uses every row.
N_SAMPLES = None

# v177 — NOTE(review): presumably the version of the arXiv metadata files under
# ./datasets; confirm. Dataset verification checks are disabled via verification_mode.
dataset = load_dataset("arxiv_dataset", data_dir="datasets", split="train",
                       trust_remote_code=True, verification_mode="no_checks")
if N_SAMPLES is not None:
    # Keep the first N_SAMPLES rows (before any shuffling).
    dataset = dataset.select(range(N_SAMPLES))

def keep_first_arxiv_category(example):
    """Add a ``category`` field holding the first-listed arXiv category.

    ``example["categories"]`` is treated as a space-separated string
    (assumed to look like ``"cs.LG stat.ML"``); only the text before the
    first space is kept. The dict is updated in place and returned, which is
    what ``Dataset.map`` expects.
    """
    primary, _sep, _rest = example["categories"].partition(" ")
    example["category"] = primary
    return example

# Derive a single label per paper, encode it as a ClassLabel, keep only the
# title text and that label, then make a seeded 80/20 train/test split.
dataset = (
    dataset
    .map(keep_first_arxiv_category)
    .class_encode_column("category")
    .remove_columns(["id", "submitter", "authors", "comments", "journal-ref", "doi",
                     "report-no", "categories", "license", "abstract", "update_date"])
    .shuffle(seed=42)
    .train_test_split(test_size=0.2, seed=42)
)
dataset

In [None]:
# Label metadata from the ClassLabel feature, used to size and name the classifier head.
category_feature = dataset['train'].features['category']
num_labels = category_feature.num_classes
id2label = {idx: name for idx, name in enumerate(category_feature.names)}
label2id = {name: idx for idx, name in id2label.items()}

num_labels

In [None]:
# Settings shared by tokenization, training and evaluation.
MODEL_NAME = 'distilbert-base-uncased'
MAX_LEN = 200  # titles are truncated to at most this many tokens
batch_size = 16

tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)

def get_tf_dataset(split):
    """Tokenize the titles of one split and wrap them in a batched tf.data pipeline.

    Args:
        split: Name of a split of ``dataset`` (``"train"`` or ``"test"`` here).

    Returns:
        A ``tf.data.Dataset`` yielding ``(features, labels)`` batches of up to
        ``batch_size`` rows. ``features`` is the dict produced by the tokenizer
        and ``labels`` are the integer ``category`` ids.
    """
    split_ds = dataset[split]
    # padding=True pads every title to the longest one in this split (at most
    # MAX_LEN tokens after truncation), so the rows form a rectangular tensor.
    encodings = tokenizer(split_ds["title"], max_length=MAX_LEN, truncation=True, padding=True)
    return (
        tf.data.Dataset.from_tensor_slices((dict(encodings), split_ds["category"]))
        .batch(batch_size)
        # Prepare the next batches while the model is busy with the current one.
        .prefetch(tf.data.AUTOTUNE)
    )

# Build the train pipeline first, then the test pipeline.
train_tf_dataset, test_tf_dataset = map(get_tf_dataset, ("train", "test"))

# Training

In [None]:
# Pretrained DistilBERT with a classification head sized to the number of
# categories; the label maps go into the config so predictions carry category names.
model = TFDistilBertForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=num_labels,
    label2id=label2id,
    id2label=id2label,
)

In [None]:
from time import time, strftime
num_epochs = 5

# Targets are integer class ids and the head emits raw logits,
# hence the sparse loss with from_logits=True.
loss = losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = optimizers.Adam(learning_rate=3e-5)
model.compile(
    optimizer=optimizer,
    loss=loss,
    metrics=["accuracy"],
    steps_per_execution=100,  # batches run per compiled call
)

# Timestamped run name; every artifact written for this run is prefixed with it.
filename = f"distilbert_arxiv_{strftime('%Y%m%d-%H%M%S')}"
print(filename)

# initial_epoch is left at its default of 0. Keras runs epochs
# initial_epoch .. epochs-1, so this trains exactly num_epochs epochs.
# batch_size is not passed because train_tf_dataset is already batched.
start = time()
history = model.fit(train_tf_dataset, epochs=num_epochs)
end = time()
print(end - start, "seconds")

# Save the fine-tuned model first (this also creates arxiv_model/), then the
# run metadata and the per-epoch Keras history as pickles next to it.
model.save_pretrained(f"arxiv_model/{filename}")

run_artifacts = {
    f"arxiv_model/{filename} info.pkl": (MODEL_NAME, MAX_LEN, start, end),
    f"arxiv_model/{filename} history.pkl": history.history,
}
for artifact_path, payload in run_artifacts.items():
    with open(artifact_path, "wb") as fh:
        pickle.dump(payload, fh)

In [None]:
# Loss and accuracy on the held-out split, returned as a dict.
# No batch_size is passed: batching comes from test_tf_dataset itself.
benchmarks = model.evaluate(test_tf_dataset, return_dict=True)
print(benchmarks)

with open(f"arxiv_model/{filename} eval.pkl", "wb") as f:
    pickle.dump(benchmarks, f)

# Write examples, labels, and predictions to a file

In [None]:
# Titles of both splits, train rows first then test rows. The accuracy split
# further down relies on this ordering.
all_examples = [*dataset["train"]["title"], *dataset["test"]["title"]]
all_encodings = tokenizer(all_examples, max_length=MAX_LEN, truncation=True, padding=True)
all_examples_tf_dataset = (
    tf.data.Dataset.from_tensor_slices(dict(all_encodings))
    .batch(batch_size)
)

In [None]:
pred_start = time()
logits = model.predict(all_examples_tf_dataset).logits
pred_end = time()
print(pred_end - pred_start, "seconds")

# Turn per-class logits into class probabilities.
preds = activations.softmax(tf.convert_to_tensor(logits)).numpy()
preds.shape

In [None]:
import pandas as pd

all_labels = dataset["train"]["category"] + dataset["test"]["category"]
class_names = dataset['train'].features['category'].names

# One row per title: the text, its true label id, then one probability column per category.
text_and_labels = pd.DataFrame({'example': all_examples, 'label': all_labels})
class_probs = pd.DataFrame(data=preds, columns=class_names)
df = pd.concat([text_and_labels, class_probs], axis=1)
df

In [None]:
# Write titles, labels and class probabilities; the DataFrame index is not saved.
outputs_csv_path = f"arxiv_model/{filename} outputs.csv"
df.to_csv(outputs_csv_path, index=False)

In [None]:
import numpy as np

# Rows are ordered train-then-test, so the first n_train rows are the training split.
n_train = len(dataset["train"])
correct = preds.argmax(axis=1) == np.asarray(all_labels)
correct_train = correct[:n_train].sum()
correct_test = correct[n_train:].sum()
print("Train accuracy", correct_train / n_train, correct_train)
print("Test accuracy", correct_test / len(dataset["test"]), correct_test)

# Query the model interactively

In [None]:
from transformers import TextClassificationPipeline
example_title = "Learned retrieval data structures"
pipe = TextClassificationPipeline(tokenizer=tokenizer, model=model)
pipe(example_title, top_k=None)  # top_k=None: score every category, not only the best one