In [1]:
%%time
%%capture
!pip install --upgrade transformers
!pip install accelerate

CPU times: user 8.1 ms, sys: 4.07 ms, total: 12.2 ms
Wall time: 7.96 s


# Llama Guard Few-Shot Classification for Polarization Detection

This notebook implements few-shot classification using the Llama Guard 3 8B model.
We use 20 examples from the training set (10 per class) to guide the model.

In [20]:
import os
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import shutil

In [3]:
# Hugging Face Hub identifier of the checkpoint loaded by the model-loading cell.
model_id = "meta-llama/Llama-Guard-3-8B"
# Device string: passed as `device_map` when loading the model and used to move
# the tokenized prompt in `moderate`.
device = "cuda"
# Dtype handed to `from_pretrained` for the model weights.
dtype = torch.bfloat16

In [5]:
%%time
# hf auth login
# NOTE(review): the line above presumably reminds the user to authenticate with the
# Hugging Face Hub (`hf auth login`) before running this cell — confirm.
# Load the tokenizer and the causal-LM weights for `model_id`, placing the model on
# `device` with the configured `dtype`.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=dtype, device_map=device)

tokenizer_config.json:   0%|          | 0.00/52.0k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.08M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/73.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/860 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

CPU times: user 21.6 s, sys: 17.8 s, total: 39.3 s
Wall time: 44.3 s


## Configuration

In [6]:
# Data paths
# Each data directory is expected to hold one `<lang_code>.csv` per language;
# predictions are written to OUTPUT_DIR as `pred_<lang_code>.csv`.
DEV_DATA_PATH = "subtask1/dev"
TRAIN_DATA_PATH = "subtask1/train"
OUTPUT_DIR = "./llama_guard_few_shot"

# Language codes
# Processed in this order when `process_all_languages` is called without arguments.
LANG_CODES = [
    "amh", "arb", "ben", "deu", "eng", "fas", "hau", "hin",
    "ita", "khm", "mya", "nep", "ori", "pan", "pol", "rus",
    "spa", "swa", "tel", "tur", "urd", "zho"
]

# Few-shot configuration
N_EXAMPLES_PER_CLASS = 10  # 10 examples for label 0, 10 for label 1
# Passed as `random_state` to every sampling step so the few-shot set is reproducible.
RANDOM_SEED = 42

## Custom Polarization Taxonomy for Llama Guard

In [7]:
# Custom category list (P1-P5) plus a short policy statement. `build_few_shot_prompt`
# places this text verbatim at the top of every prompt, so any edit here changes
# what is sent to the model.
POLARIZATION_TAXONOMY = """<BEGIN UNSAFE CONTENT CATEGORIES>
P1: Political Polarization
Content that creates strong us-vs-them divisions, vilifies political groups, or uses divisive rhetoric about political figures or parties.

P2: Implicit Hate Speech
Subtle expressions of prejudice, stereotyping, or derogatory characterizations of groups based on identity, religion, ethnicity, or nationality.

P3: Conspiracy Theories & Misinformation
Content promoting unfounded conspiracy theories, electoral fraud claims, or misinformation that fuels distrust in institutions.

P4: Sarcasm & Mockery with Hostile Intent
Sarcastic or mocking language directed at groups or individuals with intent to demean, ridicule, or inflame tensions.

P5: Gender-Based Hostility
Content expressing misogyny, gender-based violence, or hostile attitudes toward gender groups.
<END UNSAFE CONTENT CATEGORIES>

<BEGIN POLICY>
Determine if the content exhibits polarization according to the categories above. Content is polarized if it contains divisive political rhetoric, conspiracy theories, implicit hate speech, hostile sarcasm, or gender-based hostility. Consider context and implied meanings, not just explicit language.
<END POLICY>"""

## Data Loading Functions

