In [None]:
# --- Helper Functions ---
import random
import numpy as np
import torch
import os

RANDOM_SEED = 42  # Your chosen seed; every RNG below is seeded with it


def set_all_seeds(seed_value, deterministic=False):
    """Seed Python's, NumPy's and PyTorch's (CPU and every visible CUDA device) RNGs.

    Args:
        seed_value (int): Seed applied to every generator.
        deterministic (bool): If True, also force cuDNN onto deterministic kernels
            (``torch.backends.cudnn.deterministic = True``, ``benchmark = False``),
            trading some GPU speed for run-to-run reproducibility. Defaults to False,
            which leaves the cuDNN flags untouched.
    """
    print(f"Setting all seeds to: {seed_value}")
    # NOTE: PYTHONHASHSEED is only read when an interpreter starts, so this does NOT change
    # str/bytes hashing inside the already-running kernel; it only affects subprocesses
    # started afterwards with a fresh interpreter.
    os.environ['PYTHONHASHSEED'] = str(seed_value)
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    if torch.cuda.is_available():
        # manual_seed_all covers every visible GPU, the current one included.
        torch.cuda.manual_seed_all(seed_value)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

# Call this AT THE VERY BEGINNING of your script, before almost any other import or operation
# (pass deterministic=True for stricter GPU reproducibility at some speed cost).
set_all_seeds(RANDOM_SEED)

# ... then your other imports like pandas, tqdm, sklearn, transformers ...
# ... then your model definition, tokenizer, dataset, dataloader creation ...
# ... then your training loop ...

In [None]:
import os
import pandas as pd
import re

## Context window: keep only the text surrounding the candidate string


In [None]:
def create_context_window(text, target_string, window_size=200):
    """Return the slice of ``text`` around the first occurrence of ``target_string``.

    Up to ``window_size`` characters are kept on each side of the match, clamped to the
    bounds of ``text``. Returns None when ``target_string`` does not occur in ``text``.
    """
    match_start = text.find(target_string)
    if match_start == -1:
        return None

    match_end = match_start + len(target_string)
    lower = max(0, match_start - window_size)
    upper = min(len(text), match_end + window_size)
    return text[lower:upper]


In [None]:
import re
import string
def clean_text(text, remove_non_printable=True):
    """
    Clean a single text string for BART fine-tuning.

    Steps: collapse every whitespace run (newlines, tabs, unicode spaces) into one
    space, drop literal ``</s>`` / ``<eos>`` markers, optionally drop every character
    outside ``string.printable`` (note: this removes ALL non-ASCII characters, not only
    control characters), then collapse whitespace and strip once more so that the
    removals never leave double spaces or leading/trailing spaces behind.

    Args:
    - text (str): The input text. Non-string input (None, NaN, numbers) yields "".
    - remove_non_printable (bool): Whether to remove characters not in string.printable.

    Returns:
    - str: Cleaned text.
    """
    if not isinstance(text, str):
        return ""

    # 1. Collapse whitespace runs (incl. \r, \n, \t and unicode spaces) to a single space.
    #    Done before the printable filter so e.g. a non-breaking space becomes a separator
    #    instead of being deleted and gluing two words together.
    cleaned = re.sub(r'\s+', ' ', text)

    # 2. Remove embedded <eos> or </s> tokens (Bart uses </s> as EOS)
    cleaned = re.sub(r'(</s>|<eos>)', '', cleaned)

    # 3. Optionally keep only ASCII printable characters
    if remove_non_printable:
        printable_chars = set(string.printable)
        cleaned = ''.join(ch for ch in cleaned if ch in printable_chars)

    # 4. Steps 2-3 can leave "a  b" or stray edge spaces behind; normalize once more.
    return re.sub(r'\s+', ' ', cleaned).strip()

In [None]:
import pandas as pd
import numpy as np

