# Llama Guard Few-Shot Classification with vLLM

This notebook implements few-shot polarization classification using Llama Guard 3 8B, served through vLLM for efficient batch inference.
For each language, 20 labeled examples are sampled from that language's training set (10 per class) and placed in the prompt to guide the model.

In [None]:
!pip install uv

In [None]:
%%time
%%capture
!uv pip install vllm --torch-backend=auto --system

In [None]:
!pip show vllm

In [None]:
import os
import re
import zipfile

import pandas as pd
from vllm import LLM, SamplingParams

## Configuration

In [None]:
# --- Input / output locations (relative to this notebook) ---
DEV_DATA_PATH = "../../data/dev_phase/subtask1/dev"
TRAIN_DATA_PATH = "../../data/dev_phase/subtask1/train"
OUTPUT_DIR = "../../predictions/llama_guard_vllm_few_shot"

# --- Languages to process: one "<code>.csv" file is looked up per entry ---
LANG_CODES = (
    "amh arb ben deu eng fas hau hin ita khm mya "
    "nep ori pan pol rus spa swa tel tur urd zho"
).split()

# --- Few-shot sampling ---
# Examples drawn per label, so each prompt carries 2 * N_EXAMPLES_PER_CLASS examples.
N_EXAMPLES_PER_CLASS = 10
# Seed passed to the pandas sampling calls so the chosen examples are reproducible.
RANDOM_SEED = 42

## Load Llama Guard Model with vLLM

In [None]:
%%time
# Instantiate the vLLM engine once; every later cell reuses this `llm` object.
llm = LLM(
    model="meta-llama/Llama-Guard-3-8B",
    gpu_memory_utilization=0.9,  # fraction of GPU memory handed to vLLM
    dtype="auto",                # let vLLM pick the dtype from the model config
)

## Custom Polarization Taxonomy

In [None]:
# Custom category list (P1-P5) and policy statement, written with
# <BEGIN ...>/<END ...> section markers. build_few_shot_prefix() places this
# text verbatim at the top of every few-shot prompt.
POLARIZATION_TAXONOMY = """<BEGIN UNSAFE CONTENT CATEGORIES>
P1: Political Polarization
Content that creates strong us-vs-them divisions, vilifies political groups, or uses divisive rhetoric about political figures or parties.

P2: Implicit Hate Speech
Subtle expressions of prejudice, stereotyping, or derogatory characterizations of groups based on identity, religion, ethnicity, or nationality.

P3: Conspiracy Theories & Misinformation
Content promoting unfounded conspiracy theories, electoral fraud claims, or misinformation that fuels distrust in institutions.

P4: Sarcasm & Mockery with Hostile Intent
Sarcastic or mocking language directed at groups or individuals with intent to demean, ridicule, or inflame tensions.

P5: Gender-Based Hostility
Content expressing misogyny, gender-based violence, or hostile attitudes toward gender groups.
<END UNSAFE CONTENT CATEGORIES>

<BEGIN POLICY>
Determine if the content exhibits polarization according to the categories above. Content is polarized if it contains divisive political rhetoric, conspiracy theories, implicit hate speech, hostile sarcasm, or gender-based hostility. Consider context and implied meanings, not just explicit language.
<END POLICY>"""

## Data Loading and Few-Shot Example Selection

In [None]:
def load_few_shot_examples(lang_code, n_per_class=10, seed=42):
    """Load a class-balanced, shuffled set of few-shot examples for one language.

    Reads ``<TRAIN_DATA_PATH>/<lang_code>.csv`` and samples up to ``n_per_class``
    rows for each value of the ``polarization`` column (1 and 0).

    Args:
        lang_code: Language code used to locate the training CSV.
        n_per_class: Desired number of examples per label. If a label has fewer
            rows than this, all of its rows are used and a warning is printed
            instead of failing the whole language.
        seed: Random state for both the per-class sampling and the final shuffle.

    Returns:
        DataFrame with the sampled rows of both classes in shuffled order.

    Raises:
        FileNotFoundError: If the training CSV does not exist.
        ValueError: If either class has no rows at all (a one-sided few-shot
            set would bias the prompt).
    """
    train_file = os.path.join(TRAIN_DATA_PATH, f"{lang_code}.csv")

    if not os.path.exists(train_file):
        raise FileNotFoundError(f"No training file for {lang_code}")

    df = pd.read_csv(train_file)

    sampled = []
    # Polarized (1) first, then non-polarized (0): same order as before, so the
    # shuffled result is unchanged whenever enough rows are available.
    for label in (1, 0):
        pool = df[df["polarization"] == label]
        if pool.empty:
            raise ValueError(
                f"No training examples with polarization == {label} for {lang_code}"
            )

        # DataFrame.sample raises if n exceeds the population; cap it instead.
        take = min(n_per_class, len(pool))
        if take < n_per_class:
            print(
                f"⚠ {lang_code}: only {take} examples with polarization == {label} "
                f"(requested {n_per_class})"
            )
        sampled.append(pool.sample(n=take, random_state=seed))

    # Combine and shuffle so the two classes are interleaved in the prompt.
    return pd.concat(sampled).sample(frac=1, random_state=seed)

In [None]:
def _escape_format_braces(value):
    """Double every brace in ``value`` so ``str.format`` emits it literally."""
    return str(value).replace("{", "{{").replace("}", "}}")


def build_few_shot_prefix(lang_code, few_shot_df):
    """Build the few-shot prompt template for one language.

    The result is a ``str.format`` template: it contains exactly one
    placeholder, ``{text}``, which callers fill with the text to classify.
    All other content (taxonomy, language code, example texts) has its braces
    escaped, so a ``{`` or ``}`` in the data can neither break ``.format()``
    nor be mistaken for a placeholder.

    Args:
        lang_code: Language code shown on the ``Language:`` line.
        few_shot_df: DataFrame with ``text`` and ``polarization`` columns;
            rows are rendered in the order given.

    Returns:
        The template string: taxonomy, labeled examples, then the final
        ``Text: {text}`` / ``Classification:`` pair left open for the model.
    """
    header = _escape_format_braces(
        f"{POLARIZATION_TAXONOMY}\n\n"
        f"Language: {lang_code}\n\n"
        "Below are labeled examples:\n\n"
    )

    examples = []
    for text, polarization in zip(few_shot_df["text"], few_shot_df["polarization"]):
        # Use "unsafe" for polarized (1) and "safe" for non-polarized (0)
        label = "unsafe" if polarization == 1 else "safe"
        examples.append(
            f"Text: {_escape_format_braces(text)}\nClassification: {label}\n"
        )

    examples_block = "\n".join(examples)

    # Deliberately NOT escaped: {text} is the single placeholder filled later.
    footer = """
Now classify the following text:

Text: {text}
Classification:"""

    return header + examples_block + footer

## Cache Few-Shot Prefixes for All Languages

In [None]:
# Build each language's prompt template once up front; inference cells look
# the template up by language code instead of re-sampling examples.
FEW_SHOT_PREFIXES = {}

for lang in LANG_CODES:
    try:
        few_shot_df = load_few_shot_examples(
            lang,
            n_per_class=N_EXAMPLES_PER_CLASS,
            seed=RANDOM_SEED,
        )
        FEW_SHOT_PREFIXES[lang] = build_few_shot_prefix(lang, few_shot_df)
    except Exception as err:
        # Best effort: a language that cannot be prepared is reported and left
        # out of the cache rather than stopping the loop.
        print(f"⚠ Skipping {lang}: {err}")

print(f"✓ Few-shot prefixes ready for {len(FEW_SHOT_PREFIXES)} languages")

## Test Few-Shot Prefix

In [None]:
# Preview the rendered prompt for one language (truncated to 1000 characters)
test_lang = "eng"
if test_lang in FEW_SHOT_PREFIXES:
    sample_prefix = FEW_SHOT_PREFIXES[test_lang].format(text="SAMPLE_TEXT_HERE")
    print(
        f"Sample prefix for {test_lang}:",
        "=" * 80,
        sample_prefix if len(sample_prefix) <= 1000 else sample_prefix[:1000] + "...",
        "=" * 80,
        sep="\n",
    )

## Prompt Creation and Response Parsing

In [None]:
def create_prompt(text, language_code):
    """Return the full few-shot prompt for ``text`` in the given language.

    Looks up the cached template for ``language_code`` in FEW_SHOT_PREFIXES
    (raising KeyError if that language has no cached template) and fills its
    ``text`` placeholder.
    """
    return FEW_SHOT_PREFIXES[language_code].format(text=text)

In [None]:
def parse_response(response):
    """Map a raw Llama Guard completion to a binary polarization label.

    The verdict is the FIRST "safe"/"unsafe" token in the text. Generation has
    no stop sequence, so after answering the model may keep writing (e.g. start
    another "Text: ..." example); anything after the first verdict is ignored
    so that a later, incidental "unsafe" cannot flip a "safe" answer.

    Args:
        response: Generated text, or None if generation failed.

    Returns:
        1 if the first verdict is "unsafe" (polarized), otherwise 0. None and
        responses containing neither word default to 0; the latter are also
        printed for inspection.
    """
    if response is None:
        return 0  # Default to non-polarized if error

    # Leftmost match wins. "unsafe" is listed first so that, at the position
    # where it starts, it is preferred over the "safe" embedded inside it.
    verdict = re.search(r"unsafe|safe", response.lower())

    if verdict is None:
        print(f"Unclear response: {response}")
        return 0

    return 1 if verdict.group(0) == "unsafe" else 0

## Test Inference

In [None]:
%%time
# Smoke test: classify two hand-written English sentences
test_text_1 = "I hate all politicians, they are destroying our country!"
test_text_2 = "The weather is nice today."

test_prompts = [create_prompt(sample, "eng") for sample in (test_text_1, test_text_2)]

# Greedy decoding with a short budget: only the verdict is needed
test_params = SamplingParams(max_tokens=10, temperature=0.0)

test_outputs = llm.generate(test_prompts, test_params)

for i, output in enumerate(test_outputs):
    response = output.outputs[0].text.strip()
    prediction = parse_response(response)
    test_text = test_text_1 if i == 0 else test_text_2
    print(
        f"Text: {test_text}",
        f"Raw response: {response}",
        f"Predicted label: {prediction}",
        "-" * 80,
        sep="\n",
    )

## Process Dev Set with Batch Inference

In [None]:
def process_language(lang_code):
    """Classify every row of one language's dev file with batched inference.

    Reads ``<DEV_DATA_PATH>/<lang_code>.csv`` (columns ``id`` and ``text``),
    builds one few-shot prompt per row, generates all completions in a single
    ``llm.generate`` call and parses each completion into a 0/1 label.

    Args:
        lang_code: Language code used to locate the dev CSV and the cached
            few-shot prompt template.

    Returns:
        DataFrame with columns ``id`` and ``polarization`` in input row order,
        or None (after printing the reason) when the dev file is missing or no
        few-shot template was cached for this language.
    """
    input_file = os.path.join(DEV_DATA_PATH, f"{lang_code}.csv")

    if not os.path.exists(input_file):
        print(f"File not found: {input_file}")
        return None

    # A language skipped while caching prefixes would otherwise raise KeyError
    # inside create_prompt and abort the whole multi-language run.
    if lang_code not in FEW_SHOT_PREFIXES:
        print(f"No few-shot prefix for {lang_code}; skipping")
        return None

    df = pd.read_csv(input_file)
    print(f"\nProcessing {lang_code}: {len(df)} samples")

    # One prompt per row; iterating the column avoids building a Series per row.
    prompts = [create_prompt(text, lang_code) for text in df["text"]]

    # Define sampling params for fast, short outputs
    sampling_params = SamplingParams(
        temperature=0.0,  # Deterministic
        max_tokens=10,    # Enough for "safe" or "unsafe\nS1"
    )

    # Batched generation - vLLM handles batching automatically
    outputs = llm.generate(prompts, sampling_params)

    # Outputs are paired with ids by position, i.e. prompt order == row order.
    predictions = [
        {
            "id": post_id,
            "polarization": parse_response(output.outputs[0].text.strip()),
        }
        for post_id, output in zip(df["id"], outputs)
    ]

    # Explicit columns keep the CSV header even when the dev file has no rows.
    return pd.DataFrame(predictions, columns=["id", "polarization"])

In [None]:
def create_submission_zip(lang_codes_to_process=None):
    """Run predictions for the given languages and bundle them into a zip.

    For each language, writes ``pred_<lang>.csv`` into OUTPUT_DIR (languages
    for which process_language returns None are skipped). Afterwards every
    ``pred_*.csv`` currently present in OUTPUT_DIR, including files left by
    earlier calls, is added to ``subtask_1_llama_guard_vllm.zip`` under a
    ``subtask_1/`` folder.

    Args:
        lang_codes_to_process: Iterable of language codes; None means all of
            LANG_CODES.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    languages = LANG_CODES if lang_codes_to_process is None else lang_codes_to_process

    for code in languages:
        result = process_language(code)
        if result is None:
            continue

        csv_path = os.path.join(OUTPUT_DIR, f"pred_{code}.csv")
        result.to_csv(csv_path, index=False)
        print(f"Saved predictions to {csv_path}")

    zip_filename = "subtask_1_llama_guard_vllm.zip"
    with zipfile.ZipFile(zip_filename, "w", zipfile.ZIP_DEFLATED) as archive:
        prediction_files = [
            name
            for name in os.listdir(OUTPUT_DIR)
            if name.startswith("pred_") and name.endswith(".csv")
        ]
        for name in prediction_files:
            archive.write(
                os.path.join(OUTPUT_DIR, name),
                os.path.join("subtask_1", name),
            )

    print(f"\n✓ Created submission zip: {zip_filename}")

## Run Predictions

**Note:** Start with a single language to test, then process all languages.

In [None]:
# Trial run on English only, to check the pipeline end to end
create_submission_zip(lang_codes_to_process=["eng"])

In [None]:
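# Uncomment the call below to process all languages.
# (To time the run, put `%%time` on the very first line of this cell; a cell magic cannot follow a comment line.)
# create_submission_zip()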

## View Results

In [None]:
# Inspect the saved predictions for one language
sample_lang = "eng"
pred_file = os.path.join(OUTPUT_DIR, f"pred_{sample_lang}.csv")

if not os.path.exists(pred_file):
    print(f"No predictions found for {sample_lang}")
else:
    pred_df = pd.read_csv(pred_file)
    print(f"Predictions for {sample_lang}:")
    print(f"Total samples: {len(pred_df)}")
    # Count of rows per predicted label
    print("\nLabel distribution:")
    print(pred_df["polarization"].value_counts())
    print("\nFirst 10 predictions:")
    print(pred_df.head(10))