# Test detection of corrupted labels based on synthetic noise
Compare methods:
* Use voting-based method from paper [Detecting Corrupted Labels Without Training a Model to Predict](https://proceedings.mlr.press/v162/zhu22a.html) on STT 2 dataset using BERT embeddings.
* ChatGPT

## Imports

In [1]:
import sys
from pathlib import Path

from matplotlib import pyplot as plt
import pandas as pd
import torch
from torch.nn import functional as F
from tqdm.auto import tqdm
from transformers import BertForSequenceClassification, BertTokenizer, TrainingArguments, Trainer
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score

parent_dir = Path("..").absolute()
code_dir = str(parent_dir / "src")
if code_dir not in sys.path:
    sys.path.append(code_dir)
from master_thesis.datasets import load_ste_dataset
from master_thesis.detection import detect_noisy_labels_based_on_local_votes
from master_thesis.noise_generation import add_instance_dependent_noise, add_symmetric_noise
from master_thesis.metrics import compute_metrics

## Config

In [12]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
tqdm.pandas()

In [3]:
DATA_PATH = Path("/home/cicheck/Downloads/stanfordSentimentTreebank/stanfordSentimentTreebank")

## Load dataset

In [4]:
train_df, val_df, test_df = load_ste_dataset(DATA_PATH)

In [5]:
NUMBER_OF_CLASSES = len(train_df["sentiment"].unique())

In [6]:
CLASS_TO_IDX_MAP = {
        sentiment: idx
        for idx, sentiment in enumerate(test_df["sentiment"].unique())
    }

In [7]:
train_df["true_label"] = train_df["sentiment"].map(
    CLASS_TO_IDX_MAP
)
val_df["true_label"] = val_df["sentiment"].map(
    CLASS_TO_IDX_MAP
)
test_df["true_label"] = test_df["sentiment"].map(
    CLASS_TO_IDX_MAP
)
test_df["true_label"].head(2)

1116    0
1117    1
Name: true_label, dtype: int64

## Load pretrained models

In [8]:
TOKENIZER = BertTokenizer.from_pretrained('bert-base-uncased')
PRETRAINED_MODEL = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=NUMBER_OF_CLASSES)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

## Obtain embeddings

In [13]:
def get_bert_embedding(texts: list[str], model: BertForSequenceClassification) -> list[float]:
    """Get text embeddings obtained from last BERT layer averaged via mean.
    
    Args:
        texts: Series of texts.progress_apply
    Returns:
        List of embedding vectors.
    
    """
    tokenized_input = TOKENIZER(texts, return_tensors='pt', padding=True, truncation=True)
    with torch.no_grad():
        # We want to extract BERT embeddings!
        outputs = model.bert(**tokenized_input)
        last_hidden_state = outputs.last_hidden_state
    # alternative approach
    # embeddings = last_hidden_state[:, 0, :]
    embeddings = last_hidden_state.mean(dim=1)
    return embeddings.tolist()

### Without finetuning

In [14]:
train_df["embedding"] = get_bert_embedding(train_df["text"].tolist(), PRETRAINED_MODEL)
val_df["embedding"] = get_bert_embedding(val_df["text"].tolist(), PRETRAINED_MODEL)
test_df["embedding"] = get_bert_embedding(test_df["text"].tolist(), PRETRAINED_MODEL)

### With finetuning

In [14]:
def compute_training_metrics(predictions):
    logits, labels = predictions
    predictions = np.argmax(logits, axis=-1)
    return {
        # By default first metric is used during loading checkpoint
        "f1": f1_score(labels, predictions, average="macro"),
        "accuracy": accuracy_score(labels, predictions),
        "precision": precision_score(labels, predictions, average="macro"),
        "recall": recall_score(labels, predictions, average="macro"),
    }


def plot_learning_curve(train_history, eval_history):
    _, axs = plt.subplots(1, 2, figsize=(10, 5))

    axs[0].plot(train_history["f1"], label="F1")
    axs[0].plot(train_history["precision"], label="Precision")
    axs[0].plot(train_history["recall"], label="Recall")
    axs[0].set_title("Train")
    axs[0].set_xlabel("Epoch")
    axs[0].set_ylabel("Score")
    axs[0].legend()

    axs[1].plot(eval_history["f1"], label="F1")
    axs[1].plot(eval_history["precision"], label="Precision")
    axs[1].plot(eval_history["recall"], label="Recall")
    axs[1].set_title("Eval")
    axs[1].set_xlabel("Epoch")
    axs[1].set_ylabel("Score")
    axs[1].legend()

    plt.show()


train_encodings = TOKENIZER(train_df["text"].tolist(), padding=True, truncation=True)
eval_encodings = TOKENIZER(val_df["text"].tolist(), padding=True, truncation=True)


class TextClassificationDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {k: torch.tensor(v[idx]).to(DEVICE) for k, v in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx]).to(DEVICE)
        return item

    def __len__(self):
        return len(self.labels)


train_dataset = TextClassificationDataset(train_encodings, train_df["true_label"].tolist())
eval_dataset = TextClassificationDataset(eval_encodings, val_df["true_label"].tolist())


# Compute class weights and define loss function
class_counts = np.bincount(train_df["true_label"])
total_count = len(train_df["true_label"])
class_weights = total_count / class_counts


class ClassWeightedTrainer(Trainer):
    def __init__(self, class_weights, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._class_weights = torch.tensor(class_weights).to(DEVICE)

    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # compute custom loss
        loss_fct = torch.nn.CrossEntropyLoss(weight=self._class_weights)
        loss = loss_fct(logits.view(-1, NUMBER_OF_CLASSES), labels.view(-1))
        return (loss, outputs) if return_outputs else loss


training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=10,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
)

trainer = ClassWeightedTrainer(
    class_weights=class_weights,
    model=PRETRAINED_MODEL,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
)

trainer.train()
# Load the best model from training
trainer.evaluate()
train_history = {
    key: [x[key] for x in trainer.state.log_history if "train" in x]
    for key in ["f1", "precision", "recall"]
}
eval_history = {
    key: [x[key] for x in trainer.state.log_history if "eval" in x]
    for key in ["f1", "precision", "recall"]
}

# Plot learning curve
plot_learning_curve(train_history, eval_history)
model_path = trainer.state.best_model_checkpoint
FINETUNED_MODEL = BertForSequenceClassification.from_pretrained(model_path)



  0%|          | 0/2670 [00:00<?, ?it/s]

RuntimeError: cannot pin 'torch.cuda.LongTensor' only dense CPU tensors can be pinned

In [None]:
train_df["finetuned_embedding"] = get_bert_embedding(train_df["text"].tolist(), FINETUNED_MODEL)
val_df["finetuned_embedding"] = get_bert_embedding(val_df["text"].tolist(), FINETUNED_MODEL)
test_df["finetuned_embedding"] = get_bert_embedding(test_df["text"].tolist(), FINETUNED_MODEL)

In [38]:
# Save embeddings to save time
train_df.to_pickle("data/ste_train.pickle")
val_df.to_pickle("data/ste_val.pickle")
test_df.to_pickle("data/ste_test.pickle")

In [7]:
train_df = pd.read_pickle("data/ste_train.pickle")
val_df = pd.read_pickle("data/ste_val.pickle")
test_df = pd.read_pickle("data/ste_test.pickle")

## Test detection

In [28]:
corruption_rates_space = [
    0.0,
    0.1,
    0.2,
    0.3,
    0.4,
    0.5,
    0.6,
    0.7,
    0.8,
    0.9,
    1.0,
]

In [31]:
EMBEDDINGS = torch.tensor(test_df["embedding"].tolist())
FINETUNED_EMBEDDINGS

### Symmetric noise

In [33]:
symmetric_noise_metrics = []

for corruption_rate in corruption_rates_space:
    true_labels = test_df["true_label"].tolist()
    noisy_labels = add_symmetric_noise(
        true_labels=true_labels,
        corruption_rate=corruption_rate,
    )
    ground_truth = [
        true_label != noisy_label for true_label, noisy_label in zip(true_labels, noisy_labels)
    ]
    detected_corruptions = detect_noisy_labels_based_on_local_votes(
        features=FEATURES,
        original_labels=torch.tensor(true_labels),
    )
    symmetric_noise_metrics.append(
        {
            "Noise Type": "Symmetric",
            "Noise Rate": corruption_rate,
            "Detection Method": "SimFeat-V",
            **compute_metrics(ground_truth=ground_truth, predictions=detected_corruptions),
        }
    )

## Asymmetric noise
Sentiment can only shift for 1, e.g. `VERY_NEGATIVE` can only transition into `NEGATIVE` and so on. 

In [16]:
SENTIMENT_TRANSITION_MAP = {
    "VERY_NEGATIVE": ["NEGATIVE"],
    "NEGATIVE": ["VERY_NEGATIVE", "NEUTRAL"],
    "NEUTRAL": ["NEGATIVE", "POSITIVE"],
    "POSITIVE": ["NEUTRAL", "VERY_POSITIVE"],
    "VERY_POSITIVE": ["POSITIVE"],
}
SENTIMENT_TRANSITION_MAP = {
    CLASS_TO_IDX_MAP[sentiment]: [
        CLASS_TO_IDX_MAP[transition_sentiment] for transition_sentiment in transition_sentiments
    ]
    for sentiment, transition_sentiments in SENTIMENT_TRANSITION_MAP.items()
}