# Import libraries

In [1]:
# Re-install transformers
# !pip install -q transformers==4.55.4 faiss-cpu

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Sampler, DataLoader

import numpy as np
from safetensors.torch import load_model

In [3]:
from transformers import (
    BertTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorWithPadding,
    ErnieConfig,
    ErnieModel,
    EarlyStoppingCallback,
    AutoModelForSequenceClassification
)

In [4]:
import gc
import os
import random
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.metrics import f1_score, recall_score, precision_score

In [5]:
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold

In [6]:
from pytorch_metric_learning.losses import SoftTripleLoss
from pytorch_metric_learning.miners import TripletMarginMiner
from pytorch_metric_learning.distances import CosineSimilarity
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

# Set seed for deterministic

In [7]:
SEED = 42

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# torch.use_deterministic_algorithms(True) 

# Constants

In [8]:
PATH_TO_DATASET = '/mnt/d/SemEval2026/subtask1'
MODEL_NAME = '/mnt/d/SemEval2026/ernie-3.0-xbase-zh'
BATCH_SIZE_TRAIN = 24
BATCH_SIZE_EVAL = 32

LANG = 'zho'
NUM_CLASSES = 2
K_FOLDS = 5

SAVE_DIR = f'/mnt/d/SemEval2026/Ernie3-Sub1-Ablation-Baseline-{LANG}'
TEMP_DIR = f'/mnt/d/SemEval2026/Ernie3-Sub1-temp-{LANG}'

In [9]:
SUBTASK2_COLUMNS = ['political', 'racial/ethnic', 'religious', 'gender/sexual', 'other']
SUBTASK3_COLUMNS = ['stereotype','vilification','dehumanization','extreme_language','lack_of_empathy','invalidation']

In [10]:
# Training configuration
LEARNING_RATE = 2e-5
MAX_STEPS = 300
WEIGHT_DECAY = 0.01
WARMUP_RATIO = 0.1
EARLY_STOPPING_PATIENCE = 5
EVAL_STEPS = LOGGING_STEPS = 10

# Prepare tokenizer and model

In [11]:
# Load the tokenizer
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)

# Prepare dataset

In [12]:
class PolarizationDataset(torch.utils.data.Dataset):
  def __init__(self, data, tokenizer, max_length=96):
    """
    Args:
      data: Dict with 'texts', 'labels', and 'ids' keys
    """
    self.tokenizer = tokenizer
    self.max_length = max_length

    if isinstance(data, dict):
      self.texts = data['texts']
      self.labels = data['labels']
      # Add ID support. If not present, generate dummy IDs (0...N)
      self.ids = data.get('ids', list(range(len(self.texts))))
    else:
      raise ValueError("Data must be a single dict with 'texts', 'labels', and 'ids'")

    self.has_labels = any(label is not None for label in self.labels)

  def __len__(self):
    return len(self.texts)

  def __getitem__(self, idx):
    text = self.texts[idx]
    label = self.labels[idx] if self.has_labels else None
    sample_id = self.ids[idx]

    encoding = self.tokenizer(
      text,
      truncation=True,
      padding='max_length', # Changed to max_length for consistent tensor shapes in batch
      max_length=self.max_length,
      return_tensors='pt'
    )

    item = {key: encoding[key].squeeze() for key in encoding.keys()}

    if label is not None:
      item['labels'] = torch.tensor(int(label > 0), dtype=torch.long)
    
    # Return ID so we can verify if needed, though Sampler handles the logic
    # item['id'] = sample_id 

    return item

In [13]:
train_path = PATH_TO_DATASET + f'/train/{LANG}.csv'
df = pd.read_csv(train_path)

In [14]:
PATH_TO_DATASET_SUBTASK2 = '/mnt/d/SemEval2026/subtask2'
df_sub2 = pd.read_csv(
    PATH_TO_DATASET_SUBTASK2 + f'/train/{LANG}.csv'
)

In [15]:
df = df.merge(
    df_sub2.drop(columns=['text']),
    on='id',
    how='left'
)

