# Cleanlab Experiments

In [None]:
import pandas as pd
import numpy as np
import torch
from datasets import Dataset
from sklearn.model_selection import StratifiedKFold
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import LabelEncoder
from setfit import SetFitModel, SetFitTrainer
from cleanlab.filter import find_label_issues
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx # Added for the network graph
import logging
import gc
import os

## Configurations

In [None]:
# Timestamped INFO-level logging shared by every cell in this notebook
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Experiment settings
N_SPLITS = 5  # number of cross-validation folds (k)
BASE_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
CLASSIFIER_MAX_ITER = 2000
RANDOM_STATE = 42
BATCH_SIZE = 16
NUM_CONTRASTIVE_ITERATIONS = 20

# Choose the compute device: CUDA if present, otherwise Apple's MPS backend, otherwise CPU
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
logger.info(f"Using device: {DEVICE}")

## Load Data

In [None]:
# Load the train/dev/test TSV splits and pool them into one DataFrame, so that the
# cross-validated cleanlab audit below covers every labelled example.
# Failures are logged and then re-raised. The previous bare exit() raised
# SystemExit(None), which ends a script with status 0 (success) and drops the traceback.
try:
    df_train = pd.read_csv("dataset/train.tsv", sep="\t")
    df_eval = pd.read_csv("dataset/dev.tsv", sep="\t")
    df_test = pd.read_csv("dataset/test.tsv", sep="\t")
    # ignore_index=True gives the pooled frame a unique 0..n-1 index. Without it the
    # per-file indices repeat, and an index label no longer identifies a single row
    # (e.g. in the printed issue table).
    df = pd.concat([df_train, df_eval, df_test], ignore_index=True)
    del df_train, df_eval, df_test
    logger.info("Data loaded successfully.")
except FileNotFoundError:
    logger.error("Error: Dataset files not found. Make sure 'train.tsv', 'dev.tsv', and 'test.tsv' are in the 'dataset' directory.")
    raise
except Exception as e:
    # logger.exception also records the traceback
    logger.exception(f"Error loading data: {e}")
    raise

In [None]:
# Prepare the text list and integer-encoded labels for cross-validation and cleanlab.
logger.info("Preparing data...")
texts = df["text"].tolist()
noisy_labels_raw = df["label"].values  # Keep original, possibly string, labels for reference

# cleanlab's `labels` and the columns of the OOF probability matrix both identify a
# class by its integer value, so labels must be integers in the contiguous range
# 0..K-1. String labels are remapped with a LabelEncoder. So are integer labels with
# gaps or a non-zero start (e.g. 1..K), which would otherwise index past the last
# probability column. label_encoder stays None only when the raw labels already qualify.
label_encoder = None
labels_are_integer = np.issubdtype(noisy_labels_raw.dtype, np.integer)
labels_are_contiguous = False
if labels_are_integer:
    unique_raw_labels = np.unique(noisy_labels_raw)
    labels_are_contiguous = np.array_equal(unique_raw_labels, np.arange(len(unique_raw_labels)))

if labels_are_contiguous:
    logger.info("Labels are already numeric.")
    noisy_labels = noisy_labels_raw.astype(int)  # Ensure integer type
    num_classes = len(np.unique(noisy_labels))
else:
    if labels_are_integer:
        logger.info("Labels are numeric but not contiguous 0..K-1. Applying LabelEncoder.")
    else:
        logger.info("Labels are not numeric. Applying LabelEncoder.")
    label_encoder = LabelEncoder()
    noisy_labels = label_encoder.fit_transform(noisy_labels_raw)
    # LabelEncoder assigns class i the code i, in the order of classes_
    label_mapping = {cls: code for code, cls in enumerate(label_encoder.classes_)}
    logger.info(f"Label mapping: {label_mapping}")
    num_classes = len(label_encoder.classes_)

logger.info(f"Number of samples: {len(texts)}")
logger.info(f"Number of classes: {num_classes}")

## K-Fold Cross-Validation Loop to Obtain Out-of-Fold (OOF) Predictions

In [None]:
# Stratified K-Fold cross-validation to obtain out-of-fold (OOF) predicted probabilities.
# Each fold does the following:
#   1. Fine-tune a SetFit sentence-transformer body on the training split.
#   2. Embed both splits with that body.
#   3. Fit a class-balanced logistic-regression head on the training embeddings.
#   4. Store the head's probabilities for the held-out split.
# StratifiedKFold places each sample in exactly one validation split, so every row of
# oof_preds comes from a model that did not train on that sample.
logger.info(f"Starting Stratified {N_SPLITS}-Fold cross-validation to get OOF predictions...")
skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=RANDOM_STATE)
# oof_preds[i, k]: out-of-fold probability that texts[i] belongs to encoded class k
oof_preds = np.zeros((len(texts), num_classes), dtype=float)
# Keep track of original indices corresponding to oof_preds rows
original_indices = np.arange(len(texts))