In [8]:
def load_few_shot_examples(lang_code, n_per_class=10, seed=42):
    """Load shuffled few-shot examples from the training data of one language.

    Up to ``n_per_class`` rows are drawn for each value of the ``polarization``
    column (1 and 0). If a class has fewer rows than requested, every available
    row of that class is used and a warning is printed, instead of failing the
    whole language with a ``ValueError`` from ``DataFrame.sample``.

    Args:
        lang_code: Language code; the file read is ``TRAIN_DATA_PATH/<lang_code>.csv``.
        n_per_class: Maximum number of examples to draw per class.
        seed: ``random_state`` used for both the per-class draw and the final shuffle.

    Returns:
        DataFrame with the sampled rows in shuffled order (original index kept).

    Raises:
        FileNotFoundError: If the training file for ``lang_code`` does not exist.
    """
    train_file = os.path.join(TRAIN_DATA_PATH, f"{lang_code}.csv")

    if not os.path.exists(train_file):
        raise FileNotFoundError(f"No training file for {lang_code}")

    df = pd.read_csv(train_file)

    def _sample_class(label):
        # Cap the draw at the class size: sample(n=...) raises when n exceeds
        # the number of rows and replace=False.
        pool = df[df["polarization"] == label]
        take = min(n_per_class, len(pool))
        if take < n_per_class:
            print(
                f"Warning: {lang_code} has only {len(pool)} training rows with "
                f"polarization={label} (requested {n_per_class}); using {take}"
            )
        return pool.sample(n=take, random_state=seed)

    # Same draw order as before: polarized first, then non-polarized.
    polarized = _sample_class(1)
    non_polarized = _sample_class(0)

    # Combine and shuffle so the two classes are interleaved in the prompt.
    examples = pd.concat([polarized, non_polarized]).sample(frac=1, random_state=seed)

    return examples

## Prompt Building with Few-Shot Examples

In [9]:
def build_few_shot_prompt(text, lang_code, few_shot_examples):
    """
    Assemble the classification prompt: taxonomy, language tag, labeled
    demonstrations, and finally the text whose label is requested.

    Args:
        text: The text to classify
        lang_code: Language code inserted on the "Language:" line
        few_shot_examples: DataFrame with 'text' and 'polarization' columns;
            rows with polarization == 1 are shown as "unsafe", all others as "safe"

    Returns:
        Formatted prompt string ending in "Classification:"
    """
    def render_demo(example):
        # One demonstration: the text followed by its safe/unsafe verdict.
        verdict = "unsafe" if example["polarization"] == 1 else "safe"
        return f"Text: {example['text']}\n" + f"Classification: {verdict}\n\n"

    demo_section = "\n\nHere are some labeled examples:\n\n" + "".join(
        render_demo(example) for _, example in few_shot_examples.iterrows()
    )

    # Pieces are concatenated in reading order of the final prompt.
    pieces = [
        f"{POLARIZATION_TAXONOMY}\n\n",
        f"Language: {lang_code}\n",
        demo_section,
        "\nNow classify the following text:\n\n",
        f"Text: {text}\n",
        "Classification:",
    ]
    return "".join(pieces)

## Inference Functions

In [10]:
def moderate(chat):
    """Run the model on a chat and return only the newly generated text.

    The chat is tokenized through the tokenizer's chat template, moved to
    `device`, and passed to `model.generate` (at most 100 new tokens, pad id 0).
    The prompt tokens are sliced off before decoding, so the returned string
    contains the completion only, with special tokens removed.

    NOTE(review): the chat template presumably wraps the user message in the
    model's own moderation prompt, so a custom taxonomy placed inside the
    message may be treated as content to classify rather than as instructions —
    confirm against the model card.

    Args:
        chat: List of {"role": ..., "content": ...} message dicts.

    Returns:
        Decoded completion string.
    """
    prompt_ids = tokenizer.apply_chat_template(chat, return_tensors="pt").to(device)
    n_prompt_tokens = prompt_ids.shape[-1]

    generated = model.generate(input_ids=prompt_ids, max_new_tokens=100, pad_token_id=0)

    # generate() returns prompt + completion; keep only the completion part.
    completion_ids = generated[0][n_prompt_tokens:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True)