In [16]:
def prepare_data_and_folds(df, df_sub2, n_splits=5, seed=SEED):
    """
    1. Merges Subtask 1 and 2.
    2. Generates 'stage1_target' (0-5) for Metric Learning.
    3. Creates Multilabel Stratified Folds.
    
    Returns: 
        df (updated with 'stage1_target'), 
        fold_idx (list of train/val tuples)
    """
    print(f"Original df shape: {df.shape}")
    
    # --- STEP 2: GENERATE STAGE 1 TARGETS (The Geometry Fix) ---
    # We map Multi-Label vectors to a Single Integer for Triplet Loss.
    # Priority: Gender > Religious > Race > Political > Other
    # Reasoning: 'Gender' is the hardest implicit class. 'Religious' is the smallest minority.
    
    def get_fine_grained_label(row):
        if row['polarization'] == 0:
            return 0  # Class 0: Safe
        
        # Priority Check for Hate Topics
        if row['gender/sexual'] == 1: return 4  # Class 4: Gender (High Priority!)
        if row['religious'] == 1:     return 3  # Class 3: Religious (Protect Minority)
        if row['racial/ethnic'] == 1: return 2  # Class 2: Race (Dominant)
        if row['political'] == 1:     return 1  # Class 1: Politics
        if row['other'] == 1:         return 5  # Class 5: Other
        
        return 5 # Fallback

    print("Generating Fine-Grained Labels for Stage 1...")
    df['polarization'] = df.apply(get_fine_grained_label, axis=1)
    
    # --- STEP 3: STRATIFIED SPLIT (The Variance Fix) ---
    # Stratify on: Main Label + All 5 Topics
    y_main = df['polarization'].values.reshape(-1, 1)
    y_sub2 = df_sub2[SUBTASK2_COLUMNS].values
    stratify_targets = np.hstack([y_main, y_sub2])
    
    print(f"Running Multilabel Stratified K-Fold (n={n_splits})...")
    mskf = MultilabelStratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    
    fold_idx = []
    for train_idx, val_idx in mskf.split(df, stratify_targets):
        fold_idx.append((train_idx, val_idx))

    # Return the modified DF (so you can access 'stage1_target' later) and the indices
    return df, fold_idx

In [17]:
_, fold_idx = prepare_data_and_folds(df, df_sub2, n_splits=K_FOLDS, seed=SEED)

Original df shape: (4280, 8)
Generating Fine-Grained Labels for Stage 1...
Running Multilabel Stratified K-Fold (n=5)...


# Training process

## Sampler