for fold_counter, (train_index, val_index) in enumerate(skf.split(texts, noisy_labels), start=1):
    logger.info(f"--- Starting Fold {fold_counter}/{N_SPLITS} ---")

    # Split data for this fold
    train_texts = [texts[i] for i in train_index]
    train_labels = noisy_labels[train_index]
    val_texts = [texts[i] for i in val_index]
    val_original_indices = original_indices[val_index]

    # Convert training data to Hugging Face Dataset format for SetFitTrainer
    train_dataset = Dataset.from_dict({"text": train_texts, "label": train_labels})
    logger.info(f"Fold {fold_counter}: Train size={len(train_texts)}, Validation size={len(val_texts)}")

    # Fine-tune the sentence-transformer body with SetFitTrainer
    logger.info(f"Fold {fold_counter}: Fine-tuning SetFit body...")
    setfit_model_for_body_tuning = SetFitModel.from_pretrained(BASE_MODEL_NAME)
    setfit_model_for_body_tuning.to(DEVICE)

    trainer = SetFitTrainer(
        model=setfit_model_for_body_tuning,
        train_dataset=train_dataset,
        num_iterations=NUM_CONTRASTIVE_ITERATIONS,
        batch_size=BATCH_SIZE,
        seed=RANDOM_STATE,
    )
    trainer.train()  # Contrastive body tuning + a temporary head; only the body is reused below

    # Extract the fine-tuned body
    fine_tuned_body = trainer.model.model_body.to(DEVICE)
    logger.info(f"Fold {fold_counter}: Body fine-tuning complete.")

    # Embed both splits with the fine-tuned body
    logger.info(f"Fold {fold_counter}: Generating embeddings...")
    with torch.no_grad():
        inference_batch_size = BATCH_SIZE * 2
        train_embeddings = fine_tuned_body.encode(train_texts, convert_to_tensor=True, device=DEVICE, batch_size=inference_batch_size)
        val_embeddings = fine_tuned_body.encode(val_texts, convert_to_tensor=True, device=DEVICE, batch_size=inference_batch_size)

    train_embeddings_np = train_embeddings.cpu().numpy()
    val_embeddings_np = val_embeddings.cpu().numpy()
    logger.info(f"Fold {fold_counter}: Embeddings generated.")

    # Train a separate logistic-regression head. class_weight='balanced' reweights
    # classes inversely to their frequency, which counters label skew.
    # NOTE(review): liblinear handles multiclass only via one-vs-rest, and its
    # normalised probabilities are typically less calibrated than multinomial ones.
    # Newer scikit-learn releases may deprecate this combination; consider 'lbfgs'.
    logger.info(f"Fold {fold_counter}: Training separate balanced Logistic Regression head...")
    manual_classifier_head = LogisticRegression(
        class_weight='balanced',
        max_iter=CLASSIFIER_MAX_ITER,
        random_state=RANDOM_STATE,
        solver='liblinear',
        C=1.0
    )
    manual_classifier_head.fit(train_embeddings_np, train_labels)
    logger.info(f"Fold {fold_counter}: Manual head training complete.")

    # Predict probabilities for the validation split
    logger.info(f"Fold {fold_counter}: Predicting probabilities for validation set...")
    val_pred_probs = manual_classifier_head.predict_proba(val_embeddings_np)

    # predict_proba returns one column per class seen in THIS fold's training split,
    # in the order of manual_classifier_head.classes_. A class with fewer members than
    # N_SPLITS can be missing from a training split. In that case the narrower matrix
    # would not broadcast into oof_preds[val_rows]. Scattering the columns by class id
    # handles this; classes unseen in this fold keep probability 0 for these rows.
    oof_preds[np.ix_(val_original_indices, manual_classifier_head.classes_)] = val_pred_probs
    logger.info(f"Fold {fold_counter}: Stored OOF predictions.")

    # Release this fold's models and tensors before the next fold allocates new ones
    del setfit_model_for_body_tuning, trainer, fine_tuned_body, manual_classifier_head
    del train_embeddings, val_embeddings, train_embeddings_np, val_embeddings_np, train_dataset
    if DEVICE == "cuda":
        torch.cuda.empty_cache()
    elif DEVICE == "mps" and hasattr(torch, "mps"):
        torch.mps.empty_cache()
    gc.collect()
    logger.info(f"--- Finished Fold {fold_counter}/{N_SPLITS} ---")