def process_dataframe(input_df: pd.DataFrame):
    """
    Prepare X_text, X_candidate, Y_labels from a raw dataframe.

    - Requires the columns ``Issue_id``, ``text``, ``candidate_string`` and ``label``
      (``Issue_id`` is only checked for presence, it is not used)
    - Replaces ±inf with NaN, then drops rows whose text, candidate or label is missing
    - Casts label to int
    - Builds ``clean_text`` (cleaned ``text``) and ``modified_text`` (the context window
      around the candidate inside ``clean_text``, or the whole ``clean_text`` when the
      candidate cannot be located)

    Returns: (X_text, X_candidate, Y_labels) — three equally long lists of
    str, str and int; three empty lists for an empty input frame.

    Raises:
        TypeError: if ``input_df`` is not a DataFrame.
        KeyError: if a required column is missing.
    """

    # --- 1) Validate input ---
    if not isinstance(input_df, pd.DataFrame):
        raise TypeError("Input must be a Pandas DataFrame.")
    required_cols = {"Issue_id", "text", "candidate_string", "label"}
    missing = required_cols - set(input_df.columns)
    if missing:
        raise KeyError(f"Missing required column(s): {missing}")

    if input_df.empty:
        print("Warning: Input DataFrame is empty.")
        return [], [], []

    # --- 2) Work on a copy so the caller's frame is never mutated ---
    df = input_df.copy()

    # --- 3) candidate_string & text: treat ±inf as missing and drop those rows ---
    for col in ["candidate_string", "text"]:
        df[col] = df[col].replace([np.inf, -np.inf], np.nan)
        # Alternative: keep such rows and feed the tokenizer empty strings instead:
        #   df[col] = df[col].astype("string").fillna("")
        df = df.dropna(subset=[col])

    # --- 4) Clean label: replace inf→NaN, drop NaN, cast to int ---
    df["label"] = df["label"].replace([np.inf, -np.inf], np.nan)
    df = df.dropna(subset=["label"])
    df["label"] = df["label"].astype(int)

    # --- 5) Build clean_text then modified_text ---
    def _clean_safe(x):
        """clean_text that never raises: falls back to str(x) ("" for None)."""
        try:
            return clean_text("" if x is None else str(x))
        except Exception:
            # Fallback: still return string to avoid downstream type errors
            return "" if x is None else str(x)

    df["clean_text"] = df["text"].map(_clean_safe)

    def _ctx_safe(row):
        """Context window around the candidate inside clean_text, else the full clean_text.

        create_context_window returns None when the candidate is not found. That None was
        previously stored as-is and turned into the literal string "None" by
        ``astype(str)`` below, so the model saw "None" instead of the document.
        """
        cleaned = row["clean_text"]
        # Match on the same string form that is returned in X_candidate.
        candidate = str(row["candidate_string"])
        try:
            window = create_context_window(cleaned, candidate)
            if window is None:
                # clean_text collapses whitespace (and may drop characters), so a
                # candidate spanning lines can only be found in its cleaned form.
                cleaned_candidate = _clean_safe(candidate)
                if cleaned_candidate and cleaned_candidate != candidate:
                    window = create_context_window(cleaned, cleaned_candidate)
        except Exception:
            window = None
        # Fallback: if the candidate cannot be located, use the whole clean_text
        return cleaned if window is None else window

    df["modified_text"] = df.apply(_ctx_safe, axis=1)

    # --- 6) Build outputs ---
    X_text = df["modified_text"].astype(str).tolist()
    X_candidate = df["candidate_string"].astype(str).tolist()
    Y_labels = df["label"].tolist()

    return X_text, X_candidate, Y_labels


## Preprocessing helpers above: works as expected

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_linear_schedule_with_warmup
from torch.optim import AdamW
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report,f1_score,precision_score,recall_score
import numpy as np
import os # For path joining
from tqdm.auto import tqdm

In [None]:
import gc

# Start from a clean slate: collect unreachable Python objects left over from earlier runs
# in this kernel, then (only when CUDA is available) release PyTorch's cached GPU blocks.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Parameters

In [None]:
# --- Pretrained checkpoint to fine-tune: leave exactly one MODEL_NAME line uncommented ---

# Established strong performers & baselines
# 1. RoBERTa (Robustly Optimized BERT Pretraining Approach)
# MODEL_NAME = "roberta-base"

# 2-3. BERT (Bidirectional Encoder Representations from Transformers)
# MODEL_NAME = "bert-base-cased"
# MODEL_NAME = "bert-base-uncased"

# 4. ELECTRA (Efficiently Learning an Encoder that Classifies Token Replacements Accurately)
# MODEL_NAME = "google/electra-base-discriminator"

# 5-6. DistilBERT (distilled version of BERT)
# MODEL_NAME = "distilbert-base-cased"
# MODEL_NAME = "distilbert-base-uncased"

# 7. CodeBERT (pre-trained on code and natural language)
# MODEL_NAME = "microsoft/codebert-base"

