In [1]:
# 1. Uninstall existing versions to clear conflicts
!pip uninstall -y protobuf tensorboard

# 2. Install a stable, compatible version of protobuf
!pip install -q protobuf==3.20.3

# 3. Reinstall tensorboard
!pip install -q tensorboard

Found existing installation: protobuf 6.33.0
Uninstalling protobuf-6.33.0:
  Successfully uninstalled protobuf-6.33.0
Found existing installation: tensorboard 2.18.0
Uninstalling tensorboard-2.18.0:
  Successfully uninstalled tensorboard-2.18.0
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m162.1/162.1 kB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
bigframes 2.12.0 requires google-cloud-bigquery-storage<3.0.0,>=2.30.0, which is not installed.
tensorflow 2.18.0 requires tensorboard<2.19,>=2.18, which is not installed.
opentelemetry-proto 1.37.0 requires protobuf<7.0,>=5.0, but you have protobuf 3.20.3 which is incompatible.
onnx 1.18.0 requires protobuf>=4.25.1, but you have protobuf 3.20.3 which is incompatible.
a2a-sdk 0.3.10 requires protobuf>=5.29.5, but you have protobuf 3.

In [2]:
import pandas as pd
import numpy as np
import re
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter  # <--- NEW IMPORT
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import BertModel, BertTokenizer, BertForSequenceClassification, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report, accuracy_score, f1_score, recall_score
import matplotlib.pyplot as plt

2025-12-27 16:43:40.013827: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1766853820.187667      20 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1766853820.237761      20 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Hyperparameters
MAX_LEN = 256      # Max length of tweets (BERT max is 512, but tweets are short)
BATCH_SIZE = 16    # 16 or 32 is standard for BERT
EPOCHS = 2         # BERT fine-tunes quickly (2-4 epochs is usually enough)
RETRAIN_EPOCHS = 3
LEARNING_RATE = 1e-5 
USE_BALANCED = False
lambdaa = 1
TOP_K = 50

Using device: cuda


In [4]:
# --- 3. Data Loading & Minimal Cleaning ---
writer = SummaryWriter(f'runs/BERT_{EPOCHS}_epochs_{LEARNING_RATE}_lr_BALANCED' if USE_BALANCED else f'BERT_NEGATIVE_MINING_{EPOCHS}_epochs_{RETRAIN_EPOCHS}_retrain_epochs_{LEARNING_RATE}_lr')

def clean_text_bert(text):
    # Minimal cleaning for BERT. It needs context, so we keep stopwords.
    text = str(text).lower()
    text = re.sub(r'rt\s', '', text)               # Remove RT
    text = re.sub(r'@\w+', '', text)               # Remove mentions
    text = re.sub(r'https?://\S+|www\.\S+', '', text) # Remove URLs
    text = re.sub(r'&#[0-9]+;', '', text)          # Remove HTML
    # We KEEP punctuation because BERT uses it for context/structure
    return text.strip()

# Load Data
df = pd.read_csv('/kaggle/input/sentiment-analysis-twitter-hate-speech/train.csv')
df_test = pd.read_csv('/kaggle/input/sentiment-analysis-twitter-hate-speech/test.csv')
df['clean_text'] = df['tweet'].apply(clean_text_bert)
df_test['clean_text'] = df_test['tweet'].apply(clean_text_bert)
# Split Data
X_train, X_val, y_train, y_val = train_test_split(
    df['clean_text'], df['class'], test_size=0.2, random_state=42
)

df_balanced_data = pd.read_csv('/kaggle/input/sentiment-analysis-twitter-hate-speech/balanced_data.csv')
X_train_balanced, X_val_balanced, y_train_balanced, y_val_balanced = train_test_split(
    df_balanced_data['clean_text'], df_balanced_data['class'], test_size=0.2, random_state=42
)

