# SciBERT Cross-Encoder for Citation Linking

This notebook contains Arian Houshmand's and C.J. DuHamel's work to train and evaluate a **SciBERT-based cross-encoder** that predicts whether a given sentence cites a specific reference paper.

**Task definition**

- **Input:**  
  - A `sentence` from a scientific paper.  
  - A `ref_block` describing a candidate reference (title, authors, and reference text).

- **Output:**  
  - A binary label indicating whether the sentence is truly referring to that reference (`1 = refers`, `0 = does not refer`).



In [1]:
# imports
import json
import os
from collections import Counter
from statistics import mean
import re

import numpy as np
import pandas as pd


In [2]:
#!pip install -U "transformers" "scikit-learn" "accelerate" -q

import transformers
import torch
import sklearn

print("Transformers version:", transformers.__version__)
print("Torch version:", torch.__version__)
print("sklearn version:", sklearn.__version__)

Transformers version: 4.57.3
Torch version: 2.9.0+cu126
sklearn version: 1.6.1


In [3]:
# load the JSON Lines dataset (one JSON object per line)
DATA_PATH = "complete_dataset.jsonl"
assert os.path.exists(DATA_PATH), f"File not found: {DATA_PATH}"


examples = []
with open(DATA_PATH, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if not line:
            continue
        examples.append(json.loads(line))

print(f"Loaded {len(examples)} examples.")



Loaded 4120 examples.


In [4]:

df = pd.DataFrame(examples)

print("\nColumns in dataset:")
print(df.columns.tolist())

print("\nFirst 5 rows:")
df.head()



Columns in dataset:
['original_paper_id', 'sentence', 'ref_paper_id', 'ref_paper_title', 'ref_paper_authors', 'ref_paper_text', 'label']

First 5 rows:


Unnamed: 0,original_paper_id,sentence,ref_paper_id,ref_paper_title,ref_paper_authors,ref_paper_text,label
0,7255,It is believed that eosinophil transmigration ...,BIBREF37,,,\nUpdate on Anticytokine Treatment for Asthma\...,1
1,7255,It is believed that eosinophil transmigration ...,BIBREF39,,,\nAnti-interleukin-5 therapy in severe asthma\...,0
2,7255,It is believed that eosinophil transmigration ...,BIBREF45,In Vitro Generation of Interleukin 10–producin...,"F. Barrat, D. Cua, A. Boonstra, D. Richards, C...",\nIntroduction\n\nCD4 ϩ T cell subsets include...,0
3,7255,"Specifically, IL-5 is the most important media...",BIBREF37,,,\nUpdate on Anticytokine Treatment for Asthma\...,0
4,7255,"Specifically, IL-5 is the most important media...",BIBREF39,,,\nAnti-interleukin-5 therapy in severe asthma\...,1


In [5]:

# check the label distribution for class imbalance

if "label" in df.columns:
    label_counts = df["label"].value_counts().to_dict()
    total = len(df)
    print("\nLabel distribution:")
    for label, count in label_counts.items():
        print(f"  {label}: {count} ({count/total:.2%})")
else:
    print("\nNo 'label' column found in the data!")




Label distribution:
  0: 2818 (68.40%)
  1: 1302 (31.60%)


In [6]:
# check for missing key fields

def is_empty_or_nan(x):
    """Return True if the value is None, NaN, or an empty/whitespace-only string."""
    if x is None:
        return True
    if isinstance(x, float) and np.isnan(x):
        return True
    if isinstance(x, str) and x.strip() == "":
        return True
    return False

for col in ["sentence", "ref_paper_text", "ref_paper_title", "ref_paper_authors"]:
    if col in df.columns:
        missing_custom = df[col].apply(is_empty_or_nan).sum()
        print(
            f"Missing (NaN or empty) in {col}: "
            f"{missing_custom}/{len(df)} ({missing_custom/len(df):.2%})"
        )
    else:
        print(f"Column {col} not in dataset.")


Missing (NaN or empty) in sentence: 0/4120 (0.00%)
Missing (NaN or empty) in ref_paper_text: 4/4120 (0.10%)
Missing (NaN or empty) in ref_paper_title: 2013/4120 (48.86%)
Missing (NaN or empty) in ref_paper_authors: 2013/4120 (48.86%)


In [7]:

# Length stats

def safe_len_words(x):
    """Return word count if x is a string, else 0."""
    if not isinstance(x, str):
        return 0
    return len(x.split())

def safe_len_chars(x):
    """Return character count if x is a string, else 0."""
    if not isinstance(x, str):
        return 0
    return len(x)

# Sentence stats
if "sentence" in df.columns:
    sent_lens = df["sentence"].apply(safe_len_words).tolist()
    sent_lens_sorted = sorted(sent_lens)
    print("\nSentence length (in words):")
    print(f"  Mean: {mean(sent_lens):.2f}")
    print(f"  Median: {sent_lens_sorted[int(0.5 * len(sent_lens))]}")
    print(f"  Max: {max(sent_lens)}")

# Reference text stats (in characters)
if "ref_paper_text" in df.columns:
    ref_lens_chars = df["ref_paper_text"].apply(safe_len_chars).tolist()
    ref_lens_sorted = sorted(ref_lens_chars)
    print("\nref_paper_text length (in chars):")
    print(f"  Mean: {mean(ref_lens_chars):.2f}")
    print(f"  Median: {ref_lens_sorted[int(0.5 * len(ref_lens_chars))]}")
    print(f"  Max: {max(ref_lens_chars)}")


Sentence length (in words):
  Mean: 24.43
  Median: 23
  Max: 84

ref_paper_text length (in chars):
  Mean: 1650.48
  Median: 2000
  Max: 2000


## EDA results

There is a class imbalance of about 2 to 1 (negatives to positives). We need to account for this imbalance during training, probably with a weighted cross-entropy loss.

We also have some missing values to deal with: `ref_paper_title` and `ref_paper_authors` are empty in about 49% of rows, and `ref_paper_text` is empty in 4 rows.

I calculated the lengths of the sentences and reference texts to see whether we need to adjust the token limit for our model.
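
To make the weighting idea concrete, here is a minimal sketch of the "balanced" weight formula `N / (num_classes * count)`, applied to the label counts printed in the EDA. The helper name `balanced_class_weights` is hypothetical and exists only for this illustration; the weights actually used for training are computed further down from the training split.

```python
def balanced_class_weights(counts):
    """Return {label: N / (num_classes * count)} for a {label: count} mapping."""
    total = sum(counts.values())
    num_classes = len(counts)
    return {label: total / (num_classes * count) for label, count in counts.items()}


# Label counts reported in the EDA output above.
eda_counts = {0: 2818, 1: 1302}
weights = balanced_class_weights(eda_counts)

for label in sorted(weights):
    print(f"label {label}: weight = {weights[label]:.4f}")

# With these weights, each class carries the same total weight (N / num_classes).
for label, count in eda_counts.items():
    print(f"label {label}: total weight = {weights[label] * count:.1f}")
```

The rarer positive class gets the larger weight, so each positive example contributes more to the loss than each negative one.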


In [8]:



# 1. Drop rows where ref_paper_text OR sentence is missing or empty

n_before = len(df)

if "ref_paper_text" not in df.columns:
    raise ValueError("ref_paper_text column not found in DataFrame.")
if "sentence" not in df.columns:
    raise ValueError("sentence column not found in DataFrame.")

mask_ref_ok = ~df["ref_paper_text"].apply(is_empty_or_nan)
mask_sent_ok = ~df["sentence"].apply(is_empty_or_nan)
mask_keep = mask_ref_ok & mask_sent_ok


df = df[mask_keep].reset_index(drop=True)
n_after = len(df)

print(f"Rows before cleaning: {n_before}")
print(f"Rows after  cleaning: {n_after}")
print(f"Dropped rows: {n_before - n_after} ({(n_before - n_after) / n_before:.2%})")



Rows before cleaning: 4120
Rows after  cleaning: 4116
Dropped rows: 4 (0.10%)


In [9]:


# Build a single ref_block = title + authors + ref_paper_text


def clean_text(x: str) -> str:
    """Normalize whitespace; return empty string if not a valid string."""
    if not isinstance(x, str):
        return ""
    x = x.strip()
    # turn multiple whitespace characters into a single space
    x = re.sub(r"\s+", " ", x)
    return x

def build_ref_block(row) -> str:
    title   = clean_text(row.get("ref_paper_title", ""))
    authors = clean_text(row.get("ref_paper_authors", ""))
    text    = clean_text(row.get("ref_paper_text", ""))

    parts = []
    if title:
        parts.append(title)
    if authors:
        parts.append(authors)
    if text:
        parts.append(text)

    # Join non-empty parts with ". " to stay compact
    ref_block = ". ".join(parts)
    return ref_block.strip()

df["ref_block"] = df.apply(build_ref_block, axis=1)



In [10]:

print("\nSample sentence + ref_block pairs:\n")
num_examples_to_show = 2
for i in range(min(num_examples_to_show, len(df))):
    print(f"Example {i}")
    print("Sentence:", df.loc[i, "sentence"])
    rb = df.loc[i, "ref_block"]
    if len(rb) > 400:
        print("Ref block:", rb[:400] + " ...")
    else:
        print("Ref block:", rb)
    print()




Sample sentence + ref_block pairs:

Example 0
Sentence: It is believed that eosinophil transmigration into the airways is orchestrated by cytokines, such as IL-4, IL-5, TNF-, and IL-13, and is coordinated by specific chemokines, such as eotaxin.
Ref block: Update on Anticytokine Treatment for Asthma Hindawi Publishing CorporationCopyright Hindawi Publishing Corporation Luca Gallelli Clinical Pharmacology Unit Department of Health Science University "Magna Graecia" of Catanzaro Campus Universitario "S. Venuta" Viale Europa-Località Germaneto88100CatanzaroItaly Maria Teresa Busceti Department of Medical and Surgical Sciences University "Magna Graecia" ...

Example 1
Sentence: It is believed that eosinophil transmigration into the airways is orchestrated by cytokines, such as IL-4, IL-5, TNF-, and IL-13, and is coordinated by specific chemokines, such as eotaxin.
Ref block: Anti-interleukin-5 therapy in severe asthma Gilles Garcia gilles.garcia@bct.aphp.fr Faculté de médecine Université 

In [11]:
# length stats for ref_block

ref_block_lens = df["ref_block"].apply(lambda x: len(x) if isinstance(x, str) else 0).tolist()
ref_block_lens_sorted = sorted(ref_block_lens)

print("ref_block length (in characters):")
print(f"  Mean: {mean(ref_block_lens):.2f}")
print(f"  Median: {ref_block_lens_sorted[int(0.5 * len(ref_block_lens))]}")
print(f"  Max: {max(ref_block_lens)}")



ref_block length (in characters):
  Mean: 1730.69
  Median: 1990
  Max: 2760


In [12]:
# Class weights (for cross-entropy) to deal with the imbalance. This is for EDA only; the weights used in training must come from the train split.
if "label" not in df.columns:
    raise ValueError("label column not found in df.")

label_counts = df["label"].value_counts().to_dict()
N = len(df)
num_classes = 2

class_weights = {
    int(label): N / (num_classes * count)
    for label, count in label_counts.items()
}

print("\n Label dist AFTER dropping rows:")
for label, count in label_counts.items():
    print(f"  {label}: {count} ({count / N:.2%})")

print("\n total class weights:")
for label, w in sorted(class_weights.items()):
    print(f"  label {label}: weight = {w:.4f}")


 Label dist AFTER dropping rows:
  0: 2816 (68.42%)
  1: 1300 (31.58%)

 total class weights:
  label 0: weight = 0.7308
  label 1: weight = 1.5831


## cleaning and preprocessing results


Before splitting the dataset into train/validation/test, I clean the data and build the `ref_block`:

1. Drop rows where a key field (`sentence` or `ref_paper_text`) is empty or NA.

2. Build a single field, `ref_block`, that combines
   - `ref_paper_title`
   - `ref_paper_authors`
   - `ref_paper_text`  
   
   This is done after the data cleaning.

## Grouping train/validation/test split by paper

When splitting the dataset into train, validation, and test sets, we need to avoid spreading rows from the same paper across different splits.

To avoid that potential data leakage:

- I group by `original_paper_id` and shuffle the unique paper IDs once with a fixed random seed.
- I assign 70% of papers to the training set, 15% to validation, and 15% to test.

Every row inherits its split from its `original_paper_id`, so no paper appears in more than one split.
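
The grouping logic can be sketched without pandas. The snippet below is only an illustration under stated assumptions: `grouped_split` is a hypothetical helper, the rows are toy data, and it uses the standard-library `random` module rather than the NumPy generator used in the notebook, so it will not reproduce the notebook's exact assignment.

```python
import random


def grouped_split(group_ids, train_frac=0.70, val_frac=0.15, seed=42):
    """Assign each unique group id to 'train', 'val' or 'test'."""
    unique_ids = sorted(set(group_ids))  # sort first so the shuffle is reproducible
    random.Random(seed).shuffle(unique_ids)
    n_train = int(train_frac * len(unique_ids))
    n_val = int(val_frac * len(unique_ids))
    split_of = {}
    for i, gid in enumerate(unique_ids):
        if i < n_train:
            split_of[gid] = "train"
        elif i < n_train + n_val:
            split_of[gid] = "val"
        else:
            split_of[gid] = "test"
    return split_of


# Toy rows (paper_id, sentence_id): every paper contributes three rows.
rows = [(paper, s) for paper in range(20) for s in range(3)]
split_of = grouped_split([paper for paper, _ in rows])

# Each row inherits the split of its paper.
papers_by_split = {}
for paper, _ in rows:
    papers_by_split.setdefault(split_of[paper], set()).add(paper)

for name in ("train", "val", "test"):
    print(name, sorted(papers_by_split[name]))
```

Because the split is decided per paper and only then propagated to rows, the row-level proportions can drift from 70/15/15 when papers contribute different numbers of rows.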


After splitting, I compute class weights from the training split only; these are used later in the weighted cross-entropy loss to handle the label imbalance.


In [13]:

# grouped split by paper
if "original_paper_id" not in df.columns:
    raise ValueError("original_paper_id column not found - can't do grouped split.")

rng = np.random.RandomState(42)  # fixed seed for reproducibility

paper_ids = df["original_paper_id"].unique()
n_papers = len(paper_ids)
print(f"Total unique original_paper_id: {n_papers}")

# shuffle in place
rng.shuffle(paper_ids)

# 70% train, 15% val, 15% test
train_frac, val_frac = 0.70, 0.15
n_train = int(train_frac * n_papers)
n_val   = int(val_frac * n_papers)
n_test  = n_papers - n_train - n_val

train_ids = set(paper_ids[:n_train])
val_ids   = set(paper_ids[n_train:n_train + n_val])
test_ids  = set(paper_ids[n_train + n_val:])

print("\nAssigned paper IDs per split:")
print(f"  Train: {len(train_ids)}")
print(f"  Val:   {len(val_ids)}")
print(f"  Test:  {len(test_ids)}")

# Map each row to a split based on its original_paper_id
def assign_split(paper_id):
    if paper_id in train_ids:
        return "train"
    elif paper_id in val_ids:
        return "val"
    elif paper_id in test_ids:
        return "test"
    else:
        return "unknown"

df["split"] = df["original_paper_id"].apply(assign_split)

num_unknown = (df["split"] == "unknown").sum()
assert num_unknown == 0, f"{num_unknown} rows were not assigned to any split!"

# Create separate dfs
df_train = df[df["split"] == "train"].reset_index(drop=True)
df_val   = df[df["split"] == "val"].reset_index(drop=True)
df_test  = df[df["split"] == "test"].reset_index(drop=True)

print("\nRow counts per split:")
print(f"  Train: {len(df_train)}")
print(f"  Val:   {len(df_val)}")
print(f"  Test:  {len(df_test)}")




Total unique original_paper_id: 747

Assigned paper IDs per split:
  Train: 522
  Val:   112
  Test:  113

Row counts per split:
  Train: 2916
  Val:   513
  Test:  687


In [14]:

# Label distribution per split


def print_label_stats(name, dframe):
    counts = dframe["label"].value_counts().to_dict()
    total = len(dframe)
    print(f"\n label distribution in {name}:")
    for label, count in sorted(counts.items()):
        print(f"  {label}: {count} ({count/total:.2%})")

print_label_stats("train", df_train)
print_label_stats("val", df_val)
print_label_stats("test", df_test)



 label distribution in train:
  0: 2029 (69.58%)
  1: 887 (30.42%)

 label distribution in val:
  0: 326 (63.55%)
  1: 187 (36.45%)

 label distribution in test:
  0: 461 (67.10%)
  1: 226 (32.90%)


In [15]:
# calculating class weights for the train split

train_counts = df_train["label"].value_counts().to_dict()
N_train = len(df_train)
num_classes = 2

class_weights_train = {
    int(label): N_train / (num_classes * count)
    for label, count in train_counts.items()
}

print("\n train split weights:")
for label, w in sorted(class_weights_train.items()):
    print(f"  label {label}: weight = {w:.4f}")


 train split weights:
  label 0: weight = 0.7186
  label 1: weight = 1.6437


## Baseline model training run

I first run a baseline training with a default set of hyperparameters
(lr = 2e-5, 5 epochs, class-weighted loss).

The setup is below.


In [16]:
from torch.utils.data import Dataset as TorchDataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
    set_seed,
)
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# model config
set_seed(42)
MODEL_NAME = "allenai/scibert_scivocab_uncased"
MAX_LENGTH = 512  # truncate long examples to BERT's 512-token limit

# SciBERT tokenizer and model for binary classification
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=2,  # 0 = non-citation, 1 = true citation
)



Some weights of BertForSequenceClassification were not initialized from the model checkpoint at allenai/scibert_scivocab_uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [17]:

#tokenizing

class CitationDataset(TorchDataset):
    """
    Creates a training dataset from DataFrame rows with:
        - 'sentence': str
        - 'ref_block': str
        - 'label': int (0 or 1)

    Each item is tokenized as:
        [CLS] sentence [SEP] ref_block [SEP]

    """

    def __init__(self, df, tokenizer, max_length=512):
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        sentence = row["sentence"]
        ref_block = row["ref_block"]
        label = int(row["label"])

        encoding = self.tokenizer(
            sentence,
            ref_block,
            truncation="only_second",
            max_length=self.max_length,
        )

        encoding["labels"] = label
        return encoding

train_dataset = CitationDataset(df_train, tokenizer, MAX_LENGTH)
val_dataset   = CitationDataset(df_val, tokenizer, MAX_LENGTH)
test_dataset  = CitationDataset(df_test, tokenizer, MAX_LENGTH)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Val dataset size:   {len(val_dataset)}")
print(f"Test dataset size:  {len(test_dataset)}")


Train dataset size: 2916
Val dataset size:   513
Test dataset size:  687
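
The `truncation="only_second"` setting means that only the reference block is shortened when a pair is too long, so the citing sentence is always kept whole. The sketch below illustrates that budget arithmetic. It is a simplification under explicit assumptions: whitespace-separated words stand in for WordPiece tokens, `pair_with_truncation` is a hypothetical helper, and the case where the sentence alone exceeds the limit is handled here by raising an error, which is not a claim about what the real tokenizer does.

```python
def pair_with_truncation(sentence_tokens, ref_tokens, max_length=512):
    """Build [CLS] sentence [SEP] ref [SEP], trimming only the reference."""
    num_special = 3  # one [CLS] plus two [SEP]
    budget = max_length - num_special - len(sentence_tokens)
    if budget < 0:
        raise ValueError("sentence alone exceeds max_length")
    kept_ref = ref_tokens[:budget]  # the reference absorbs all truncation
    return ["[CLS]", *sentence_tokens, "[SEP]", *kept_ref, "[SEP]"]


# Whitespace "tokens" are an assumption for illustration, not real WordPiece output.
sentence = "IL-5 is the most important mediator".split()
reference = ("word " * 1000).split()

encoded = pair_with_truncation(sentence, reference, max_length=32)
print(len(encoded))
print(encoded[:9])
```

With a median `ref_block` of roughly 2,000 characters, many references will likely be cut by the 512-token limit, so the title and authors placed at the start of the block are the parts most certain to survive.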


In [18]:
# use the class weights computed earlier from the training split

w0 = class_weights_train[0]
w1 = class_weights_train[1]
class_weights_tensor = torch.tensor([w0, w1], dtype=torch.float)
print("\nUsing class weights:", class_weights_tensor)




Using class weights: tensor([0.7186, 1.6437])
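
As a rough illustration of what these weights do inside the loss, here is a pure-Python sketch of weighted cross-entropy with mean reduction, where the mean is normalised by the sum of the target weights (the convention PyTorch's `CrossEntropyLoss` uses when `weight` is given). The function name and the toy logits are assumptions made for this example only.

```python
import math


def weighted_cross_entropy(logits_batch, labels, class_weights):
    """Weighted CE, averaged over the sum of the weights of the true classes."""
    weighted_sum = 0.0
    weight_total = 0.0
    for logits, label in zip(logits_batch, labels):
        # Stable log-softmax for the true class.
        max_logit = max(logits)
        log_norm = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
        log_prob_true = logits[label] - log_norm
        w = class_weights[label]
        weighted_sum += -w * log_prob_true
        weight_total += w
    return weighted_sum / weight_total


# Toy batch: both examples are scored as class 0, but the second is truly class 1.
toy_logits = [[2.0, -2.0], [2.0, -2.0]]
toy_labels = [0, 1]

plain = weighted_cross_entropy(toy_logits, toy_labels, [1.0, 1.0])
weighted = weighted_cross_entropy(toy_logits, toy_labels, [0.7186, 1.6437])
print(f"unweighted loss: {plain:.4f}")
print(f"weighted loss:   {weighted:.4f}")
```

The missed positive dominates the weighted loss more than the unweighted one, which is the intended effect: errors on the minority class are penalised more heavily.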


In [19]:
# Trainer with our own class-weighted loss

class WeightedTrainer(Trainer):
    def __init__(self, class_weights, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):

        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits


        loss_fct = torch.nn.CrossEntropyLoss(
            weight=self.class_weights.to(logits.device)
        )
        loss = loss_fct(
            logits.view(-1, self.model.config.num_labels),
            labels.view(-1),
        )

        if return_outputs:
            return loss, outputs
        return loss

# Metrics: accuracy, precision, recall, and F1 for the positive class (label=1).
# Model selection is by validation loss only.

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = logits.argmax(axis=-1)

    acc = accuracy_score(labels, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels,
        preds,
        average="binary",
        pos_label=1,
        zero_division=0,
    )

    return {
        "accuracy": acc,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }

# for dynamic padding
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)


# Evaluate and save at the end of each epoch,
# and keep the checkpoint with the LOWEST validation loss.


training_args = TrainingArguments(
    output_dir="./scibert_citation_classifier",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=5,          # baseline value; to be tuned later
    learning_rate=2e-5,
    weight_decay=0.01,
    logging_steps=50,

    eval_strategy="epoch",
    save_strategy="epoch",

    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    save_total_limit=2,
    seed=42,
    report_to="none",
)

#init trainer

trainer = WeightedTrainer(
    class_weights=class_weights_tensor,
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

print("\n done. ready to start training")



 done. ready to start training




In [20]:

trainer.train()


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.6393,0.683443,0.641326,0.508021,0.508021,0.508021
2,0.5172,0.705323,0.596491,0.46124,0.636364,0.534831
3,0.371,1.133602,0.651072,0.523256,0.481283,0.501393
4,0.3226,1.632048,0.645224,0.514286,0.481283,0.497238
5,0.1945,1.843032,0.647173,0.516854,0.491979,0.50411


TrainOutput(global_step=1825, training_loss=0.4026275765405942, metrics={'train_runtime': 688.0712, 'train_samples_per_second': 21.19, 'train_steps_per_second': 2.652, 'total_flos': 3798647361763920.0, 'train_loss': 0.4026275765405942, 'epoch': 5.0})

In [21]:
# Evaluate on validation
trainer.evaluate(eval_dataset=val_dataset)


{'eval_loss': 0.6834434270858765,
 'eval_accuracy': 0.6413255360623782,
 'eval_precision': 0.5080213903743316,
 'eval_recall': 0.5080213903743316,
 'eval_f1': 0.5080213903743316,
 'eval_runtime': 7.3037,
 'eval_samples_per_second': 70.239,
 'eval_steps_per_second': 8.9,
 'epoch': 5.0}

The baseline model is mediocre: at the default 0.5 cutoff it reaches about 64% accuracy and 51% F1 on the validation set (51% precision and 51% recall). Two options for improving it are probability threshold tuning and hyperparameter tuning.

Next we tune the probability threshold to find the cutoff that best suits our purposes. We prioritize recall over precision: both matter, but missing a citation is more serious (it risks plagiarism) than suggesting an extra one.


Then we run a grid search over learning rate and number of epochs to find the best settings for the model.
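
One way to express a preference for recall in a single number is the F-beta score with beta greater than 1. This notebook selects by plain F1, so the sketch below is only an illustration of the alternative, not a description of the method used here. The precision and recall values are taken from epoch 2 of the baseline training log above.

```python
def f_beta(precision, recall, beta=1.0):
    """F-beta score; beta > 1 weights recall more heavily than precision."""
    if precision == 0.0 and recall == 0.0:
        return 0.0
    b2 = beta * beta
    return (1 + b2) * precision * recall / (b2 * precision + recall)


# Epoch 2 of the baseline run: precision 0.46124, recall 0.636364.
p, r = 0.46124, 0.636364

f1 = f_beta(p, r, beta=1.0)
f2 = f_beta(p, r, beta=2.0)
print(f"F1 = {f1:.4f}")
print(f"F2 = {f2:.4f}")
```

When recall is higher than precision, F2 comes out above F1, so ranking thresholds by F2 would tend to favour lower, more recall-heavy cutoffs.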



## Threshold tuning on the validation set

After training the baseline model, I keep the model weights fixed and tune the decision threshold on the validation set.

Our model outputs a probability for class 1 ("true citation"). Instead of always using 0.5 as the cutoff, I sweep over thresholds from 0.1 to 0.9 and compute precision, recall, and F1 for each value. This does not change the model; it only picks the best decision rule for us.


I then select the threshold that maximizes F1 on the validation set and store it as `BEST_THRESHOLD`, a value I will use for hyperparameter tuning and final model evaluation.
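
To show what a threshold changes for a single example, here is a small sketch that turns two logits into a class-1 probability and applies a cutoff. The helper names and the logit values are hypothetical and chosen only for illustration.

```python
import math


def positive_probability(logits):
    """Softmax probability of class 1 for a two-logit output."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    return exps[1] / sum(exps)


def predict(logits, threshold=0.5):
    """Return 1 if the class-1 probability reaches the threshold, else 0."""
    return int(positive_probability(logits) >= threshold)


toy_logits = [0.8, 0.0]  # hypothetical model output leaning towards class 0
prob = positive_probability(toy_logits)

print(f"P(class 1) = {prob:.3f}")
print("prediction at 0.50:", predict(toy_logits, 0.50))
print("prediction at 0.25:", predict(toy_logits, 0.25))
```

The same example flips from negative to positive when the cutoff is lowered, which is how a lower threshold trades precision for recall without retraining.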


In [22]:

val_pred = trainer.predict(val_dataset)
val_logits = val_pred.predictions
val_labels = val_pred.label_ids

# softmax over the two logits; keep the class-1 probability
val_probs = torch.softmax(torch.tensor(val_logits), dim=-1)[:, 1].numpy()

thresholds = np.linspace(0.1, 0.9, 33)

best = {"t": None, "precision": 0.0, "recall": 0.0, "f1": 0.0}

print("Threshold sweep on val set:")
for t in thresholds:
    preds_t = (val_probs >= t).astype(int)
    precision, recall, f1, _ = precision_recall_fscore_support(
        val_labels,
        preds_t,
        average="binary",
        pos_label=1,
        zero_division=0,
    )
    print(f"t={t:.2f}  P={precision:.3f}  R={recall:.3f}  F1={f1:.3f}")
    if f1 > best["f1"]:
        best = {"t": t, "precision": precision, "recall": recall, "f1": f1}

BEST_THRESHOLD = float(best["t"])

print("\nBest threshold on val by F1:")
print(
    f"t={BEST_THRESHOLD:.2f}, "
    f"precision={best['precision']:.3f}, "
    f"recall={best['recall']:.3f}, "
    f"F1={best['f1']:.3f}"
)
print(f"\n saving best threshold {BEST_THRESHOLD:.2f} ")


Threshold sweep on val set:
t=0.10  P=0.365  R=1.000  F1=0.534
t=0.12  P=0.365  R=1.000  F1=0.534
t=0.15  P=0.365  R=1.000  F1=0.534
t=0.18  P=0.365  R=1.000  F1=0.534
t=0.20  P=0.365  R=1.000  F1=0.535
t=0.23  P=0.383  R=0.947  F1=0.545
t=0.25  P=0.420  R=0.866  F1=0.565
t=0.28  P=0.447  R=0.813  F1=0.577
t=0.30  P=0.460  R=0.733  F1=0.565
t=0.33  P=0.470  R=0.674  F1=0.554
t=0.35  P=0.478  R=0.636  F1=0.546
t=0.38  P=0.487  R=0.594  F1=0.535
t=0.40  P=0.495  R=0.572  F1=0.531
t=0.43  P=0.500  R=0.545  F1=0.522
t=0.45  P=0.508  R=0.535  F1=0.521
t=0.47  P=0.508  R=0.519  F1=0.513
t=0.50  P=0.508  R=0.508  F1=0.508
t=0.53  P=0.500  R=0.481  F1=0.490
t=0.55  P=0.503  R=0.465  F1=0.483
t=0.58  P=0.521  R=0.455  F1=0.486
t=0.60  P=0.528  R=0.449  F1=0.486
t=0.62  P=0.547  R=0.439  F1=0.487
t=0.65  P=0.543  R=0.406  F1=0.465
t=0.68  P=0.537  R=0.385  F1=0.449
t=0.70  P=0.556  R=0.374  F1=0.447
t=0.72  P=0.561  R=0.342  F1=0.425
t=0.75  P=0.574  R=0.310  F1=0.403
t=0.78  P=0.565  R=0.257  F

The rest of the notebook uses a threshold of 0.25. In the sweep output shown above, t = 0.25 gives an F1 of 0.565 with a recall of 0.866 and a precision of 0.420; the highest F1 in the visible rows is 0.577 at t = 0.28 (precision 0.447, recall 0.813).

Compared with the default cutoff of 0.5 (precision 0.508, recall 0.508), t = 0.25 gains about 36 points of recall for about 9 points of precision, which I believe is a worthwhile trade.

It is important to note, though, that this gap between recall and precision will lower our accuracy.
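
The effect on accuracy can be worked out from precision, recall, and the class counts alone. The sketch below uses the validation split sizes reported earlier (187 positives, 326 negatives); `accuracy_from_pr` is a hypothetical helper, and the t = 0.25 figure is approximate because the sweep prints precision and recall rounded to three decimals.

```python
def accuracy_from_pr(precision, recall, n_pos, n_neg):
    """Recover accuracy from precision/recall and class counts (precision > 0)."""
    tp = recall * n_pos                 # true positives
    fp = tp * (1.0 / precision - 1.0)   # false positives implied by precision
    tn = n_neg - fp                     # true negatives
    return (tp + tn) / (n_pos + n_neg)


N_POS, N_NEG = 187, 326  # validation label counts from the split statistics

# Default 0.5 cutoff: precision and recall from the baseline evaluation output.
acc_default = accuracy_from_pr(0.5080213903743316, 0.5080213903743316, N_POS, N_NEG)

# t = 0.25 row of the sweep (inputs rounded to three decimals, so approximate).
acc_low_t = accuracy_from_pr(0.420, 0.866, N_POS, N_NEG)

print(f"accuracy at t=0.50: {acc_default:.4f}")
print(f"accuracy at t=0.25: {acc_low_t:.4f} (approximate)")
```

The first value reproduces the accuracy reported by `trainer.evaluate`, and the second is lower: the extra false positives admitted by the lower cutoff outnumber the true positives gained.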

## Hyperparameter tuning with Grid search over learning rate and number of epochs

We run a simple grid search over two key hyperparameters:
- learning rate: `1e-5`, `2e-5`, `3e-5`  
- number of epochs: `3`, `4`


For each combination:

1. I create a new SciBERT model and train it with the same setup as before
   (class-weighted loss, grouped train/val split, best epoch chosen by lowest validation loss).
2. After training, I evaluate the model on the validation set using the tuned probability threshold `BEST_THRESHOLD` (found earlier on the baseline model).
3. I compute validation accuracy, precision, recall, and F1 at this fixed threshold and
   store the results.


At the end, I collect all runs in a DataFrame, sort them by validation F1, and select
the best hyperparameter configuration.
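
As a quick sanity check on the size of the search, the grid can be enumerated on its own before any training is run. This is a minimal sketch that only builds the configuration list; the dictionary layout is an assumption for illustration.

```python
import itertools

learning_rates = [1e-5, 2e-5, 3e-5]
num_epochs_list = [3, 4]

# One entry per (learning rate, epochs) combination, in the order they will run.
configs = [
    {"run_name": f"lr{lr}_epochs{n}", "learning_rate": lr, "num_epochs": n}
    for lr, n in itertools.product(learning_rates, num_epochs_list)
]

for cfg in configs:
    print(cfg["run_name"])

# Total number of training epochs across the whole grid.
total_epochs = sum(cfg["num_epochs"] for cfg in configs)
print("runs:", len(configs), "total epochs:", total_epochs)
```

Six runs and 21 training epochs in total is small enough to run exhaustively, which is why a plain grid is used rather than a sampled search.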


In [None]:
import itertools
import pandas as pd
from transformers import AutoModelForSequenceClassification, TrainingArguments

learning_rates = [1e-5, 2e-5, 3e-5]
num_epochs_list = [3, 4]
BATCH_SIZE = 8

print(f"Using best thresh {BEST_THRESHOLD:.2f} for grid search starting now")

results = []

for lr, num_epochs in itertools.product(learning_rates, num_epochs_list):
    run_name = f"lr{lr}_epochs{num_epochs}"
    print(f"\n ===== Running config: {run_name} =====")


    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        num_labels=2,
    )

    training_args = TrainingArguments(
        output_dir=f"./scibert_grid/{run_name}",
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        num_train_epochs=num_epochs,
        learning_rate=lr,
        weight_decay=0.01,
        logging_steps=50,

        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        save_total_limit=1,
        report_to="none",
        seed=42,
    )

    trainer_grid = WeightedTrainer(
        class_weights=class_weights_tensor,
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )


    _ = trainer_grid.train()



    val_pred = trainer_grid.predict(val_dataset)
    val_logits = val_pred.predictions
    val_labels = val_pred.label_ids


    val_probs = torch.softmax(torch.tensor(val_logits), dim=-1)[:, 1].numpy()
    val_preds_t = (val_probs >= BEST_THRESHOLD).astype(int)

    prec, rec, f1, _ = precision_recall_fscore_support(
        val_labels,
        val_preds_t,
        average="binary",
        pos_label=1,
        zero_division=0,
    )
    acc = (val_preds_t == val_labels).mean()

    print(
        f"Validation @ t={BEST_THRESHOLD:.2f}: "
        f"Acc={acc:.3f}, P={prec:.3f}, R={rec:.3f}, F1={f1:.3f}"
    )

    results.append(
        {
            "run_name": run_name,
            "learning_rate": lr,
            "num_epochs": num_epochs,
            "val_accuracy_t": float(acc),
            "val_precision_t": float(prec),
            "val_recall_t": float(rec),
            "val_f1_t": float(f1),
        }
    )

    # Clean up before next run
    del trainer_grid, model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# summarizing results

results_df = pd.DataFrame(results).sort_values("val_f1_t", ascending=False)
print("\n=====  Results (sorted by val F1 =====")
display(results_df)

best_cfg = results_df.iloc[0].to_dict()
print("\nBest config:")
for k, v in best_cfg.items():
    print(f"  {k}: {v}")


Using best thresh 0.25 for grid search starting now

 ===== Running config: lr1e-05_epochs3 =====




Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.6382,0.668624,0.645224,0.513369,0.513369,0.513369
2,0.5136,0.686945,0.606238,0.473118,0.705882,0.566524
3,0.4225,0.783704,0.651072,0.519417,0.572193,0.544529


Validation @ t=0.25: Acc=0.405, P=0.377, R=0.968, F1=0.543

 ===== Running config: lr1e-05_epochs4 =====




Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.6375,0.6799,0.668616,0.56391,0.40107,0.46875
2,0.5118,0.667723,0.621832,0.486891,0.695187,0.572687
3,0.4272,0.932317,0.664717,0.552448,0.42246,0.478788
4,0.3334,1.070866,0.658869,0.535294,0.486631,0.509804


Validation @ t=0.25: Acc=0.526, P=0.427, R=0.882, F1=0.576

 ===== Running config: lr2e-05_epochs3 =====




Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.6437,0.670173,0.635478,0.5,0.540107,0.51928
2,0.4976,0.735885,0.615984,0.48227,0.727273,0.579957
3,0.3142,1.073618,0.658869,0.536585,0.470588,0.501425


Validation @ t=0.25: Acc=0.517, P=0.422, R=0.877, F1=0.569

 ===== Running config: lr2e-05_epochs4 =====




Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.6534,0.666562,0.668616,0.560284,0.42246,0.481707
2,0.5419,0.68475,0.621832,0.486275,0.663102,0.561086
3,0.4218,1.287294,0.682261,0.62,0.331551,0.432056
4,0.2601,1.592849,0.660819,0.535912,0.518717,0.527174


Validation @ t=0.25: Acc=0.485, P=0.407, R=0.904, F1=0.561

 ===== Running config: lr3e-05_epochs3 =====




Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.6367,0.68915,0.664717,0.549669,0.44385,0.491124
2,0.4942,0.743475,0.617934,0.479263,0.55615,0.514851
3,0.362,1.27705,0.639376,0.505747,0.470588,0.487535


Validation @ t=0.25: Acc=0.552, P=0.438, R=0.813, F1=0.569

 ===== Running config: lr3e-05_epochs4 =====




Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.644,0.681056,0.664717,0.555556,0.40107,0.465839
2,0.5363,0.675583,0.662768,0.53125,0.636364,0.579075
3,0.4679,1.244384,0.692008,0.633028,0.368984,0.466216
4,0.2272,1.56768,0.688109,0.588235,0.481283,0.529412


Validation @ t=0.25: Acc=0.515, P=0.419, R=0.856, F1=0.562

===== Results (sorted by val F1) =====


Unnamed: 0,run_name,learning_rate,num_epochs,val_accuracy_t,val_precision_t,val_recall_t,val_f1_t
1,lr1e-05_epochs4,1e-05,4,0.526316,0.427461,0.882353,0.575916
2,lr2e-05_epochs3,2e-05,3,0.516569,0.421594,0.877005,0.569444
4,lr3e-05_epochs3,3e-05,3,0.551657,0.43804,0.812834,0.569288
5,lr3e-05_epochs4,3e-05,4,0.51462,0.418848,0.855615,0.56239
3,lr2e-05_epochs4,2e-05,4,0.48538,0.407229,0.903743,0.561462
0,lr1e-05_epochs3,1e-05,3,0.405458,0.377083,0.967914,0.542729



Best config:
  run_name: lr1e-05_epochs4
  learning_rate: 1e-05
  num_epochs: 4
  val_accuracy_t: 0.5263157894736842
  val_precision_t: 0.4274611398963731
  val_recall_t: 0.8823529411764706
  val_f1_t: 0.5759162303664922


In [30]:
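# Programmatic selection from the grid-search results, kept for reference:
#best_cfg = results_df.sort_values("val_f1_t", ascending=False).iloc[0].to_dict()
#BEST_LR = float(best_cfg["learning_rate"])
#BEST_EPOCHS = int(best_cfg["num_epochs"])
#FINAL_THRESHOLD = BEST_THRESHOLD
#BATCH_SIZE = 8

# Hardcoded to match the best row of the table above (lr1e-05_epochs4,
# highest validation F1 at t=0.25).
BEST_LR = 1e-5
BEST_EPOCHS = 4
FINAL_THRESHOLD = 0.25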

print("Best hyperparameters from grid search:")
print(f"  learning_rate = {BEST_LR}")
print(f"  num_epochs    = {BEST_EPOCHS}")
print(f"Using FINAL_THRESHOLD = {FINAL_THRESHOLD:.3f} (from validation sweep).")


Best hyperparameters from grid search:
  learning_rate = 1e-05
  num_epochs    = 4
Using FINAL_THRESHOLD = 0.250 (from validation sweep).


In [31]:
model_final = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=2,
)

training_args_final = TrainingArguments(
    output_dir="./scibert_final_model",
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=BEST_EPOCHS,
    learning_rate=BEST_LR,
    weight_decay=0.01,
    logging_steps=50,

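    # Evaluate and checkpoint once per epoch; at the end of training, reload
    # the checkpoint with the lowest validation loss.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    save_total_limit=1,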
    report_to="none",
    seed=42,
)

trainer_final = WeightedTrainer(
    class_weights=class_weights_tensor,
    model=model_final,
    args=training_args_final,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

print("\n Starting final training with best hyperparameters")
final_train_result = trainer_final.train()
print("\n Done")



Some weights of BertForSequenceClassification were not initialized from the model checkpoint at allenai/scibert_scivocab_uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  super().__init__(*args, **kwargs)



 Starting final training with best hyperparameters


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.6293,0.669395,0.668616,0.552147,0.481283,0.514286
2,0.5107,0.687285,0.617934,0.483019,0.684492,0.566372
3,0.4086,0.942903,0.635478,0.5,0.475936,0.487671
4,0.3063,1.078487,0.637427,0.502646,0.508021,0.505319



 Done
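
Because `load_best_model_at_end=True` is combined with `metric_for_best_model="eval_loss"`, the weights held by `trainer_final` after training should be those of the checkpoint with the lowest validation loss, not necessarily the last epoch. The short check below uses only numbers copied from the log above to show which epoch that is. It assumes the Trainer's documented checkpoint-selection behavior and does not inspect the checkpoints themselves.

```python
# (epoch, validation loss, F1 as logged per epoch), copied from the table above.
epoch_log = [
    (1, 0.669395, 0.514286),
    (2, 0.687285, 0.566372),
    (3, 0.942903, 0.487671),
    (4, 1.078487, 0.505319),
]

# metric_for_best_model="eval_loss" with greater_is_better=False: smallest loss wins.
best_by_loss = min(epoch_log, key=lambda row: row[1])

# For comparison only: the epoch that selecting on the logged F1 would pick.
best_by_f1 = max(epoch_log, key=lambda row: row[2])

print(f"Lowest validation loss: epoch {best_by_loss[0]} ({best_by_loss[1]:.4f})")
print(f"Highest logged F1:      epoch {best_by_f1[0]} ({best_by_f1[2]:.4f})")
```

If the intent is to keep the full 4-epoch model, or the epoch with the best F1, `metric_for_best_model` would need to be changed accordingly.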


In [32]:

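def compute_metrics_at_final_threshold(trainer, dataset, split_name):
    """Score `dataset` with `trainer` and report metrics at FINAL_THRESHOLD."""
    pred = trainer.predict(dataset)
    logits = pred.predictions
    labels = pred.label_ids

    # Softmax over the two logits; column 1 is P(label = 1), i.e. "refers".
    probs = torch.softmax(torch.tensor(logits), dim=-1)[:, 1].numpy()
    # Predict 1 whenever that probability reaches the threshold.
    preds = (probs >= FINAL_THRESHOLD).astype(int)
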
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels,
        preds,
        average="binary",
        pos_label=1,
        zero_division=0,
    )
    acc = (preds == labels).mean()

    metrics = {
        "accuracy": float(acc),
        "precision": float(precision),
        "recall": float(recall),
        "f1": float(f1),
    }
    print(f"\n{split_name} metrics @ FINAL_THRESHOLD = {FINAL_THRESHOLD:.3f}: {metrics}")
    return metrics


final_val_metrics  = compute_metrics_at_final_threshold(trainer_final, val_dataset,  "VAL (final model)")
final_test_metrics = compute_metrics_at_final_threshold(trainer_final, test_dataset, "TEST (final model)")



VAL (final model) metrics @ FINAL_THRESHOLD = 0.250: {'accuracy': 0.5126705653021443, 'precision': 0.42065491183879095, 'recall': 0.893048128342246, 'f1': 0.571917808219178}



TEST (final model) metrics @ FINAL_THRESHOLD = 0.250: {'accuracy': 0.5021834061135371, 'precision': 0.3862745098039216, 'recall': 0.8716814159292036, 'f1': 0.5353260869565217}
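
As a quick sanity check on the numbers above, F1 can be recomputed from the reported precision and recall as their harmonic mean. The snippet below is a minimal sketch; `f1_from` is a helper introduced here for illustration and is not part of the notebook.

```python
def f1_from(precision, recall):
    """Harmonic mean of precision and recall (0.0 when both are zero)."""
    total = precision + recall
    return 2 * precision * recall / total if total else 0.0


# Values copied from the printed metrics at FINAL_THRESHOLD = 0.25.
reported = {
    "VAL": {
        "precision": 0.42065491183879095,
        "recall": 0.893048128342246,
        "f1": 0.571917808219178,
    },
    "TEST": {
        "precision": 0.3862745098039216,
        "recall": 0.8716814159292036,
        "f1": 0.5353260869565217,
    },
}

for split, m in reported.items():
    recomputed = f1_from(m["precision"], m["recall"])
    print(f"{split}: recomputed F1 = {recomputed:.6f}, reported F1 = {m['f1']:.6f}")
```

Both splits show the same pattern: the low threshold buys high recall at the cost of precision below 0.5, and F1 lands between the two.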


## Saving best model for inference

I save the model and everything it needs in one folder, so it is easy to move around and to load later for inference.

The `saved_model` folder includes:
- the fine-tuned SciBERT weights and config,
- the tokenizer vocabulary and config,
- `inference_config.json`, which stores:
  - `FINAL_THRESHOLD`,
  - the max sequence length,
  - the label mapping (`0 → NOT_CITATION`, `1 → CITATION`),
  - and the main training hyperparameters.
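
Once the folder has been written, a downstream script could consume it roughly as sketched below. All function names here are hypothetical. The two-segment tokenizer call in `score_pair` is an assumption about how the pairs were encoded during training and should be adjusted to match the notebook's preprocessing.

```python
import json
import math
import os


def load_inference_config(model_dir):
    """Read the threshold, max length, and label mapping saved next to the model."""
    path = os.path.join(model_dir, "inference_config.json")
    with open(path, encoding="utf-8") as f:
        cfg = json.load(f)
    # JSON object keys are always strings, so id2label comes back as
    # {"0": ..., "1": ...}; convert the keys back to integers.
    cfg["id2label"] = {int(k): v for k, v in cfg["id2label"].items()}
    return cfg


def positive_probability(logit_neg, logit_pos):
    """Two-class softmax probability of class 1, written as a sigmoid of the logit gap."""
    return 1.0 / (1.0 + math.exp(logit_neg - logit_pos))


def decide(logit_neg, logit_pos, cfg):
    """Return (label, probability), predicting class 1 when P(class 1) >= threshold."""
    prob = positive_probability(logit_neg, logit_pos)
    label_id = int(prob >= cfg["threshold"])
    return cfg["id2label"][label_id], prob


def score_pair(model_dir, sentence, ref_block):
    """Hypothetical helper: score one (sentence, ref_block) pair with the saved model."""
    # Imported lazily so the helpers above work without torch/transformers installed.
    import torch
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    cfg = load_inference_config(model_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir).eval()
    # Assumption: the pair is encoded as two segments with truncation.
    enc = tokenizer(
        sentence,
        ref_block,
        truncation=True,
        max_length=cfg["max_length"],
        return_tensors="pt",
    )
    with torch.no_grad():
        logit_neg, logit_pos = model(**enc).logits[0].tolist()
    return decide(logit_neg, logit_pos, cfg)
```

Keeping the threshold in `inference_config.json` rather than in code means the decision rule travels with the weights, so the same cutoff used for the reported metrics is applied at inference time.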



In [33]:


SAVE_DIR = "saved_model"
os.makedirs(SAVE_DIR, exist_ok=True)

# Store the label mapping in the model config.
id2label = {0: "NOT_CITATION", 1: "CITATION"}
label2id = {"NOT_CITATION": 0, "CITATION": 1}

trainer_final.model.config.id2label = id2label
trainer_final.model.config.label2id = label2id

# Store the max sequence length. `config.max_length` is treated as a generation
# setting, which is why saving prints "Non-default generation parameters".
trainer_final.model.config.max_length = MAX_LENGTH

# Save the model weights and config.
trainer_final.save_model(SAVE_DIR)

# Save the tokenizer.
tokenizer.save_pretrained(SAVE_DIR)

# Save inference metadata (threshold, hyperparameters).

inference_meta = {
    "threshold": float(FINAL_THRESHOLD),
    "model_name": MODEL_NAME,
    "max_length": int(MAX_LENGTH),
    "id2label": id2label,
    "label2id": label2id,
    "learning_rate": float(BEST_LR),
    "num_epochs": int(BEST_EPOCHS),
    "batch_size": int(BATCH_SIZE),
}

with open(os.path.join(SAVE_DIR, "inference_config.json"), "w") as f:
    json.dump(inference_meta, f, indent=2)

print(f"Saved all to: {os.path.abspath(SAVE_DIR)}")


Non-default generation parameters: {'max_length': 512}


Saved all to: /content/saved_model


In [34]:

!zip -r saved_model.zip saved_model


from google.colab import files
files.download("saved_model.zip")


  adding: saved_model/ (stored 0%)
  adding: saved_model/special_tokens_map.json (deflated 42%)
  adding: saved_model/vocab.txt (deflated 52%)
  adding: saved_model/config.json (deflated 50%)
  adding: saved_model/inference_config.json (deflated 37%)
  adding: saved_model/tokenizer.json (deflated 71%)
  adding: saved_model/training_args.bin (deflated 53%)
  adding: saved_model/model.safetensors (deflated 7%)
  adding: saved_model/tokenizer_config.json (deflated 74%)