In [18]:
class HierarchicalSampler(Sampler):
    """
    Custom sampler for HF Trainer that ensures:
    1. Each batch is 50% Safe (Label 0) and 50% Hate (Labels 1-5).
    2. The Hate half is stratified: it cycles through sub-types (1->2->3->4->5) 
       to guarantee diversity and inclusion of rare classes (e.g., Rape Apology).
    3. Handles exhaustion and oversampling automatically.
    """

    def __init__(self, dataset, batch_size, shuffle=True, seed=SEED):
        """
        Args:
            dataset: Dataset instance containing .labels (values 0-5)
            batch_size: Size of each batch
            shuffle: Whether to shuffle samples
            seed: Random seed for reproducibility
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.seed = seed

        # Set random seed if provided
        if self.seed is not None:
            random.seed(self.seed)
            np.random.seed(self.seed)

        # Initialize storage for all 6 classes
        self.label_indices = {0: [], 1: [], 2: [], 3: [], 4: [], 5: []}
        
        # Group indices by their fine-grained label (0-5)
        for idx in range(len(self.dataset)):
            label = self.dataset.labels[idx]
            # Ensure we capture all valid labels 0-5
            if label is not None and label in self.label_indices:
                self.label_indices[label].append(idx)

        self.total_samples = len(dataset)
        # Standard calculation for number of batches
        self.num_batches = (self.total_samples + self.batch_size - 1) // self.batch_size

    def _create_epoch_pools(self):
        """Create shuffled pools for every class 0-5"""
        pools = {k: [] for k in self.label_indices.keys()}

        for label in self.label_indices:
            indices = self.label_indices[label].copy()
            if self.shuffle:
                random.shuffle(indices)
            pools[label] = indices

        return pools

    def _get_balanced_samples(self, pools, num_samples):
        """
        Construct a batch with:
        - 50% Label 0 (Safe)
        - 50% Labels 1-5 (Hate) mixed via Round Robin
        """
        n_safe_target = num_samples // 2
        n_hate_target = num_samples - n_safe_target
        
        selected = []

        # --- PART 1: Fill Safe Slots (Label 0) ---
        for _ in range(n_safe_target):
            if pools[0]:
                selected.append(pools[0].pop(0))
            else:
                # Fallback: If 0 is empty, try to steal from Hate pools (1-5)
                # We try 1..5 in order until we find a sample
                found = False
                for fallback_label in [1, 2, 3, 4, 5]:
                    if pools[fallback_label]:
                        selected.append(pools[fallback_label].pop(0))
                        found = True
                        break
                # If everything is empty (shouldn't happen in loop logic), break
                if not found:
                    break

        # --- PART 2: Fill Hate Slots (Labels 1-5) via Round Robin ---
        # We cycle 1 -> 2 -> 3 -> 4 -> 5 -> 1 ... to ensure rare classes (4) get picked
        hate_labels = [1, 2, 3, 4, 5]
        
        # We keep looping until we fill the quota or run out of hate samples
        current_hate_idx = 0 
        attempts = 0
        max_attempts = len(hate_labels) * 5 # Prevent infinite loop if all hate pools empty

        while len(selected) < num_samples:
            # Check if we have exhausted all pools (Safe and Hate)
            if all(not pools[k] for k in pools):
                break

            target_label = hate_labels[current_hate_idx % len(hate_labels)]
            
            if pools[target_label]:
                selected.append(pools[target_label].pop(0))
                attempts = 0 # Reset attempts on success
            else:
                attempts += 1
            
            # Move to next hate label (Round Robin)
            current_hate_idx += 1

            # Fallback logic: If we tried all hate labels and found nothing,
            # try filling the remaining slots with Safe (0) if available
            if attempts >= len(hate_labels):
                if pools[0]:
                    selected.append(pools[0].pop(0))
                    attempts = 0 # Reset because we found a safe sample
                else:
                    # If pools[0] is also empty, we are truly done
                    break

        return selected

    def _get_remaining_count(self, pools):
        """Sum of all remaining samples across all classes"""
        return sum(len(indices) for indices in pools.values())

    def __iter__(self):
        # Create fresh pools for this epoch
        pools = self._create_epoch_pools()

        all_batches = []

        # Generate batches until pools are exhausted
        while self._get_remaining_count(pools) > 0:
            num_to_sample = min(self.batch_size, self._get_remaining_count(pools))
            if num_to_sample > 0:
                batch = self._get_balanced_samples(pools, num_to_sample)
                all_batches.append(batch)

        # Handle last batch - oversample if needed to reach batch_size
        # (This prevents Metric Learning from crashing on a tiny last batch)
        if all_batches and len(all_batches[-1]) < self.batch_size:
            last_batch = all_batches[-1]
            needed = self.batch_size - len(last_batch)

            # We just sample randomly from the whole dataset to fill the gap
            oversample_pool = range(self.__len__())
            last_batch.extend(random.choices(oversample_pool, k=needed))

            all_batches[-1] = last_batch

        # Shuffle WITHIN the batch (Standard practice)
        if self.shuffle:
            for i in range(len(all_batches)):
                random.shuffle(all_batches[i])

        # Flatten into a single list of indices for the Trainer
        indices = [idx for batch in all_batches for idx in batch]

        return iter(indices)

    def __len__(self):
        return self.total_samples

## Collator

In [19]:
data_collator = DataCollatorWithPadding(tokenizer)

## Trainers

In [20]:
class CustomTrainer(Trainer):
    def get_train_dataloader(self):
        if self.train_dataset is None:
            raise ValueError("Trainer: train_dataset has to be defined for training.")

        train_sampler = HierarchicalSampler(
            dataset=self.train_dataset,
            batch_size=self.args.per_device_train_batch_size,
            shuffle=True,
            seed=self.args.seed,
        )

        data_collator = self.data_collator
        if data_collator is None:
            # Do not wrap collator with accelerator.prepare; pass it directly
            data_collator = DataCollatorWithPadding(self.tokenizer)

        dataloader = DataLoader(
            self.train_dataset,
            # shuffle=True,
            sampler=train_sampler,
            # batch_sampler=train_sampler,
            collate_fn=data_collator,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            batch_size=self.args.per_device_train_batch_size,
        )
        return self.accelerator.prepare(dataloader)
    
    def get_eval_dataloader(self, eval_dataset=None):
        if eval_dataset is None:
            raise ValueError("Trainer: eval_dataset has to be defined for training.")

        eval_sampler = HierarchicalSampler(
            dataset=eval_dataset,
            batch_size=self.args.per_device_eval_batch_size,
            shuffle=False,
            seed=self.args.seed,
        )

        data_collator = self.data_collator
        if data_collator is None:
            # Do not wrap collator with accelerator.prepare; pass it directly
            data_collator = DataCollatorWithPadding(self.tokenizer)

        dataloader = DataLoader(
            eval_dataset,
            # shuffle=False,
            sampler=eval_sampler,
            # batch_sampler=train_sampler,
            collate_fn=data_collator,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            batch_size=self.args.per_device_eval_batch_size,
        )
        return self.accelerator.prepare(dataloader)

## Callbacks

In [21]:
early_stopping_callback = EarlyStoppingCallback(
    early_stopping_patience=EARLY_STOPPING_PATIENCE
)

## Training loop

In [22]:
# for i, (train_index, val_index) in enumerate(skfold.split(df['text'], df['polarization'])):
for i, (train_index, val_index) in enumerate(fold_idx):
    print(f"Starting fold {i+1}")

    # Prepare fold-specific datasets
    fold_train_data = {
        'texts': [df['text'][j] for j in train_index],
        'labels': [df['polarization'][j] for j in train_index]
    }
    fold_val_data = {
        'texts': [df['text'][j] for j in val_index],
        'labels': [df['polarization'][j] for j in val_index]
    }

    fold_train_dataset = PolarizationDataset(fold_train_data, tokenizer)
    fold_val_dataset = PolarizationDataset(fold_val_data, tokenizer)

    # Initialize model
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        num_labels=NUM_CLASSES,
    )

    ### Training ###
    training_args = TrainingArguments(
        output_dir=TEMP_DIR + f'/fold_{i+1}',
        max_steps=MAX_STEPS,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        lr_scheduler_type="cosine",
        per_device_train_batch_size=BATCH_SIZE_TRAIN,
        per_device_eval_batch_size=BATCH_SIZE_EVAL,
        eval_strategy="steps",
        eval_steps=EVAL_STEPS,
        save_strategy="best",
        logging_steps=LOGGING_STEPS,
        disable_tqdm=False,
        report_to="none",
        metric_for_best_model="loss",
        load_best_model_at_end=True,
        save_total_limit=2,
        dataloader_num_workers=4,
        warmup_ratio=WARMUP_RATIO,
        bf16=True,
    )

    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=fold_train_dataset,
        eval_dataset=fold_val_dataset,
        data_collator=data_collator,
        callbacks=[early_stopping_callback],
    )

    trainer.train()

    ### Save the best model for this fold ###
    os.makedirs(f"{SAVE_DIR}", exist_ok=True)
    trainer.save_model(f"{SAVE_DIR}/fold_{i+1}_best_model")


    ### Clean up for next fold ###
    del fold_train_dataset
    del fold_val_dataset
    del trainer
    del model

    gc.collect()
    torch.cuda.empty_cache()

Starting fold 1


Some weights of ErnieForSequenceClassification were not initialized from the model checkpoint at /mnt/d/SemEval2026/ernie-3.0-xbase-zh and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss,Validation Loss
10,0.7119,0.698431
20,0.6837,0.647317
30,0.6045,0.523308
40,0.5053,0.435859
50,0.4966,0.37918
60,0.3397,0.375515
70,0.4974,0.439997
80,0.4457,0.329453
90,0.36,0.312584
100,0.3226,0.362736


Saving vocabulary to /mnt/d/SemEval2026/Ernie3-Sub1-temp-zho/fold_1/checkpoint-10/vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to /mnt/d/SemEval2026/Ernie3-Sub1-temp-zho/fold_1/checkpoint-20/vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to /mnt/d/SemEval2026/Ernie3-Sub1-temp-zho/fold_1/checkpoint-30/vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to /mnt/d/SemEval2026/Ernie3-Sub1-temp-zho/fold_1/checkpoint-40/vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to /mnt/d/SemEval2026/Ernie3-Sub1-temp-zho/fold_1/checkpoint-50/vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to /mnt/d/SemEval2026/Ernie3-Sub1-temp-zho/fold_1/checkpoint-6

Starting fold 2


Some weights of ErnieForSequenceClassification were not initialized from the model checkpoint at /mnt/d/SemEval2026/ernie-3.0-xbase-zh and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss,Validation Loss
10,0.7279,0.683408
20,0.6857,0.634338
30,0.6178,0.513198
40,0.5204,0.366309
50,0.4761,0.311419
60,0.3576,0.295508
70,0.41,0.284452
80,0.2887,0.28487
90,0.372,0.308303
100,0.4124,0.339678


Saving vocabulary to /mnt/d/SemEval2026/Ernie3-Sub1-temp-zho/fold_2/checkpoint-10/vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to /mnt/d/SemEval2026/Ernie3-Sub1-temp-zho/fold_2/checkpoint-20/vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to /mnt/d/SemEval2026/Ernie3-Sub1-temp-zho/fold_2/checkpoint-30/vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to /mnt/d/SemEval2026/Ernie3-Sub1-temp-zho/fold_2/checkpoint-40/vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to /mnt/d/SemEval2026/Ernie3-Sub1-temp-zho/fold_2/checkpoint-50/vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to /mnt/d/SemEval2026/Ernie3-Sub1-temp-zho/fold_2/checkpoint-6

Starting fold 3


Some weights of ErnieForSequenceClassification were not initialized from the model checkpoint at /mnt/d/SemEval2026/ernie-3.0-xbase-zh and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss,Validation Loss
10,0.6926,0.681422
20,0.6548,0.637697
30,0.6029,0.516969
40,0.5515,0.420607
50,0.3968,0.393595
60,0.35,0.350208
70,0.4831,0.406327
80,0.3646,0.291668
90,0.3289,0.268097
100,0.2887,0.290145


Saving vocabulary to /mnt/d/SemEval2026/Ernie3-Sub1-temp-zho/fold_3/checkpoint-10/vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to /mnt/d/SemEval2026/Ernie3-Sub1-temp-zho/fold_3/checkpoint-20/vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to /mnt/d/SemEval2026/Ernie3-Sub1-temp-zho/fold_3/checkpoint-30/vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to /mnt/d/SemEval2026/Ernie3-Sub1-temp-zho/fold_3/checkpoint-40/vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to /mnt/d/SemEval2026/Ernie3-Sub1-temp-zho/fold_3/checkpoint-50/vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to /mnt/d/SemEval2026/Ernie3-Sub1-temp-zho/fold_3/checkpoint-6

Starting fold 4


Some weights of ErnieForSequenceClassification were not initialized from the model checkpoint at /mnt/d/SemEval2026/ernie-3.0-xbase-zh and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss,Validation Loss
10,0.7293,0.681751
20,0.7003,0.654084
30,0.641,0.546268
40,0.5508,0.468188
50,0.4701,0.423009
60,0.3655,0.385429
70,0.4448,0.350976
80,0.348,0.399182
90,0.3635,0.33977
100,0.2655,0.333072


Saving vocabulary to /mnt/d/SemEval2026/Ernie3-Sub1-temp-zho/fold_4/checkpoint-10/vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to /mnt/d/SemEval2026/Ernie3-Sub1-temp-zho/fold_4/checkpoint-20/vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to /mnt/d/SemEval2026/Ernie3-Sub1-temp-zho/fold_4/checkpoint-30/vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to /mnt/d/SemEval2026/Ernie3-Sub1-temp-zho/fold_4/checkpoint-40/vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to /mnt/d/SemEval2026/Ernie3-Sub1-temp-zho/fold_4/checkpoint-50/vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to /mnt/d/SemEval2026/Ernie3-Sub1-temp-zho/fold_4/checkpoint-6

Starting fold 5


Some weights of ErnieForSequenceClassification were not initialized from the model checkpoint at /mnt/d/SemEval2026/ernie-3.0-xbase-zh and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss,Validation Loss
10,0.6828,0.68192
20,0.666,0.643706
30,0.633,0.540656
40,0.4924,0.397669
50,0.4735,0.357719
60,0.3307,0.335868
70,0.3744,0.34079
80,0.4063,0.303497
90,0.3115,0.403702
100,0.3241,0.299159


Saving vocabulary to /mnt/d/SemEval2026/Ernie3-Sub1-temp-zho/fold_5/checkpoint-10/vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to /mnt/d/SemEval2026/Ernie3-Sub1-temp-zho/fold_5/checkpoint-20/vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to /mnt/d/SemEval2026/Ernie3-Sub1-temp-zho/fold_5/checkpoint-30/vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to /mnt/d/SemEval2026/Ernie3-Sub1-temp-zho/fold_5/checkpoint-40/vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to /mnt/d/SemEval2026/Ernie3-Sub1-temp-zho/fold_5/checkpoint-50/vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to /mnt/d/SemEval2026/Ernie3-Sub1-temp-zho/fold_5/checkpoint-6

# Inferences on test set

## Load all fold models

In [23]:
model_list = []
for i in range(K_FOLDS):
    print(f"Intialize model fold {i+1}")
    model = AutoModelForSequenceClassification.from_pretrained(f"{SAVE_DIR}/fold_{i+1}_best_model")
    model.eval()
    model.cuda()
    model_list.append(model)

Intialize model fold 1
Intialize model fold 2
Intialize model fold 3
Intialize model fold 4
Intialize model fold 5


## Load test dataset (same for all subtasks)

In [24]:
# PUBLIC TEST WITH LABELS
PATH_TO_PUBLIC_TEST_WITH_LABELS = '/mnt/d/SemEval2026/test_phase/subtask1/dev'
test = pd.read_csv(PATH_TO_PUBLIC_TEST_WITH_LABELS + f'/{LANG}.csv')
test_dataset = PolarizationDataset({
    'texts': test['text'].tolist(),
    'labels': test['polarization'].tolist()
}, tokenizer=tokenizer)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE_EVAL,
    shuffle=False,
    collate_fn=data_collator
)

## Generate the result

In [44]:
fold_raw_predictions = []
for fold_idx, model in enumerate(model_list):
    print(f"Evaluating fold {fold_idx+1} on public test set")
    all_preds = []
    for batch in test_dataloader:
        batch = {k: v.cuda() for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
            logits = outputs.logits
            probs = torch.softmax(logits, dim=1)
            all_preds.extend(probs.cpu().numpy())
    fold_raw_predictions.append(all_preds)

# Ensemble by averaging logits
final_raw_predictions = np.mean(np.array(fold_raw_predictions), axis=0)

Evaluating fold 1 on public test set
Evaluating fold 2 on public test set
Evaluating fold 3 on public test set
Evaluating fold 4 on public test set
Evaluating fold 5 on public test set


### Evaluate the public test results

In [45]:
ground_truth = test_dataset.labels

In [46]:
search_range = np.linspace(0.01, 0.99, 99)  # From 0.01 to 0.99 with step 0.01
best_f1 = 0.0
best_threshold = 0.5
for threshold in search_range:
    binarized_preds = (final_raw_predictions[:, 1] >= threshold).astype(int)
    f1 = f1_score(ground_truth, binarized_preds, average='macro')
    if f1 > best_f1:
        best_f1 = f1
        best_threshold = threshold
best_threshold, best_f1

(np.float64(0.37), 0.9345737246680643)

In [47]:
predicted_labels = (final_raw_predictions[:, 1] >= best_threshold).astype(int)

In [48]:
macro_f1 = f1_score(ground_truth, predicted_labels, average='macro')
precision = precision_score(ground_truth, predicted_labels, average='macro')
recall = recall_score(ground_truth, predicted_labels, average='macro')

In [49]:
print(f"Macro F1: {macro_f1:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}")

Macro F1: 0.9346, Precision: 0.9346, Recall: 0.9347