In [5]:
# --- 1. Initialize Tokenizer ---z
# 
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# --- 2. Custom Dataset Class ---
class TwoHeadDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len):
        # Reset index to avoid errors if dataframe was shuffled/split
        self.texts = texts.reset_index(drop=True)
        self.labels = labels.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        text = str(self.texts[item])
        label = self.labels[item]

        # Encoding: This handles Tokenization, Padding, and Truncation
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,    # Add [CLS] and [SEP]
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',       # Pad to max_len
            truncation=True,            # Truncate if too long
            return_attention_mask=True,
            return_tensors='pt',        # Return PyTorch tensors
        )

        return {
            'text': text,
            # Flatten because DataLoader adds the batch dimension later
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            # We pass the raw label (0, 1, or 2). 
            # The train_fn logic will handle splitting this into binary targets.
            'labels': torch.tensor(label, dtype=torch.long)
        }

# --- 3. Create DataLoaders ---

# Create Dataset objects
if not USE_BALANCED:
    train_dataset = TwoHeadDataset(
        texts=X_train, 
        labels=y_train, 
        tokenizer=tokenizer, 
        max_len=MAX_LEN
    )
    
    val_dataset = TwoHeadDataset(
        texts=X_val, 
        labels=y_val, 
        tokenizer=tokenizer, 
        max_len=MAX_LEN
    )
else:
    train_dataset = TwoHeadDataset(
        texts=X_train_balanced, 
        labels=y_train_balanced, 
        tokenizer=tokenizer, 
        max_len=MAX_LEN
    )
    
    val_dataset = TwoHeadDataset(
        texts=X_val_balanced, 
        labels=y_val_balanced, 
        tokenizer=tokenizer, 
        max_len=MAX_LEN
    )
test_dataset = TwoHeadDataset(
    texts=df_test['clean_text'],
    labels=df_test['class'],
    tokenizer=tokenizer,
    max_len=MAX_LEN
)
# Create DataLoaders
# shuffle=True for training to break correlations
train_loader = DataLoader(
    train_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True,
    num_workers=2 # Optional: speeds up data loading
)

# shuffle=False for validation so results are reproducible
val_loader = DataLoader(
    val_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=False,
    num_workers=2 
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2
)

print(f"Data Loaded: {len(train_dataset)} training samples, {len(val_dataset)} validation samples.")

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Data Loaded: 15860 training samples, 3966 validation samples.


In [6]:
class BertTwoHeadHier(nn.Module):
    """
    Hierarchical 3-class via two binary heads:
      A: toxic?   (0=Neither, 1=Toxic)
      B: hate?    (0=Offensive, 1=Hate)  computed/trained only on toxic samples
    """
    def __init__(self, bert_name="bert-base-uncased", dropout=0.1):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_name)
        hidden = self.bert.config.hidden_size
        self.drop = nn.Dropout(dropout)
        self.head_toxic = nn.Linear(hidden, 1)  # logitA
        self.head_hate  = nn.Linear(hidden, 1)  # logitB

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        out = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True
        )
        cls = out.last_hidden_state[:, 0, :]  # [B, H] CLS token
        cls = self.drop(cls)
        logitA = self.head_toxic(cls).squeeze(-1)  # [B]
        logitB = self.head_hate(cls).squeeze(-1)   # [B]
        return logitA, logitB