In [11]:
def parse_llama_guard_response(response):
    """
    Parse a Llama Guard response into a binary label.

    The verdict is expected as the first word of the first non-empty line:
    - "safe" for non-polarized content (label 0)
    - "unsafe", optionally followed by category codes on later lines,
      for polarized content (label 1)

    The first word decides when it is a clear verdict. This avoids labeling a
    reply as polarized merely because the word "unsafe" appears somewhere after
    a leading "safe". If the first word is neither, the whole response is
    searched for "unsafe" and then "safe", as a best-effort fallback.

    Args:
        response: Raw response from Llama Guard (or None)

    Returns:
        0 for safe (non-polarized), 1 for unsafe (polarized). None and
        unparseable responses default to 0; the latter also print a warning.
    """
    if response is None:
        return 0  # Default to non-polarized

    response_lower = response.lower().strip()

    # Primary rule: first token of the first line, ignoring trailing punctuation.
    first_line_tokens = response_lower.split("\n", 1)[0].split()
    verdict = first_line_tokens[0].strip(".,:;!") if first_line_tokens else ""
    if verdict == "unsafe":
        return 1
    if verdict == "safe":
        return 0

    # Fallback: substring search. "unsafe" is tested first because it contains "safe".
    if "unsafe" in response_lower:
        return 1
    if "safe" in response_lower:
        return 0

    # If unclear, default to non-polarized
    print(f"Warning: Unclear response: {response}")
    return 0

In [12]:
def classify_text(text, lang_code, few_shot_examples):
    """
    Classify one text: build the few-shot prompt, send it to the model as a
    single user turn, and parse the reply.

    Args:
        text: Text to classify
        lang_code: Language code
        few_shot_examples: DataFrame with training examples

    Returns:
        Tuple (label, response): the binary label (0 or 1) and the raw model reply
    """
    user_turn = {
        "role": "user",
        "content": build_few_shot_prompt(text, lang_code, few_shot_examples),
    }

    # The whole prompt travels as one user message.
    raw_reply = moderate([user_turn])

    return parse_llama_guard_response(raw_reply), raw_reply

## Test the Implementation

In [13]:
# Test with a simple example
# Smoke test: load the English few-shot set with the configured size and seed,
# then print the class counts to confirm the sample is balanced.
test_lang = "eng"
test_examples = load_few_shot_examples(test_lang, n_per_class=N_EXAMPLES_PER_CLASS, seed=RANDOM_SEED)

print(f"Loaded {len(test_examples)} few-shot examples for {test_lang}")
print(f"\nClass distribution:")
print(test_examples["polarization"].value_counts())

Loaded 20 few-shot examples for eng

Class distribution:
polarization
1    10
0    10
Name: count, dtype: int64


In [14]:
# Test classification
test_text_1 = "I hate all politicians, they are destroying our country!"
test_text_2 = "The weather is nice today."

print("Test 1 (polarized):")
label_1, response_1 = classify_text(test_text_1, test_lang, test_examples)
print(f"Text: {test_text_1}")
print(f"Raw response: {response_1}")
print(f"Predicted label: {label_1}")

print("\n" + "="*80 + "\n")

print("Test 2 (non-polarized):")
label_2, response_2 = classify_text(test_text_2, test_lang, test_examples)
print(f"Text: {test_text_2}")
print(f"Raw response: {response_2}")
print(f"Predicted label: {label_2}")

Test 1 (polarized):
Text: I hate all politicians, they are destroying our country!
Raw response: 

safe
Predicted label: 0


Test 2 (non-polarized):
Text: The weather is nice today.
Raw response: 

safe
Predicted label: 0


## Process Dev Set

In [15]:
def process_language(lang_code, few_shot_examples):
    """
    Classify every row of one language's dev file.

    Reads ``DEV_DATA_PATH/<lang_code>.csv``, runs `classify_text` on each row's
    "text" column, and pairs the resulting label with the row's "id".

    Args:
        lang_code: Language code to process
        few_shot_examples: DataFrame with training examples for this language

    Returns:
        DataFrame with one {"id", "polarization"} record per dev row, or None
        (after printing a message) when the dev file does not exist
    """
    dev_file = os.path.join(DEV_DATA_PATH, f"{lang_code}.csv")

    if os.path.exists(dev_file):
        dev_df = pd.read_csv(dev_file)
    else:
        print(f"Dev file not found: {dev_file}")
        return None

    print(f"\nProcessing {lang_code}: {len(dev_df)} samples")

    def predict(sample):
        # Read both fields before calling the model, then keep only the label;
        # the raw reply is not stored.
        sample_text, sample_id = sample["text"], sample["id"]
        verdict, _raw_reply = classify_text(sample_text, lang_code, few_shot_examples)
        return {"id": sample_id, "polarization": verdict}

    records = [predict(sample) for _, sample in dev_df.iterrows()]
    return pd.DataFrame(records)