# Newer or more specialized models to experiment with
# 8. ALBERT (A Lite BERT) - parameter-reduction techniques for lower memory use and faster
#    training; often a good efficiency/performance trade-off.
#    NOTE: crashed in an earlier run of this notebook; cause not yet diagnosed (TODO).
# MODEL_NAME = "albert-base-v2"

# 9. XLNet - permutation language modeling instead of BERT's masked LM; captures
#    bidirectional context well.
# MODEL_NAME = "xlnet-base-cased"

# 10. BigBird - block-sparse attention for efficient long-sequence inputs; built on the
#     RoBERTa architecture.
#     NOTE: also reported to crash; cause not yet diagnosed (TODO). One thing to check:
#     the model is loaded with use_safetensors=True, which needs the checkpoint to
#     provide safetensors weights.
MODEL_NAME = "google/bigbird-roberta-base"

# 11. Funnel-Transformer - progressively shortens the sequence in deeper layers, focusing
#     computation on higher-level representations; can be more efficient.
# MODEL_NAME = "funnel-transformer/medium"

# 12. LUKE (Language Understanding with Knowledge-based Embeddings) - adds entity
#     embeddings; could help if the secrets are named entities or have known structures.
# MODEL_NAME = "studio-ousia/luke-base"


In [None]:
# --- Data loading and preprocessing ---
# Each split lives in its own CSV; process_dataframe turns it into parallel lists of
# context texts, candidate strings and integer labels.
train_df = pd.read_csv("../Data/train.csv")
val_df = pd.read_csv("../Data/val.csv")
test_df = pd.read_csv("../Data/test.csv")

# Dataset variant name; used below to build output paths (models/<DATASET_TYPE>/...).
DATASET_TYPE = "balanced"

(X_text_train, X_candidate_train, Y_labels_train) = process_dataframe(train_df)
(X_text_val, X_candidate_val, Y_labels_val) = process_dataframe(val_df)
(X_text_test, X_candidate_test, Y_labels_test) = process_dataframe(test_df)

print(f"Train samples: {len(X_text_train)}")
print(f"Validation samples: {len(X_text_val)}")
print(f"Test samples: {len(X_text_test)}")




In [None]:
# --- Configuration ---
# NOTE: seeding already happened in the first cell; redefining RANDOM_SEED here does not
# re-seed anything by itself.
RANDOM_SEED = 42
BATCH_SIZE = 8
NUM_EPOCHS = 1  # kept small for a quick demonstration; increase for real training
LEARNING_RATE = 2e-5  # common starting point for fine-tuning transformers
MAX_LENGTH = 256  # tokens per (context, candidate) pair; reduced for speed (RoBERTa can handle 512)
BEST_MODEL_PATH = f"models/{DATASET_TYPE}/best_{MODEL_NAME.replace('/', '_')}_model.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [None]:
# --- Tokenization ---
# Each example is encoded as a *sequence pair* (context window, candidate string), so the
# tokenizer inserts the model's own special/separator tokens between the two parts.
# Alternative considered: concatenating text + " [SEP] " + candidate into one string.
# "[SEP]" is BERT's separator token; RoBERTa-style tokenizers use a different one, so the
# pair form is the safer choice across the candidate models above.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


def _encode_pairs(texts, candidates):
    """Tokenize parallel lists of contexts and candidates into padded PyTorch tensors.

    padding=True pads to the longest pair in the call; truncation=True with a pair
    shortens the longer sequence first, so a very long candidate can itself be cut
    to fit MAX_LENGTH.
    """
    return tokenizer(texts, candidates, padding=True, truncation=True, return_tensors='pt', max_length=MAX_LENGTH)


print("Tokenizing training data...")
train_encodings = _encode_pairs(X_text_train, X_candidate_train)
print("Tokenizing validation data...")
val_encodings = _encode_pairs(X_text_val, X_candidate_val)
print("Tokenizing test data...")
test_encodings = _encode_pairs(X_text_test, X_candidate_test)


