<a href="https://colab.research.google.com/github/efandresena/SemEval/blob/main/subtask_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Multililingual for subtask 2

In [None]:
import os
import pandas as pd
import torch
import torch.nn as nn
import numpy as np
import random
from torch.utils.data import Dataset
from transformers import (
    AutoTokenizer, AutoModelForSequenceClassification,
    Trainer, TrainingArguments, DataCollatorWithPadding,
    EarlyStoppingCallback
)
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
import warnings
warnings.filterwarnings('ignore')


In [None]:
workdir = "/content/drive/MyDrive/NLP/SemEval"

CONFIG = {
    'model': 'xlm-roberta-large',
    'max_len': 250,
    'epochs': 6,
    'lr': 2e-5,
    'batch': 16,
    'grad_accum': 2,
    'warmup': 0.1,
    'weight_decay': 0.01,
    'augment_factor': 3,
}

LABELS = ['gender/sexual', 'political', 'religious', 'racial/ethnic', 'other']

SYNONYMS = {
    'eng': {
        'gender_sexual': {
            'woman': ['female', 'lady', 'girl', 'bitch', 'slut', 'whore', 'cunt', 'hoe', 'skank', 'thot'],
            'man': ['male', 'guy', 'dude', 'faggot', 'fag', 'pussy', 'beta', 'simp'],
            'gay': ['LGBT', 'queer', 'homo', 'fag', 'dyke', 'tranny', 'trans', 'sissy', 'queen'],
            'rape': ['assault', 'violence', 'molest', 'grope', 'force', 'violated'],
            'marriage': ['union', 'wedlock', 'family', 'traditional marriage'],
            # Additional common terms
            'lesbian': ['dyke', 'lezzie'],
            'transgender': ['tranny', 'shemale', 'he-she'],
            'feminist': ['feminazi', 'SJW'],
            'misogyny': ['hate women', 'women are inferior'],
            'sexism': ['sexist', 'misogynist']
        },
        'political': {
            'government': ['regime', 'authorities', 'dictators', 'tyrants', 'deep state'],
            'election': ['vote', 'ballot', 'rigged', 'stolen election'],
            'corruption': ['bribery', 'fraud', 'embezzlement', 'thieves', 'looters'],
            'protest': ['rally', 'demonstration', 'riot', 'anarchy'],
            # Additional common terms
            'liberal': ['leftist', 'snowflake', 'woke', 'commie', 'socialist'],
            'conservative': ['right-wing', 'fascist', 'nazi', 'bigot', 'MAGA'],
            'traitor': ['sellout', 'enemy within', 'deep state puppet'],
            'patriot': ['true patriot', 'real American'],
            'democracy': ['fake democracy', 'stolen vote'],
            'opposition': ['enemies', 'traitors', 'coup plotters']
        },
        'religious': {
            'muslim': ['islamic', 'muzz', 'raghead', 'terrorist', 'jihadist'],
            'christian': ['believer', 'infidel', 'kafir', 'crusader'],
            'church': ['temple', 'mosque', 'synagogue', 'shrine'],
            'god': ['deity', 'creator', 'allah', 'jesus'],
            # Additional common terms
            'islam': ['terrorist religion', 'sharia', 'radical islam'],
            'christianity': ['fake religion', 'crusaders'],
            'jew': ['kike', 'zionist', 'jewed'],
            'hindu': ['idol worshipper'],
            'atheist': ['godless', 'infidel'],
            'blasphemy': ['insult god', 'offend religion']
        },
        'racial_ethnic': {
            'black': ['african', 'nigger', 'nigga', 'monkey', 'ape', 'coon'],
            'white': ['caucasian', 'cracker', 'white devil', 'colonizer'],
            'racism': ['prejudice', 'bigotry', 'supremacy', 'hate'],
            'immigrant': ['migrant', 'foreigner', 'illegal', 'invader', 'alien'],
            # Additional common terms
            'asian': ['chink', 'gook', 'slant-eye'],
            'latino': ['beaner', 'wetback'],
            'arab': ['sand nigger', 'camel jockey'],
            'native': ['savage', 'redskin'],
            'ethnic': ['tribal', 'primitive'],
            'xenophobia': ['hate foreigners', 'go back home']
        },
        'other': {
            'violence': ['brutality', 'aggression', 'kill', 'murder', 'slaughter'],
            'hate': ['hostility', 'chuki', 'enmity', 'loathing'],
            'war': ['conflict', 'battle', 'genocide', 'ethnic cleansing'],
            # Additional terms
            'dehumanize': ['animal', 'cockroach', 'weed', 'vermin', 'subhuman'],
            'threat': ['kill', 'die', 'eliminate', 'wipe out'],
            'incite': ['rally', 'mobilize', 'attack', 'burn'],
            'disability': ['retard', 'cripple', 'spastic'],
            'age': ['boomer', 'snowflake', 'old fart']
        }
    },
    'swa': {
        'gender_sexual': {
            'mwanamke': ['mama', 'bibi', 'msichana', 'malaya', 'kahaba', 'dada', 'mrembo'],
            'mwanamume': ['baba', 'bwana', 'mzee', 'mwanaume', 'shoga', 'shoga', 'ngombe'],
            'ndoa': ['kuoana', 'harusi', 'ndoa ya jadi'],
            # Additional common terms
            'mwanamke': ['mwanamke wa mitaani', 'kiboko', 'dada wa mtaa'],
            'shoga': ['msenge', 'shoga', 'mashoga', 'homosexual'],
            'ngono': ['ngono', 'kufanya mapenzi', 'kudhalilisha'],
            'feminism': ['feminazi', 'wanawake wenye hasira'],
            'ubaguzi wa jinsia': ['sexism', 'misogyny']
        },
        'political': {
            'serikali': ['hukuma', 'dola', 'viongozi', 'wadhalilishaji'],
            'rais': ['mkuu', 'kiongozi', 'dikteta', 'mwizi'],
            'uchaguzi': ['kura', 'uchaguzi ulioibiwa', 'rigged'],
            'rushwa': ['ufisadi', 'kutoa rushwa', 'kuiba'],
            # Additional common terms
            'upinzani': ['maadui', 'wasaliti', 'wauzaji wa nchi'],
            'wazalendo': ['watajua', 'wale wengine'],
            'watajua': ['watajua hawajui', 'watajua'],
            'chunga kura': ['secure vote', 'protect vote'],
            'madoadoa': ['dots', 'madoadoa'],
            'mende': ['cockroaches'],
            'wabara': ['wabara waende kwao']
        },
        'religious': {
            'mwislamu': ['muislamu', 'musalama', 'mgaidi', 'kaffir'],
            'mkristo': ['muumini', 'mchawi', 'kafir'],
            'dini': ['imani', 'dini ya kishenzi', 'dini ya kuficha'],
            'mungu': ['mwenyezi mungu', 'allah'],
            'islam': ['dini ya magaidi', 'sharia'],
            'kristo': ['dini ya kishenzi', 'kristo'],
            'kafir': ['kafir', 'kaffir'],
            'chinja kafir': ['chinja kafir', 'kill infidel']
        },
        'racial_ethnic': {
            'mweusi': ['mwafrika', 'mweusi', 'madoadoa'],
            'mzungu': ['mweupe', 'mzungu', 'mkoloni'],
            'kabila': ['jamii', 'kabila', 'ukabila'],
            # Additional common terms
            'mende': ['mende', 'cockroaches'],
            'kwekwe': ['kwekwe', 'weeds'],
            'madoadoa': ['madoadoa', 'spots'],
            'wabara': ['wabara waende kwao'],
            'wakuja': ['wakuja'],
            'watajua': ['watajua hawajui']
        },
        'other': {
            'vita': ['mapambano', 'mzozo', 'vita', 'chinja'],
            'chuki': ['uadui', 'chuki', 'hasira'],
            'ua': ['mauaji', 'kill', 'slaughter'],
            # Additional common terms
            'chinja': ['chinja', 'butcher'],
            'mende': ['mende', 'cockroaches'],
            'kwekwe': ['kwekwe', 'weeds'],
            'madoadoa': ['madoadoa'],
            'operation linda kura': ['protect vote'],
            'hatupangwingwi': ["we won't be told what to do"]
        }
    }
}