logger.info("Cross-validation finished. Out-of-fold predictions generated for all samples.")

In [None]:
# Run cleanlab on the out-of-fold probabilities to flag likely mislabelled samples.
# Indices come back ranked by self-confidence, i.e. by the predicted probability
# of each sample's given label.
logger.info("Running cleanlab to find potential label issues...")
try:
    label_issues_indices = find_label_issues(
        labels=noisy_labels,
        pred_probs=oof_preds,
        return_indices_ranked_by='self_confidence',  # Rank by confidence in given label
    )
except Exception as e:
    # logger.exception keeps the traceback. Re-raising halts later cells with a
    # non-zero status. The previous bare exit() raised SystemExit(None), which reports
    # success (status 0) for a script and drops the traceback.
    logger.exception(f"Error running cleanlab: {e}")
    logger.error("Cannot proceed without label issue indices.")
    raise

num_issues = len(label_issues_indices)
logger.info(f"Cleanlab found {num_issues} potential label issues out of {len(texts)} samples.")

In [None]:
# Inspect cleanlab results: print a table of the top-ranked suspected label issues,
# with the given label, the model's argmax suggestion, and the model's probability
# for the given label.
logger.info("\n--- Analyzing Top Potential Label Issues ---")
N_ISSUES_TO_INSPECT = 50
if num_issues <= 0:
    logger.info("No potential label issues identified by cleanlab with current settings.")
else:
    top_n_issues = min(N_ISSUES_TO_INSPECT, num_issues)
    logger.info(f"\n--- Top {top_n_issues} Potential Label Issues (Ranked by 'self_confidence') ---")

    # find_label_issues returned indices in ranked order, so a prefix slice is the top N
    top_issues_indices = label_issues_indices[:top_n_issues]

    # Translate encoded class ids back to the dataset's label values when an encoder was used
    decode_labels = label_encoder.inverse_transform if label_encoder is not None else (lambda ids: ids)

    original_numeric_labels_issues = noisy_labels[top_issues_indices]
    # Probability the model assigned to each sample's given label (its self-confidence)
    pred_prob_for_given_label = list(oof_preds[top_issues_indices, original_numeric_labels_issues])
    # Class the model prefers for each sample (argmax of its OOF probabilities)
    suggested_numeric_label = oof_preds[top_issues_indices].argmax(axis=1)

    issues_df = df.iloc[top_issues_indices].copy()
    issues_df['cleanlab_rank'] = range(top_n_issues)  # 0 = most suspicious
    issues_df['original_label'] = decode_labels(original_numeric_labels_issues)
    issues_df['pred_prob_for_given_label'] = pred_prob_for_given_label
    issues_df['suggested_label'] = decode_labels(suggested_numeric_label)

    report_columns = ['cleanlab_rank', 'text', 'original_label', 'suggested_label', 'pred_prob_for_given_label']
    print(issues_df[report_columns].to_string())
    logger.info("--- End Potential Issues List ---")

In [None]:
# Calculate Conflict Matrix: among the flagged samples only, count how often each
# given label (rows) pairs with the model's argmax prediction (columns).
conflict_matrix = None
class_names = []
if num_issues > 0:
    logger.info("\n--- Calculating Conflict Matrix for Flagged Issues ---")
    # Original and predicted labels *only for the flagged issues*
    original_labels_for_issues = noisy_labels[label_issues_indices]
    predicted_labels_for_issues = np.argmax(oof_preds[label_issues_indices], axis=1)

    # Display names, indexed by encoded class id
    if label_encoder is not None:
        class_names = label_encoder.classes_
    else:
        class_names = [str(i) for i in range(num_classes)]

    # label_encoder and num_classes are both set by the data-preparation cell, so they
    # should always agree. A mismatch signals stale notebook state. Report it rather
    # than silently overwriting the global num_classes: the network-graph cell uses
    # num_classes to index conflict_matrix, and the overwrite also made the error
    # branch below unreachable.
    if len(class_names) == num_classes:
        conflict_matrix = confusion_matrix(
            y_true=original_labels_for_issues,
            y_pred=predicted_labels_for_issues,
            labels=np.arange(num_classes),  # Ensure all classes 0..N-1 are included
        )
        logger.info("Conflict matrix calculated.")
    else:
        logger.error(f"Final class name length ({len(class_names)}) does not match num_classes ({num_classes}). Cannot create matrix.")
        conflict_matrix = None  # Ensure it's None if calculation failed