In [7]:
def train_fn(data_loader, model, optimizer, device, scheduler=None, start_steps=0, lambdaa = 1.2):
    model.train()
    total_loss = 0
    # We use BCEWithLogitsLoss because our heads output raw logits (no sigmoid applied yet)
    criterion = nn.BCEWithLogitsLoss()
    
    for idx, batch in enumerate(tqdm(data_loader, desc="Training")):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        
        # Labels are 0 (Hate), 1 (Offensive), 2 (Neither)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()

        # 1. Forward Pass
        logitA, logitB = model(input_ids, attention_mask)

        # 2. Create Binary Targets on the Fly
        
        # Target A: 1 if Toxic (Class 0 or 1), 0 if Neither (Class 2)
        target_A = (labels <= 1).float()
        
        # Target B: 1 if Hate (Class 0), 0 if Offensive (Class 1)
        target_B = (labels == 0).float()

        # 3. Calculate Loss A (Toxic Detection)
        # This is calculated for EVERY sample in the batch
        loss_A = criterion(logitA, target_A)

        # 4. Calculate Loss B (Hate Detection)
        # This is calculated ONLY for samples that are actually Toxic (Label 0 or 1)
        
        # Create a mask: True where label is 0 or 1
        toxic_mask = (labels <= 1)
        
        if toxic_mask.sum() > 0:
            # Select only the logits and targets corresponding to toxic samples
            loss_B = criterion(logitB[toxic_mask], target_B[toxic_mask])
        else:
            # If batch has no toxic samples, Loss B is 0
            loss_B = torch.tensor(0.0, device=device)

        # 5. Total Loss
        # You can weigh these terms if needed (e.g., loss = loss_A + 2.0 * loss_B)
        loss = loss_A + lambdaa * loss_B
        
        loss.backward()
        
        # Clip gradients to prevent explosion
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        if scheduler:
            scheduler.step()
            
        total_loss += loss.item()
        current_step = start_steps + idx
        writer.add_scalar('Loss/Train', loss.item(), current_step)
    return total_loss / len(data_loader)

# --- 3. Custom Evaluation Function (Hierarchical Inference) ---
def evaluate_fn(data_loader, model, device, threshold = 0.5):
    model.eval()
    
    final_targets = []
    final_predictions = []
    
    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            # Get Logits
            logitA, logitB = model(input_ids, attention_mask)
            
            # Convert to Probabilities
            probA = torch.sigmoid(logitA) # Prob of being Toxic
            probB = torch.sigmoid(logitB) # Prob of being Hate (conditional)
            
            # Hierarchical Decision Logic
            batch_preds = []
            for pA, pB in zip(probA, probB):
                # Step 1: Is it Toxic?
                if pA < 0.5:
                    # Not Toxic -> Predict Class 2 (Neither)
                    batch_preds.append(2)
                else:
                    # Is Toxic -> Step 2: Is it Hate?
                    if pB > 0.5:
                        # Hate -> Predict Class 0
                        batch_preds.append(0)
                    else:
                        # Not Hate (but Toxic) -> Predict Class 1 (Offensive)
                        batch_preds.append(1)
            
            final_targets.extend(labels.cpu().numpy())
            final_predictions.extend(batch_preds)
            
    # Metrics
    acc = accuracy_score(final_targets, final_predictions)
    print(f"\nValidation Accuracy: {acc:.4f}")
    
    target_names = ['Hate Speech (0)', 'Offensive (1)', 'Neither (2)']
    print(classification_report(final_targets, final_predictions, target_names=target_names))
    
    return acc

In [8]:
model = BertTwoHeadHier("bert-base-uncased").to(device)
    
# Initialize Optimizer
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

