In [1]:
%load_ext autoreload
%autoreload 2


In [2]:
from datasets import DatasetDict, Dataset
from transformers import pipeline
from tqdm import tqdm
from sklearn.metrics import classification_report, accuracy_score
import pandas as pd
import glob
from sklearn.model_selection import train_test_split


In [3]:
# Load every interim JSON-lines shard and combine them into one DataFrame.
# The shard paths are sorted because glob returns them in filesystem-dependent
# order; the resulting row order feeds the seeded stratified split further down,
# so an unsorted list would make that split irreproducible across machines.
data_path = "data/interim/part-*.json"
json_files = sorted(glob.glob(data_path))
if not json_files:
    # Fail with an actionable message instead of pandas' "No objects to concatenate".
    raise FileNotFoundError(f"No files match {data_path!r}; check the working directory.")
papers_df = pd.concat([pd.read_json(file, lines=True) for file in json_files], ignore_index=True)

# Classifier input: the paper title and abstract joined by a newline.
papers_df["text"] = papers_df["title"] + "\n" + papers_df["summary"]
papers_df = papers_df[["text", "main_category"]]


In [4]:
# Drop rare categories: a class needs at least `min_examples` papers so that the
# two stratified splits below have enough examples of it to place it in every split.
min_examples = 10
category_counts = papers_df['main_category'].value_counts()
is_frequent = category_counts >= min_examples
candidate_categories = category_counts.index[is_frequent.to_numpy()].tolist()

# Keep only the papers whose category survived the cut.
papers_df = papers_df.loc[papers_df['main_category'].isin(candidate_categories)]


In [5]:
# Two-stage stratified split on `main_category`:
#   1. hold out 20% of the papers (temp_df), keeping 80% for training;
#   2. cut the hold-out in half, giving 10% validation and 10% test overall.
# The same fixed seed is used for both stages so the split is repeatable.
split_seed = 42

train_df, temp_df = train_test_split(
    papers_df,
    test_size=0.2,
    random_state=split_seed,
    stratify=papers_df["main_category"],
)

val_df, test_df = train_test_split(
    temp_df,
    test_size=0.5,
    random_state=split_seed,
    stratify=temp_df["main_category"],
)


In [6]:
# Sanity-check the split: sizes as a share of the filtered dataset, and whether
# every category made it into every split.
n_papers = len(papers_df)
print(f"Training set: {len(train_df)} examples ({len(train_df)/n_papers*100:.1f}%)")
print(f"Validation set: {len(val_df)} examples ({len(val_df)/n_papers*100:.1f}%)")
print(f"Test set: {len(test_df)} examples ({len(test_df)/n_papers*100:.1f}%)")

# Category universe computed once and reused for every split.
all_categories = set(papers_df["main_category"].unique())
n_categories = len(all_categories)
print(f"Total categories to classify: {n_categories}")

# Dictionary of DataFrames for easy iteration
split_dfs = {
    "train": train_df,
    "validation": val_df,
    "test": test_df
}

for split_name, df in split_dfs.items():
    # Categories present in this split
    split_categories = set(df["main_category"].unique())

    # Categories that exist in the full dataset but not in this split
    missing_categories = all_categories - split_categories

    print(f"\nSplit: {split_name}")
    print(f"Number of unique categories: {len(split_categories)}")
    print(f"Categories present: {len(split_categories)}/{n_categories}")
    # Name the missing categories, if any, so an incomplete split is visible
    # rather than only implied by the count above.
    if missing_categories:
        print(f"Missing categories ({len(missing_categories)}): {sorted(missing_categories)}")


Training set: 5310 examples (80.0%)
Validation set: 664 examples (10.0%)
Test set: 664 examples (10.0%)
Total categories to classify: 103

Split: train
Number of unique categories: 103
Categories present: 103/103

Split: validation
Number of unique categories: 103
Categories present: 103/103

Split: test
Number of unique categories: 103
Categories present: 103/103


In [7]:
# Convert each pandas split into a Hugging Face Dataset (dropping the pandas
# index so it does not become an extra column) and bundle them by split name.
frames_by_split = {"train": train_df, "validation": val_df, "test": test_df}
data = DatasetDict(
    {
        split_name: Dataset.from_pandas(frame, preserve_index=False)
        for split_name, frame in frames_by_split.items()
    }
)

# Show the resulting split structure (features and row counts)
print(data)


DatasetDict({
    train: Dataset({
        features: ['text', 'main_category'],
        num_rows: 5310
    })
    validation: Dataset({
        features: ['text', 'main_category'],
        num_rows: 664
    })
    test: Dataset({
        features: ['text', 'main_category'],
        num_rows: 664
    })
})


In [8]:
# Hugging Face Hub id of the fine-tuned classifier; used for both the model
# weights and its tokenizer.
model_path = "oracat/bert-paper-classifier-arxiv"