In [None]:
class PairDataset(Dataset):
    """Map-style dataset over pre-tokenized (text, candidate) pairs.

    ``encodings`` maps a field name (e.g. ``input_ids``) to a tensor whose first dimension
    indexes examples; ``labels`` is an indexable sequence of class ids of the same length.
    Each item is a dict of that example's tensors plus a ``'labels'`` long tensor.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        sample = {}
        for field, values in self.encodings.items():
            # Copy the row so it does not share storage with the full encoding tensor.
            sample[field] = values[idx].clone().detach()
        sample['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)
        return sample

# Integer label arrays, one per split (PairDataset turns each entry into a long tensor).
Y_labels_train_arr, Y_labels_val_arr, Y_labels_test_arr = (
    np.array(split_labels).astype(int)
    for split_labels in (Y_labels_train, Y_labels_val, Y_labels_test)
)

train_dataset = PairDataset(train_encodings, Y_labels_train_arr)
val_dataset = PairDataset(val_encodings, Y_labels_val_arr)
test_dataset = PairDataset(test_encodings, Y_labels_test_arr)

# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=BATCH_SIZE)

In [None]:
# --- Model, Optimizer, Scheduler ---
# Size the classification head from every split's labels. The loss needs each label id in
# [0, num_labels), so when some id is skipped (e.g. labels {0, 2}) the head still needs
# max_id + 1 outputs; counting unique values alone would make it too small.
all_label_ids = np.concatenate([Y_labels_train_arr, Y_labels_val_arr, Y_labels_test_arr])
if all_label_ids.size == 0:
    raise ValueError("No labels found in the train/val/test splits; cannot size the classification head.")
num_labels = max(len(np.unique(all_label_ids)), int(all_label_ids.max()) + 1)
print(f"Number of labels (classifier outputs): {num_labels}")

model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=num_labels, use_safetensors=True)
model.to(DEVICE)

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

# The scheduler is stepped once per optimizer step, so its horizon is counted in batches.
num_training_steps = NUM_EPOCHS * len(train_loader)
num_warmup_steps = int(0.1 * num_training_steps)  # 10% warmup is common

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps
)


In [None]:
from pathlib import Path

# Make sure models/<DATASET_TYPE> exists before anything is saved into it.
path = Path("models") / DATASET_TYPE
path.mkdir(parents=True, exist_ok=True)

In [None]:
import os
import math
import torch
from pathlib import Path
from sklearn.utils.class_weight import compute_class_weight
import numpy as np
from tqdm import tqdm
from sklearn.metrics import classification_report, accuracy_score, precision_recall_fscore_support

# Checkpoints for this (dataset, model) pair:
# models/<DATASET_TYPE>/<MODEL_NAME with '/' replaced by '_'>/checkpoints
CKPT_DIR = Path("models", DATASET_TYPE, MODEL_NAME.replace("/", "_"), "checkpoints")
CKPT_DIR.mkdir(parents=True, exist_ok=True)
BEST_CKPT_PATH = CKPT_DIR / "best.pt"      # state of the best epoch by validation metric
LATEST_CKPT_PATH = CKPT_DIR / "latest.pt"  # state of the most recent epoch (for resuming)

def save_checkpoint(path, model, optimizer, scheduler, epoch, global_step, best_metric):
    """Write model/optimizer/scheduler state plus progress counters to ``path`` via torch.save.

    ``optimizer`` and ``scheduler`` may be None, in which case their entry is stored as None.
    The keys written here are the ones load_checkpoint reads back.
    """
    def _state_or_none(component):
        return None if component is None else component.state_dict()

    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": _state_or_none(optimizer),
            "scheduler_state_dict": _state_or_none(scheduler),
            "epoch": epoch,
            "global_step": global_step,
            "best_metric": best_metric,
        },
        path,
    )

def load_checkpoint(path, model, optimizer=None, scheduler=None):
    """Restore state written by save_checkpoint into ``model`` (and optimizer/scheduler if given).

    Optimizer/scheduler state is restored only when both the object is passed and the
    checkpoint holds a non-None state for it. Tensors are mapped to CPU first;
    load_state_dict then copies them into the existing parameters.

    Returns:
        (epoch, global_step, best_metric), defaulting to (0, 0, -inf) for missing keys.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints you created yourself.
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    for component, key in ((optimizer, "optimizer_state_dict"), (scheduler, "scheduler_state_dict")):
        saved_state = checkpoint.get(key)
        if component is None or saved_state is None:
            continue
        component.load_state_dict(saved_state)
    return (
        checkpoint.get("epoch", 0),
        checkpoint.get("global_step", 0),
        checkpoint.get("best_metric", -float("inf")),
    )

# ---- Resume automatically when 'latest.pt' exists ----
resume = os.path.exists(LATEST_CKPT_PATH)
start_epoch = 0
global_step = 0
best_val_f1_macro = -float("inf")