In [None]:
import string
import random

def augment_text(text, category, lang, prob=0.5):
    """
    Augment text by replacing words from the given category with synonyms.
    Returns (new_text, was_changed)
    """
    cat_syns = SYNONYMS.get(lang, {}).get(category, {})
    if not cat_syns:
        return text, False

    words = text.split()
    aug_words = []
    changed = False

    for word in words:
        # Remove all punctuation for matching, but remember original form
        cleaned = word.translate(str.maketrans('', '', string.punctuation)).lower()

        if cleaned in cat_syns and random.random() < prob:
            syn = random.choice(cat_syns[cleaned])

            # Preserve capitalization
            if word and word[0].isupper():
                syn = syn.capitalize()

            # Re-attach trailing punctuation if original had it
            if word and word[-1] in string.punctuation:
                # Find the last punctuation and attach it
                trailing_punct = word[-1]
                if len(word) > 1 and word[-2] in string.punctuation:
                    trailing_punct = word[-2] + word[-1]  # handle !!, ??, etc.
                syn += trailing_punct

            aug_words.append(syn)
            changed = True
        else:
            aug_words.append(word)

    new_text = ' '.join(aug_words)
    return new_text, changed


def augment_df(df, lang, factor=3):
    print(f"Augmenting {lang}: {len(df)} samples")

    aug_texts = []
    aug_labels = []
    total_attempts = 0
    successful_changes = 0

    for _, row in df.iterrows():
        text = row['text']
        label_vals = {label: row[label] for label in LABELS}
        pos_labels = [l for l, v in label_vals.items() if v == 1]

        if not pos_labels:
            continue  # Skip non-toxic rows

        # Adaptive number of augmentations per row
        min_count = min(df[l].sum() for l in pos_labels)
        if min_count < 50:
            n_aug = factor * 4      # Very rare → aggressive augmentation
        elif min_count < 150:
            n_aug = factor * 3
        elif min_count < 300:
            n_aug = factor * 2
        else:
            n_aug = factor           # Common enough → moderate

        for _ in range(n_aug):
            total_attempts += 1
            cat = random.choice(pos_labels)  # Pick one positive label to target
            new_text, was_changed = augment_text(text, cat, lang, prob=0.9)

            if was_changed and new_text != text:
                aug_texts.append(new_text)
                aug_labels.append(label_vals.copy())
                successful_changes += 1

    print(f"  Total attempts: {total_attempts}")
    print(f"  Successful augmentations: {successful_changes}")

    if aug_texts:
        aug_df = pd.DataFrame({
            'text': aug_texts,
            **{label: [d[label] for d in aug_labels] for label in LABELS}
        })
        result = pd.concat([df, aug_df], ignore_index=True)
        print(f"  +{len(aug_df)} augmented → {len(result)} total samples")
        return result
    else:
        print("  No successful augmentations generated.")
        return df

