# Recreation of Huang et al. 2021 "Balancing Methods for Multi-label Text Classification with Long-Tailed Class Distribution" Results, Comparison to Our Own Initial Experiments with BERT

# Notes on Huang et al. 2021 + Comparisons to Our Own Experiments with BERT

## 1. Introduction

This paper aims to tackle the issue of datasets with long-tailed class distributions, as well as label linkage (co-occurrence), in the context of multi-label text classification. This sort of class imbalance occurs when a small subset of labels (the head labels) have many instances, while the majority of labels (the tail labels) have only a few instances each. This is evident in the Reuters-21578 dataset, where the vast majority of documents carry head labels, while most labels (for example, copper, strategic metal, and nickel) appear in less than 5% of the training data. Label co-occurrence becomes a challenge when common head labels frequently appear alongside rare tail labels. In such cases, the model tends to be biased toward assigning documents to the common labels. Some solutions have been proposed, such as resampling documents with less frequent labels, using co-occurrence information in model initialization, and hybrid solutions for head and tail categories with a multi-task architecture, but they are either not suitable for imbalanced datasets or dependent on the model architecture.

In this paper, the authors introduce the use of balancing loss functions for multi-label text classification. They work with two datasets: the aforementioned Reuters-21578 dataset and the PubMed dataset. For both datasets, the proposed distribution-balancing methods outperform other loss functions on overall metrics and lead to significant improvements on tail labels.

## 2. Loss Functions

Binary cross-entropy (BCE) is a loss function commonly used for multi-label text classification. Given a dataset

$$\{(x^1, y^1),...,(x^N, y^N)\}$$

with $N$ training instances, with each $x^k$ having a corresponding

$$y^k = [y^k_1,...,y^k_C] \in \{0, 1\}^C$$ 

where $C$ is the number of classes (basically each $y^k$ is an array the same length as the total number of labels, where each position can be 1 or 0, indicating that the label for that position is or is not associated with the given document $x^k$), and a corresponding classifier output of

$$z^k = [z^k_1,...,z^k_C] \in \mathbb{R}^C$$

where $z^k$ is a $C$-dimensional vector of real numbers (not just 0s and 1s), each $z^k_i$ is the raw logit (score) predicted by the model for a given label $i$, $\in \mathbb{R}^C$ means that each of the $C$ raw logits can be any real (floating point) value, and each logit can be passed through a sigmoid function to get a probability between 0 and 1 (that is, $p^k_i = \sigma(z^k_i)$ is the probability for a given logit), the BCE function is defined as:

$$L_{BCE} = -\frac{1}{N}\sum_{k=1}^{N}\sum_{i=1}^{C}[y^{k}_{i}\log(p^{k}_{i})+(1-y^{k}_{i})\log(1-p^{k}_{i})]$$

(the log is base $e$ or the natural logarithm ($\ln$))
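
As a quick sanity check of the formula above, here is a minimal pure-Python sketch of BCE on a toy batch. The names `sigmoid` and `bce_loss` and the toy logits are illustrative assumptions, not code from the paper or its repository.

```python
import math

def sigmoid(z):
    """Logistic function mapping a raw logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))

def bce_loss(logits, targets):
    """Sum BCE over all labels of every document, then average over the N documents."""
    total = 0.0
    for z_k, y_k in zip(logits, targets):
        for z, y in zip(z_k, y_k):
            p = sigmoid(z)
            total += y * math.log(p) + (1 - y) * math.log(1 - p)
    return -total / len(logits)

# Two documents, three labels (hypothetical values).
toy_logits = [[2.0, -1.0, 0.5], [-0.5, 1.5, -2.0]]
toy_targets = [[1, 0, 1], [0, 1, 0]]
print(round(bce_loss(toy_logits, toy_targets), 4))
```

A logit of 0 gives $p = 0.5$ and a per-label loss of $\ln 2$, which is a handy reference point when reading loss values.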

Close observation of the BCE loss function and its role in the weight update pipeline reveals some key characteristics:

1. Head labels influence the direction of the gradient more strongly than rare labels, thereby influencing the weight updates to a greater degree, optimizing them for head labels.

2. Documents with many labels, similarly, influence the direction of the gradient more strongly than documents with few labels, leading to a similar effect.

Consider a multi-label dataset with 1000 documents to illustrate both of these points.

During training:

Head label "business" (1000 documents):
- Often contains "profit"
- Consistently tells model weights/parameters: "increase to better detect 'profit'"
- 1000 documents push weights/parameters in similar direction
- weights become well-tuned for detecting "profit"

Rare label "rare-disease-research" (10 documents):
- Contains "rare-disease"
- Only 10 documents tell weights/parameters: "increase to detect 'rare-disease'"
- Not enough consistent signal to tune weights well to this specific label

In greater detail:

1. Binary Cross Entropy Loss for Individual Documents:

$$L_k = -\sum_i [y^k_i \log(p^k_i) + (1-y^k_i)\log(1-p^k_i)]$$

where:
- $y^k_i \in \{0,1\}$ is the true label
- $p^k_i = \sigma(z^k_i)$ is the predicted probability
- $\sigma(x) = \frac{1}{1 + e^{-x}}$ is the sigmoid function

2. Average Loss Over N Documents:
$$L_{avg} = \frac{1}{N} \sum_k L_k$$
$$= \frac{1}{N} \sum_k (-\sum_i [y^k_i\log(p^k_i) + (1-y^k_i)\log(1-p^k_i)])$$

3. Derivative of Average Loss with respect to any weight w:

$$\frac{\partial L_{avg}}{\partial w} = \frac{\partial}{\partial w}(\frac{1}{N} \sum_k L_k)$$

$$= \frac{1}{N} \sum_k \frac{\partial L_k}{\partial w}$$

4. Chain Rule for Individual Loss Terms:

$$\frac{\partial L_k}{\partial w} = \sum_i \frac{\partial L_k}{\partial p^k_i} \times \frac{\partial p^k_i}{\partial z^k_i} \times \frac{\partial z^k_i}{\partial w}$$

where:

$$\frac{\partial L_k}{\partial p^k_i} = -\frac{y^k_i}{p^k_i} + \frac{1-y^k_i}{1-p^k_i}$$

$$\frac{\partial p^k_i}{\partial z^k_i} = p^k_i(1-p^k_i)$$

The important part here is that the loss function is ultimately a function that depends on the weights, just via several layers of abstraction. The loss function depends on the probabilities calculated for a given document, which themselves depend on the logits calculated for that document, which themselves depend on the weights and inputs and how they interact. Thus, via the chain rule, the loss function is ultimately dependent on the weights, and thus can be differentiated with respect to the weights.
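
Multiplying the first two factors of the chain rule above gives a compact form that is worth keeping in mind: the gradient of the per-document loss with respect to each logit is simply the prediction error for that label.

$$\frac{\partial L_k}{\partial z^k_i} = \left(-\frac{y^k_i}{p^k_i} + \frac{1-y^k_i}{1-p^k_i}\right) p^k_i(1-p^k_i) = -y^k_i(1-p^k_i) + (1-y^k_i)p^k_i = p^k_i - y^k_i$$

So a positive label predicted with low probability produces a gradient close to $-1$ on its logit, while a well-predicted label produces a gradient close to 0.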

5. Multi-label Document Effect:
For a single-label doc (showing only the term for its one positive label, $i = 1$; the negative labels contribute terms of the same form):

$$\frac{\partial L_k}{\partial w} = \frac{\partial L_k}{\partial p^k_1} \times \frac{\partial p^k_1}{\partial z^k_1} \times \frac{\partial z^k_1}{\partial w}$$

For a three-label doc (positive labels $i = 1, 2, 3$):

$$\frac{\partial L_k}{\partial w} = \frac{\partial L_k}{\partial p^k_1} \times \frac{\partial p^k_1}{\partial z^k_1} \times \frac{\partial z^k_1}{\partial w} + \frac{\partial L_k}{\partial p^k_2} \times \frac{\partial p^k_2}{\partial z^k_2} \times \frac{\partial z^k_2}{\partial w} + \frac{\partial L_k}{\partial p^k_3} \times \frac{\partial p^k_3}{\partial z^k_3} \times \frac{\partial z^k_3}{\partial w}$$

6. Weight Update:

$$w_{new} = w_{old} - \eta \times \frac{\partial L_{avg}}{\partial w}$$

where η is the learning rate.
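
To put rough numbers on the head-versus-tail imbalance, the sketch below compares how much of the averaged gradient on a label's logit comes from its positive examples. It assumes an untrained model that outputs $p = 0.5$ everywhere and reuses the 1000-document example from earlier; the function name `positive_gradient_mass` is a hypothetical helper, not something from the paper.

```python
def positive_gradient_mass(n_positive, n_docs, p=0.5):
    """Total magnitude of the averaged gradient dL_avg/dz contributed by the
    positive examples of one label, assuming every prediction equals p."""
    # For a positive example, dL_k/dz = p - 1, so its magnitude is (1 - p).
    return n_positive * (1.0 - p) / n_docs

# Head label with 1000 positive documents vs. rare label with 10.
head = positive_gradient_mass(n_positive=1000, n_docs=1000)
tail = positive_gradient_mass(n_positive=10, n_docs=1000)
print(head, tail, head / tail)
```

Under these simplifying assumptions the head label's positive signal is 100 times larger than the rare label's, which is the imbalance the reweighted losses below try to correct.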

As a note, one should remember that the loss is used both as a function and as a number. The number, or loss value, helps us answer whether the model is learning (loss should be decreasing), the rate at which it's learning, whether there is any overfitting (comparing training vs. validation loss), and which model/hyperparameters are better (comparing loss values). The scalar loss value itself isn't what drives the actual training process, though (that is, computing gradients and updating weights). For this, we use the function representation instead, since the gradients come from differentiating the function.

Now, the paper we're working with describes three alternative approaches that address class imbalance in long-tailed datasets in the context of multi-label text classification. The main idea is to reweight the BCE loss function such that those labels considered uncommon can 'influence' the loss to a greater degree.

### 2.1 Focal Loss

This method essentially adds a 'modulating factor' to the two cases of BCE. It places a higher loss weight on 'hard-to-classify' instances, i.e., those whose ground-truth class is predicted with low probability:

$$
L_{FL} = \begin{cases}
-(1-p_i^k)^\gamma \log(p_i^k) & \text{if } y_i^k = 1 \\
-(p_i^k)^\gamma \log(1-p_i^k) & \text{otherwise}
\end{cases}
$$

(remember, the two cases above are still being multiplied by either the label (when the label is positive, i.e., $y_i^k = 1$) or 1 minus the label (when the label is negative, i.e., $y_i^k = 0$), and then added together. This way one of the two components of the addition is always zeroed out, depending on whether the label is 1 or 0.)

This shows both the positive case (when $y_i^k = 1$) and the negative case (when $y_i^k = 0$), where:
- $p_i^k$ is the predicted probability
- $\gamma$ is the focusing parameter (≥ 0)
- $y_i^k$ is the true label

Consider an example where we have a positive example that the model has predicted poorly, with $p_i^k = 0.3$, and a positive example that the model has predicted rather well, with $p_i^k = 0.9$. The paper uses a focusing parameter of $\gamma = 2$ in its experiments.

For the poor prediction, the modulating factor is:

$(1-p_i^k)^\gamma = (1-0.3)^{2} = 0.49$

For the good prediction, the modulating factor is:

$(1-p_i^k)^\gamma = (1-0.9)^{2} = 0.01$

As you can see, for poor predictions, the modulating factor is higher, therefore making the contribution to the overall loss by this specific document and label higher. For good predictions, the modulating factor is lower, thereby reducing the impact of this already well classified example on the loss. Thus, this weighting scheme helps combat class imbalance by preventing easy, majority-class examples from dominating loss during training.

This method doesn't necessarily prevent the model from learning easily classified majority-class examples either. The modulating factor doesn't immediately zero out easy examples. Loss contribution decreases gradually as confidence for a given label increases, thus creating a smooth curve. Furthermore, if the model begins to 'forget' an easy example, the confidence would drop, thus raising the modulating factor and consequently the contribution of this specific example to the loss.
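
The two worked cases above ($p = 0.3$ and $p = 0.9$ with $\gamma = 2$) can be reproduced with a short sketch. This is a direct transcription of the piecewise formula for a single (document, label) pair; the function names are illustrative, not taken from the paper's code.

```python
import math

def focal_loss_term(p, y, gamma=2.0):
    """Focal loss for one (document, label) pair with predicted probability p."""
    if y == 1:
        return -((1.0 - p) ** gamma) * math.log(p)
    return -(p ** gamma) * math.log(1.0 - p)

def bce_term(p, y):
    """Plain BCE for the same pair, i.e. focal loss with gamma = 0."""
    return focal_loss_term(p, y, gamma=0.0)

# Compare BCE and focal loss for a hard and an easy positive example.
for p in (0.3, 0.9):
    print(p, round(bce_term(p, 1), 4), round(focal_loss_term(p, 1), 4))
```

Setting `gamma=0.0` recovers plain BCE, which makes it easy to see how much the modulating factor shrinks the loss of the well-classified example relative to the poorly classified one.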

### 2.2 Class-balanced Focal Loss

This method basically adds another factor to the earlier focal loss function that takes into account the number of samples for a given class/label, thereby capturing the 'diminishing marginal benefits' of data for each class:

$$
L_{CB} = \begin{cases}
-r_{CB}(1-p_i^k)^\gamma \log(p_i^k) & \text{if } y_i^k = 1 \\
-r_{CB}(p_i^k)^\gamma \log(1-p_i^k) & \text{otherwise}
\end{cases}
$$

The added term here is:

$$
r_{CB} = \frac{1-\beta}{1-\beta^{n_i}}
$$

Where:
- $n_i$ is the number of samples for class i
- $\beta \in [0,1)$ is a hyperparameter that controls how fast the effective number grows
- The paper uses $\beta = 0.9$ in their experiments


For classes with many samples (large $n_i$):
- $\beta^{n_i}$ approaches 0
- This makes $r_{CB}$ smaller, reducing the impact of these common classes

For classes with few samples (small $n_i$):
- $\beta^{n_i}$ stays larger
- This makes $r_{CB}$ larger, increasing the impact of rare classes

Example:
Let's say $\beta = 0.9$ and we have:
- A common class with 100 samples: $r_{CB} = \frac{0.1}{1-0.9^{100}} \approx 0.1$
- A rare class with 10 samples: $r_{CB} = \frac{0.1}{1-0.9^{10}} \approx 0.15$

The rare class gets about 1.5x more weight in loss contribution than the common class.
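
The same arithmetic as a small sketch, in case you want to try other class sizes or values of $\beta$. The function name `class_balanced_weight` is an assumption made for illustration.

```python
def class_balanced_weight(n_i, beta=0.9):
    """r_CB = (1 - beta) / (1 - beta ** n_i) for a class with n_i samples."""
    return (1.0 - beta) / (1.0 - beta ** n_i)

common = class_balanced_weight(100)
rare = class_balanced_weight(10)
print(round(common, 4), round(rare, 4), round(rare / common, 2))
```

One thing this makes visible: with $\beta = 0.9$, $\beta^{n_i}$ is already close to 0 for classes with more than about 50 samples, so all such classes receive nearly the same weight of roughly 0.1. The reweighting mainly separates very small classes from everything else.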

### 2.3 Distribution-Balanced Loss 

This is a more comprehensive approach to the problem of long-tailed class distributions and co-occurrence of labels. It has two parts, Rebalanced Focal Loss and Negative Tolerant Regularization. We will start with describing rebalanced focal loss.

The simplified version of the Rebalanced Focal Loss function is defined as follows:

$$
L_{R-FL} = \begin{cases}
-\hat{r}_{DB}(1-p_i^k)^\gamma \log(p_i^k) & \text{if } y_i^k = 1 \\
-\hat{r}_{DB}\frac{1}{\lambda}(p_i^k)^\gamma \log(1-p_i^k) & \text{otherwise}
\end{cases}
$$

The new term here, $\hat{r}_{DB}$, is defined as:

$$
\hat{r}_{DB} = \alpha + \sigma(\beta \times (r_{DB} - \mu))
$$

Broadly speaking, you can think of $\hat{r}_{DB}$ as a weighting factor that adjusts the loss attributed to a given label in a given document by considering a number of factors. In order to explain these factors, let's look at the variables making up $\hat{r}_{DB}$.

$\alpha$, $\beta$, and $\mu$ are all hyperparameters that are tuned by the user ($\sigma$ denotes the sigmoid function, not a hyperparameter), while $r_{DB}$ is a parameter dependent on the frequency of the labels associated with a given document and the total number of labels in the dataset.

- $\alpha$:

    - Sets the minimum value of $\hat{r}_{DB}$, shifting its output range to $(\alpha, \alpha + 1)$ (the sigmoid itself never quite reaches 0 or 1). This ensures that even common classes get some minimum weight, thus preventing any class from being completely ignored.
    
    - Paper uses $\alpha = 0.1$
    
- $\beta$:

    - Controls the 'steepness' of the sigmoid curve. The key here is to visualize what the sigmoid function looks like when graphed: 
    
    $$f(x) = \frac{1}{1+e^{-x}}$$ 
    
    - In essence, it is a curve where increasingly negative values of x cause y to approach 0, and increasingly positive values cause y to approach 1. At x = 0, the curve passes through its inflection point at y = 0.5. In our case, $x = \beta \times (r_{DB} - \mu)$. We will discuss the terms $r_{DB}$ and $\mu$ soon, but for now, let's focus on $\beta$ itself. $\beta$ scales the exponent of $e$: it amplifies it when $\beta > 1$, dampens it when $\beta \in (0, 1)$, inverts and dampens it when $\beta \in (-1, 0)$, and inverts and amplifies it when $\beta < -1$. This influences the steepness of the sigmoid curve around its inflection point, with larger values of $\beta$ causing greater steepness and values approaching 0 causing the curve to flatten (with $\beta=0$ turning the sigmoid into a flat line where y equals 0.5 at all values of x). The intuition here is that $\beta$ controls how differently 'rare' and 'common' labels are weighted, with a higher $\beta$ giving rare labels a much higher weight factor ($\hat{r}_{DB}$) and common labels a much lower one, and a lower $\beta$ making the difference less pronounced.
    
    - Paper uses $\beta = 10$
    
- $\mu$

    - This controls the location of the inflection point along the $r_{DB}$ axis. In other words, it controls which labels are considered on the rarer side and which ones are on the more common side. Below the inflection point, the y-value (in other words, $\hat{r}_{DB}$, the weighting factor we're discussing) quickly approaches $\alpha$ before leveling out near it, and above it, the y-value approaches $1 + \alpha$ before leveling out near there.
    - Paper uses $\mu = 0.9$, and thus the inflection point is at $r_{DB} = 0.9$
    
- $r_{DB}$

    - The formula for this term is:
    
        $$r_{DB} = \frac{P_i^C}{P^I}$$
    
        Where:
        
        $${P_i^C} = \frac{1}{C}\frac{1}{n_i}$$
    
        $$P^I = \frac{1}{C}\sum_{y_i^k = 1}\frac{1}{n_i}$$
    
        $n_i$ is the number of documents associated with a given label $i$, $C$ is the total number of labels in the dataset, and $y_i^k$ is each label $i$ associated with a given document $k$
        
    - As you can see, once the $\frac{1}{C}$ terms cancel out, it's essentially the inverse frequency of a label divided by the sum of the inverse frequencies of each label associated with a given document. In the case of single-label documents, $r_{DB}$ becomes 1 (which is quite close to the paper's $\mu$ value for the Reuters dataset). In practice, this helps common labels in a given document get downweighted, and rare labels in a given document get upweighted. 
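
To see how $r_{DB}$ and $\hat{r}_{DB}$ behave together, here is a minimal sketch that follows the description above and covers only the positive labels of a single document. The function name `rebalancing_weights` and the label counts are hypothetical; the defaults are the paper's $\alpha = 0.1$, $\beta = 10$, $\mu = 0.9$. It is not the `ResampleLoss` implementation from the authors' repository.

```python
import math

def rebalancing_weights(doc_labels, label_counts, alpha=0.1, beta=10.0, mu=0.9):
    """Return {label: (r_db, r_hat_db)} for the positive labels of one document."""
    # Inverse frequency of each label on the document; the 1/C factors cancel.
    inv_freq = {label: 1.0 / label_counts[label] for label in doc_labels}
    inv_freq_sum = sum(inv_freq.values())
    weights = {}
    for label, value in inv_freq.items():
        r_db = value / inv_freq_sum
        # Smooth r_db through a shifted, scaled sigmoid and add the floor alpha.
        r_hat = alpha + 1.0 / (1.0 + math.exp(-beta * (r_db - mu)))
        weights[label] = (r_db, r_hat)
    return weights

counts = {"earn": 2000, "nickel": 8}  # hypothetical training-set counts
print(rebalancing_weights(["earn", "nickel"], counts))
print(rebalancing_weights(["earn"], counts))
```

In the two-label document, the rare label takes almost all of the $r_{DB}$ mass and the common label's weight falls to roughly $\alpha$. In the single-label document, $r_{DB} = 1$ regardless of how common the label is.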



Now, onto the Negative Tolerant Regularization. 

$$
L_{NTR-FL} = \begin{cases}
-(1-q_i^k)^\gamma\log(q_i^k) & \text{if } y_i^k = 1 \\
-\frac{1}{\lambda}(q_i^k)^\gamma\log(1-q_i^k) & \text{otherwise}
\end{cases}
$$

$q_i^k = \sigma(z_i^k - v_i)$ when  $y_i^k = 1$

$q_i^k = \sigma(\lambda(z_i^k - v_i))$ when  $y_i^k = 0$

$v_i = -\kappa\times\hat{b}_i$

The paper uses $\kappa = 0.05$ and $\lambda = 2$

$\hat{b}_i = -\log(\frac{1}{p_i}-1)$

$p_i = \frac{n_i}{N}$

Let's consider an example where we have 10000 documents total, and a rare class with about 100 documents associated. This means the rare class has 100 positive examples, and 9900 negative examples. We also have another class, with 1000 documents associated, and thus 9000 negative examples. We then have another class with 9000 documents associated, and thus 1000 negative examples:

$p_1 = \frac{100}{10000} = 0.01$

$\hat{b}_1 = -\log(\frac{1}{p_1}-1) = -\log(\frac{1}{0.01}-1) \approx -4.595$

$v_1 = -\kappa\times\hat{b}_1 = -0.05 \times (-4.595) \approx 0.2298$

$p_2 = \frac{1000}{10000} = 0.1$

$\hat{b}_2 = -\log(\frac{1}{p_2}-1) = -\log(\frac{1}{0.1}-1) \approx -2.197$

$v_2 = -\kappa\times\hat{b}_2 = -0.05 \times (-2.197) \approx 0.1099$

$p_3 = \frac{9000}{10000} = 0.9$

$\hat{b}_3 = -\log(\frac{1}{p_3}-1) = -\log(\frac{1}{0.9}-1) \approx 2.197$

$v_3 = -\kappa\times\hat{b}_3 = -0.05 \times 2.197 \approx -0.1099$


As you can see from the above, rare labels are assigned increasingly positive $v_i$, whereas common labels are assigned increasingly negative $v_i$. Let's continue this exercise by examining the $q_i^k$ for each example document, in both the positive and the negative classification case (that is, the case where the prediction should indeed be closer to 'yes' for a given label i, and the case where the prediction should be closer to 'no' for a given label i). 

Positive ($y_i^k = 1$):

$q_1^k = \sigma(z_1^k - v_1) = \frac{1}{1+e^{-(z_1^k - 0.2298)}}$ -> rare

$q_2^k = \sigma(z_2^k - v_2) = \frac{1}{1+e^{-(z_2^k - 0.1099)}}$ -> uncommon

$q_3^k = \sigma(z_3^k - v_3) = \frac{1}{1+e^{-(z_3^k - (-0.1099))}} = \frac{1}{1+e^{-(z_3^k + 0.1099)}}$ -> very common

As you can see from the above, when using the NTR-FL loss function's method for calculating probabilities in the positive case, the way the logit is changed depends on whether the label is rare, uncommon, or common, with rarer labels having their logits artificially reduced, and common labels having their logits artificially increased. Smaller and smaller logits cause the calculated probabilities to be lower and lower (note, the negative sign in the exponent is not part of the logit when considering the logit's value). Consequently, with lower probabilities, there will be a greater loss, a greater penalty, and stronger gradient updates toward better predicting true positives for rare labels. Conversely, in the case of very common labels, you can see that the logits are increased. Larger logits cause the calculated probabilities to be higher, and thus cause the loss to be smaller. Therefore, gradient updates focusing on predicting true positives for common labels are less pronounced.
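
The $v_i$ values and the shifted positive-case probabilities above can be checked with a short sketch. The function names and the shared logit of 1.0 are assumptions made for illustration; the class sizes and $\kappa = 0.05$ come from the worked example.

```python
import math

def class_bias(n_i, n_total, kappa=0.05):
    """v_i = -kappa * b_hat_i, with b_hat_i = -log(1 / p_i - 1) and p_i = n_i / N."""
    p_i = n_i / n_total
    b_hat = -math.log(1.0 / p_i - 1.0)
    return -kappa * b_hat

def q_positive(z, v):
    """Shifted probability used when the label is present (y = 1)."""
    return 1.0 / (1.0 + math.exp(-(z - v)))

z = 1.0  # the same hypothetical logit for all three labels
for n_i in (100, 1000, 9000):
    v = class_bias(n_i, 10000)
    print(n_i, round(v, 4), round(q_positive(z, v), 4))
```

For an identical raw logit, the rare label ends up with the lowest $q$ and therefore the largest positive-case loss, matching the reasoning above.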

Negative case($y_i^k = 0$):

$q_1^k = \sigma(\lambda(z_1^k - v_1)) = \frac{1}{1+e^{-2(z_1^k - 0.2298)}}$ -> rare

$q_2^k = \sigma(\lambda(z_2^k - v_2)) = \frac{1}{1+e^{-2(z_2^k - 0.1099)}}$ -> uncommon

$q_3^k = \sigma(\lambda(z_3^k - v_3)) = \frac{1}{1+e^{-2(z_3^k - (-0.1099))}} = \frac{1}{1+e^{-2(z_3^k + 0.1099)}}$ -> very common

The difference here is the presence of the $\lambda$ term. Similar to the $\beta$ term from the rebalanced FL loss function discussed earlier, this introduces an amplifier that influences the 'steepness' of the sigmoid function, with $\lambda > 1$ causing the sigmoid to be steeper around the inflection point, $0 < \lambda < 1$ causing it to be less steep, $\lambda = 0$ causing it to flatline, and values below 0 following the same pattern but mirrored. The impact of this is that small deviations in the logit push the predicted probability further away from 0.5, which, consequently, causes larger changes in the loss (as in, what would have had high loss before will have even higher loss, and what would have had lower loss will have even lower loss). However, given that the final loss calculation also has a $\frac{1}{\lambda}$ term, we can't immediately conclude that. The loss is also scaled down by a factor of $\lambda$. For example, with $\lambda = 2$ and a logit of 1 (ignoring $v_i$ and the focal factor):

$-\log(1-(\frac{1}{1+e^{-(\lambda\times1)}}))\div\lambda \approx \frac{2.127}{\lambda} \approx 1.063$

$-\log(1-(\frac{1}{1+e^{-(1)}})) \approx 1.313$

As you can see, without the division by $\lambda$, the 2x amplification in the exponent of $e$ causes a larger loss (about 2.127 versus 1.313). However, with the division by $\lambda$, the loss in the first case is halved, and thus the final loss (about 1.063) is actually less than in the second case, which has no $\lambda$ at all.

You'll notice that for correct negative predictions, the loss remains small and only gets smaller by a factor of lambda, whereas for incorrect negative predictions, the loss will initially be smaller than if lambda was not involved in either the exponent or as divisor in the loss function, but then levels out as the loss grows greater.

Overall, the impact of these adjustments (including the adjustments to the logit calculation with $v_i$) in the negative case is that the model is less penalized for incorrectly predicting negative instances of rare classes, and more penalized for incorrectly predicting negative instances of common classes. Thus, the model is encouraged to focus on correctly classifying positive instances of rare classes, and encouraged to maintain high accuracy on common classes.
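
The negative-case behavior can be explored numerically with the sketch below. It leaves out the focal factor $(q_i^k)^\gamma$ to isolate the effect of $\lambda$ and $v_i$; the function name `negative_loss` and the probe logits are assumptions for illustration.

```python
import math

def negative_loss(z, v=0.0, lam=2.0):
    """-(1 / lam) * log(1 - sigmoid(lam * (z - v))), without the focal factor."""
    q = 1.0 / (1.0 + math.exp(-lam * (z - v)))
    return -math.log(1.0 - q) / lam

# Columns: logit, loss without lambda (lam = 1), loss with lam = 2.
for z in (-2.0, 0.0, 1.0, 4.0):
    print(z, round(negative_loss(z, lam=1.0), 4), round(negative_loss(z, lam=2.0), 4))

# Same logit, with the rare-label and common-label biases from the worked example.
print(round(negative_loss(1.0, v=0.2298), 4), round(negative_loss(1.0, v=-0.1099), 4))
```

The last line shows the effect described above: for the same logit on a negative example, the rare label (positive $v_i$) is penalized less than the very common label (negative $v_i$).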

Finally, the Distribution-Balanced loss function makes use of both Rebalanced-FL and Negative-Tolerant Regularization as follows:

$$L_{DB} = \begin{cases}
-\hat{r}_{DB}(1-q_i^k)^\gamma\log(q_i^k) & \text{if } y_i^k = 1 \\
-\hat{r}_{DB}\frac{1}{\lambda}(q_i^k)^\gamma\log(1-q_i^k) & \text{otherwise}
\end{cases}$$
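
Putting the pieces together, here is a minimal sketch of the combined formula for a single (document, label) pair. It is a literal transcription of the piecewise definition above, taking $\hat{r}_{DB}$ and $v_i$ as precomputed inputs; it is not the `ResampleLoss` class from the authors' repository, and the name `db_loss_term` and the example inputs are hypothetical.

```python
import math

def db_loss_term(z, y, r_hat, v, gamma=2.0, lam=2.0):
    """Distribution-balanced loss for one (document, label) pair.

    z: raw logit, y: true label (0 or 1), r_hat: rebalancing weight,
    v: class-specific bias from negative-tolerant regularization.
    """
    if y == 1:
        q = 1.0 / (1.0 + math.exp(-(z - v)))
        return -r_hat * (1.0 - q) ** gamma * math.log(q)
    q = 1.0 / (1.0 + math.exp(-lam * (z - v)))
    return -r_hat * (q ** gamma) * math.log(1.0 - q) / lam

# Same logit, positive label: rare-label-like inputs vs. common-label-like inputs.
print(round(db_loss_term(1.0, 1, r_hat=0.82, v=0.2298), 5))
print(round(db_loss_term(1.0, 1, r_hat=0.10, v=-0.1099), 5))
```

With `r_hat=1.0`, `v=0.0`, `gamma=0.0`, and `lam=1.0`, the function reduces to plain BCE, which is a convenient way to check each modification in isolation.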

## 3. Experiments

### 3.1 Datasets

In this section, the authors describe the specific datasets utilized in this study. Both the Reuters-21578 and the PubMed datasets are suited for multi-label text classification experiments. However, given that our own earlier experiments focused on multi-label classification of Reuters-21578 utilizing BERT and SBERT, we will be focusing on only the Reuters-21578 dataset as well. The authors of Huang et al. utilized the 'aptemod' split (not to be confused with the 'modapte' split described in the readme of the Reuters-21578 dataset, the one we were attempting to use in our set 'b' experiments). This split is partitioned such that 7769 of the documents are used for training (1000 of which went into validation), and 3019 for testing. The labels are nearly equally split into head (30 with $\geq$ 35 associated documents), medium (31 with 8-35 associated documents), and tail (30 with $\leq$ 8 associated documents).

### 3.2 Experimental Settings, Comparison to Our Own Experimental Settings

They compared the use of BCE and its modifications described in the loss functions section, as well as an SVM one-vs-rest model. More specifically, they compare the performance of the following loss functions on Reuters-21578:

- Binary Cross Entropy
- Focal Loss with $\gamma = 2$
- Class-Balanced Focal Loss with $\beta = 0.9$
- Rebalanced Focal Loss with $\alpha = 0.1$, $\beta = 10$, and $\mu = 0.9$
- Negative-Tolerant Regularization Focal Loss with $\kappa = 0.05$ and $\lambda = 2$
- Distribution-Balanced Loss with the same values as R-FL and NTR-FL
- Class-Balanced Negative-Tolerant Regularization, in which the only difference is the use of CB weight $r_{CB}$ instead of the rebalancing weight $\hat{r}_{DB}$

For the Reuters dataset, they utilized bert-base-cased with a backbone of BertForSequenceClassification. They used a maximum sequence length of 512, a batch size of 32, an AdamW optimizer with a weight decay of 0.01, and a learning rate determined by hyperparameter search. They implemented their experiments in PyTorch, utilizing one V100 GPU. For their SVM one-vs-rest model, they utilized sklearn with TF-IDF features, applying a linear kernel and hyper-plane shifting optimized on each validation set as part of their hyperparameter search. Alongside the main experiment, they also investigated the effectiveness of the loss functions against the number of labels per instance. For the Reuters dataset, they split the test instances into two groups: 2583 single-label instances and 436 multi-label instances. During evaluation, they select as their final model the one that scores best on micro-f1 on the validation set, and then evaluate its performance on the test set with micro-f1 and macro-f1 scores.
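
Since the evaluation relies on both micro-f1 and macro-f1, it helps to see why the two can diverge on long-tailed data. The sketch below is a pure-Python illustration with a made-up two-label example; the function names are hypothetical, and a label with no true or predicted positives is assumed to score 0.0.

```python
def f1_from_counts(tp, fp, fn):
    """F1 = 2TP / (2TP + FP + FN); defined as 0.0 when the denominator is 0."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0

def micro_macro_f1(y_true, y_pred):
    """Return (micro_f1, macro_f1) for binary indicator matrices (lists of lists)."""
    n_labels = len(y_true[0])
    per_label = []
    for i in range(n_labels):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t[i] == 1 and p[i] == 1)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t[i] == 0 and p[i] == 1)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t[i] == 1 and p[i] == 0)
        per_label.append((tp, fp, fn))
    # Micro: pool the counts over all labels, then compute a single F1.
    micro = f1_from_counts(*(sum(c[j] for c in per_label) for j in range(3)))
    # Macro: compute F1 per label, then take the unweighted mean.
    macro = sum(f1_from_counts(*c) for c in per_label) / n_labels
    return micro, macro

# Label 0 is a head label predicted perfectly; label 1 is a tail label that is missed.
y_true = [[1, 0], [1, 0], [1, 0], [1, 1]]
y_pred = [[1, 0], [1, 0], [1, 0], [1, 0]]
micro, macro = micro_macro_f1(y_true, y_pred)
print(round(micro, 4), round(macro, 4))
```

Micro-f1 stays high because the head label dominates the pooled counts, while macro-f1 drops sharply because the missed tail label counts as much as the head label. This is why macro-f1 is the more telling metric for tail-label performance.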

In our own initial experiments utilizing BERT, we utilized bert-base-cased with a backbone of BertForSequenceClassification as well. We had a maximum sequence length of 512, a batch size of 16, an AdamW optimizer with a weight decay of 0.01, and a learning rate of 2e-5 (0.00002). We did not utilize hyperparameter search. We only utilized BCE as our loss function.

## 4. Results

Here is the table showing their micro-f1 and macro-f1 scores for each model/loss function across the Reuters dataset and its subsets:

![alt text](Figure-2.png)

Our own initial micro-f1 score for the Reuters-21578 dataset utilizing Bert-Base-Cased as well as BertForSequenceClassification was roughly 69.46. This is in stark contrast to the 89.14 score achieved by the authors of Huang et al. utilizing the same model and transformer. The difference can likely be chalked up to differences in weight decay and epoch count.

## Recreation of Huang et al.'s pipeline, with the exception of the SVM model

First, we will be recreating the results of Huang et al. utilizing the same exact dataset, the same exact split, and consequently, the same exact hyperparameters and evaluation. Later, we will attempt to apply parts of Huang et al.'s methodology to our own initial experiments of type 'a' and 'b', and then compare results not only to Huang et al.'s results, but also to the results of our own initial experiments. Results for each of the different loss functions (except BCE, which is covered in this notebook) can be found in .ipynb files in the same folder as this notebook with corresponding names (for example, my replication of Huang et al.'s Focal Loss pipeline is located inside 'Huang-et-al-FL.ipynb').

### Huang et al. data preprocessing

In [44]:
# Importing aptemod dataset, getting training, validation, and test splits into same format as Huang et al.

import os
import xml.etree.ElementTree as ET
from collections import Counter
from sklearn.model_selection import train_test_split
import pickle

def read_labels(labels_path):
    """Parse labels file into a dict mapping doc_id to list of labels"""
    doc_to_labels = {}
    with open(labels_path, 'r', encoding='utf-8') as f:
        for line in f:
            doc_id, label = line.strip().split(' ', 1)
            doc_id = doc_id.replace('test/', '')
            doc_id = doc_id.replace('training/', '')
            doc_to_labels[doc_id] = label.split(' ')
    return doc_to_labels

def read_document(file_path):
    """Read a single document, clean its contents, and return them"""
    with open(file_path, 'r', encoding='latin-1') as f:
        
        content = f.read()
        content = content.replace('\n', ' ')
        content = ' '.join(content.split())
        return content

#     # Parse XML
#     root = ET.fromstring(content)
    
#     # Get document content
#     doc_content = root.find('.//document_content').text.strip()
    
#     return {
#         'text': doc_content,
#         'labels': labels.get(doc_id, [])  # Get labels for this doc_id
#     }

# Read in document ids and associated labels

labels_path = os.path.join('reuters-aptemod', 'cats.txt')
labels = read_labels(labels_path)

print(f"Number of documents with labels: {len(labels)}")

# Read in document texts

training_path = os.path.join('reuters-aptemod', 'training')
data_train_all = []
for file in os.listdir(training_path):
    if file in labels:
        file_dict = {
            'text': read_document(os.path.join(training_path, file)),
            'labels': labels[file]
        }
        data_train_all.append(file_dict)

test_path = os.path.join('reuters-aptemod', 'test')
data_test = []
for file in os.listdir(test_path):
    file_dict = {}
    if file in labels:
        file_dict = {
            'text': read_document(os.path.join(test_path, file)),
            'labels': labels[file]
        }
        data_test.append(file_dict)

# Split validation data from training data. 

# Using a different random seed from Huang et al. because their seed of 123 was splitting my
# 'data_train_all' variable such that the training set was missing a single label, 'groundnut-oil'.
# This discrepancy would occur even with the same seed because our 'data_train_all' variable has
# its documents in a different order than what Huang et al. originally had. I could not determine
# the exact order in which Huang et al. had their training documents prior to splitting off the
# validation data, but this should not be a big issue so long as our training set still has all
# 90 labels. The results of the various loss functions should not vary greatly from Huang et al.'s
# original results since we're just working with a slightly different variation of their split.
data_train, data_validation = train_test_split(data_train_all, random_state=100, test_size=1000)

print(f"Number of training documents {len(data_train)}")

print(f"Number of validation documents {len(data_validation)}")

print(f"Number of testing documents {len(data_test)}")

Number of documents with labels: 10788
Number of training documents 6769
Number of validation documents 1000
Number of testing documents 3019


In [45]:
# Making sure number of unique labels in the entire dataset is 90

unique_labels = set()
for label_list in labels.values():
    unique_labels.update(label_list)
print(f"Number of unique labels in cats.txt: {len(unique_labels)}")
print(f"Labels are: {sorted(list(unique_labels))}")

Number of unique labels in cats.txt: 90
Labels are: ['acq', 'alum', 'barley', 'bop', 'carcass', 'castor-oil', 'cocoa', 'coconut', 'coconut-oil', 'coffee', 'copper', 'copra-cake', 'corn', 'cotton', 'cotton-oil', 'cpi', 'cpu', 'crude', 'dfl', 'dlr', 'dmk', 'earn', 'fuel', 'gas', 'gnp', 'gold', 'grain', 'groundnut', 'groundnut-oil', 'heat', 'hog', 'housing', 'income', 'instal-debt', 'interest', 'ipi', 'iron-steel', 'jet', 'jobs', 'l-cattle', 'lead', 'lei', 'lin-oil', 'livestock', 'lumber', 'meal-feed', 'money-fx', 'money-supply', 'naphtha', 'nat-gas', 'nickel', 'nkr', 'nzdlr', 'oat', 'oilseed', 'orange', 'palladium', 'palm-oil', 'palmkernel', 'pet-chem', 'platinum', 'potato', 'propane', 'rand', 'rape-oil', 'rapeseed', 'reserves', 'retail', 'rice', 'rubber', 'rye', 'ship', 'silver', 'sorghum', 'soy-meal', 'soy-oil', 'soybean', 'strategic-metal', 'sugar', 'sun-meal', 'sun-oil', 'sunseed', 'tea', 'tin', 'trade', 'veg-oil', 'wheat', 'wpi', 'yen', 'zinc']


In [46]:
term2count = Counter([x for docu in data_train for x in docu['labels']])
FREQ_CUTOFF = 0 
term_freq = sorted([term for term, count in term2count.items() if count>=FREQ_CUTOFF])
labels_ref = sorted([z for z in set([y for x in data_train for y in x['labels']]) if z in term_freq]) 
print(len(term2count), len(labels_ref))
class_freq = [term2count[x] for x in labels_ref]
train_num = len(data_train)

90 90


In [47]:
import numpy as np
print(
    np.quantile(np.array(list(class_freq)),q=0.25),
    np.quantile(np.array(list(class_freq)),q=0.33),
    np.quantile(np.array(list(class_freq)),q=0.5),
    np.quantile(np.array(list(class_freq)),q=0.66),
    np.quantile(np.array(list(class_freq)),q=0.75))

# Slightly different results compared to Huang et al's specific split

5.0 8.0 17.0 36.74 59.5


In [48]:
r_all = []
for docu in data_train:
    docu_p = [1/term2count[x] for x in docu['labels']]
    docu_p_sum = sum(docu_p)
    r_all.extend([p/docu_p_sum for p in docu_p])
    
import numpy as np
print(np.mean(r_all),np.median(r_all))

0.8088182578563747 1.0


In [49]:
(np.mean(r_all) + np.median(r_all)) / 2
# 0.9 still remains a good choice for mu (the paper refers to this term as mu, whereas the code
# seems to refer to it as gamma for some reason). mu essentially changes the location of the
# inflection point in the weighting factor \hat{r}_{DB}. The inflection point determines, in some
# sense, which labels are considered common as opposed to rare. By averaging the mean and median
# of the document-normalized inverse label frequencies across all documents, they're trying to
# set an inflection point that reflects the relative 'rareness' of labels.

0.9044091289281874
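
To make the role of mu concrete, here is a minimal sketch of the smoothing function, assuming it has the form alpha + sigmoid(beta * (r - mu)). The values alpha=0.1 and beta=10 are illustrative assumptions rather than settings confirmed in this notebook, and `rebalance_weight` is a hypothetical name, not a function from util_loss.py:

```python
import math


def rebalance_weight(r, alpha=0.1, beta=10.0, mu=0.9):
    """Smoothed weight alpha + sigmoid(beta * (r - mu)).

    r is the document-normalized inverse frequency of a label in a document;
    mu is the inflection point of the sigmoid.
    alpha and beta are assumed values, used here only for illustration.
    """
    return alpha + 1.0 / (1.0 + math.exp(-beta * (r - mu)))


# Weights rise most steeply around r = mu
for r in (0.5, 0.8, 0.9, 1.0):
    print(f"r={r:.1f} -> weight={rebalance_weight(r):.3f}")
```

At r = mu the sigmoid equals 0.5, so the weight is alpha + 0.5. Values of r below mu are pushed towards alpha, and values above it towards alpha + 1.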

### Huang et al. Loss Function Implementation: the ResampleLoss Class and Associated Functions

The util_loss.py module contains the ResampleLoss class, as well as some external helper functions. The ResampleLoss class drives the implementation of the various loss functions described in the paper: changing its parameters selects and tunes the individual loss functions. The class contains methods for reweighting, logit adjustment, and loss computation. The external helper functions in the module reduce and weight the loss and compute the BCE. Since reimplementing this module from scratch would be rather tedious, we've downloaded the util_loss.py module directly from the Huang et al. repo.

In [50]:
from util_loss import ResampleLoss

## Huang et al. BCE Training and Evaluation + Comparison to my own methodology and prior results

In [51]:
import sys
import os
import torch
import json
import numpy as np
from torch import nn
from transformers import BertForSequenceClassification, BertTokenizer, AdamW
from tqdm import trange
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score, roc_auc_score
from util_loss import ResampleLoss
from torch.utils.data import Dataset, DataLoader

In [52]:
# Initializing model and tokenizer
num_labels = len(labels_ref)
model = BertForSequenceClassification.from_pretrained(
    'bert-base-cased',
    num_labels=num_labels,
    problem_type="multi_label_classification"
)

tokenizer = BertTokenizer.from_pretrained('bert-base-cased') # Note: the Huang et al. paper reports using bert-base-cased, but their code repo uses bert-base-uncased.

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [53]:
# Set up device and move model to it
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [54]:
# Defining optimizer
# Our own original experiments did not use grouped parameters to control which parameters have weight decay applied. Grouping is an improvement over our setup: it lets us exempt the bias and LayerNorm weights from weight decay while still regularizing all other weights.

no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=1e-4) # the learning rate applied is also different relative to our own experiments from before.



In [55]:
# Setting up BCE Loss Function according to Huang et al. methodology
loss_func = ResampleLoss(
    reweight_func=None,
    loss_weight=1.0,
    focal=dict(focal=False, alpha=0.5, gamma=2),
    logit_reg=dict(),
    class_freq=class_freq,
    train_num=train_num
)
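
With `reweight_func=None`, focal weighting disabled, and an empty `logit_reg`, this configuration should reduce to ordinary mean-reduced binary cross-entropy on the logits. That is our reading of the parameters, not something verified line by line against util_loss.py. As a sanity reference, here is a minimal pure-Python sketch of that quantity (the name `bce_with_logits` is ours and is not part of the module):

```python
import math


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy over every (document, label) entry.

    Uses the numerically stable form max(z, 0) - z*y + log(1 + exp(-|z|)),
    which equals -[y*log(sigmoid(z)) + (1 - y)*log(1 - sigmoid(z))].
    """
    total, count = 0.0, 0
    for row_logits, row_targets in zip(logits, targets):
        for z, y in zip(row_logits, row_targets):
            total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
            count += 1
    return total / count


# Two documents, three labels each (toy values for illustration)
toy_logits = [[2.0, -1.0, 0.0], [-3.0, 0.5, 1.5]]
toy_targets = [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
print(bce_with_logits(toy_logits, toy_targets))
```

Because the loss is averaged over all 90 labels per document and most entries are easy negatives, small absolute loss values such as those in the training log below are expected.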

In [56]:
# Tokenize texts, tensorize labels, and define a wrapper class for our data. Then create the datasets with this custom class and build the dataloaders.

# In my previous experiments, I tokenized all of the training and testing texts, and tensorized all of the labels, before passing them into my wrapper for the 'Dataset' class. Huang et al., on the other hand, define a function that tokenizes the text and tensorizes the labels of a single document on the fly. The wrapper class calls this function every time it retrieves the tokenized text, the tensorized labels, and the attention mask for a given document.

def preprocess_function(docu):
    labels = [1 if x in docu['labels'] else 0 for x in labels_ref]
    encodings = tokenizer(
        docu['text'],
        truncation=True,
        padding='max_length',
        max_length=512,
        return_tensors='pt'
    )
    return {
        'input_ids': encodings['input_ids'].flatten(),
        'attention_mask': encodings['attention_mask'].flatten(),
        'labels': torch.tensor(labels, dtype=torch.float)
    }

class CustomDataset(Dataset):
    def __init__(self, documents):
        self.documents = documents

    def __len__(self):
        return len(self.documents)

    def __getitem__(self, index):
        return preprocess_function(self.documents[index])

# Create datasets
train_dataset = CustomDataset(data_train)
val_dataset = CustomDataset(data_validation)

# Create data loaders
train_dataloader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True
)

validation_dataloader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False
)


In [58]:
# Create training loop as per Huang et al.
# Variables used to save/replace models and logs
source_dir = './'

prefix = 'reuters'
loss_func_name = 'BCE'
suffix = 'rand100'
model_name = 'bert-base-cased'

epochs = 40 # Epoch count utilized by Huang et al.
best_f1_for_epoch = 0 # Tracking best f1 score
epochs_without_improvement = 0 # Early-stopping counter: epochs since the validation F1 last improved

# Create directories if they don't already exist
model_dir = os.path.join(source_dir, 'models')
log_dir = os.path.join(source_dir, 'logs')

os.makedirs(model_dir, exist_ok=True)  # Creates models directory if it doesn't exist
os.makedirs(log_dir, exist_ok=True)    # Creates logs directory if it doesn't exist

for epoch in trange(epochs, desc='Epoch'): # Using trange from the tqdm library for the progress bar. 
    model.train()
    training_loss = 0
    training_steps = 0
    
    for batch in train_dataloader:
        batch = {key: value.to(device) for key, value in batch.items()} # Moving tensors in batch to GPU
        b_input_ids = batch['input_ids']
        b_input_mask = batch['attention_mask']
        b_labels = batch['labels']
        optimizer.zero_grad() # Clearing gradients from prior batch, prevent accumulation across batches
        outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask) # passing input into BERT model to retrieve logits
        logits = outputs[0]
        loss = loss_func(logits.view(-1,num_labels),b_labels.type_as(logits).view(-1,num_labels)) # calculating loss via the loss function we specified from the util_loss module's ResampleLoss class, in this case just regular BCE. Additionally, we're reshaping the logits to match the labels, converting labels to match the same data type as the logits, and also reshaping them.
        loss.backward() # Computing gradients
        optimizer.step() # Updating weights
        training_loss += loss.item() # Summing training loss
        training_steps += 1 # Counting training steps
        
    print("Train loss: {}".format(training_loss/training_steps))
    
    # Validation section
    model.eval()
    val_loss = 0
    val_steps = 0
    true_labels,pred_labels = [],[]
    
    for batch in validation_dataloader:
        batch = {key: value.to(device) for key, value in batch.items()}
        b_input_ids = batch['input_ids']
        b_input_mask = batch['attention_mask']
        b_labels = batch['labels']
        with torch.no_grad():
            outs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)
            b_logit_pred = outs[0]
            pred_label = torch.sigmoid(b_logit_pred) # Applying sigmoid to logits to acquire probabilities
            loss = loss_func(b_logit_pred.view(-1,num_labels),b_labels.type_as(b_logit_pred).view(-1,num_labels))
            val_loss += loss.item()
            val_steps += 1
            
            b_logit_pred = b_logit_pred.detach().cpu().numpy()
            pred_label = pred_label.to('cpu').numpy()
            b_labels = b_labels.to('cpu').numpy()
            
        true_labels.append(b_labels)
        pred_labels.append(pred_label)
        
    print("Validation loss: {}".format(val_loss/val_steps))
    
    # Flatten outputs into 1d lists.
    true_labels = [item for sublist in true_labels for item in sublist]
    pred_labels = [item for sublist in pred_labels for item in sublist]
    
    # Calculate micro-averaged F1, precision, and recall at the default threshold
    threshold = 0.5
    true_bools = [tl==1 for tl in true_labels] # turning actual labels into booleans
    pred_bools = [pl>threshold for pl in pred_labels] # predicting labels based on threshold
    val_f1_accuracy = f1_score(true_bools,pred_bools,average='micro')
    val_precision_accuracy = precision_score(true_bools, pred_bools,average='micro')
    val_recall_accuracy = recall_score(true_bools, pred_bools,average='micro')
    print('F1 Validation Accuracy: ', val_f1_accuracy)
    print('Precision Validation Accuracy: ', val_precision_accuracy)
    print('Recall Validation Accuracy: ', val_recall_accuracy)
    
    # Also calculate the micro-averaged ROC AUC. It is computed from the predicted probabilities rather than thresholded labels, so it does not depend on the threshold.
    val_auc_score = roc_auc_score(true_bools, pred_labels, average='micro')
    print('AUC Validation: ', val_auc_score)
    
    # Search for the threshold that gives the best F1.
    # We build a range of thresholds from 0.40 to 0.60 in steps of 0.01, compute micro F1, precision, and recall at each one,
    # and report the threshold with the highest F1 score.
    best_med_th = 0.5
    micro_thresholds = (np.array(range(-10,11))/100)+best_med_th
    f1_results, prec_results, recall_results = [], [], []
    for th in micro_thresholds:
        pred_bools = [pl>th for pl in pred_labels]
        test_f1_accuracy = f1_score(true_bools,pred_bools,average='micro')
        test_precision_accuracy = precision_score(true_bools, pred_bools,average='micro')
        test_recall_accuracy = recall_score(true_bools, pred_bools,average='micro')
        f1_results.append(test_f1_accuracy)
        prec_results.append(test_precision_accuracy)
        recall_results.append(test_recall_accuracy)
    best_f1_idx = np.argmax(f1_results) #best threshold value
    
    print('Best Threshold: ', micro_thresholds[best_f1_idx])
    print('Test F1 Accuracy: ', f1_results[best_f1_idx])
    
    # Save the model if this epoch's best validation F1 exceeds 99.5% of the best F1 recorded so far (earlier checkpoints for this run are deleted first)
    if f1_results[best_f1_idx] > (best_f1_for_epoch * 0.995):
        best_f1_for_epoch = f1_results[best_f1_idx]
        epochs_without_improvement = 0
        for fname in os.listdir(model_dir):
            if fname.startswith('_'.join([prefix,model_name,loss_func_name,suffix])):
                os.remove(os.path.join(model_dir, fname))
        torch.save(model.state_dict(), os.path.join(model_dir, '_'.join([prefix,model_name,loss_func_name,suffix,'epoch'])+str(epoch+1)+'para'))
    else:
        epochs_without_improvement += 1    
        
    
    # Log all results in validation set with different thresholds
    with open(os.path.join(log_dir, '_'.join([prefix,model_name,loss_func_name,suffix,'epoch'])+str(epoch+1)+'.json'),'w') as f:
        d = {}
        d["f1_accuracy_default"] =  val_f1_accuracy
        d["pr_accuracy_default"] =  val_precision_accuracy
        d["rec_accuracy_default"] =  val_recall_accuracy
        d["auc_score_default"] =  val_auc_score
        d["thresholds"] =  list(micro_thresholds)
        d["threshold_f1s"] =  f1_results
        d["threshold_precs"] =  prec_results
        d["threshold_recalls"] =  recall_results
        json.dump(d, f)
    
    # If 5 epochs pass without improvement consider the model as saturated and exit
    if epochs_without_improvement > 4:
        break

Epoch:   0%|                                                                                    | 0/40 [00:00<?, ?it/s]

Train loss: 0.03934908225991816
Validation loss: 0.032533054414670914
F1 Validation Accuracy:  0.6206509539842874
Precision Validation Accuracy:  0.9770318021201413
Recall Validation Accuracy:  0.45476973684210525
AUC Validation:  0.9631864947883568
Best Threshold:  0.4
Test F1 Accuracy:  0.6477644492911668


Epoch:   2%|█▊                                                                       | 1/40 [02:09<1:23:58, 129.19s/it]

Train loss: 0.03145492828960689
Validation loss: 0.0256824346142821
F1 Validation Accuracy:  0.6925498426023085
Precision Validation Accuracy:  0.9565217391304348
Recall Validation Accuracy:  0.5427631578947368
AUC Validation:  0.9754872355053306
Best Threshold:  0.4
Test F1 Accuracy:  0.7224448897795591


Epoch:   5%|███▋                                                                     | 2/40 [04:14<1:20:19, 126.82s/it]

Train loss: 0.02478542134060331
Validation loss: 0.020070059137651697
F1 Validation Accuracy:  0.7586891757696127
Precision Validation Accuracy:  0.9573934837092731
Recall Validation Accuracy:  0.6282894736842105
AUC Validation:  0.9836131115596338
Best Threshold:  0.4
Test F1 Accuracy:  0.7817836812144212


Epoch:   8%|█████▍                                                                   | 3/40 [06:25<1:19:20, 128.67s/it]

Train loss: 0.020093773715725204
Validation loss: 0.017657555523328483
F1 Validation Accuracy:  0.786527514231499
Precision Validation Accuracy:  0.929372197309417
Recall Validation Accuracy:  0.6817434210526315
AUC Validation:  0.9888136257362636
Best Threshold:  0.4
Test F1 Accuracy:  0.8115419296663662


Epoch:  10%|███████▎                                                                 | 4/40 [08:34<1:17:21, 128.92s/it]

Train loss: 0.01644830618115177
Validation loss: 0.015312486822949722
F1 Validation Accuracy:  0.8275555555555555
Precision Validation Accuracy:  0.9003868471953579
Recall Validation Accuracy:  0.765625
AUC Validation:  0.9922921578301211
Best Threshold:  0.41000000000000003
Test F1 Accuracy:  0.8360017308524449


Epoch:  12%|█████████▏                                                               | 5/40 [10:41<1:14:46, 128.18s/it]

Train loss: 0.01416897512576982
Validation loss: 0.013786337149213068
F1 Validation Accuracy:  0.8475637013857846
Precision Validation Accuracy:  0.9285014691478942
Recall Validation Accuracy:  0.7796052631578947
AUC Validation:  0.9934656750845932
Best Threshold:  0.4
Test F1 Accuracy:  0.8547968885047537


Epoch:  15%|██████████▉                                                              | 6/40 [12:48<1:12:20, 127.65s/it]

Train loss: 0.01222063318105801
Validation loss: 0.013000208862649743
F1 Validation Accuracy:  0.849911190053286
Precision Validation Accuracy:  0.9237451737451737
Recall Validation Accuracy:  0.7870065789473685
AUC Validation:  0.9929668159744288
Best Threshold:  0.4
Test F1 Accuracy:  0.8632034632034632


Epoch:  18%|████████████▊                                                            | 7/40 [14:58<1:10:39, 128.47s/it]

Train loss: 0.011818978413807685
Validation loss: 0.013076848736091051
F1 Validation Accuracy:  0.862023653088042
Precision Validation Accuracy:  0.922211808809747
Recall Validation Accuracy:  0.8092105263157895
AUC Validation:  0.9929256901433166
Best Threshold:  0.53
Test F1 Accuracy:  0.8629173989455184


Epoch:  20%|██████████████▌                                                          | 8/40 [17:05<1:08:18, 128.08s/it]

Train loss: 0.010243189450845404
Validation loss: 0.011936609080294147
F1 Validation Accuracy:  0.8761099365750528
Precision Validation Accuracy:  0.9016536118363795
Recall Validation Accuracy:  0.8519736842105263
AUC Validation:  0.9949967230863669
Best Threshold:  0.51
Test F1 Accuracy:  0.8769100169779288


Epoch:  22%|████████████████▍                                                        | 9/40 [19:15<1:06:28, 128.65s/it]

Train loss: 0.009053533187878476
Validation loss: 0.010589136472844984
F1 Validation Accuracy:  0.8901239846088072
Precision Validation Accuracy:  0.9269813000890472
Recall Validation Accuracy:  0.8560855263157895
AUC Validation:  0.9954523954425761
Best Threshold:  0.44
Test F1 Accuracy:  0.8910555320050869


Epoch:  25%|██████████████████                                                      | 10/40 [21:25<1:04:37, 129.25s/it]

Train loss: 0.00809018095990397
Validation loss: 0.011753163053072058
F1 Validation Accuracy:  0.8704883227176221
Precision Validation Accuracy:  0.8999122036874452
Recall Validation Accuracy:  0.8429276315789473
AUC Validation:  0.9937629713094346


Epoch:  28%|███████████████████▊                                                    | 11/40 [23:34<1:02:22, 129.05s/it]

Best Threshold:  0.43
Test F1 Accuracy:  0.8716870004206984
Train loss: 0.007865371762990143
Validation loss: 0.010847375931916758
F1 Validation Accuracy:  0.8837405223251894
Precision Validation Accuracy:  0.9058721934369602
Recall Validation Accuracy:  0.8626644736842105
AUC Validation:  0.9946476490696523


Epoch:  30%|█████████████████████▌                                                  | 12/40 [25:42<1:00:03, 128.69s/it]

Best Threshold:  0.4
Test F1 Accuracy:  0.8856669428334714
Train loss: 0.007011908712744151
Validation loss: 0.010273305972077651
F1 Validation Accuracy:  0.8813131313131314
Precision Validation Accuracy:  0.9025862068965518
Recall Validation Accuracy:  0.8610197368421053
AUC Validation:  0.9961134978089936


Epoch:  32%|████████████████████████                                                  | 13/40 [27:51<57:56, 128.74s/it]

Best Threshold:  0.47
Test F1 Accuracy:  0.8832147341984095
Train loss: 0.0064832204354285565
Validation loss: 0.010483085421583382
F1 Validation Accuracy:  0.8887949260042284
Precision Validation Accuracy:  0.9147084421235857
Recall Validation Accuracy:  0.8643092105263158
AUC Validation:  0.9945381561756029
Best Threshold:  0.5
Test F1 Accuracy:  0.8887949260042284


Epoch:  35%|█████████████████████████▉                                                | 14/40 [30:01<56:02, 129.32s/it]

Train loss: 0.005698524710924347
Validation loss: 0.010278599616867723
F1 Validation Accuracy:  0.8899006622516558
Precision Validation Accuracy:  0.8958333333333334
Recall Validation Accuracy:  0.884046052631579
AUC Validation:  0.994557204660216
Best Threshold:  0.56
Test F1 Accuracy:  0.8917676556623485


Epoch:  38%|███████████████████████████▊                                              | 15/40 [32:10<53:50, 129.22s/it]

Train loss: 0.005164993883501562
Validation loss: 0.010224846722849179
F1 Validation Accuracy:  0.8898233809924306
Precision Validation Accuracy:  0.9104991394148021
Recall Validation Accuracy:  0.8700657894736842
AUC Validation:  0.9955999436242662
Best Threshold:  0.59
Test F1 Accuracy:  0.8927511657481983


Epoch:  40%|█████████████████████████████▌                                            | 16/40 [34:20<51:41, 129.22s/it]

Train loss: 0.004868349572753464
Validation loss: 0.010554726797636249
F1 Validation Accuracy:  0.8889819857561793
Precision Validation Accuracy:  0.9060631938514091
Recall Validation Accuracy:  0.8725328947368421
AUC Validation:  0.9944486380236244
Best Threshold:  0.44
Test F1 Accuracy:  0.8903654485049833


Epoch:  42%|███████████████████████████████▍                                          | 17/40 [36:25<49:02, 127.93s/it]

Train loss: 0.00452144036617543
Validation loss: 0.010024860488556442
F1 Validation Accuracy:  0.8994132439228835
Precision Validation Accuracy:  0.917094017094017
Recall Validation Accuracy:  0.8824013157894737
AUC Validation:  0.9941813664342675
Best Threshold:  0.54
Test F1 Accuracy:  0.9007569386038689


Epoch:  45%|█████████████████████████████████▎                                        | 18/40 [38:32<46:53, 127.86s/it]

Train loss: 0.004248874939874887
Validation loss: 0.010567951312623336
F1 Validation Accuracy:  0.8897119341563785
Precision Validation Accuracy:  0.8904448105436573
Recall Validation Accuracy:  0.8889802631578947
AUC Validation:  0.9954875144940767


Epoch:  48%|███████████████████████████████████▏                                      | 19/40 [40:38<44:33, 127.30s/it]

Best Threshold:  0.48
Test F1 Accuracy:  0.8910809699958898
Train loss: 0.003880934821900841
Validation loss: 0.010713060961279552
F1 Validation Accuracy:  0.8925895087427144
Precision Validation Accuracy:  0.9038785834738617
Recall Validation Accuracy:  0.881578947368421
AUC Validation:  0.9948842846936029


Epoch:  50%|█████████████████████████████████████                                     | 20/40 [42:43<42:09, 126.49s/it]

Best Threshold:  0.6
Test F1 Accuracy:  0.8935456831517185
Train loss: 0.0038815156920738343
Validation loss: 0.014277639360443573
F1 Validation Accuracy:  0.8691666666666666
Precision Validation Accuracy:  0.8809121621621622
Recall Validation Accuracy:  0.8577302631578947
AUC Validation:  0.9872849119032826


Epoch:  52%|██████████████████████████████████████▊                                   | 21/40 [44:49<40:02, 126.47s/it]

Best Threshold:  0.59
Test F1 Accuracy:  0.8733587462939433
Train loss: 0.00442186042648314
Validation loss: 0.011962634616793366
F1 Validation Accuracy:  0.8736310025273799
Precision Validation Accuracy:  0.8955094991364422
Recall Validation Accuracy:  0.852796052631579
AUC Validation:  0.9944596651186559


Epoch:  55%|████████████████████████████████████████▋                                 | 22/40 [46:55<37:50, 126.13s/it]

Best Threshold:  0.45
Test F1 Accuracy:  0.8754180602006689
Train loss: 0.0042758013082246455
Validation loss: 0.010992140375492454
F1 Validation Accuracy:  0.8890728476821191
Precision Validation Accuracy:  0.895
Recall Validation Accuracy:  0.8832236842105263
AUC Validation:  0.9954033223224787


Epoch:  55%|████████████████████████████████████████▋                                 | 22/40 [49:00<40:05, 133.65s/it]

Best Threshold:  0.6
Test F1 Accuracy:  0.8905597326649959
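
The run above ends after 23 of the 40 planned epochs. To check that this matches the save/early-stop rule in the training loop, here is a small replay of that rule over the logged best-threshold F1 scores, rounded to four decimals. `replay_early_stopping` is a helper written for this note, not part of the Huang et al. code:

```python
def replay_early_stopping(f1_per_epoch, tolerance=0.995, patience=5):
    """Replay the checkpoint/early-stop rule from the training loop.

    An epoch counts as an improvement (and its model is saved) when its F1
    exceeds tolerance * best_f1; training stops once `patience` consecutive
    epochs fail that test. Returns (stop_epoch, saved_epochs), 1-indexed.
    """
    best_f1 = 0.0
    stale = 0
    saved_epochs = []
    for epoch, f1 in enumerate(f1_per_epoch, start=1):
        if f1 > best_f1 * tolerance:
            best_f1 = f1  # note: this can move slightly downward
            stale = 0
            saved_epochs.append(epoch)
        else:
            stale += 1
        if stale >= patience:
            return epoch, saved_epochs
    return None, saved_epochs


# Best-threshold validation F1 per epoch, rounded from the log above
logged_f1 = [
    0.6478, 0.7224, 0.7818, 0.8115, 0.8360, 0.8548, 0.8632, 0.8629,
    0.8769, 0.8911, 0.8717, 0.8857, 0.8832, 0.8888, 0.8918, 0.8928,
    0.8904, 0.9008, 0.8911, 0.8935, 0.8734, 0.8754, 0.8906,
]
stop_epoch, saved_epochs = replay_early_stopping(logged_f1)
print("stopped after epoch:", stop_epoch)
print("last checkpoint saved at epoch:", saved_epochs[-1])
```

Under this rule, the last checkpoint written corresponds to epoch 18 (best-threshold F1 of about 0.901), and the five epochs after it all fall below 99.5% of that score.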





Although the run stopped before completing all 40 epochs, we can see that the Huang et al. BCE pipeline reaches a higher micro-F1 within 4 epochs of training than my old BCE pipeline reached after 5 epochs. This discrepancy can probably be attributed to a number of factors:

1. Threshold tuning for optimal F1 score. My pipeline uses a fixed threshold of 0.5 to convert predicted probabilities into binary labels, and I did not explore or adjust the threshold during evaluation. Huang et al.'s pipeline, on the other hand, searches thresholds from 0.40 to 0.60 after each epoch and keeps the one that maximizes the F1 score on the validation set.

2. Higher learning rate. In my original pipeline, I utilized a learning rate of 2e-5, which is five times lower than the learning rate used in the Huang et al. pipeline. The fact that they have a learning rate of 1e-4 probably speeds up the model's convergence, allowing it to reach better performance levels in fewer epochs. Theoretically, my model might have caught up if we had let it run for many more epochs, but this is something we will have to test.

3. Grouping parameters into those with and without weight decay. In my original pipeline, I applied weight decay universally. Huang et al., on the other hand, use grouping to separate the weights that should have weight decay from those that should not. They use the recommended grouping from Hugging Face's documentation, applying weight decay to all weights EXCEPT those whose names contain 'bias' or 'LayerNorm.weight'. This probably made weight optimization more effective in Huang et al.'s pipeline than in mine.

4. Batch size. Originally, I used a batch size of 16 for training. When replicating the Huang et al. pipeline, however, I increased the batch size to 32. This larger batch size probably led to faster convergence as well. I will have to test my original pipeline with a similar batch size later.

5. Use of a learning rate scheduler. My original pipeline used a learning rate scheduler; the replication of Huang et al. did not. This likely contributed to the faster convergence of the Huang et al. pipeline as well. Starting from an already lower learning rate, my original pipeline converged more slowly. That lowers the risk of overshooting minima, but it is costly when training for only a few epochs.
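
The threshold search from point 1 can be sketched without any dependencies. This is a simplified stand-in for the scikit-learn based loop in the training cell, assuming micro-averaged F1 and the same 0.40 to 0.60 grid. The helper names and the toy data are made up for illustration:

```python
def micro_f1(y_true, y_prob, threshold):
    """Micro-averaged F1: pool TP/FP/FN over every (document, label) entry."""
    tp = fp = fn = 0
    for true_row, prob_row in zip(y_true, y_prob):
        for t, p in zip(true_row, prob_row):
            predicted = p > threshold
            if predicted and t:
                tp += 1
            elif predicted and not t:
                fp += 1
            elif t:
                fn += 1
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def best_threshold(y_true, y_prob, center=0.5, half_width=10):
    """Try center +/- half_width/100 in steps of 0.01; return (threshold, F1)."""
    thresholds = [center + step / 100 for step in range(-half_width, half_width + 1)]
    scores = [micro_f1(y_true, y_prob, th) for th in thresholds]
    best_idx = max(range(len(scores)), key=scores.__getitem__)  # first maximum
    return thresholds[best_idx], scores[best_idx]


# Toy example: one true label is predicted with probability 0.45,
# so it is missed at the fixed 0.5 threshold but recovered at a lower one.
toy_true = [[1, 0, 1], [0, 1, 0]]
toy_prob = [[0.9, 0.2, 0.45], [0.1, 0.55, 0.3]]
print("F1 at 0.5:", micro_f1(toy_true, toy_prob, 0.5))
print("best (threshold, F1):", best_threshold(toy_true, toy_prob))
```

Note that the threshold is selected on the same validation set on which the F1 is then reported, so the tuned score is somewhat optimistic compared with a fixed-threshold score.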

Both this pipeline and my old pipeline use BCE as the loss, so the gap described above comes from the training setup rather than from the loss function. Based on the results reported in the paper, we expect the other loss functions described by Huang et al. to yield not only higher micro-F1 scores but also (arguably more importantly) better scores on rarer and co-occurring labels.

Once I have finished testing Huang et al.'s methodology, especially the loss functions adjusted for class imbalance and label co-occurrence, I will apply the above changes to my original pipeline and measure how much its performance improves.