if resume:
    # latest.pt is written at the end of every epoch with that epoch's (0-based) index, i.e.
    # the last epoch that FINISHED. Continue with the next one instead of re-training it.
    last_completed_epoch, global_step, best_val_f1_macro = load_checkpoint(LATEST_CKPT_PATH, model, optimizer, scheduler)
    start_epoch = last_completed_epoch + 1
    print(f"Resuming from {LATEST_CKPT_PATH} after epoch {last_completed_epoch + 1} (global step {global_step})")

# --- Optional: also save periodic step checkpoints ---
SAVE_EVERY_STEPS = 0  # set e.g. 500 to enable

# --- Compute class weights for imbalanced dataset ---
def compute_weights(train_loader):
    """Return sklearn 'balanced' class weights for the labels yielded by ``train_loader``.

    Iterates the loader once, gathering every batch's ``labels``. Weights are computed for the
    classes that actually occur (sorted), returned as a float32 tensor on DEVICE, e.g. for
    ``torch.nn.CrossEntropyLoss(weight=...)``.
    """
    observed_labels = [label for batch in train_loader for label in batch['labels'].cpu().numpy()]
    weights = compute_class_weight(class_weight='balanced', classes=np.unique(observed_labels), y=observed_labels)
    return torch.tensor(weights, dtype=torch.float).to(DEVICE)

# Optional class-weighted loss for imbalanced data (disabled; the model's built-in loss is used):
# class_weights = compute_weights(train_loader)
# criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

# --- Training and Validation Loop ---


def _forward_batch(batch):
    """Move one collated batch to DEVICE and run ``model`` on it.

    ``labels`` are passed so the model also returns its built-in loss; ``token_type_ids`` are
    forwarded only when the tokenizer produced them. Returns ``(outputs, labels)`` with
    ``labels`` already on DEVICE.
    """
    labels = batch['labels'].to(DEVICE)
    model_inputs = {
        "input_ids": batch['input_ids'].to(DEVICE),
        "attention_mask": batch['attention_mask'].to(DEVICE),
        "labels": labels,
    }
    token_type_ids = batch.get('token_type_ids')
    if token_type_ids is not None:
        model_inputs["token_type_ids"] = token_type_ids.to(DEVICE)
    return model(**model_inputs), labels