In [None]:

class FocalLoss(nn.Module):
  def __init__(self, alpha=0.25, gamma=2.0, pos_weight=None):
      super().__init__()
      self.alpha = alpha
      self.gamma = gamma
      self.pos_weight = pos_weight

  def forward(self, inputs, targets):
      bce = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction='none', pos_weight=self.pos_weight)
      probs = torch.sigmoid(inputs)
      pt = torch.where(targets == 1, probs, 1 - probs)
      focal = (1 - pt) ** self.gamma
      if self.alpha is not None:
          alpha_w = torch.where(targets == 1, self.alpha, 1 - self.alpha)
          focal = alpha_w * focal
      return (focal * bce).mean()

class WeightedTrainer(Trainer):
    def __init__(self, *args, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        pos_weight = self.class_weights.to(outputs.logits.device) if self.class_weights is not None else None
        loss = FocalLoss(alpha=0.25, gamma=2.0, pos_weight=pos_weight)(outputs.logits, labels)
        return (loss, outputs) if return_outputs else loss

class PolarizationDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        enc = self.tokenizer(self.texts[idx], truncation=True, max_length=self.max_len, padding=False, return_tensors='pt')
        item = {k: v.squeeze(0) for k, v in enc.items()}
        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.float)
        return item

def compute_metrics(pred):
    probs = torch.sigmoid(torch.from_numpy(pred.predictions)).numpy()
    preds = (probs > 0.5).astype(int)
    f1_macro = f1_score(pred.label_ids, preds, average='macro', zero_division=0)
    f1_per = f1_score(pred.label_ids, preds, average=None, zero_division=0)
    metrics = {'f1_macro': f1_macro}
    for i, label in enumerate(LABELS):
        metrics[f'f1_{label.replace("/", "_").replace(" ", "_")}'] = f1_per[i]
    return metrics

