In [1]:
%load_ext autoreload
%autoreload 2


In [2]:
from datasets import DatasetDict, Dataset, ClassLabel
from transformers import pipeline
from tqdm import tqdm
from sklearn.metrics import classification_report, accuracy_score
import pandas as pd
import glob
from sklearn.model_selection import train_test_split
from src.utils import map_category


In [3]:
# Define the path to the interim data directory
data_path = "data/interim/part-*.json"
json_files = glob.glob(data_path)
papers_df = pd.concat([pd.read_json(file, lines=True) for file in json_files], ignore_index=True)

papers_df["label"] = papers_df["main_category"].apply(map_category)
papers_df["text"] = papers_df["title"] + "\n" + papers_df["summary"]
papers_df = papers_df[["text", "label"]]


In [4]:
# First split: 80% training, 20% temp (test + validation)
train_df, temp_df = train_test_split(
    papers_df, 
    test_size=0.2,
    stratify=papers_df["label"],
    random_state=42
)

val_df, test_df = train_test_split(
    temp_df,
    test_size=0.5,
    stratify=temp_df["label"],
    random_state=42
)


In [5]:
# Print split sizes to verify
print(f"Training set: {len(train_df)} examples ({len(train_df)/len(papers_df)*100:.1f}%)")
print(f"Validation set: {len(val_df)} examples ({len(val_df)/len(papers_df)*100:.1f}%)")
print(f"Test set: {len(test_df)} examples ({len(test_df)/len(papers_df)*100:.1f}%)")

# Check category distribution across all splits using pandas
print(f"Total categories to classify: {len(papers_df['label'].unique())}")

# Dictionary of DataFrames for easy iteration
split_dfs = {
    "train": train_df,
    "validation": val_df,
    "test": test_df
}

for split_name, df in split_dfs.items():
    # Get unique categories in this split
    split_categories = set(df["label"].unique())
    
    # Check if all categories are present
    missing_categories = set(papers_df["label"]) - split_categories
    
    print(f"\nSplit: {split_name}")
    print(f"Number of unique categories: {len(split_categories)}")
    print(f"Categories present: {len(split_categories)}/{len(papers_df['label'].unique())}")


Training set: 7121 examples (80.0%)
Validation set: 890 examples (10.0%)
Test set: 891 examples (10.0%)
Total categories to classify: 20

Split: train
Number of unique categories: 20
Categories present: 20/20

Split: validation
Number of unique categories: 20
Categories present: 20/20

Split: test
Number of unique categories: 20
Categories present: 20/20


In [6]:
# Create the final DatasetDict with train, validation, and test splits
data = DatasetDict({
    "train": Dataset.from_pandas(train_df, preserve_index=False),
    "validation": Dataset.from_pandas(val_df, preserve_index=False),
    "test": Dataset.from_pandas(test_df, preserve_index=False)
})

# Print the final split dataset structure
print(data)

labels = sorted(papers_df["label"].unique())
class_label = ClassLabel(names=labels)

data = data.cast_column("label", class_label)
print(data)


DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 7121
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 890
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 891
    })
})


Casting the dataset:   0%|          | 0/7121 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/890 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/891 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 7121
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 890
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 891
    })
})


In [8]:
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

# Convert text to embeddings
train_embeddings = model.encode(data["train"]["text"], show_progress_bar=True)
valid_embeddings = model.encode(data["validation"]["text"], show_progress_bar=True)
test_embeddings = model.encode(data["test"]["text"], show_progress_bar=True)


Batches:   0%|          | 0/223 [00:00<?, ?it/s]

Batches:   0%|          | 0/28 [00:00<?, ?it/s]

Batches:   0%|          | 0/28 [00:00<?, ?it/s]

In [10]:
from sklearn.linear_model import LogisticRegression

lr_clf = LogisticRegression(random_state=42)
lr_clf.fit(train_embeddings, data["train"]["label"])
y_pred = lr_clf.predict(valid_embeddings)
print(classification_report(data["validation"]["label"], y_pred))


              precision    recall  f1-score   support

           0       0.91      0.91      0.91        56
           1       0.82      0.86      0.84        64
           2       0.87      0.93      0.90       381
           3       0.50      0.33      0.40         6
           4       0.61      0.37      0.46        38
           5       0.61      0.73      0.67        15
           6       1.00      0.20      0.33         5
           7       0.00      0.00      0.00         3
           8       0.62      0.83      0.71        18
           9       0.75      0.69      0.72        13
          10       0.85      0.87      0.86       150
          11       0.00      0.00      0.00         4
          12       0.00      0.00      0.00         4
          13       1.00      0.33      0.50         3
          14       0.60      0.50      0.55         6
          15       0.77      0.66      0.71        56
          16       0.50      0.14      0.22         7
          17       0.33    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [13]:
from sklearn.svm import SVC

svc_clf = SVC(random_state=42, class_weight="balanced")
svc_clf.fit(train_embeddings, data["train"]["label"])
y_pred = svc_clf.predict(valid_embeddings)
print(classification_report(data["validation"]["label"], y_pred))


              precision    recall  f1-score   support

           0       0.96      0.88      0.92        56
           1       0.79      0.83      0.81        64
           2       0.95      0.78      0.86       381
           3       0.50      0.67      0.57         6
           4       0.36      0.68      0.47        38
           5       0.62      0.87      0.72        15
           6       1.00      0.80      0.89         5
           7       1.00      0.67      0.80         3
           8       0.80      0.89      0.84        18
           9       0.53      0.77      0.62        13
          10       0.84      0.82      0.83       150
          11       0.00      0.00      0.00         4
          12       0.50      0.25      0.33         4
          13       1.00      0.67      0.80         3
          14       0.80      0.67      0.73         6
          15       0.65      0.73      0.69        56
          16       0.27      0.43      0.33         7
          17       0.50    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [12]:
import xgboost as xgb

params = {
    "n_estimators": 1000,
    'max_depth': 7,
    'eta': 0.3,
    'objective': 'multi:softmax',
    'num_class': len(labels),
    'eval_metric': 'mlogloss'
}

xgb_clf = xgb.XGBClassifier(**params)
xgb_clf.fit(train_embeddings, data["train"]["label"])
y_pred = xgb_clf.predict(valid_embeddings)
print(classification_report(data["validation"]["label"], y_pred))


              precision    recall  f1-score   support

           0       0.91      0.91      0.91        56
           1       0.79      0.81      0.80        64
           2       0.87      0.94      0.90       381
           3       0.38      0.50      0.43         6
           4       0.65      0.29      0.40        38
           5       0.67      0.67      0.67        15
           6       1.00      0.40      0.57         5
           7       0.50      0.33      0.40         3
           8       0.57      0.67      0.62        18
           9       0.69      0.69      0.69        13
          10       0.84      0.87      0.86       150
          11       0.00      0.00      0.00         4
          12       0.00      0.00      0.00         4
          13       1.00      0.33      0.50         3
          14       0.50      0.33      0.40         6
          15       0.64      0.61      0.62        56
          16       0.00      0.00      0.00         7
          17       0.50    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