for epoch in range(start_epoch, NUM_EPOCHS):
    print(f"\n--- Epoch {epoch + 1}/{NUM_EPOCHS} ---")

    # ---- Training ----
    model.train()
    total_train_loss = 0
    train_progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1} Training", leave=False)
    for batch in train_progress_bar:
        optimizer.zero_grad(set_to_none=True)
        outputs, labels = _forward_batch(batch)
        # loss = criterion(outputs.logits, labels)  # weighted alternative (see above)
        loss = outputs.loss
        total_train_loss += loss.item()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # cap the gradient norm at 1.0
        optimizer.step()
        if scheduler is not None:
            scheduler.step()  # scheduler advances once per optimizer step

        global_step += 1
        train_progress_bar.set_postfix({'loss': f"{loss.item():.4f}", 'step': global_step})

        # Periodic step checkpoint (disabled while SAVE_EVERY_STEPS == 0)
        if SAVE_EVERY_STEPS and (global_step % SAVE_EVERY_STEPS == 0):
            step_path = CKPT_DIR / f"step_{global_step}.pt"
            save_checkpoint(step_path, model, optimizer, scheduler, epoch, global_step, best_val_f1_macro)

    avg_train_loss = total_train_loss / max(1, len(train_loader))
    print(f"Average Training Loss: {avg_train_loss:.4f}")

    # ---- Validation ----
    model.eval()
    total_val_loss = 0
    all_val_preds, all_val_labels = [], []
    val_progress_bar = tqdm(val_loader, desc=f"Epoch {epoch+1} Validation", leave=False)
    with torch.no_grad():
        for batch in val_progress_bar:
            outputs, labels = _forward_batch(batch)
            # loss = criterion(outputs.logits, labels)  # weighted alternative
            loss = outputs.loss
            total_val_loss += loss.item()

            preds = torch.argmax(outputs.logits, dim=-1)
            all_val_preds.extend(preds.cpu().numpy())
            all_val_labels.extend(labels.cpu().numpy())

            val_progress_bar.set_postfix({'val_loss': f"{loss.item():.4f}"})

    avg_val_loss = total_val_loss / max(1, len(val_loader))
    print(f"\n--- Detailed Validation Metrics for Epoch {epoch + 1} ---")
    print(f"Validation Loss: {avg_val_loss:.4f}")

    unique_labels_true, counts_true = np.unique(all_val_labels, return_counts=True)
    print(f"True label distribution in validation set: {dict(zip(unique_labels_true, counts_true))}")
    unique_labels_pred, counts_pred = np.unique(all_val_preds, return_counts=True)
    print(f"Predicted label distribution in validation set: {dict(zip(unique_labels_pred, counts_pred))}")

    labels_for_report = sorted(set(all_val_labels) | set(all_val_preds))
    print("\nDetailed Classification Report (Validation):")
    print(classification_report(all_val_labels, all_val_preds, labels=labels_for_report, zero_division=0))

    val_accuracy = accuracy_score(all_val_labels, all_val_preds)
    if num_labels == 2:
        val_precision_binary, val_recall_binary, val_f1_binary, _ = precision_recall_fscore_support(
            all_val_labels, all_val_preds, average='binary', pos_label=1, zero_division=0
        )
    val_precision_macro, val_recall_macro, val_f1_macro, _ = precision_recall_fscore_support(
        all_val_labels, all_val_preds, average='macro', zero_division=0
    )
    val_precision_weighted, val_recall_weighted, val_f1_weighted, _ = precision_recall_fscore_support(
        all_val_labels, all_val_preds, average='weighted', zero_division=0
    )

    print(f"\nOverall Validation Metrics (Epoch {epoch+1}):")
    print(f"Validation Accuracy: {val_accuracy:.4f}")
    if num_labels == 2:
        print(f"Validation F1-Score (Binary, for class 1): {val_f1_binary:.4f}")
        print(f"Validation Precision (Binary, for class 1): {val_precision_binary:.4f}")
        print(f"Validation Recall (Binary, for class 1): {val_recall_binary:.4f}")
    print(f"Validation F1-Score (Macro): {val_f1_macro:.4f}")
    print(f"Validation Precision (Macro): {val_precision_macro:.4f}")
    print(f"Validation Recall (Macro): {val_recall_macro:.4f}")
    print(f"Validation F1-Score (Weighted): {val_f1_weighted:.4f}")
    print(f"Validation Precision (Weighted): {val_precision_weighted:.4f}")
    print(f"Validation Recall (Weighted): {val_recall_weighted:.4f}")

    # Choose metric to optimize
    current_metric_to_optimize = val_f1_macro  # or val_f1_binary, etc.

    # Update the running best BEFORE writing latest.pt. Previously latest.pt stored the
    # pre-update best, so a resumed run could overwrite best.pt / BEST_MODEL_PATH with a
    # model worse than the one already saved.
    is_new_best = current_metric_to_optimize > best_val_f1_macro
    if is_new_best:
        best_val_f1_macro = current_metric_to_optimize

    # --- Save latest checkpoint every epoch (for resume) ---
    save_checkpoint(LATEST_CKPT_PATH, model, optimizer, scheduler, epoch, global_step, best_val_f1_macro)

    # --- Save best model (by chosen metric) ---
    if is_new_best:
        save_checkpoint(BEST_CKPT_PATH, model, optimizer, scheduler, epoch, global_step, best_val_f1_macro)
        # Plain state_dict copy used by the evaluation cell
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        print(f"🌟 New best model saved to {BEST_CKPT_PATH} (metric: {best_val_f1_macro:.4f})")
    else:
        print(f"No improvement this epoch. Best metric so far: {best_val_f1_macro:.4f}")


In [None]:
import shutil

# Delete this run's checkpoint directory (latest/best/step checkpoints). The standalone
# weights at BEST_MODEL_PATH are stored outside CKPT_DIR and are kept.
if not os.path.exists(CKPT_DIR):
    print("Directory does not exist")
else:
    shutil.rmtree(CKPT_DIR)  # removes everything inside
    print(f"Removed: {CKPT_DIR}")

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
from typing import List, Optional, Union, Any

