# NLBSE'25 Code Comment Classification Using A Lightweight BERT Transformer

## Authors
- Brendan Scheidt
- Coen Petto
- Jeremiah Geisterfer

## Overview

The primary objective of this notebook is to classify code comments into content categories (for example summary, usage, or deprecation for Java). Because real-world datasets are inherently class-imbalanced, techniques like oversampling and a specialized loss function are integrated to improve model performance while keeping the model lightweight and efficient. Additionally, per-label probability thresholds are tuned on the validation set to improve evaluation metrics such as precision, recall, and F1.

***Techniques***:
1.   **Oversampling via synonym replacement for training data**
* Addresses the initial class imbalance by duplicating training samples and replacing words within them with synonyms
2.   **Focal loss function with label smoothing**
* Makes the model focus on harder examples during training, and stabilizes training by smoothing the labels
3.   **Custom probability thresholds for evaluation**
* Optimizes classification thresholds per label on the validation set to improve test evaluation metrics



## Imports and config

### Libraries

* Installing ```datasets```, ```transformers```, and ```torch``` is essential for handling data, building custom transformer models, and performing deep learning tasks

### Imports

* **NumPy and Random**: For numerical operations and randomness
* **NLTK**: Used for natural language processing tasks like synonym replacement and POS tagging words into grammatical categories like noun, adverb, adjective, and verb
* **Hugging Face's Datasets and Transformers**: Used for loading datasets and fine-tuning pre-trained transformer models
* **Torch**: Essential library for tensor operations and efficient model training
* **SciPy's Expit**: The logistic sigmoid function, used to convert logits to probabilities
* **Scikit-learn Metrics**: For evaluating the model's performance
* **Pandas**: Used for data manipulation and analysis
* **Time**: Used for measuring execution time
* **Autocast**: Used for mixed-precision (FP16) inference to speed up computation and reduce memory usage

### NLTK Downloads

* **WordNet**: A lexical database of English used for synonym replacement
* **OMW (Open Multilingual Wordnet)**: A multilingual wordnet database used in conjunction with the English WordNet
* **Averaged Perceptron Tagger**: Used for POS tagging words into nouns, adverbs, adjectives, and verbs
* **Stopwords**: A list of common words (e.g. "the", "is") that are excluded from synonym replacement

*https://www.nltk.org/*

In [27]:
!pip install -q datasets transformers torch

import numpy as np
import random
import nltk
from nltk.corpus import wordnet
from datasets import Dataset, concatenate_datasets, load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments, EarlyStoppingCallback
import torch
import json
from scipy.special import expit
from sklearn.metrics import precision_recall_fscore_support, f1_score
import pandas as pd
import time
from torch.cuda.amp import autocast

nltk.download('wordnet')
nltk.download('omw-1.4')
nltk.download('averaged_perceptron_tagger_eng')
nltk.download('stopwords')

from nltk.corpus import stopwords
stop_words = set(stopwords.words('english'))


[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger_eng to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger_eng is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


## Data Augmentation: Oversampling and Synonym Replacement

### Overview

*Oversampling is used to increase the number of samples for minority classes. Because this creates duplicate entries, synonym replacement is applied as a form of data augmentation to introduce variability and mitigate the effects of those duplicates.*

### Defined Functions

* ```get_wordnet_pos()```

This function converts POS (Part-Of-Speech) tags from the Penn Treebank format returned by NLTK's ```pos_tag()``` function to the constants expected by WordNet (adjective, verb, noun, or adverb), and returns ```None``` for any other tag. Matching the POS restricts replacements to synonyms with the same grammatical role, which helps retain the grammatical integrity of the sample.

* ```synonym_replacement()```

This function replaces up to ```n``` words in a sentence with WordNet synonyms, where ```n``` is a hyperparameter. Introducing synonyms into the training samples increases data diversity, which helps the model generalize better and reduces overfitting.

* **Process**
1. The sentence is split on whitespace into words.
2. POS tagging assigns a POS tag to each word so that replacements match the word's grammatical role.
3. Identify each valid *candidate* word in the sentence: a word that is not in ```stop_words``` and whose POS tag maps to WordNet.
4. Visit the candidates in random order; for each one, retrieve synonyms from WordNet and keep only valid ones (alphabetic, single-word, and different from the original).
5. Randomly select a synonym and replace the original word, capitalizing it if the original word started with an uppercase letter. Stop once ```min(n, number of candidates)``` words have been replaced.



> Code Based on: *https://github.com/jasonwei20/eda_nlp/blob/master/code/eda.py*

> From the paper: *https://arxiv.org/pdf/1901.11196*

> Knowledge of WordNet obtained from: *https://iaoa.org/isc2012/docs/encycloped.article.pdf*
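The candidate-selection and replacement logic described above can be illustrated without NLTK. The sketch below is a simplified stand-in, not the notebook's function: a hypothetical hand-written ```TOY_SYNONYMS``` dictionary takes the place of WordNet, a small hypothetical ```STOP_WORDS``` set takes the place of NLTK's stop-word list, and POS filtering is omitted, so it demonstrates only steps 3 to 5.

```python
import random

# Hypothetical stand-ins for NLTK's stop-word list and WordNet synonym lookups.
STOP_WORDS = {"the", "a", "an", "of", "to", "is", "this"}
TOY_SYNONYMS = {
    "returns": ["yields", "gives"],
    "value": ["amount", "quantity"],
    "method": ["function", "routine"],
}


def toy_synonym_replacement(sentence, n=1, rng=None):
    """Replace up to n non-stop-words that have a known synonym."""
    rng = rng or random.Random()
    words = sentence.split()
    # Step 3: candidates are words outside the stop-word list (no POS check here).
    candidates = [i for i, w in enumerate(words) if w.lower() not in STOP_WORDS]
    rng.shuffle(candidates)  # visit candidates in random order
    new_words = words.copy()
    replaced = 0
    for idx in candidates:
        if replaced >= n:
            break
        word = words[idx]
        # Step 4: keep only synonyms that differ from the original word.
        options = [s for s in TOY_SYNONYMS.get(word.lower(), []) if s != word.lower()]
        if not options:
            continue
        # Step 5: choose one at random and keep a leading capital letter.
        new_word = rng.choice(options)
        if word[0].isupper():
            new_word = new_word.capitalize()
        new_words[idx] = new_word
        replaced += 1
    return " ".join(new_words)


print(toy_synonym_replacement("Method returns the value", n=2, rng=random.Random(0)))
```

Because the stop word ```the``` is never a candidate and every other word has a toy synonym, exactly two of the three content words change with ```n=2```.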



* ```oversample_multilabel()```

This function balances the label distribution of the training dataset by oversampling and then augmenting the oversampled samples with ```synonym_replacement()```. This step is important in multilabel classification because increasing the representation of minority classes keeps the model from being biased toward majority classes.

* **Process**

1. Analyze the label distribution: count the instances per label, compute each label's frequency relative to the dataset size, take the inverse frequencies to prioritize minority classes, and normalize the inverse frequencies to a mean of 1.
2. Assign each sample a weight equal to the sum of the normalized inverse frequencies of its labels, so samples carrying minority labels receive higher weights.
3. Calculate the target number of samples after oversampling as the original dataset size times the hyperparameter ```target_multiplier```.
4. Draw that many indices with replacement, with probabilities proportional to the sample weights.
5. Apply synonym replacement to each resampled sample and append every copy that actually changed to the oversampled set, adding variability beyond exact duplicates.
6. Report the new label distribution after oversampling and synonym augmentation.
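Steps 1 to 4 can be traced on a toy label matrix. This standard-library sketch mirrors the arithmetic of ```oversample_multilabel()``` (inverse frequencies normalized to a mean of 1, per-sample weights as the sum over a sample's labels, then weighted sampling with replacement); the label matrix and multiplier are made-up illustration values, and ```random.choices``` stands in for ```np.random.choice```.

```python
import random

# Toy multi-hot label matrix: 4 samples x 3 labels (illustrative only).
labels = [
    [1, 0, 0],
    [1, 0, 0],
    [1, 1, 0],
    [0, 0, 1],
]
n = len(labels)
num_labels = len(labels[0])

# Step 1: per-label counts, frequencies, inverse frequencies, mean-normalized.
counts = [sum(row[j] for row in labels) for j in range(num_labels)]
freqs = [c / n for c in counts]
inv = [1.0 / (f + 1e-8) for f in freqs]
mean_inv = sum(inv) / num_labels
inv_norm = [v / mean_inv for v in inv]

# Step 2: a sample's weight is the sum of its labels' normalized inverse frequencies.
weights = [sum(r * w for r, w in zip(row, inv_norm)) for row in labels]

# Steps 3-4: draw target_multiplier * n indices with replacement, proportional to weight.
target_multiplier = 2
rng = random.Random(42)
indices = rng.choices(range(n), weights=weights, k=int(n * target_multiplier))

print("counts:", counts)                              # [3, 1, 1]
print("weights:", [round(w, 3) for w in weights])     # [0.429, 0.429, 1.714, 1.286]
print("resampled indices:", indices)
```

The two samples carrying the rare labels receive three to four times the weight of the samples labeled only with the frequent label, so they are drawn more often. Unlike ```np.random.choice```, ```random.choices``` does not need the weights normalized to sum to 1.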

> Oversampling technique inspired by SMOTE: *https://arxiv.org/pdf/1106.1813*

In [37]:
def get_wordnet_pos(treebank_tag):
    if treebank_tag.startswith('J'):
        return wordnet.ADJ
    elif treebank_tag.startswith('V'):
        return wordnet.VERB
    elif treebank_tag.startswith('N'):
        return wordnet.NOUN
    elif treebank_tag.startswith('R'):
        return wordnet.ADV
    else:
        return None

def synonym_replacement(sentence, n=1):
    words = sentence.split()
    if len(words) == 0:
        return sentence

    # Get the POS tags for better synonym replacement
    pos_tags = nltk.pos_tag(words)

    # Generate candidates for replacement: words whose POS tag maps to WordNet and that are not stop words.
    candidates = [i for i, (word, pos) in enumerate(pos_tags)
                  if get_wordnet_pos(pos) is not None and word.lower() not in stop_words]

    # If there are no valid candidates for replacement, return the original sentence
    if len(candidates) == 0:
        return sentence

    # Replace at most min(n, number of candidates) words, visiting candidates in random order
    random.shuffle(candidates)
    num_replacements = min(n, len(candidates))

    new_words = words.copy()
    replaced = 0

    for idx in candidates:
        word = words[idx]
        pos = get_wordnet_pos(pos_tags[idx][1])
        synonyms = wordnet.synsets(word, pos=pos)
        if not synonyms:
            continue
        # Extract synonyms, excluding the original word
        synonym_words = set()
        for syn in synonyms:
            for lemma in syn.lemmas():
                synonym = lemma.name().replace('_', ' ').lower()
                if synonym != word.lower() and synonym.isalpha():
                    synonym_words.add(synonym)
        if synonym_words:
            new_word = random.choice(list(synonym_words))
            # Preserve the original casing
            if word[0].isupper():
                new_word = new_word.capitalize()
            new_words[idx] = new_word
            replaced += 1
        if replaced >= num_replacements:
            break

    return ' '.join(new_words)

def oversample_multilabel(train_dataset, labels_list, language, target_multiplier=2, augment_synonyms=True, n_synonyms=1):
    labels = labels_list[language]
    labels_array = np.array(train_dataset['labels'])

    print(f"\n=== Oversampling for Language: {language} ===")

    # Compute label counts
    label_counts = labels_array.sum(axis=0)
    print("Initial Label Counts:")
    for label, count in zip(labels, label_counts):
        print(f"  {label}: {count}")

    # Compute label frequencies
    label_freqs = label_counts / len(train_dataset)
    print("\nLabel Frequencies:")
    for label, freq in zip(labels, label_freqs):
        print(f"  {label}: {freq:.4f}")

    # Compute inverse frequencies
    inv_freqs = 1.0 / (label_freqs + 1e-8)
    print("\nInverse Label Frequencies (Before Normalization):")
    for label, inv_freq in zip(labels, inv_freqs):
        print(f"  {label}: {inv_freq:.4f}")

    # Normalize inverse frequencies
    inv_freqs_normalized = inv_freqs / inv_freqs.mean()
    print("\nInverse Label Frequencies (After Normalization):")
    for label, inv_freq_norm in zip(labels, inv_freqs_normalized):
        print(f"  {label}: {inv_freq_norm:.4f}")

    # Compute sample weights as the sum of inverse frequencies of labels
    sample_weights = labels_array.dot(inv_freqs_normalized)
    print("\nSample Weights Statistics:")
    print(f"  Min Weight: {sample_weights.min():.4f}")
    print(f"  Max Weight: {sample_weights.max():.4f}")
    print(f"  Mean Weight: {sample_weights.mean():.4f}")
    print(f"  Median Weight: {np.median(sample_weights):.4f}")

    # Normalize sample weights to sum to 1
    sample_weights_normalized = sample_weights / sample_weights.sum()

    # Determine the number of samples after oversampling using target_multiplier hyperparameter
    n_samples_before = len(train_dataset)
    n_samples_after = int(n_samples_before * target_multiplier)
    print(f"\nNumber of Samples Before Oversampling: {n_samples_before}")
    print(f"Target Number of Samples After Oversampling: {n_samples_after}")
    # Resample indices based on sample weights
    indices = np.random.choice(
        len(train_dataset),
        size=n_samples_after,
        replace=True,
        p=sample_weights_normalized
    )
    print(f"  Number of Samples Selected for Resampling: {len(indices)}")

    # Create oversampled dataset
    oversampled_dataset = train_dataset.select(indices)

    # Compute new label counts after oversampling
    new_labels_array = np.array(oversampled_dataset['labels'])
    new_label_counts = new_labels_array.sum(axis=0)
    print("\nLabel Counts After Oversampling:")
    for label, count in zip(labels, new_label_counts):
        print(f"  {label}: {count}")

    # Calculate and display the increase in samples per label
    print("\nIncrease in Label Counts:")
    for label, initial, new in zip(labels, label_counts, new_label_counts):
        increase = new - initial
        print(f"  {label}: +{increase}")

    # Perform synonym replacement on oversampled samples
    if augment_synonyms:
        print("\nPerforming Synonym Replacement on Oversampled Samples...")
        augmented_samples = []
        for idx in indices:
            sample = train_dataset[int(idx)]
            original_sentence = sample['comment_sentence']
            augmented_sentence = synonym_replacement(original_sentence, n=n_synonyms)
            if augmented_sentence != original_sentence:
                augmented_sample = {
                    'class': sample['class'],
                    'comment_sentence': augmented_sentence,
                    'labels': sample['labels']
                }
                augmented_samples.append(augmented_sample)

        print(f"  Number of Samples Augmented with Synonyms: {len(augmented_samples)}")

        if augmented_samples:
            # Transform list of dicts to dict of lists
            augmented_dict = {key: [] for key in augmented_samples[0].keys()}
            for sample in augmented_samples:
                for key, value in sample.items():
                    augmented_dict[key].append(value)

            augmented_dataset = Dataset.from_dict(augmented_dict)
            oversampled_dataset = concatenate_datasets([oversampled_dataset, augmented_dataset])
            print(f"  Total Samples After Augmentation: {len(oversampled_dataset)}")
        else:
            print("  No samples were augmented with synonyms.")

    # Compute final label counts after augmentation
    final_labels_array = np.array(oversampled_dataset['labels'])
    final_label_counts = final_labels_array.sum(axis=0)
    print("\nFinal Label Counts After Augmentation:")
    for label, count in zip(labels, final_label_counts):
        print(f"  {label}: {count}")

    # Overall summary
    print("\n=== Oversampling and Augmentation Completed ===\n")

    return oversampled_dataset


## Custom Loss Function and Model Architecture

### Focal Loss with Label Smoothing

*This loss function was chosen for its effectiveness in addressing class imbalance and improving model robustness.*

Focal loss reduces the contribution of easy, well-classified examples to the loss and focuses training on hard, misclassified ones, which further counteracts the initial class imbalance. Label smoothing softens the target labels, which keeps the model from becoming overconfident and improves its generalization. Together, these techniques address class imbalance and overconfident predictions, with the goal of a more balanced and reliable model.

* **Parameters**

The ```alpha``` hyperparameter is a tensor of per-label weights that reflects each class's importance or frequency in the data (in this notebook, normalized inverse label frequencies).

The ```gamma``` hyperparameter is the focal loss focusing parameter; it controls how strongly easy examples are down-weighted.

The ```smoothing``` hyperparameter specifies the degree to which label smoothing is applied to the targets.

The ```reduction``` hyperparameter specifies how the per-element losses are aggregated: 'mean' or 'sum' (any value other than 'mean' produces a sum).

* **Forward Pass**

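1. Label smoothing pulls each binary target toward 0.5: ```targets * (1 - smoothing) + 0.5 * smoothing```.
2. The BCE (Binary Cross-Entropy) loss is computed on the logits without reduction to retain the per-element loss.
3. ```pt = exp(-BCE)``` recovers the model's probability for the target class (exactly for hard 0/1 targets; for a smoothed target ```t``` it equals ```p ** t * (1 - p) ** (1 - t)```).
4. Focal weighting ```(1 - pt) ** gamma``` is applied to the BCE loss, which is first multiplied by ```alpha``` when class weights are given.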
5. The reduction step aggregates the loss based on the specified reduction method.
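Written out as a formula, the forward pass above computes the following per-element loss for a logit $x$, binary target $t$, smoothing factor $s$, label weight $\alpha_c$ (equal to 1 when ```alpha``` is ```None```) and focusing parameter $\gamma$. The notation is ours and $\sigma$ denotes the sigmoid; note that $p_t$ is taken from the unweighted BCE, as in the code.

$$
\begin{aligned}
\tilde{t} &= t\,(1 - s) + \tfrac{1}{2}s \\
\mathrm{BCE} &= -\left[\tilde{t}\,\log\sigma(x) + (1 - \tilde{t})\,\log\bigl(1 - \sigma(x)\bigr)\right] \\
p_t &= \exp(-\mathrm{BCE}) \\
\mathcal{L} &= (1 - p_t)^{\gamma}\,\alpha_c\,\mathrm{BCE}
\end{aligned}
$$

With ```reduction='mean'```, $\mathcal{L}$ is averaged over every sample and label in the batch; otherwise it is summed.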



> Focal Loss formula implementation based on: *https://arxiv.org/pdf/1708.02002*

> Focal Loss code based on: *https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/focal_loss.py*

> Label smoothing idea based on: *https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=7780677*
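To make the forward pass concrete, the scalar sketch below recomputes the loss by hand for a single logit in plain Python (no PyTorch), following the same steps as ```FocalLossWithLabelSmoothing```: smoothing toward 0.5, BCE on the logit, ```pt = exp(-BCE)```, an optional label weight, and the ```(1 - pt) ** gamma``` factor. The function name and logit values are illustrative; ```smoothing=0.3``` and ```gamma=2``` match the training loop.

```python
import math


def focal_bce_scalar(logit, target, alpha=1.0, gamma=2.0, smoothing=0.3):
    """Per-element focal loss with label smoothing, mirroring the PyTorch module."""
    # 1. Label smoothing: pull the 0/1 target toward 0.5.
    t = target * (1 - smoothing) + 0.5 * smoothing
    # 2. Binary cross-entropy on the logit (sigmoid applied here explicitly).
    p = 1.0 / (1.0 + math.exp(-logit))
    bce = -(t * math.log(p) + (1 - t) * math.log(1 - p))
    # 3. pt = exp(-BCE), computed before the label weight is applied.
    pt = math.exp(-bce)
    # 4. Focal weighting on the alpha-weighted BCE.
    return (1 - pt) ** gamma * (alpha * bce)


easy = focal_bce_scalar(logit=3.0, target=1)   # confident and correct
hard = focal_bce_scalar(logit=-1.0, target=1)  # wrong side of the decision boundary
print(f"easy: {easy:.4f}  hard: {hard:.4f}")
```

With these values the misclassified example contributes roughly seven times as much loss as the confident one. With ```smoothing=0.3``` even a very confident correct prediction keeps a nonzero loss, because the smoothed target is 0.85 rather than 1.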



### Custom Model Class

*This class defines a custom model that integrates a pre-trained transformer, ```prajjwal1/bert-tiny```, with the custom loss function, ```FocalLossWithLabelSmoothing```, for the multilabel classification task.*

It wraps ```AutoModelForSequenceClassification``` from Hugging Face Transformers and replaces the model's built-in loss with our custom loss function. Training the pre-trained model against this specialized loss lets it account for label imbalance, with the aim of improving performance on imbalanced multilabel data.

* **Forward Pass**
1. Passes ```input_ids``` and ```attention_mask``` to the pre-trained model (with ```labels=None```, so the model's built-in loss is not computed) to obtain logits.
2. If labels are provided, computes the loss with the custom loss function.
3. Returns the tuple ```(loss, logits)``` if labels were provided, or ```(logits,)``` otherwise.

> Knowledge of BERT obtained from: *https://arxiv.org/pdf/1810.04805*


In [None]:
class FocalLossWithLabelSmoothing(torch.nn.Module):
    def __init__(self, alpha=None, gamma=2, smoothing=0.1, reduction='mean'):
        super(FocalLossWithLabelSmoothing, self).__init__()
        self.alpha = alpha  # Class weights
        self.gamma = gamma
        self.smoothing = smoothing
        self.reduction = reduction
        self.bce_loss = torch.nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, inputs, targets):
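        # Apply label smoothing: pull each 0/1 target toward 0.5
        targets = targets * (1 - self.smoothing) + 0.5 * self.smoothing
        # Per-element BCE (no reduction), shape (batch_size, num_labels)
        bce_loss = self.bce_loss(inputs, targets.float())
        # pt = exp(-BCE); equals the probability of the target class for hard 0/1 targets
        pt = torch.exp(-bce_loss)
        if self.alpha is not None:
            alpha = self.alpha.unsqueeze(0)  # Shape (1, num_labels) to broadcast over the batch
            bce_loss = alpha * bce_loss
        # Down-weight easy examples (pt close to 1)
        focal_loss = (1 - pt) ** self.gamma * bce_loss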
        if self.reduction == 'mean':
            return focal_loss.mean()
        else:
            return focal_loss.sum()

class CustomModel(torch.nn.Module):
    def __init__(self, model_name, num_labels, alpha=None, gamma=2, smoothing=0.1):
        super(CustomModel, self).__init__()
        self.num_labels = num_labels
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=num_labels,
            problem_type='multi_label_classification'
        )
        self.loss_fn = FocalLossWithLabelSmoothing(alpha=alpha, gamma=gamma, smoothing=smoothing)

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=None
        )
        logits = outputs.logits
        if labels is not None:
            loss = self.loss_fn(logits, labels)
            outputs = (loss, logits)
        else:
            outputs = (logits,)
        return outputs


## Metric Function

*This function calculates the evaluation metrics (precision, recall, and F1 score) from the model's predictions and the true labels.*

Macro averaging gives every label equal weight, so in this multilabel setting each label's performance contributes equally to the overall score regardless of how frequent the label is.

* **Process**
1. The sigmoid activation function converts the logits to probabilities.
2. A threshold of *0.5* is applied to determine binary predictions.
3. Using Scikit-learn's ```precision_recall_fscore_support()``` function, the macro-averaged metrics are calculated; metrics that are undefined for a label (for example precision when a label has no positive predictions) are set to 0 via ```zero_division=0```.

> Scikit-learn Documentation used: *https://scikit-learn.org/1.5/modules/generated/sklearn.metrics.precision_recall_fscore_support.html#sklearn.metrics.precision_recall_fscore_support*
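A small hand computation shows what macro averaging does with a 0.5 threshold. The plain-Python sketch below thresholds toy logits (illustrative values, not model output), computes per-label precision, recall and F1 while scoring undefined ratios as 0 (the analogue of ```zero_division=0```), and then averages each metric across labels with equal weight. The helper name ```macro_prf``` is hypothetical.

```python
import math
from statistics import fmean


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def macro_prf(y_true, y_pred):
    """Macro-averaged precision, recall and F1, scoring undefined ratios as 0."""
    precisions, recalls, f1s = [], [], []
    for j in range(len(y_true[0])):
        pairs = [(t[j], p[j]) for t, p in zip(y_true, y_pred)]
        tp = sum(1 for t, p in pairs if t == 1 and p == 1)
        fp = sum(1 for t, p in pairs if t == 0 and p == 1)
        fn = sum(1 for t, p in pairs if t == 1 and p == 0)
        prec = tp / (tp + fp) if tp + fp else 0.0
        rec = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
        precisions.append(prec)
        recalls.append(rec)
        f1s.append(f1)
    # Macro averaging: every label counts equally, however rare it is.
    return fmean(precisions), fmean(recalls), fmean(f1s)


logits = [[2.0, -1.0], [0.5, 0.3], [-2.0, -0.2]]  # 3 samples x 2 labels
labels = [[1, 0], [1, 1], [0, 1]]
preds = [[int(sigmoid(z) > 0.5) for z in row] for row in logits]
print(preds)  # [[1, 0], [1, 1], [0, 0]]
print(macro_prf(labels, preds))
```

Here the second label has precision 1.0 but recall 0.5, so the macro scores are precision 1.0, recall 0.75 and F1 of about 0.833: the mean of the per-label F1 scores (1.0 and 0.667), which is also what scikit-learn reports for ```average='macro'```, rather than the F1 of the averaged precision and recall.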


In [None]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = (torch.sigmoid(torch.tensor(logits)) > 0.5).int().numpy()
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, predictions, average='macro', zero_division=0
    )
    return {'precision': precision, 'recall': recall, 'f1': f1}


## Threshold Optimization Function

*This function optimizes a classification threshold for each label on the validation set by maximizing that label's F1 score.*

This technique is used because the default threshold (0.5) may not be optimal for every label, especially in an imbalanced dataset. By setting a custom prediction threshold per label, we aim to improve the metrics achieved during testing. Every gain matters here because the model is deliberately lightweight, prioritizing speed and efficiency over raw accuracy.

* **Process**
1. Iterate over candidate thresholds from 0.1 to 0.9 (in steps of 0.1) for each label.
2. For each threshold candidate, compute the label's F1 score and keep the threshold that produces the highest F1 (on ties, the lower threshold is kept).
3. Return the best threshold for each label; a label for which no candidate achieves an F1 above 0 keeps the default of 0.5.

> Multilabel classification with probabilistic thresholding: *https://www.researchgate.net/publication/285805721_Multilabel_classifiers_with_a_probabilistic_thresholding_strategy*
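The per-label search can be reproduced on toy data with the standard library only. The sketch follows the same rules as ```optimize_thresholds()``` (a ```>=``` comparison, strict improvement required, default 0.5), but the probabilities and labels are made up, and the helper names are hypothetical. It also rounds the candidates to one decimal, which avoids values such as ```0.30000000000000004``` that ```np.linspace``` produces in the Java output below.

```python
def f1_binary(y_true, y_pred):
    """Binary F1 for one label, returning 0.0 when undefined (zero_division=0)."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def optimize_thresholds_toy(y_true, y_probs, candidates=None):
    """Pick, for each label (column), the candidate threshold with the highest F1."""
    if candidates is None:
        candidates = [round(0.1 * k, 1) for k in range(1, 10)]  # 0.1 ... 0.9
    best = []
    for j in range(len(y_true[0])):
        col_true = [row[j] for row in y_true]
        col_prob = [row[j] for row in y_probs]
        best_f1, best_t = 0.0, 0.5  # default when no candidate beats F1 = 0
        for t in candidates:
            f1 = f1_binary(col_true, [int(p >= t) for p in col_prob])
            if f1 > best_f1:  # strict: ties keep the lower threshold
                best_f1, best_t = f1, t
        best.append(best_t)
    return best


y_true = [[1, 0], [1, 0], [0, 1], [0, 0]]
y_probs = [[0.35, 0.2], [0.25, 0.6], [0.1, 0.85], [0.05, 0.7]]
print(optimize_thresholds_toy(y_true, y_probs))  # [0.2, 0.8]
```

On this toy data the best thresholds are 0.2 and 0.8, showing how far the optimum can sit from the default 0.5 and how much it can differ from label to label.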

In [None]:
def optimize_thresholds(y_true, y_probs, threshold_candidates=np.linspace(0.1, 0.9, 9)):
    best_thresholds = []
    for i in range(y_probs.shape[1]):
        best_f1 = 0
        best_thresh = 0.5
        for thresh in threshold_candidates:
            y_pred = (y_probs[:, i] >= thresh).astype(int)
            f1 = f1_score(y_true[:, i], y_pred, zero_division=0)
            if f1 > best_f1:
                best_f1 = f1
                best_thresh = thresh
        best_thresholds.append(best_thresh)
    return best_thresholds


## Inference Function for Evaluation

*This function performs the inference step on the unaugmented test dataset using the trained model and applies the optimized thresholds to generate the final predictions on the test data.*

Combining optimized thresholds with mixed precision aims to get the best metrics we can from a small model while keeping inference as fast as possible, since 16-bit arithmetic is cheaper than 32-bit on supported GPUs.

* **Process**
1. No gradients are computed during inference, which saves memory and computation.
2. ```input_ids``` and ```attention_mask``` are moved to ```device``` (the GPU when available) for faster computation.
3. ```autocast``` runs eligible operations in 16-bit instead of 32-bit floating point, speeding up inference.
4. We obtain the logit outputs from the model.
5. The sigmoid function converts the logits to probabilities so the thresholds can be applied.
6. The per-label thresholds are applied to determine binary predictions.
7. The predictions are returned as a NumPy array transposed to shape (num_labels, num_samples), following the inference step of the baseline model.

> Mixed Precision Training: *https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html*

> Checklist applied to Inference Efficiency: *https://pytorch.org/serve/performance_checklist.html*
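Steps 5 to 7 amount to a per-label comparison followed by a transpose. The sketch below reproduces that logic in plain Python on made-up logits and thresholds (in the notebook these come from the model and ```thresholds_tensor```, and tensor broadcasting does the comparison); the helper name ```apply_thresholds``` is hypothetical. It shows how a (num_samples, num_labels) prediction matrix becomes (num_labels, num_samples).

```python
import math


def apply_thresholds(logits, thresholds):
    """Sigmoid each logit, compare with its label's threshold, then transpose."""
    probs = [[1.0 / (1.0 + math.exp(-z)) for z in row] for row in logits]
    # Row i, column j: 1 if sample i is predicted to carry label j.
    preds = [[int(p >= t) for p, t in zip(row, thresholds)] for row in probs]
    # Transpose to one row per label, like preds.T in inference().
    return [list(col) for col in zip(*preds)]


logits = [[1.2, -0.4, 0.1], [-2.0, 0.9, 2.5]]  # 2 samples x 3 labels
thresholds = [0.3, 0.7, 0.8]                    # one threshold per label
print(apply_thresholds(logits, thresholds))     # [[1, 0], [0, 1], [0, 1]]
```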

In [None]:
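def inference():
    # Relies on globals defined elsewhere in the notebook: model, test_dataset, device, thresholds_tensor.
    # `outputs.logits` requires a standard Hugging Face model (such as one loaded from the files saved
    # to ./models/{language}); CustomModel.forward returns a plain tuple instead.
    with torch.no_grad():  # No gradient tracking at inference time
        input_ids = test_dataset['input_ids'].to(device)
        attention_mask = test_dataset['attention_mask'].to(device)
        with autocast():  # Enable autocast for FP16
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            probs = torch.sigmoid(logits)  # Convert logits to probabilities
            # Compare each label's probability with that label's threshold (broadcast over samples)
            preds = (probs >= thresholds_tensor).int().cpu().numpy()
            return preds.T  # Shape (num_labels, num_samples)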


## Dataset Loading

Load the NLBSE'25 code comment classification competition dataset from the Hugging Face Hub.

> Labels generated using this paper: *https://www.sciencedirect.com/science/article/pii/S0164121221001448?via%3Dihub*

> Ethical and Legal Dissection of using Open-Source Code for training LLMs: *https://arxiv.org/pdf/2302.13681*

> NLBSE 2025 Competition Homepage: *https://nlbse2025.github.io/tools/*

In [42]:
labels_list = {
    'java': ['summary', 'Ownership', 'Expand', 'usage', 'Pointer', 'deprecation', 'rational'],
    'python': ['Usage', 'Parameters', 'DevelopmentNotes', 'Expand', 'Summary'],
    'pharo': ['Keyimplementationpoints', 'Example', 'Responsibilities', 'Classreferences', 'Intent', 'Keymessages', 'Collaborators']
}

dataset = load_dataset('NLBSE/nlbse25-code-comment-classification')

langs = ['java', 'python', 'pharo']


## Model Training for Each Language

*During training, just like the baseline transformer, we used a separate model for each language, all based on the same architecture. This allows language-specific nuances in code comments to be captured effectively. By integrating our data augmentation/oversampling, custom loss function, and threshold optimization, we aim for robust, balanced training and evaluation across all categories.*

* **Setting up the training**
1. Determine whether to use GPU or CPU for computations.
2. Initialize an empty threshold dictionary to store the thresholds optimized on each language's validation set.

### The Training Loop

1. Load the tokenizer for the pre-trained model (```prajjwal1/bert-tiny```).
2. Select the current language's training split from the full dataset.
3. Split the original training data into training and validation sets (80/20 split, seed 42).
4. Apply the ```oversample_multilabel()``` function to the training set to balance label distributions and augment data with synonym replacement.
5. Define a ```tokenize()``` function that tokenizes the concatenated ```class``` and ```comment_sentence``` features, padding and truncating each input to exactly 128 tokens.
6. Apply the tokenization to the augmented training data and the validation data in batches of 1024 samples for efficiency.
7. Set the format of the tokenized data to PyTorch tensors, keeping only the columns needed for training.
8. Calculate the class weights: compute the frequency of each class in the augmented training data, take weights inversely proportional to those frequencies, and normalize them to a mean of 1.
9. Create an instance of ```CustomModel``` with parameters like class weights, gamma, and label smoothing factor.
10. Define the training arguments: (output directory for model checkpoints, number of training epochs (30 for Java, 50 otherwise), batch sizes, learning rate, per-epoch evaluation and saving, logging, best-model selection by validation F1, FP16 mixed precision)
11. Initialize a Hugging Face Trainer with the training arguments to manage the training loop and validation evaluation, with an early stopping callback (patience of 10 epochs).
12. Initiate training!
13. Save the underlying Hugging Face model and the tokenizer to ```./models/{language}``` for use in the test evaluation and submission score.
14. Optimize thresholds: (obtain logits from the validation set, convert to probabilities, calculate thresholds based on validation performance, store the thresholds in the thresholds dictionary)

> Trainer Documentation: *https://huggingface.co/docs/transformers/main_classes/trainer*

> Fine-tuning a pre-trained model: *https://huggingface.co/docs/transformers/en/training*
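Step 8 uses the same inverse-frequency arithmetic as ```oversample_multilabel()```, normalized so the weights average 1. The plain-Python sketch below applies it to the Java label counts of the un-oversampled training split printed in the output below (6091 samples). The notebook itself computes the weights on the *augmented* training set, whose full counts are not shown in the log, so these numbers only illustrate the formula.

```python
# Java label counts before oversampling, as printed in the training output.
labels = ["summary", "Ownership", "Expand", "usage", "Pointer", "deprecation", "rational"]
counts = [2897, 213, 412, 1671, 721, 92, 244]
total = 6091

# Inverse frequency per label, then normalize so the weights have a mean of 1.
freqs = [c / total for c in counts]
inv = [1.0 / (f + 1e-8) for f in freqs]
mean_inv = sum(inv) / len(inv)
class_weights = [w / mean_inv for w in inv]

for name, w in zip(labels, class_weights):
    print(f"{name}: {w:.4f}")
```

The results match the "Inverse Label Frequencies (After Normalization)" block printed for Java; the rarest label (```deprecation```) ends up weighted about 31 times more heavily than ```summary```.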

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
thresholds_dict = {}

for language in langs:
    print(f"\nProcessing language: {language}")
    num_labels = len(labels_list[language])

    # Initialize tokenizer
    model_name = 'prajjwal1/bert-tiny'
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    original_train_dataset = dataset[f'{language}_train']

    # Split the original training dataset into train and validation sets
    train_val_split = original_train_dataset.train_test_split(test_size=0.2, seed=42)
    train_original = train_val_split['train']
    val_original = train_val_split['test']

    # Perform oversampling and augmentation only on the training set
    augmented_train_dataset = oversample_multilabel(
        train_dataset=train_original,
        labels_list=labels_list,
        language=language,
        target_multiplier=2,
        augment_synonyms=True,
        n_synonyms=2
    )

    # Tokenization function
    def tokenize(batch):
        inputs = [f"{cls} {comment}" for cls, comment in zip(batch['class'], batch['comment_sentence'])]
        return tokenizer(inputs, padding='max_length', truncation=True, max_length=128)

    # Apply tokenization
    augmented_train_dataset = augmented_train_dataset.map(tokenize, batched=True, batch_size=1024)
    val_original = val_original.map(tokenize, batched=True, batch_size=1024)

    # Set format for PyTorch
    augmented_train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
    val_original.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

    # Compute class weights using augmented_train_dataset
    labels_array = np.array(augmented_train_dataset['labels'])
    class_counts = labels_array.sum(axis=0)
    total_counts = labels_array.shape[0]
    class_freqs = class_counts / total_counts
    class_weights = 1.0 / (class_freqs + 1e-8)
    class_weights = class_weights / np.mean(class_weights)
    class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

    # Initialize the custom model
    model = CustomModel(
        model_name=model_name,
        num_labels=num_labels,
        alpha=class_weights,
        gamma=2,
        smoothing=0.3
    )
    model.to(device)

    # Training arguments
    training_args = TrainingArguments(
        output_dir=f'./results/{language}',
        num_train_epochs=30 if language=='java' else 50,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=32,
        learning_rate=4e-5,
        evaluation_strategy='epoch',
        save_strategy='epoch',
        save_total_limit=1,
        logging_dir='./logs',
        logging_steps=50,
        load_best_model_at_end=True,
        metric_for_best_model='eval_f1',
        greater_is_better=True,
        seed=42,
        save_safetensors=False,
        fp16=True,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=augmented_train_dataset,
        eval_dataset=val_original,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=10)],
    )

    # Train the model
    trainer.train()

    # Save the tokenizer and the underlying standard model
    tokenizer.save_pretrained(f'./models/{language}')
    model.model.save_pretrained(f'./models/{language}')

    # Obtain logits and labels from the validation set
    val_logits = trainer.predict(val_original).predictions
    val_labels = np.array(val_original['labels'])

    # Convert logits to probabilities
    val_probs = expit(val_logits)

    # Optimize thresholds based on validation set
    best_thresholds = optimize_thresholds(val_labels, val_probs)
    thresholds_dict[language] = best_thresholds
    print(f"Thresholds for {language}: {best_thresholds}")



Processing language: java




=== Oversampling for Language: java ===
Initial Label Counts:
  summary: 2897
  Ownership: 213
  Expand: 412
  usage: 1671
  Pointer: 721
  deprecation: 92
  rational: 244

Label Frequencies:
  summary: 0.4756
  Ownership: 0.0350
  Expand: 0.0676
  usage: 0.2743
  Pointer: 0.1184
  deprecation: 0.0151
  rational: 0.0401

Inverse Label Frequencies (Before Normalization):
  summary: 2.1025
  Ownership: 28.5962
  Expand: 14.7840
  usage: 3.6451
  Pointer: 8.4480
  deprecation: 66.2065
  rational: 24.9631

Inverse Label Frequencies (After Normalization):
  summary: 0.0989
  Ownership: 1.3457
  Expand: 0.6957
  usage: 0.1715
  Pointer: 0.3976
  deprecation: 3.1157
  rational: 1.1748

Sample Weights Statistics:
  Min Weight: 0.0989
  Max Weight: 3.8114
  Mean Weight: 0.3294
  Median Weight: 0.1715

Number of Samples Before Oversampling: 6091
Target Number of Samples After Oversampling: 12182
  Number of Samples Selected for Resampling: 12182

Label Counts After Oversampling:
  summary: 2003

Map:   0%|          | 0/21971 [00:00<?, ? examples/s]

Map:   0%|          | 0/1523 [00:00<?, ? examples/s]


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

| Epoch | Training Loss | Validation Loss | Precision | Recall | F1 |
| --- | --- | --- | --- | --- | --- |
| 1 | 0.0744 | 0.075453 | 0.76528 | 0.779637 | 0.764471 |
| 2 | 0.0612 | 0.067281 | 0.785679 | 0.847078 | 0.808946 |
| 3 | 0.0578 | 0.069092 | 0.787094 | 0.855009 | 0.812547 |
| 4 | 0.0564 | 0.068404 | 0.813168 | 0.849449 | 0.829165 |
| 5 | 0.0546 | 0.069403 | 0.797337 | 0.853507 | 0.82094 |
| 6 | 0.0545 | 0.071267 | 0.78726 | 0.844959 | 0.810022 |
| 7 | 0.0533 | 0.068325 | 0.825388 | 0.848403 | 0.835887 |
| 8 | 0.0526 | 0.070651 | 0.805703 | 0.852437 | 0.825239 |
| 9 | 0.0518 | 0.069701 | 0.810398 | 0.850114 | 0.828335 |
| 10 | 0.0519 | 0.068813 | 0.825971 | 0.843464 | 0.834249 |


Thresholds for java: [0.30000000000000004, 0.30000000000000004, 0.6, 0.7000000000000001, 0.7000000000000001, 0.8, 0.8]

Processing language: python

=== Oversampling for Language: python ===
Initial Label Counts:
  Usage: 464
  Parameters: 470
  DevelopmentNotes: 166
  Expand: 262
  Summary: 282

Label Frequencies:
  Usage: 0.3079
  Parameters: 0.3119
  DevelopmentNotes: 0.1102
  Expand: 0.1739
  Summary: 0.1871

Inverse Label Frequencies (Before Normalization):
  Usage: 3.2478
  Parameters: 3.2064
  DevelopmentNotes: 9.0783
  Expand: 5.7519
  Summary: 5.3440

Inverse Label Frequencies (After Normalization):
  Usage: 0.6098
  Parameters: 0.6021
  DevelopmentNotes: 1.7046
  Expand: 1.0800
  Summary: 1.0034

Sample Weights Statistics:
  Min Weight: 0.6021
  Max Weight: 3.3945
  Mean Weight: 0.9388
  Median Weight: 0.6098

Number of Samples Before Oversampling: 1507
Target Number of Samples After Oversampling: 3014
  Number of Samples Selected for Resampling: 3014

Label Counts After Over


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


| Epoch | Training Loss | Validation Loss | Precision | Recall | F1 |
| --- | --- | --- | --- | --- | --- |
| 1 | 0.1437 | 0.136703 | 0.460317 | 0.047636 | 0.086311 |
| 2 | 0.1226 | 0.123857 | 0.720705 | 0.413178 | 0.515859 |
| 3 | 0.1066 | 0.114451 | 0.671755 | 0.544909 | 0.599936 |
| 4 | 0.0891 | 0.113361 | 0.666377 | 0.614873 | 0.637963 |
| 5 | 0.0805 | 0.115645 | 0.654925 | 0.651758 | 0.652174 |
| 6 | 0.0754 | 0.121453 | 0.660663 | 0.670956 | 0.662119 |
| 7 | 0.0699 | 0.123819 | 0.655247 | 0.63732 | 0.64514 |
| 8 | 0.0654 | 0.128483 | 0.652895 | 0.636864 | 0.642999 |
| 9 | 0.0646 | 0.131634 | 0.67385 | 0.665528 | 0.667786 |
| 10 | 0.0653 | 0.134708 | 0.667908 | 0.656473 | 0.660348 |


Thresholds for python: [0.6, 0.4, 0.5, 0.5, 0.5]

Processing language: pharo

=== Oversampling for Language: pharo ===
Initial Label Counts:
  Keyimplementationpoints: 139
  Example: 424
  Responsibilities: 202
  Classreferences: 39
  Intent: 129
  Keymessages: 176
  Collaborators: 61

Label Frequencies:
  Keyimplementationpoints: 0.1339
  Example: 0.4085
  Responsibilities: 0.1946
  Classreferences: 0.0376
  Intent: 0.1243
  Keymessages: 0.1696
  Collaborators: 0.0588

Inverse Label Frequencies (Before Normalization):
  Keyimplementationpoints: 7.4676
  Example: 2.4481
  Responsibilities: 5.1386
  Classreferences: 26.6154
  Intent: 8.0465
  Keymessages: 5.8977
  Collaborators: 17.0164

Inverse Label Frequencies (After Normalization):
  Keyimplementationpoints: 0.7197
  Example: 0.2359
  Responsibilities: 0.4953
  Classreferences: 2.5651
  Intent: 0.7755
  Keymessages: 0.5684
  Collaborators: 1.6400

Sample Weights Statistics:
  Min Weight: 0.2359
  Max Weight: 5.4759
  Mean Weight: 0.


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


| Epoch | Training Loss | Validation Loss | Precision | Recall | F1 |
| --- | --- | --- | --- | --- | --- |
| 1 | 0.1339 | 0.129001 | 0.285714 | 0.05034 | 0.081349 |
| 2 | 0.1183 | 0.114861 | 0.673469 | 0.118863 | 0.193434 |
| 3 | 0.1051 | 0.105193 | 0.631406 | 0.325412 | 0.414986 |
| 4 | 0.0952 | 0.098962 | 0.645201 | 0.397842 | 0.475619 |
| 5 | 0.0859 | 0.095188 | 0.642861 | 0.428197 | 0.494648 |
| 6 | 0.0796 | 0.092311 | 0.678139 | 0.480498 | 0.54588 |
| 7 | 0.0734 | 0.093985 | 0.655793 | 0.484215 | 0.530454 |
| 8 | 0.069 | 0.09246 | 0.655134 | 0.527907 | 0.57269 |
| 9 | 0.0667 | 0.093728 | 0.646367 | 0.544967 | 0.578081 |
| 10 | 0.063 | 0.094301 | 0.676234 | 0.541494 | 0.584028 |


Thresholds for pharo: [0.5, 0.4, 0.6, 0.7000000000000001, 0.5, 0.6, 0.2]
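
The per-label thresholds printed above come from ```optimize_thresholds```, which is called in the training cell but not defined in this section. A minimal sketch of one way such a search can work is below, assuming a per-label grid search that keeps, for each label independently, the threshold with the best binary F1 on the unaugmented validation set. The name ```optimize_thresholds_sketch``` and the 0.1-step grid are assumptions, not the notebook's code. Note that a grid built with ```np.arange(0.1, 1.0, 0.1)``` contains floating-point values such as ```0.30000000000000004```, which would be consistent with the trailing digits in the printed thresholds.

```python
import numpy as np

def optimize_thresholds_sketch(y_true, y_prob, grid=None):
    """Pick, per label, the threshold that maximizes that label's binary F1.

    y_true: (n_samples, n_labels) array of 0/1 validation labels
    y_prob: (n_samples, n_labels) array of sigmoid probabilities
    """
    if grid is None:
        # Assumed coarse grid; the notebook's actual grid is not shown here.
        grid = np.arange(0.1, 1.0, 0.1)
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    best = []
    for j in range(y_true.shape[1]):
        best_t, best_f1 = 0.5, -1.0
        for t in grid:
            pred = (y_prob[:, j] >= t).astype(int)
            tp = int(np.sum((pred == 1) & (y_true[:, j] == 1)))
            fp = int(np.sum((pred == 1) & (y_true[:, j] == 0)))
            fn = int(np.sum((pred == 0) & (y_true[:, j] == 1)))
            denom = 2 * tp + fp + fn
            f1 = 2 * tp / denom if denom > 0 else 0.0
            if f1 > best_f1:  # strict ">" keeps the smallest threshold on ties
                best_t, best_f1 = float(t), f1
        best.append(best_t)
    return best
```

Because each label is tuned on its own, rare labels can end up with low cut-offs (such as 0.2 for the last Pharo label) while frequent labels can end up with high ones.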


## Submission Score Calculation

*This section of code was adapted from the baseline model with minimal modifications so that the score is calculated the same way as for the baseline. In essence, we injected our model into the baseline's otherwise unedited scoring loop. Inference is repeated 10 times per language so that runtime and FLOPs are averaged over runs, and precision, recall, and F1 are computed on the test dataset from the resulting predictions.*

* **Setting up the score calculation**

1. Initialize ```total_flops``` and ```total_time``` to zero and ```scores``` to an empty list so they can accumulate results across all runs on all languages.
2. Define the device as GPU if available or CPU otherwise.
3. Reload the dataset to ensure no leakage from the augmentation/training step.

### The Score Loop

1. Load the tokenizer and model previously saved from training.
2. Prepare the test data by applying the same tokenization that was applied to the training and validation sets during training.
3. Format the test dataset to tensors and extract the labels.
4. Retrieve the optimized thresholds for the current language computed during training.
5. Use PyTorch's profiler to measure FLOPs during inference.
6. Run custom inference 10 times and measure the total time taken. (This step differs slightly from the baseline: SetFit has built-in inference, so the baseline passes the test inputs directly to ```model()```, whereas we must call our custom ```inference()``` function.)
7. Accumulate FLOPs from the profiling results.
8. Calculate the per-category metrics (precision, recall, F1 score) using the true labels and the predicted labels for each category.
9. Append the resulting score to the scores list.

> Code used from baseline: *https://github.com/nlbse2025/code-comment-classification/blob/main/SetFit_baseline.ipynb*
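
The loop calls ```inference()```, which is defined earlier in the notebook and not shown in this section. The thresholding it has to perform can be sketched in NumPy, assuming it applies a sigmoid to the logits, compares each column with that label's tuned threshold, and returns one row per category so that ```y_pred[i]``` lines up with ```y_true[i]``` from ```test_labels.T```. The helper name ```apply_thresholds``` and the ```>=``` comparison are assumptions.

```python
import numpy as np

def apply_thresholds(logits, thresholds):
    """Turn raw multi-label logits into 0/1 predictions, one row per category.

    logits: (n_samples, n_labels) array of model outputs
    thresholds: length-n_labels sequence of per-label cut-offs
    """
    logits = np.asarray(logits, dtype=float)
    probs = 1.0 / (1.0 + np.exp(-logits))                   # sigmoid, same as scipy.special.expit
    cutoffs = np.asarray(thresholds, dtype=float)[None, :]  # shape (1, n_labels), broadcasts over samples
    preds = (probs >= cutoffs).astype(int)
    return preds.T                                          # shape (n_labels, n_samples)
```

In PyTorch the same comparison could be written as ```(torch.sigmoid(logits) >= thresholds_tensor).int().T```; the ```unsqueeze(0)``` applied to ```thresholds_tensor``` in the loop gives it shape ```(1, num_labels)```, which broadcasts across the batch in exactly this way.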

In [52]:
total_flops = 0
total_time = 0
scores = []

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

labels = {
    'java': ['summary', 'Ownership', 'Expand', 'usage', 'Pointer', 'deprecation', 'rational'],
    'python': ['Usage', 'Parameters', 'DevelopmentNotes', 'Expand', 'Summary'],
    'pharo': ['Keyimplementationpoints', 'Example', 'Responsibilities', 'Classreferences', 'Intent', 'Keymessages', 'Collaborators']
}

ds = load_dataset('NLBSE/nlbse25-code-comment-classification')

for lan in langs:
    print(f"Processing language: {lan}")
    num_labels = len(labels[lan])

    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(f'./models/{lan}')
    model = AutoModelForSequenceClassification.from_pretrained(
        f'./models/{lan}',
        num_labels=num_labels,
        problem_type='multi_label_classification'
    )
    model.to(device)
    model.eval()

    # Load test dataset
    test_dataset = ds[f'{lan}_test']

    # Tokenization
    def tokenize(batch):
        inputs = [f"{cls} {comment}" for cls, comment in zip(batch['class'], batch['comment_sentence'])]
        return tokenizer(inputs, padding='max_length', truncation=True, max_length=128)

    test_dataset = test_dataset.map(tokenize, batched=True, batch_size=1024)
    test_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
    test_labels = np.array(test_dataset['labels'])

    # Load the thresholds for the current language
    best_thresholds = [float(threshold) for threshold in thresholds_dict[lan]]
    assert len(best_thresholds) == num_labels, f"Mismatch in thresholds and labels for {lan}"
    thresholds_tensor = torch.tensor(best_thresholds).unsqueeze(0).to(device)

    # Run inference multiple times
    with torch.profiler.profile(with_flops=True) as p:
        begin = time.time()
        for i in range(10):
            y_pred = inference()
        total = time.time() - begin
        total_time += total
    total_flops += (sum(k.flops for k in p.key_averages()) / 1e9)

    y_true = test_labels.T
    for i in range(len(y_pred)):
        assert len(y_pred[i]) == len(y_true[i])
        tp = sum([true == pred == 1 for (true, pred) in zip(y_true[i], y_pred[i])])
        tn = sum([true == pred == 0 for (true, pred) in zip(y_true[i], y_pred[i])])
        fp = sum([true == 0 and pred == 1 for (true, pred) in zip(y_true[i], y_pred[i])])
        fn = sum([true == 1 and pred == 0 for (true, pred) in zip(y_true[i], y_pred[i])])
        if tp + fp > 0:
            precision = tp / (tp + fp)
        else:
            precision = 0
        if tp + fn > 0:
            recall = tp / (tp + fn)
        else:
            recall = 0
        if (2 * tp + fp + fn) > 0:
            f1 = (2 * tp) / (2 * tp + fp + fn)
        else:
            f1 = 0
        scores.append({'lan': lan, 'cat': labels[lan][i], 'precision': precision, 'recall': recall, 'f1': f1})


Processing language: java
Processing language: python
Processing language: pharo
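
The per-category counting in the score loop can be cross-checked with a vectorized NumPy version. The helper ```per_category_scores``` below is hypothetical and not part of the baseline code; it takes the same transposed layout (one row per category, one column per test sample) and, like the loop, returns 0 whenever a denominator is 0.

```python
import numpy as np

def per_category_scores(y_true, y_pred):
    """Vectorized equivalent of the tp/fp/fn loop.

    Both inputs have shape (n_categories, n_samples) with 0/1 entries.
    Returns precision, recall and F1 arrays of length n_categories.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_pred = np.asarray(y_pred).astype(bool)
    tp = np.sum(y_true & y_pred, axis=1)
    fp = np.sum(~y_true & y_pred, axis=1)
    fn = np.sum(y_true & ~y_pred, axis=1)
    n = len(tp)
    # Same zero-division convention as the loop: the score is 0 when the denominator is 0
    precision = np.divide(tp, tp + fp, out=np.zeros(n), where=(tp + fp) > 0)
    recall = np.divide(tp, tp + fn, out=np.zeros(n), where=(tp + fn) > 0)
    f1 = np.divide(2 * tp, 2 * tp + fp + fn, out=np.zeros(n), where=(2 * tp + fp + fn) > 0)
    return precision, recall, f1
```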


## View Submission Score and Metrics for Each Language and Category

*The submission score balances model quality (average F1) against computational efficiency (runtime and GFLOPs). Our advantage over the baseline comes from efficiency: although our model produces slightly lower precision, recall, and F1 scores in almost all categories, it runs about 11x faster than the baseline, which gives it a higher overall submission score.*

**Here is the formula for the submission score calculation:**

\begin{align}
submission\_score(model) &= 0.60 \times avg. \space F_1 \\
 &\quad + 0.2 \times \frac{max\_avg\_runtime - measured\_avg\_runtime}{max\_avg\_runtime} \\
 &\quad + 0.2 \times \frac{max\_avg\_GFLOPS - measured\_avg\_GFLOPS}{max\_avg\_GFLOPS}
\end{align}

> Code retrieved from: *https://github.com/nlbse2025/code-comment-classification/blob/main/SetFit_baseline.ipynb*

> Background info on submission score calculation: *https://colab.research.google.com/drive/1GhpyzTYcRs8SGzOMH3Xb6rLfdFVUBN0P*

In [53]:
print("Compute in GFLOPs:", total_flops / 10)
print("Avg runtime in seconds:", total_time / 10)
scores = pd.DataFrame(scores)
print(scores)

# Submission score calculation (same as baseline)
max_avg_runtime = 5
max_avg_flops = 5000

def score(avg_f1, avg_runtime, avg_flops):
    return (0.6 * avg_f1 +
            0.2 * ((max_avg_runtime - avg_runtime) / max_avg_runtime) +
            0.2 * ((max_avg_flops - avg_flops) / max_avg_flops))

avg_f1 = scores['f1'].mean()
avg_runtime = total_time / 10
avg_flops = total_flops / 10

print(f"Submission Score: {round(score(avg_f1, avg_runtime, avg_flops), 2)}")


Compute in GFLOPs: 243.88685004799999
Avg runtime in seconds: 0.0898674488067627
       lan                      cat  precision    recall        f1
0     java                  summary   0.852679  0.856502  0.854586
1     java                Ownership   0.978261  1.000000  0.989011
2     java                   Expand   0.318182  0.343137  0.330189
3     java                    usage   0.929155  0.791183  0.854637
4     java                  Pointer   0.799043  0.907609  0.849873
5     java              deprecation   0.666667  0.666667  0.666667
6     java                 rational   0.211268  0.220588  0.215827
7   python                    Usage   0.744444  0.553719  0.635071
8   python               Parameters   0.740458  0.757812  0.749035
9   python         DevelopmentNotes   0.302326  0.317073  0.309524
10  python                   Expand   0.434211  0.515625  0.471429
11  python                  Summary   0.581395  0.609756  0.595238
12   pharo  Keyimplementationpoints   0.625000  
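
As a worked example, substituting the averages printed above (0.0899 s runtime and 243.89 GFLOPs, rounded) and the maxima set in the cell (```max_avg_runtime = 5``` and ```max_avg_flops = 5000```) into the formula gives the efficiency part of the score. The average F1 is left symbolic because the metrics table is cut off in this export.

\begin{align}
0.2 \times \frac{5 - 0.0899}{5} &\approx 0.1964 \\
0.2 \times \frac{5000 - 243.89}{5000} &\approx 0.1902 \\
submission\_score(model) &\approx 0.60 \times avg. \space F_1 + 0.1964 + 0.1902 = 0.60 \times avg. \space F_1 + 0.3866
\end{align}

The two efficiency terms together contribute roughly 0.39 of a possible 0.40, which is why the lightweight model's speed carries so much weight in the final score.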

## Following Data Through the Model Pipeline

From loading to scoring, the data goes through a series of transformations that end in a tensor of predictions and, finally, a submission score.

* **1. Raw Data → 2. Data Splitting → 3. Data Augmentation (Oversampling + Synonym Replacement) → 4. Tokenization → 5. Dataset Formatting → 6. Class Weight Computation → 7. Model Training → 8. Threshold Optimization → 9. Inference on Test Data → 10. Evaluation Metrics Calculation → 11. Submission Scoring**

### Breakdown of Data Transformations

1. The pipeline begins by loading the dataset using ```load_dataset```. It is then organized per programming language, with each containing its own training and testing splits. Labels are then defined to represent different types of code comments and help guide the model in multilabel classification.
2. Each training set is then further split into training and validation sets (80/20) to help monitor performance while training.
3. The training set is then oversampled to increase representation of minority classes and each oversampled sample goes through synonym replacement to increase diversity further.
4. For each sample, the class name and comment sentence are concatenated and passed to the ```prajjwal1/bert-tiny``` tokenizer, which converts the text into token IDs and pads or truncates every sequence to a fixed length of 128 tokens.
5. The tokenized output is then converted to tensors specifying the ```input_ids```, ```attention_mask```, and ```labels```.
6. Class weights are calculated inversely proportional to class frequencies, normalized, and used in the custom focal loss function to emphasize minority classes during training.
7. The model is initialized from the pre-trained ```prajjwal1/bert-tiny``` checkpoint and fine-tuned on the data via the ```Trainer```, using our custom loss function and tuned configuration.
8. After training, the unaugmented validation set is used to search for per-label thresholds that maximize the F1 score. These thresholds are used during inference to convert probabilities into binary label predictions.
9. When evaluating on the test set, the trained model and tokenizer are loaded for each language, and the test set is tokenized in the same manner as the training and validation sets. Inference is then run 10 times for each language while measuring the FLOPs and runtime of the inference function.
10. After obtaining a tensor of predictions from inference, the precision, recall, and F1 score for each category are aggregated into a dataframe.
11. Using the submission score formula, the average GFLOPs, average runtime, and calculated metrics are combined into the final submission score.
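
Steps 3 and 6 can be illustrated by reproducing the oversampling statistics logged during training. The sketch below is inferred from those logs rather than taken from the notebook: dividing each label's inverse frequency by the mean of all inverse frequencies reproduces the "After Normalization" column, and taking a sample's weight as the sum of its labels' weights matches the logged Java maximum (3.1157 + 0.6957 = 3.8114, which would correspond to a comment labelled both ```deprecation``` and ```Expand```). The helper names, the seeded ```np.random.default_rng``` draw, and ```factor=2``` (chosen because the logged target is twice the original sample count) are assumptions; synonym replacement is not shown.

```python
import numpy as np

def label_weights_from_counts(counts, n_samples):
    """Inverse-frequency label weights, normalized so they average to 1."""
    freqs = np.asarray(counts, dtype=float) / n_samples  # "Label Frequencies"
    inv = 1.0 / freqs                                    # "Before Normalization" (assumes every count > 0)
    return inv / inv.mean()                              # "After Normalization"

def sample_weights(label_matrix, label_w):
    """Assumed rule: a sample's weight is the sum of the weights of its positive labels."""
    return np.asarray(label_matrix, dtype=float) @ np.asarray(label_w, dtype=float)

def weighted_oversample_indices(sample_w, factor=2, seed=0):
    """Draw factor * n indices with replacement, with probability proportional to sample weight."""
    rng = np.random.default_rng(seed)
    p = np.asarray(sample_w, dtype=float)
    p = p / p.sum()
    return rng.choice(len(p), size=factor * len(p), replace=True, p=p)
```

With the Java counts logged above and ```n_samples = 6091```, ```label_weights_from_counts``` returns values that round to the logged "After Normalization" column, for example 0.0989 for ```summary``` and 3.1157 for ```deprecation```. Weighting whole samples rather than labels means a rare label drags its co-occurring frequent labels along, which is one reason the post-oversampling counts are still not perfectly balanced.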