In [None]:


def train_combined(class_weight=True):
  print("="*60)
  print("COMBINED MULTILINGUAL TRAINING")
  print("="*60)

  # Load data
  eng_train = pd.read_csv(os.path.join(workdir, "dev_phase/subtask2/train/eng.csv"))
  swa_train = pd.read_csv(os.path.join(workdir, "dev_phase/subtask2/train/swa.csv"))
  eng_dev = pd.read_csv(os.path.join(workdir, "dev_phase/subtask2/dev/eng.csv"))
  swa_dev = pd.read_csv(os.path.join(workdir, "dev_phase/subtask2/dev/swa.csv"))

  # Augment
  eng_train = augment_df(eng_train, 'eng', CONFIG['augment_factor'])
  swa_train = augment_df(swa_train, 'swa', CONFIG['augment_factor'])

  # Combine
  combined_train = pd.concat([eng_train, swa_train], ignore_index=True)
  print(f"\nCombined training: {len(combined_train)} samples")
  for label in LABELS:
      print(f"  {label}: {combined_train[label].sum()}")

  # Split validation
  train_df, val_df = train_test_split(combined_train, test_size=0.15, random_state=42)

  # Class weights
  pos_counts = train_df[LABELS].sum()
  neg_counts = len(train_df) - pos_counts
  class_weights = torch.clamp(torch.tensor(neg_counts / (pos_counts + 1e-6), dtype=torch.float32), 1.0, 10.0)

  # Model
  tokenizer = AutoTokenizer.from_pretrained(CONFIG['model'])
  model = AutoModelForSequenceClassification.from_pretrained(CONFIG['model'], num_labels=5, problem_type="multi_label_classification")

  # Datasets
  train_dataset = PolarizationDataset(train_df['text'].tolist(), train_df[LABELS].values.tolist(), tokenizer, CONFIG['max_len'])
  val_dataset = PolarizationDataset(val_df['text'].tolist(), val_df[LABELS].values.tolist(), tokenizer, CONFIG['max_len'])

  # Training
  training_args = TrainingArguments(
      output_dir="./results_combined",
      num_train_epochs=CONFIG['epochs'],
      learning_rate=CONFIG['lr'],
      per_device_train_batch_size=CONFIG['batch'],
      per_device_eval_batch_size=CONFIG['batch'] * 2,
      gradient_accumulation_steps=CONFIG['grad_accum'],
      warmup_ratio=CONFIG['warmup'],
      weight_decay=CONFIG['weight_decay'],
      eval_strategy="steps",
      eval_steps=100,
      save_strategy="steps",
      save_steps=100,
      load_best_model_at_end=True,
      metric_for_best_model="f1_macro",
      greater_is_better=True,
      logging_steps=50,
      fp16=True,
      report_to="none",
      save_total_limit=2,
  )
  if class_weight:
    trainer = WeightedTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
        data_collator=DataCollatorWithPadding(tokenizer),
        callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
        class_weights=class_weights
    )
  else:
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
        data_collator=DataCollatorWithPadding(tokenizer),
        callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
    )

  trainer.train()

    # Evaluate
  final = trainer.evaluate()
  print(f"\nFinal F1 Macro: {final['eval_f1_macro']:.4f}")

  # Find threshold
  val_preds = trainer.predict(val_dataset)
  val_probs = torch.sigmoid(torch.tensor(val_preds.predictions)).numpy()
  best_thresh, best_f1 = 0.5, 0
  for thresh in np.arange(0.3, 0.40, 0.55):
      preds = (val_probs > thresh).astype(int)
      f1 = f1_score(val_df[LABELS].values, preds, average='macro', zero_division=0)
      if f1 > best_f1:
          best_f1, best_thresh = f1, thresh
  print(f"Best threshold: {best_thresh:.2f} (F1: {best_f1:.4f})")

  # Predict English
  eng_dataset = PolarizationDataset(eng_dev['text'].tolist(), [[0]*5]*len(eng_dev), tokenizer, CONFIG['max_len'])
  eng_preds = trainer.predict(eng_dataset)
  eng_probs = torch.sigmoid(torch.tensor(eng_preds.predictions)).numpy()
  eng_binary = (eng_probs > best_thresh).astype(int)
  eng_result = pd.DataFrame(eng_binary, columns=LABELS)
  eng_result.insert(0, 'id', eng_dev['id'])
  eng_result.to_csv(os.path.join(workdir, "pred_eng_mul.csv"), index=False)
  print(f"\n✓ English predictions saved")
  print(f"Distribution:\n{eng_result[LABELS].sum()}")

  # Predict Swahili
  swa_dataset = PolarizationDataset(swa_dev['text'].tolist(), [[0]*5]*len(swa_dev), tokenizer, CONFIG['max_len'])
  swa_preds = trainer.predict(swa_dataset)
  swa_probs = torch.sigmoid(torch.tensor(swa_preds.predictions)).numpy()
  swa_binary = (swa_probs > best_thresh).astype(int)
  swa_result = pd.DataFrame(swa_binary, columns=LABELS)
  swa_result.insert(0, 'id', swa_dev['id'])
  swa_result.to_csv(os.path.join(workdir, "pred_swa_mul.csv"), index=False)
  print(f"\n✓ Swahili predictions saved")
  print(f"Distribution:\n{swa_result[LABELS].sum()}")