def plot_confusion_matrix(
    y_true: Union[np.ndarray, pd.Series, List[Any]],
    y_pred: Union[np.ndarray, pd.Series, List[Any]],
    classes: Optional[List[Any]] = None,
    normalize: bool = False,
    title: str = 'Confusion Matrix',
    cmap: Any = plt.cm.Blues, # Colormap name or object, e.g., Blues, Greens, YlOrRd
    figsize: tuple = (8, 6),
    fmt: Optional[str] = None, # Format for annotations (e.g., '.2f' for normalize)
    print_raw_matrix: bool = False # Option to print the raw matrix to console
) -> None:
    """
    Print (optionally) and plot the confusion matrix of ``y_true`` vs ``y_pred``.
    Normalization can be applied by setting `normalize=True`.

    Args:
        y_true (array-like): Ground truth (correct) target values.
        y_pred (array-like): Estimated targets as returned by a classifier.
        classes (list, optional): Either
            - the label values themselves (every label present in the data must be listed;
              their order sets the row/column order, and listed values absent from the data
              get an all-zero row/column), or
            - display names for the sorted distinct labels present in the data
              (e.g. ['Class 0', 'Class 1'] for labels 0/1); must then have one name per label.
            If None or empty, the sorted distinct labels are used as tick labels.
        normalize (bool, optional): Row-normalize the matrix (each true class sums to 1).
        title (str, optional): Title for the plot. Defaults to 'Confusion Matrix'.
        cmap (str or Colormap, optional): Matplotlib colormap. Defaults to plt.cm.Blues.
        figsize (tuple, optional): Figure size (width, height) in inches. Defaults to (8, 6).
        fmt (str, optional): Annotation format. Defaults to 'd', or '.2f' if normalize=True.
        print_raw_matrix (bool, optional): If True, prints the (unnormalized) confusion matrix
                                          that is plotted. Defaults to False.

    Raises:
        ValueError: if ``classes`` are display names whose count differs from the number of
            distinct labels in the data.
    """
    y_true_arr = np.asarray(y_true)
    y_pred_arr = np.asarray(y_pred)
    # Distinct labels present in either array, sorted (sklearn's default row/column order).
    present_labels = np.unique(np.concatenate((y_true_arr, y_pred_arr)))

    # Decide which label values index the matrix and how the ticks are named.
    if classes is None or len(classes) == 0:
        matrix_labels = list(present_labels)
        tick_labels = [str(c) for c in present_labels]
    elif set(present_labels.tolist()) <= set(classes):
        # `classes` are the label values themselves.
        matrix_labels = list(classes)
        tick_labels = list(classes)
    else:
        # `classes` are display names for the sorted present labels. (Passing them as
        # `labels=` to confusion_matrix, as before, fails or yields an all-zero matrix.)
        if len(classes) != len(present_labels):
            raise ValueError(
                f"`classes` has {len(classes)} names but the data contains "
                f"{len(present_labels)} distinct labels: {present_labels.tolist()}"
            )
        matrix_labels = list(present_labels)
        tick_labels = list(classes)

    # Compute once, so the printed matrix is exactly the one plotted.
    cm = confusion_matrix(y_true_arr, y_pred_arr, labels=matrix_labels)

    if print_raw_matrix:
        print("Raw Confusion Matrix:")
        print(cm)
        print("-" * 30)

    # Handle normalization
    if normalize:
        row_sums = cm.sum(axis=1)[:, np.newaxis]
        # Rows with no true samples would divide by zero; divide those by 1 (row stays 0).
        safe_row_sums = np.where(row_sums == 0, 1, row_sums)
        plot_data = cm.astype('float') / safe_row_sums
        if fmt is None:
            fmt = '.2f' # Default format for normalized values
        title = f'{title} (Normalized)'
    else:
        plot_data = cm
        if fmt is None:
            fmt = 'd' # Default format for counts

    # Plotting (explicit figure/axes interface)
    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(
        plot_data,
        annot=True,       # Show numbers in cells
        fmt=fmt,          # Format for the numbers
        cmap=cmap,        # Colormap
        xticklabels=tick_labels,
        yticklabels=tick_labels,
        linewidths=.5,    # Add lines between cells
        cbar=True,        # Show color bar
        ax=ax,
    )

    ax.set_title(title, fontsize=16)
    ax.set_ylabel('True Label', fontsize=14)
    ax.set_xlabel('Predicted Label', fontsize=14)
    ax.tick_params(axis='x', labelsize=12)
    ax.tick_params(axis='y', labelsize=12, labelrotation=0) # Keep y-axis labels horizontal
    fig.tight_layout() # Adjust plot to prevent labels from overlapping
    plt.show()

In [None]:
# --- Evaluate the best saved weights on the held-out test set ---
# map_location="cpu" lets weights saved on a GPU load on a CPU-only machine;
# load_state_dict copies them into the model's parameters.
model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location="cpu"))
model.to(DEVICE)  # Ensure model is on the correct device after loading
model.eval()

all_test_preds = []
all_test_labels = []
total_test_loss = 0