## Heatmap

In [None]:
# Conflict Heatmap: raw counts of (given label, model-suggested label) pairs among flagged issues
logger.info("\n--- Generating Basic Conflict Heatmap ---")
if conflict_matrix is None:
    logger.info("Conflict matrix not available or empty, skipping heatmap generation.")
else:
    # Label the rows/columns with class names so the axes are readable (no row/column sums)
    conflict_df = pd.DataFrame(conflict_matrix, index=class_names, columns=class_names)

    heatmap_style = dict(annot=True, fmt="d", cmap="viridis", linewidths=.5, cbar=True)
    plt.figure(figsize=(10, 8))
    sns.heatmap(conflict_df, **heatmap_style)

    plt.title('Label Conflict Heatmap (Counts of Original vs. Predicted Labels for Flagged Issues)')
    plt.xlabel('Predicted Label (Suggested by Model)')
    plt.ylabel('Original Label (From Noisy Dataset)')
    plt.xticks(rotation=45, ha='right')  # Angle column labels so long class names don't collide
    plt.yticks(rotation=0)
    plt.tight_layout()
    # To keep a copy on disk: plt.savefig("label_conflict_heatmap_basic.png")
    plt.show()
    logger.info("Basic heatmap generated.")

## Network Graph

In [None]:
# Network Graph: one node per class, plus a directed edge (given label -> model's
# suggested label) for each off-diagonal conflict-matrix cell holding at least
# min_conflict_display flagged samples. Edge width grows with that count.
logger.info("\n--- Generating Conflict Network Graph ---")
if conflict_matrix is not None and conflict_matrix.any():
    G = nx.DiGraph()  # Directed: edge u -> v means "labelled u, model suggests v"

    # Add every class as a node, so classes without conflicts still appear
    G.add_nodes_from(class_names)

    # Off-diagonal cells are disagreements between the given and the predicted label
    off_diagonal = ~np.eye(conflict_matrix.shape[0], dtype=bool)
    # initial=0 keeps max() defined when there are no off-diagonal cells (a single
    # class). A plain .max() on that empty selection raises ValueError.
    max_conflict = conflict_matrix[off_diagonal].max(initial=0)
    min_conflict_display = 1  # Minimum count to draw an edge

    edges_to_add = []
    # np.nonzero yields (row, col) pairs in row-major order, matching a nested i/j loop
    for i, j in zip(*np.nonzero(off_diagonal & (conflict_matrix >= min_conflict_display))):
        weight = conflict_matrix[i, j]
        # Scale weight for visual thickness (adjust scaling factor '10' as needed)
        scaled_viz_weight = 1 + 10 * (weight / max_conflict if max_conflict > 0 else 0)
        edges_to_add.append((class_names[i], class_names[j], {'weight': weight, 'viz_weight': scaled_viz_weight}))

    if not edges_to_add:
        logger.info(f"No off-diagonal conflicts >= {min_conflict_display} found to draw in the network graph.")
    else:
        G.add_edges_from(edges_to_add)
        logger.info(f"Added {len(edges_to_add)} edges to the network graph.")

        plt.figure(figsize=(14, 14))  # Adjust figure size as needed
        # Force-directed layout; a larger k spreads nodes further apart
        pos = nx.spring_layout(G, k=0.9, iterations=50, seed=RANDOM_STATE)

        # Edge widths come from the 'viz_weight' attribute, in G.edges() order
        edge_widths = [d['viz_weight'] for u, v, d in G.edges(data=True)]

        # Draw the network components
        nx.draw_networkx_nodes(G, pos, node_size=700, node_color='lightblue', alpha=0.9)
        nx.draw_networkx_edges(G, pos, width=edge_widths, edge_color='grey', alpha=0.6,
                               arrows=True, arrowstyle='-|>', arrowsize=15,  # Directed arrows
                               connectionstyle='arc3,rad=0.1')  # Curve edges so A->B and B->A don't overlap
        nx.draw_networkx_labels(G, pos, font_size=10, font_weight='bold')

        plt.title("Label Conflict Network Graph (Edge Thickness ~ Conflict Count for Flagged Issues)")
        plt.axis('off')  # Hide the axes
        plt.tight_layout()
        plt.show()
        logger.info("Network graph generated.")

elif conflict_matrix is not None:
    # Reached only when the matrix exists but .any() is False, i.e. all zeros
    logger.info("Conflict matrix exists but is all zeros. No conflicts to draw.")
else:
    logger.info("Conflict matrix not available, skipping network graph generation.")

logger.info("\nScript finished.")