# NOTE(review): recent transformers releases deprecate `return_all_scores=True`
# in favour of `top_k=None`. Switching also changes the result for a single
# input string (a flat list of {label, score} dicts), so the `[0]` indexing in
# the evaluation cell would have to change at the same time — verify before migrating.
pipeline_options = dict(
    model=model_path,
    tokenizer=model_path,
    return_all_scores=True,  # score every label, not only the top one
    device=-1,               # -1 selects the CPU
    truncation=True,         # truncate inputs longer than max_length tokens
    max_length=512,
)
cls = pipeline("text-classification", **pipeline_options)


Device set to use cpu


In [9]:
def top_pct(preds, threshold=.95):
    """
    Keep the highest-scoring predictions until their cumulative score reaches a threshold.

    Parameters
    ----------
    preds : list of dict
        Prediction dicts, each with at least a ``"score"`` key (the pipeline's
        per-label dicts also carry ``"label"``). The input list is not modified.
    threshold : float, default 0.95
        Cumulative-score cut-off.

    Returns
    -------
    list of dict
        The predictions sorted by descending score, truncated right after the
        item at which the running score sum first reaches ``threshold``. If the
        sum never reaches it, all predictions are returned; an empty input
        yields an empty list.
    """
    ranked = sorted(preds, key=lambda x: -x["score"])

    cum_score = 0.0
    for n_kept, item in enumerate(ranked, start=1):
        cum_score += item["score"]
        if cum_score >= threshold:
            return ranked[:n_kept]

    # Threshold never reached (or no predictions at all): keep everything.
    # Returning here avoids the unbound loop variable the empty case used to hit.
    return ranked


def format_predictions(preds) -> str:
    """
    Render ranked predictions as a numbered, human-readable list.

    Each prediction dict must provide ``"label"`` and ``"score"``. Every entry
    becomes one newline-terminated line of the form ``"<rank>. <label> (score <s>)"``
    with the rank starting at 1 and the score shown to two decimals. An empty
    input produces an empty string.
    """
    lines = [
        f"{rank}. {entry['label']} (score {entry['score']:.2f})\n"
        for rank, entry in enumerate(preds, start=1)
    ]
    return "".join(lines)


In [10]:
# Classify every test example with a single forward pass each. The pipeline
# output is indexed with [0] to get this text's list of {label, score} dicts;
# both the full score list and its cumulative-95% subset (top_pct) are kept.
results = []
for example in tqdm(data["test"], desc="Classifying"):
    scores = cls(example["text"])[0]  # run the model once, reuse the result
    results.append((example["main_category"], scores, top_pct(scores)))

# Unpack results
true_labels, raw_preds, all_predictions = zip(*results)
predicted_labels = [p[0]["label"] for p in all_predictions]

# Calculate and display metrics. The hit count is computed directly rather than
# as int(accuracy * n), which can truncate one low due to float rounding.
n_examples = len(true_labels)
n_correct = sum(t == p for t, p in zip(true_labels, predicted_labels))
accuracy = accuracy_score(true_labels, predicted_labels)
print(f"Test: {n_examples} examples, {len(set(true_labels))} classes, {len(set(predicted_labels))} predicted")
print(f"Accuracy: {accuracy:.4f} ({n_correct}/{n_examples})")
# zero_division=0 reports 0.0 for labels that are never predicted (the same value
# the default uses) but without emitting an UndefinedMetricWarning for each one.
report = classification_report(true_labels, predicted_labels, zero_division=0)
print(f"\nClassification Report:\n{report}")

# Show sample predictions
print("\nSample Predictions (first 5):")
for i, (true, pred, top_preds) in enumerate(zip(true_labels[:5], predicted_labels[:5], all_predictions[:5])):
    print(f"\nEx {i+1}: True: {true} | Pred: {pred}")
    print(format_predictions(top_preds))
    

Classifying: 100%|██████████| 664/664 [03:39<00:00,  3.02it/s]

Test: 664 examples, 103 classes, 59 predicted
Accuracy: 0.3178 (211/664)

Classification Report:
                    precision    recall  f1-score   support

       astro-ph.CO       0.09      0.50      0.15         8
       astro-ph.EP       0.00      0.00      0.00         6
       astro-ph.GA       0.00      0.00      0.00         9
       astro-ph.HE       0.00      0.00      0.00         8
       astro-ph.IM       0.11      0.75      0.19         4
       astro-ph.SR       0.00      0.00      0.00         7
   cond-mat.dis-nn       0.00      0.00      0.00         2
 cond-mat.mes-hall       0.00      0.00      0.00         9
 cond-mat.mtrl-sci       0.00      0.00      0.00        14
cond-mat.quant-gas       0.00      0.00      0.00         3
     cond-mat.soft       0.00      0.00      0.00         7
cond-mat.stat-mech       0.03      0.25      0.06         4
   cond-mat.str-el       0.00      0.00      0.00         8
 cond-mat.supr-con       0.00      0.00      0.00         4
  


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