# Weighted-loss alternative (needs class_weights from the training cell):
# criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

test_progress_bar = tqdm(test_loader, desc="Testing", leave=False)
with torch.no_grad():
    for batch in test_progress_bar:
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        labels = batch['labels'].to(DEVICE)

        # token_type_ids are forwarded only when the tokenizer produced them.
        model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
        token_type_ids = batch.get('token_type_ids')
        if token_type_ids is not None:
            model_inputs["token_type_ids"] = token_type_ids.to(DEVICE)
        outputs = model(**model_inputs)

        # loss = criterion(outputs.logits, labels)  # weighted alternative
        loss = outputs.loss
        total_test_loss += loss.item()

        predictions = torch.argmax(outputs.logits, dim=-1)
        all_test_preds.extend(predictions.cpu().numpy())
        all_test_labels.extend(labels.cpu().numpy())
        test_progress_bar.set_postfix({'test_loss': f"{loss.item():.4f}"})

# Overall test loss (guarded against an empty loader) and accuracy
avg_test_loss = total_test_loss / max(1, len(test_loader))
test_accuracy = accuracy_score(all_test_labels, all_test_preds)

print(f"Test Loss: {avg_test_loss:.4f}")
print(f"Test Accuracy: {test_accuracy:.4f}")
print("\nTest Set Classification Report:")

# Actual vs. predicted label distribution
unique_labels_true, counts_true = np.unique(all_test_labels, return_counts=True)
print(f"True label distribution in test set: {dict(zip(unique_labels_true, counts_true))}")
unique_labels_pred, counts_pred = np.unique(all_test_preds, return_counts=True)
print(f"Predicted label distribution in test set: {dict(zip(unique_labels_pred, counts_pred))}")

# Detailed classification report. Stored under its own name: assigning it to
# `classification_report` (as before) shadowed sklearn's function, so any later call,
# including re-running this cell or the validation loop, failed.
print("\nDetailed Classification Report (Test):")
labels_for_report = sorted(set(all_test_labels) | set(all_test_preds))
test_classification_report = classification_report(
    all_test_labels, all_test_preds, labels=labels_for_report, zero_division=0
)
print(test_classification_report)

# Macro / weighted metrics work for any number of classes
test_precision_macro, test_recall_macro, test_f1_macro, _ = precision_recall_fscore_support(
    all_test_labels, all_test_preds, average='macro', zero_division=0
)
test_precision_weighted, test_recall_weighted, test_f1_weighted, _ = precision_recall_fscore_support(
    all_test_labels, all_test_preds, average='weighted', zero_division=0
)
# Binary (positive class = 1) metrics are only defined for two-class problems;
# f1/precision/recall_score with their default average='binary' raise otherwise.
is_binary_task = num_labels == 2
if is_binary_task:
    test_f1 = f1_score(all_test_labels, all_test_preds, zero_division=0)
    test_precision = precision_score(all_test_labels, all_test_preds, zero_division=0)
    test_recall = recall_score(all_test_labels, all_test_preds, zero_division=0)

# Append this run's metrics to the shared per-dataset metrics file
output_file = "models/" + DATASET_TYPE + "/" + "test_metrics.txt"
with open(output_file, "a", encoding="utf-8") as f:
    f.write("\nOverall Test Metrics (" + MODEL_NAME + ")\n")
    f.write(f"Test Loss: {avg_test_loss:.4f}\n")
    f.write(f"Test Accuracy: {test_accuracy:.4f}\n")
    if is_binary_task:
        f.write(f"Test F1-Score: {test_f1:.4f}\n")
        f.write(f"Test Precision: {test_precision:.4f}\n")
        f.write(f"Test Recall: {test_recall:.4f}\n")
    f.write(f"Test F1-Score (Macro): {test_f1_macro:.4f}\n")
    f.write(f"Test Precision (Macro): {test_precision_macro:.4f}\n")
    f.write(f"Test Recall (Macro): {test_recall_macro:.4f}\n")
    f.write(f"Test F1-Score (Weighted): {test_f1_weighted:.4f}\n")
    f.write(f"Test Precision (Weighted): {test_precision_weighted:.4f}\n")
    f.write(f"Test Recall (Weighted): {test_recall_weighted:.4f}\n")
    f.write(f"Detailed Classification Report (Test):\n {test_classification_report}\n")

print(f"\nMetrics have been saved to '{output_file}'.")

In [None]:
plot_confusion_matrix(y_true=all_test_labels, y_pred=all_test_preds)