In [16]:
def process_all_languages(lang_codes=None):
    """
    Run the few-shot pipeline for each language and write one CSV per language.

    For every language: load its few-shot examples, classify its dev set, and
    save the predictions to ``OUTPUT_DIR/pred_<lang_code>.csv``. A failure in one
    language is reported and does not stop the remaining languages; a language
    whose dev file is missing produces no output file.

    Args:
        lang_codes: List of language codes to process (default: all of LANG_CODES)
    """
    targets = LANG_CODES if lang_codes is None else lang_codes

    # Make sure the destination exists before any language is processed.
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    for lang_code in targets:
        try:
            shots = load_few_shot_examples(
                lang_code,
                n_per_class=N_EXAMPLES_PER_CLASS,
                seed=RANDOM_SEED,
            )
            pred_df = process_language(lang_code, shots)

            # None means the dev file was missing; nothing to save.
            if pred_df is None:
                continue

            destination = os.path.join(OUTPUT_DIR, f"pred_{lang_code}.csv")
            pred_df.to_csv(destination, index=False)
            print(f"✓ Saved predictions to {destination}")
        except Exception as e:
            # Best effort: report and move on to the next language.
            print(f"✗ Error processing {lang_code}: {e}")

    print("\n" + "="*80)
    print("Processing complete!")

## Run Predictions on Dev Set

**Note:** You can process a single language first to test, or process all languages.
Processing all languages may take a while depending on your GPU.

In [17]:
# Process a single language for testing
# process_all_languages(["eng"])

In [18]:
%%time
# Full run: with no argument, every language in LANG_CODES is processed.
# Pass a list (e.g. ["eng"]) to restrict the run while testing.
process_all_languages()


Processing amh: 166 samples
  Processed 10/166 samples
  Processed 20/166 samples
  Processed 30/166 samples
  Processed 40/166 samples
  Processed 50/166 samples
  Processed 60/166 samples
  Processed 70/166 samples
  Processed 80/166 samples
  Processed 90/166 samples
  Processed 100/166 samples
  Processed 110/166 samples
  Processed 120/166 samples
  Processed 130/166 samples
  Processed 140/166 samples
  Processed 150/166 samples
  Processed 160/166 samples
✓ Saved predictions to ./llama_guard_few_shot/pred_amh.csv

Processing arb: 169 samples
  Processed 10/169 samples
  Processed 20/169 samples
  Processed 30/169 samples
  Processed 40/169 samples
  Processed 50/169 samples
  Processed 60/169 samples
  Processed 70/169 samples
  Processed 80/169 samples
  Processed 90/169 samples
  Processed 100/169 samples
  Processed 110/169 samples
  Processed 120/169 samples
  Processed 130/169 samples
  Processed 140/169 samples
  Processed 150/169 samples
  Processed 160/169 samples
✓ Sav

## View Sample Predictions

In [19]:
# # View predictions for a language
# sample_lang = "eng"
# pred_file = os.path.join(OUTPUT_DIR, f"pred_{sample_lang}.csv")

# if os.path.exists(pred_file):
#     pred_df = pd.read_csv(pred_file)
#     print(f"Predictions for {sample_lang}:")
#     print(f"Total samples: {len(pred_df)}")
#     print(f"\nLabel distribution:")
#     print(pred_df["polarization"].value_counts())
#     print(f"\nFirst 10 predictions:")
#     print(pred_df.head(10))
# else:
#     print(f"No predictions found for {sample_lang}")

In [21]:
# Bundle the prediction CSVs into subtask_1.zip. The archive root is OUTPUT_DIR,
# the same constant the predictions were written to, so the two cannot drift apart.
shutil.make_archive(base_name="subtask_1", format="zip", root_dir=OUTPUT_DIR)

'/home/jovyan/work/subtask_1.zip'