# Assume train_loader and val_loader are already created from previous steps
total_steps = len(train_loader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(optimizer, 0, total_steps)
best_val_loss = float('inf')
save_path = f"BERT_{EPOCHS}_epochs_{LEARNING_RATE}_lr_BALANCED.pth" if USE_BALANCED else f"BERT_{EPOCHS}_epochs_{LEARNING_RATE}_lr.pth"
current_steps = 0
# Training Loop
for epoch in range(EPOCHS):
    print(f"Epoch {epoch + 1}/{EPOCHS}")
    train_loss = train_fn(train_loader, model, optimizer, device, scheduler, current_steps, lambdaa)
    current_steps += len(train_loader)
    model.eval()
    val_loss = 0
    val_criterion = nn.BCEWithLogitsLoss()
    
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            logitA, logitB = model(input_ids, attention_mask)
            
            target_A = (labels <= 1).float()
            target_B = (labels == 0).float()
            
            loss_A = val_criterion(logitA, target_A)
            toxic_mask = (labels <= 1)
            if toxic_mask.sum() > 0:
                loss_B = val_criterion(logitB[toxic_mask], target_B[toxic_mask])
            else:
                loss_B = torch.tensor(0.0, device=device)
            
            val_loss += (loss_A + lambdaa * loss_B).item()
            
    avg_val_loss = val_loss / len(val_loader)
    
    # Get Accuracy from your existing evaluate function
    val_acc = evaluate_fn(val_loader, model, device) # Assuming this returns accuracy
    
    print(f"Train Loss: {train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val Acc: {val_acc:.4f}")
    
    # --- 3. Logging to TensorBoard ---
    writer.add_scalar('Loss/Validation', avg_val_loss, epoch)
    writer.add_scalar('Accuracy/Validation', val_acc, epoch)
    
    # --- 4. Save Model if Val Loss Improved ---
    if avg_val_loss < best_val_loss:
        print(f"Validation loss decreased ({best_val_loss:.4f} --> {avg_val_loss:.4f}). Saving model...")
        torch.save(model.state_dict(), save_path)
        print(f"Saved at {save_path}")
        best_val_loss = avg_val_loss

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Epoch 1/2


Training: 100%|██████████| 992/992 [07:12<00:00,  2.29it/s]
Evaluating: 100%|██████████| 248/248 [00:28<00:00,  8.68it/s]



Validation Accuracy: 0.9160
                 precision    recall  f1-score   support

Hate Speech (0)       0.62      0.17      0.27       220
  Offensive (1)       0.93      0.97      0.95      3052
    Neither (2)       0.87      0.91      0.89       694

       accuracy                           0.92      3966
      macro avg       0.81      0.68      0.70      3966
   weighted avg       0.90      0.92      0.90      3966

Train Loss: 0.3580 | Val Loss: 0.2741 | Val Acc: 0.9160
Validation loss decreased (inf --> 0.2741). Saving model...
Saved at BERT_2_epochs_1e-05_lr.pth
Epoch 2/2


Training: 100%|██████████| 992/992 [07:15<00:00,  2.28it/s]
Evaluating: 100%|██████████| 248/248 [00:28<00:00,  8.63it/s]


Validation Accuracy: 0.9196
                 precision    recall  f1-score   support

Hate Speech (0)       0.61      0.32      0.42       220
  Offensive (1)       0.94      0.97      0.95      3052
    Neither (2)       0.89      0.89      0.89       694

       accuracy                           0.92      3966
      macro avg       0.81      0.73      0.76      3966
   weighted avg       0.91      0.92      0.91      3966

Train Loss: 0.2547 | Val Loss: 0.2765 | Val Acc: 0.9196





In [9]:
def hierarchical_predict(probA, probB, threshold=0.4, toxic_cut=0.5):
    probA = np.asarray(probA)
    probB = np.asarray(probB)
    pred = np.full_like(probA, 2, dtype=int)
    toxic = probA >= toxic_cut
    pred[toxic] = np.where(probB[toxic] > threshold, 0, 1)
    return pred


@torch.no_grad()
def mine_hard_negatives_from_loader(
    model,
    loader,
    device,
    threshold=0.4,
    toxic_cut=0.5,
    # selection knobs
    top_k=3000,
    min_pB=0.60,
    include_misclassified=True,
    score_mode="pB",  # "pB" or "pB_plus_uncertainty"
    alpha_uncert=0.5,
):
    """
    Mines hard negatives among TRUE offensive (label=1) inside loader.
    Returns a dataframe with columns: text, class, probA_toxic, probB_hate, pred, hard_score
    """
    model.eval().to(device)

    rows = []
    for batch in loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        y = batch["labels"].to(device)  # 0/1/2
        texts = batch["text"]           # list[str]

        logitA, logitB = model(input_ids=input_ids, attention_mask=attention_mask)
        pA = torch.sigmoid(logitA).detach().cpu().numpy()
        pB = torch.sigmoid(logitB).detach().cpu().numpy()
        y_np = y.detach().cpu().numpy()

        pred = hierarchical_predict(pA, pB, threshold=threshold, toxic_cut=toxic_cut)

        for t, label, a, b, pr in zip(texts, y_np, pA, pB, pred):
            if label == 1:  # only mine from true offensive
                rows.append({
                    "clean_text": t,
                    "class": int(label),
                    "probA_toxic": float(a),
                    "probB_hate": float(b),
                    "pred": int(pr),
                })

    df_off = pd.DataFrame(rows)
    if len(df_off) == 0:
        raise ValueError("No offensive samples found in this loader to mine from.")

    # Only consider toxic-ish offensive
    cand = df_off[df_off["probA_toxic"] >= toxic_cut].copy()

    # Score
    if score_mode == "pB":
        cand["hard_score"] = cand["probB_hate"]
    elif score_mode == "pB_plus_uncertainty":
        uncert = 1.0 - np.abs(cand["probB_hate"].to_numpy() - 0.5) * 2.0
        cand["hard_score"] = cand["probB_hate"].to_numpy() + alpha_uncert * uncert
    else:
        raise ValueError("score_mode must be 'pB' or 'pB_plus_uncertainty'.")

    # Misclassified offensive -> predicted hate
    misclf = cand[cand["pred"] == 0].copy()

    # Optional filter
    if min_pB is not None:
        cand = cand[cand["probB_hate"] >= float(min_pB)].copy()

    cand = cand.sort_values("hard_score", ascending=False)

    hard = cand.head(int(top_k)).copy() if top_k is not None else cand.copy()
    if include_misclassified:
        hard = pd.concat([hard, misclf], axis=0).drop_duplicates()

    stats = {
        "offensive_total_in_loader": int(len(df_off)),
        "toxicish_offensive": int(len(cand)),
        "misclassified_offensive_as_hate": int((df_off["pred"] == 0).sum()),
        "hard_selected": int(len(hard)),
        "threshold": float(threshold),
        "toxic_cut": float(toxic_cut),
        "min_pB": None if min_pB is None else float(min_pB),
    }

    return hard.reset_index(drop=True), stats

In [10]:
hard_off, stats = mine_hard_negatives_from_loader(
    model=model,
    loader=train_loader,
    device=device,
    threshold=0.4,
    toxic_cut=0.5,
    top_k=3000,
    min_pB=0.60,
    include_misclassified=True
)

print(stats)
display(hard_off.head())


{'offensive_total_in_loader': 12300, 'toxicish_offensive': 47, 'misclassified_offensive_as_hate': 317, 'hard_selected': 307, 'threshold': 0.4, 'toxic_cut': 0.5, 'min_pB': 0.6}


Unnamed: 0,clean_text,class,probA_toxic,probB_hate,pred,hard_score
0,you're killin' me man! &lt;3 not all conservat...,1,0.916685,0.762517,0,0.762517
1,the irish in #california are all white trash.,1,0.948489,0.761547,0,0.761547
2,only white trash and morons are offended by th...,1,0.844629,0.759079,0,0.759079
3,faggot is a public spectacle. most gays are no...,1,0.992158,0.736295,0,0.736295
4,i think i figured out where white trash americ...,1,0.933996,0.731573,0,0.731573


In [11]:
hard_off.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 307 entries, 0 to 306
Data columns (total 6 columns):
 #   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  
 0   clean_text   307 non-null    object 
 1   class        307 non-null    int64  
 2   probA_toxic  307 non-null    float64
 3   probB_hate   307 non-null    float64
 4   pred         307 non-null    int64  
 5   hard_score   307 non-null    float64
dtypes: float64(3), int64(2), object(1)
memory usage: 14.5+ KB


In [12]:
def build_retrain_df(
    df_train,          # your original train dataframe (must include clean_text, class)
    hard_off_df,       # returned from mining (offensive-only)
    add_random_offensive=1000,
    random_state=42
):
    # Keep ALL hate + neither from original train
    df_hate = df_train[df_train["class"] == 0]
    df_nei  = df_train[df_train["class"] == 2]

    # Hard offensive (already class=1)
    df_hard = hard_off_df[["clean_text", "class"]].copy()

    parts = [df_hate, df_nei, df_hard]

    # Optional: add extra random offensive for diversity (not hard)
    if add_random_offensive and add_random_offensive > 0:
        df_off = df_train[df_train["class"] == 1].copy()
        hard_texts = set(df_hard["clean_text"].astype(str).tolist())
        remaining = df_off[~df_off["clean_text"].astype(str).isin(hard_texts)]
        n = min(int(add_random_offensive), len(remaining))
        if n > 0:
            parts.append(remaining.sample(n=n, random_state=random_state))

    df_retrain = (
        pd.concat(parts, axis=0)
          .drop_duplicates(subset=["clean_text", "class"])
          .sample(frac=1, random_state=random_state)
          .reset_index(drop=True)
    )

    vc = df_retrain["class"].value_counts().to_dict()
    summary = {
        "hate(0)": int(vc.get(0, 0)),
        "offensive(1)": int(vc.get(1, 0)),
        "neither(2)": int(vc.get(2, 0)),
        "total": int(len(df_retrain)),
    }
    return df_retrain, summary

In [13]:
df_train = pd.concat([X_train, y_train], axis = 1, keys=['clean_text', 'class'])
df_retrain, summary = build_retrain_df(
    df_train=df_train,            # IMPORTANT: the same train split dataframe
    hard_off_df=hard_off,
    add_random_offensive=1000     # try 0 / 500 / 1000
)

X_retrain, y_retrain = df_retrain['clean_text'], df_retrain['class']

print(summary)

{'hate(0)': 915, 'offensive(1)': 1305, 'neither(2)': 2600, 'total': 4820}


In [14]:
retrain_dataset = TwoHeadDataset(
    texts=X_retrain, 
    labels=y_retrain, 
    tokenizer=tokenizer, 
    max_len=MAX_LEN
)

retrain_loader = DataLoader(
    retrain_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True,
    num_workers=2 # Optional: speeds up data loading
)

In [15]:
optimizer = AdamW(model.parameters(), lr=2e-5)

# Assume train_loader and val_loader are already created from previous steps
model.load_state_dict(torch.load(save_path))
total_steps = len(retrain_loader) * RETRAIN_EPOCHS
scheduler = get_linear_schedule_with_warmup(optimizer, 0, total_steps)
best_val_loss = float('inf')
save_path = f"BERT_RETRAIN_{RETRAIN_EPOCHS}_epochs_{LEARNING_RATE}_lr_BALANCED.pth" if USE_BALANCED else f"BERT_RETRAIN_{RETRAIN_EPOCHS}_epochs_{LEARNING_RATE}_lr.pth"

# Training Loop
for epoch in range(EPOCHS, EPOCHS + RETRAIN_EPOCHS):
    print(f"Epoch {epoch + 1}/{EPOCHS + RETRAIN_EPOCHS}")
    train_loss = train_fn(retrain_loader, model, optimizer, device, scheduler, current_steps, lambdaa)
    current_steps += len(retrain_loader)
    model.eval()
    val_loss = 0
    val_criterion = nn.BCEWithLogitsLoss()
    
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            logitA, logitB = model(input_ids, attention_mask)
            
            target_A = (labels <= 1).float()
            target_B = (labels == 0).float()
            
            loss_A = val_criterion(logitA, target_A)
            toxic_mask = (labels <= 1)
            if toxic_mask.sum() > 0:
                loss_B = val_criterion(logitB[toxic_mask], target_B[toxic_mask])
            else:
                loss_B = torch.tensor(0.0, device=device)
            
            val_loss += (loss_A + lambdaa * loss_B).item()
            
    avg_val_loss = val_loss / len(val_loader)
    
    # Get Accuracy from your existing evaluate function
    val_acc = evaluate_fn(val_loader, model, device) # Assuming this returns accuracy
    
    print(f"Train Loss: {train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val Acc: {val_acc:.4f}")
    
    # --- 3. Logging to TensorBoard ---
    writer.add_scalar('Loss/Validation', avg_val_loss, epoch)
    writer.add_scalar('Accuracy/Validation', val_acc, epoch)
    
    # --- 4. Save Model if Val Loss Improved ---
    if avg_val_loss < best_val_loss:
        print(f"Validation loss decreased ({best_val_loss:.4f} --> {avg_val_loss:.4f}). Saving model...")
        torch.save(model.state_dict(), save_path)
        print(f"Saved at {save_path}")
        best_val_loss = avg_val_loss

Epoch 3/5


Training: 100%|██████████| 302/302 [02:12<00:00,  2.28it/s]
Evaluating: 100%|██████████| 248/248 [00:28<00:00,  8.63it/s]



Validation Accuracy: 0.8376
                 precision    recall  f1-score   support

Hate Speech (0)       0.23      0.58      0.33       220
  Offensive (1)       0.97      0.84      0.90      3052
    Neither (2)       0.83      0.93      0.88       694

       accuracy                           0.84      3966
      macro avg       0.68      0.78      0.70      3966
   weighted avg       0.90      0.84      0.86      3966

Train Loss: 0.7397 | Val Loss: 0.4118 | Val Acc: 0.8376
Validation loss decreased (inf --> 0.4118). Saving model...
Saved at BERT_RETRAIN_3_epochs_1e-05_lr.pth
Epoch 4/5


Training: 100%|██████████| 302/302 [02:12<00:00,  2.28it/s]
Evaluating: 100%|██████████| 248/248 [00:28<00:00,  8.62it/s]



Validation Accuracy: 0.8056
                 precision    recall  f1-score   support

Hate Speech (0)       0.20      0.65      0.31       220
  Offensive (1)       0.96      0.79      0.87      3052
    Neither (2)       0.87      0.90      0.88       694

       accuracy                           0.81      3966
      macro avg       0.68      0.78      0.69      3966
   weighted avg       0.90      0.81      0.84      3966

Train Loss: 0.5743 | Val Loss: 0.4694 | Val Acc: 0.8056
Epoch 5/5


Training: 100%|██████████| 302/302 [02:12<00:00,  2.28it/s]
Evaluating: 100%|██████████| 248/248 [00:28<00:00,  8.63it/s]


Validation Accuracy: 0.8235
                 precision    recall  f1-score   support

Hate Speech (0)       0.23      0.68      0.34       220
  Offensive (1)       0.97      0.81      0.88      3052
    Neither (2)       0.83      0.94      0.88       694

       accuracy                           0.82      3966
      macro avg       0.68      0.81      0.70      3966
   weighted avg       0.91      0.82      0.85      3966

Train Loss: 0.4405 | Val Loss: 0.5202 | Val Acc: 0.8235





In [16]:
def get_predictions_two_head(model, data_loader, threshold = 0.5):
    model = model.eval()
    final_predictions = []
    real_values = []
    
    with torch.no_grad():
        for d in data_loader:
            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            targets = d["labels"].to(device)

            # 1. Forward Pass: Get the two binary logits
            logitA, logitB = model(input_ids, attention_mask=attention_mask)
            
            # 2. Convert Logits to Probabilities (Sigmoid)
            probA = torch.sigmoid(logitA) # Probability of being Toxic
            probB = torch.sigmoid(logitB) # Probability of being Hate (conditional)
            
            # 3. Apply Hierarchical Logic
            # We iterate through the batch to assign final classes 0, 1, or 2
            batch_preds = []
            for pA, pB in zip(probA, probB):
                if pA < 0.5:
                    # Head A says "Not Toxic" -> Predict Class 2 (Neither)
                    batch_preds.append(2)
                else:
                    # Head A says "Toxic" -> Check Head B
                    if pB > threshold:
                        # Head B says "Hate" -> Predict Class 0 (Hate Speech)
                        batch_preds.append(0)
                    else:
                        # Head B says "Offensive" -> Predict Class 1 (Offensive)
                        batch_preds.append(1)
            
            final_predictions.extend(batch_preds)
            real_values.extend(targets.cpu().numpy())
            
    # Convert list to Tensor/Array for scikit-learn
    return torch.tensor(final_predictions), torch.tensor(real_values)

In [17]:
def sweep_thresholds_two_head(
    model,
    data_loader,
    thresholds=np.arange(0.05, 0.95, 0.05),
    toxic_cut=0.5
):
    rows = []
    for thr in thresholds:
        y_pred, y_true = get_predictions_two_head(model, data_loader, threshold=float(thr))

        y_true_np = y_true.cpu().numpy()
        y_pred_np = y_pred.cpu().numpy()

        # class-specific F1 for hate (label 0)
        f1_hate = f1_score(y_true_np, y_pred_np, labels=[0], average="macro")

        # weighted F1 across all 3 classes
        f1_weighted = f1_score(y_true_np, y_pred_np, average="weighted")

        # optional: macro F1 (often useful with imbalance)
        f1_macro = f1_score(y_true_np, y_pred_np, average="macro")

        rows.append({
            "threshold": float(thr),
            "f1_hate": float(f1_hate),
            "f1_weighted": float(f1_weighted),
            "recall_hate": recall_score(y_true_np, y_pred_np, labels=[0], average="macro"),
            "f1_macro": float(f1_macro),
        })

    return pd.DataFrame(rows).sort_values("threshold").reset_index(drop=True)

def pick_threshold_weighted_then_hate_recall(df_thr, tol=1e-4):
    """
    1) Maximize weighted F1
    2) Among near-ties, maximize hate recall
    3) If still tied, pick the smaller threshold (more recall-friendly)
    """
    best_w = df_thr["recall_hate"].max()

    # Step 1: keep near-best weighted F1
    candidates = df_thr[df_thr["recall_hate"] >= best_w - tol].copy()

    # Step 2: maximize hate recall
    candidates = candidates.sort_values(
        ["f1_weighted", "threshold"],
        ascending=[False, True]
    )

    return candidates.iloc[0]

def pick_threshold_by_weighted_score(
    df_thr,
    w_recall=0.6,
    w_f1=0.4,
    min_recall_hate=None,   # e.g., 0.80
    min_f1_weighted=None,   # e.g., 0.75
    tol=1e-6
):
    df = df_thr.copy()

    # Optional hard constraints
    if min_recall_hate is not None:
        df = df[df["recall_hate"] >= float(min_recall_hate)]
    if min_f1_weighted is not None:
        df = df[df["f1_weighted"] >= float(min_f1_weighted)]

    # If constraints filter everything out, fall back to unconstrained scoring
    if len(df) == 0:
        df = df_thr.copy()

    df["score"] = w_recall * df["recall_hate"] + w_f1 * df["f1_weighted"]

    # Tie-breaks: higher score, then higher f1_weighted, then higher recall_hate,
    # then smaller threshold (more recall-friendly)
    df = df.sort_values(
        ["score", "f1_weighted", "recall_hate", "threshold"],
        ascending=[False, False, False, True]
    )

    # keep near-best within tol
    best_score = df["score"].iloc[0]
    top = df[df["score"] >= best_score - tol].copy()
    top = top.sort_values(["threshold"], ascending=True)

    return top.iloc[0]


thresholds = [0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
df_thr = sweep_thresholds_two_head(model, val_loader, thresholds=thresholds)
# df_thr.sort_values("f1_weighted", ascending=False).head(10)
best = pick_threshold_by_weighted_score(df_thr)
best

threshold      0.200000
f1_hate        0.323322
f1_weighted    0.820420
recall_hate    0.831818
f1_macro       0.682553
score          0.827259
Name: 2, dtype: float64

In [18]:
best['threshold']

0.2

In [19]:
y_pred, y_test = get_predictions_two_head(model, test_loader, best['threshold'])
target_names = ['Hate Speech (0)', 'Offensive Language (1)', 'Neither (2)']
print(classification_report(y_test, y_pred, target_names=target_names))

                        precision    recall  f1-score   support

       Hate Speech (0)       0.20      0.86      0.32       286
Offensive Language (1)       0.99      0.71      0.83      3838
           Neither (2)       0.83      0.94      0.88       833

              accuracy                           0.76      4957
             macro avg       0.67      0.84      0.68      4957
          weighted avg       0.91      0.76      0.81      4957