In [None]:
train_combined(class_weight=False)

COMBINED MULTILINGUAL TRAINING
Augmenting eng: 3222 samples
  Total attempts: 5703
  Successful augmentations: 504
  +504 augmented → 3726 total samples
Augmenting swa: 6991 samples
  Total attempts: 12267
  Successful augmentations: 131
  +131 augmented → 7122 total samples

Combined training: 10848 samples
  gender/sexual: 244
  political: 1888
  religious: 504
  racial/ethnic: 2807
  other: 754


Some weights of XLMRobertaForSequenceClassification were not initialized from the model checkpoint at xlm-roberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss,Validation Loss,F1 Macro,F1 Gender Sexual,F1 Political,F1 Religious,F1 Racial Ethnic,F1 Other
100,0.3065,0.277445,0.162139,0.0,0.53012,0.0,0.280576,0.0
200,0.2708,0.232237,0.227488,0.0,0.728682,0.0,0.408759,0.0
300,0.2225,0.229803,0.423706,0.0,0.697674,0.732558,0.688297,0.0
400,0.1992,0.188028,0.457698,0.0,0.757315,0.785185,0.745989,0.0
500,0.2059,0.186335,0.448233,0.0,0.740443,0.755245,0.745476,0.0
600,0.1837,0.182779,0.457013,0.0,0.791367,0.730539,0.763158,0.0
700,0.1738,0.176243,0.467086,0.0,0.782435,0.785185,0.76781,0.0
800,0.1561,0.180064,0.475823,0.044444,0.787546,0.8,0.747126,0.0
900,0.1526,0.174309,0.508443,0.042553,0.795009,0.771242,0.769231,0.164179
1000,0.1362,0.176425,0.497663,0.043478,0.782931,0.794118,0.759259,0.108527


Step,Training Loss,Validation Loss,F1 Macro,F1 Gender Sexual,F1 Political,F1 Religious,F1 Racial Ethnic,F1 Other
100,0.3065,0.277445,0.162139,0.0,0.53012,0.0,0.280576,0.0
200,0.2708,0.232237,0.227488,0.0,0.728682,0.0,0.408759,0.0
300,0.2225,0.229803,0.423706,0.0,0.697674,0.732558,0.688297,0.0
400,0.1992,0.188028,0.457698,0.0,0.757315,0.785185,0.745989,0.0
500,0.2059,0.186335,0.448233,0.0,0.740443,0.755245,0.745476,0.0
600,0.1837,0.182779,0.457013,0.0,0.791367,0.730539,0.763158,0.0
700,0.1738,0.176243,0.467086,0.0,0.782435,0.785185,0.76781,0.0
800,0.1561,0.180064,0.475823,0.044444,0.787546,0.8,0.747126,0.0
900,0.1526,0.174309,0.508443,0.042553,0.795009,0.771242,0.769231,0.164179
1000,0.1362,0.176425,0.497663,0.043478,0.782931,0.794118,0.759259,0.108527



Final F1 Macro: 0.5621
Best threshold: 0.30 (F1: 0.5801)



✓ English predictions saved
Distribution:
gender/sexual     0
political        56
religious         6
racial/ethnic    19
other             1
dtype: int64



✓ Swahili predictions saved
Distribution:
gender/sexual      2
political          9
religious         14
racial/ethnic    159
other             23
dtype: int64
