# Counterfactual Sample Generation for Microaggression Detection

Generate counterfactual (non-microaggressive) samples from microaggressive texts using few-shot
prompting with GPT-4.1, producing a balanced binary classification dataset.

## Pipeline
1. **Setup**: Load dataset and configure API
2. **Generate & Balance**: For each split, generate counterfactuals and combine with originals
3. **Save**: Output balanced train/validation/test CSV files

## 1. Setup and Configuration

In [None]:
# Imports
import os
import pandas as pd
from datasets import DatasetDict
from google.colab import drive, userdata

# Configuration
BASE_PATH = '/content/drive/MyDrive/266_project/'
DATASET_PATH = BASE_PATH + 'selfMA_ds'
RANDOM_SEED = 42

In [None]:
# Mount Google Drive and load dataset
drive.mount('/content/drive')
selfMA = DatasetDict.load_from_disk(DATASET_PATH)

print("Dataset loaded:")
print(selfMA)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Dataset loaded:
DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 1040
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 130
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 130
    })
})


In [None]:
# Setup OpenAI API
OPENAI_API_KEY = userdata.get('OPENAI_API_KEY')
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
print("OpenAI API key configured.")

OpenAI API key configured.


## 2. Counterfactual Generation Function

Define a function that uses few-shot prompting to transform a microaggressive text into a non-microaggressive equivalent that keeps any positive or neutral sentiment of the original.
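
The prompt in the function is a single hard-coded f-string. As a sketch of the same few-shot pattern, the prompt can instead be assembled from a list of (microaggressive, non-microaggressive) example pairs, which makes the examples easier to edit and inspect. The helper `build_few_shot_prompt` and the shortened instruction text are illustrative assumptions, not part of the original notebook; the two example pairs are taken from the original prompt.

```python
# Hypothetical helper (not from the original notebook): assemble a few-shot
# prompt from example pairs using the same "Microaggressive:/Non-Microaggressive:"
# layout as the hard-coded prompt.
INSTRUCTIONS = (
    "Task: Transform the given microaggressive text into a non-microaggressive equivalent.\n"
    "Produce the non-microaggressive equivalent phrase as the answer."
)

def build_few_shot_prompt(examples, query, instructions=INSTRUCTIONS):
    """Return a prompt string: instructions, example pairs, then the open query."""
    parts = [instructions, "", "Examples:"]
    for source, target in examples:
        parts.append(f'Microaggressive: "{source}"')
        parts.append(f'Non-Microaggressive: "{target}"')
        parts.append("")  # blank line between examples
    # The query ends with an empty answer slot for the model to complete
    parts.append(f'Microaggressive: "{query}"')
    parts.append("Non-Microaggressive:")
    return "\n".join(parts)

EXAMPLES = [
    ("Where are you really from?", "Where did you grow up?"),
    ("Are you the diversity hire?", "Are you the new hire?"),
]

prompt = build_few_shot_prompt(EXAMPLES, "You're lucky you're cute.")
print(prompt)
```

Keeping the examples as data also makes it straightforward to vary how many are included when comparing prompt variants.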

In [None]:
from openai import OpenAI

client = OpenAI()

def generate_non_microaggressive_equivalent(microaggressive_text):
    """
    Transform microaggressive text into a non-microaggressive equivalent.
    Uses few-shot prompting with examples to guide the transformation.

    Args:
        microaggressive_text: The input text containing microaggression

    Returns:
        Non-microaggressive equivalent text
    """
    prompt = f"""Task: Transform the given microaggressive text into a non-microaggressive equivalent.
The non-microaggressive equivalent should remove any bias, insensitivity, or invalidation present
in the original text, while retaining the core positive or neutral sentiment if applicable.
Do not repeat the original text in the answer. Do not use labels in the answer.
Produce the non-microaggressive equivalent phrase as the answer.

Examples:
Microaggressive: "You're very articulate for someone like you."
Non-Microaggressive: "Your presentation was very articulate."

Microaggressive: "Where are you really from?"
Non-Microaggressive: "Where did you grow up?"

Microaggressive: "You must be good at math since you're Asian."
Non-Microaggressive: "You are really good at math."

Microaggressive: "You're too pretty to be a software engineer."
Non-Microaggressive: "You are pretty."

Microaggressive: "Are you sure you want to lead this project?"
Non-Microaggressive: "Can you lead this project?"

Microaggressive: "You're surprisingly smart for someone like you."
Non-Microaggressive: "You are really smart."

Microaggressive: "I didn't expect someone like you to be this good at coding."
Non-Microaggressive: "You are good at coding."

Microaggressive: "Are you the diversity hire?"
Non-Microaggressive: "Are you the new hire?"

Microaggressive: "You look too feminine to be a scientist."
Non-Microaggressive: "You are a scientist."

Microaggressive: "You're lucky you're cute."
Non-Microaggressive: "You're cute."

Microaggressive: "{microaggressive_text}"
Non-Microaggressive:
"""
    resp = client.chat.completions.create(
        model="gpt-4.1",
        messages=[{"role": "user", "content": prompt}]
    )
    return resp.choices[0].message.content.strip()
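
The function above adds no retry logic of its own (the `openai` client retries some failed requests a small number of times, configurable via `OpenAI(max_retries=...)`). For long batch runs, one option is an explicit retry loop with exponential backoff around the generator. The sketch below is a generic, hypothetical wrapper; `with_retries`, its parameters, and the flaky stand-in function are illustrative, and the delays are assumptions rather than tuned values.

```python
import random
import time

def with_retries(fn, max_attempts=4, base_delay=1.0, sleep=time.sleep):
    """Wrap fn so that failed calls are retried with exponential backoff.

    Hypothetical helper: waits base_delay * 2**attempt seconds (plus up to
    10% random jitter) between attempts and re-raises the last exception
    once max_attempts calls have failed.
    """
    def wrapped(*args, **kwargs):
        for attempt in range(max_attempts):
            try:
                return fn(*args, **kwargs)
            except Exception:
                if attempt == max_attempts - 1:
                    raise  # out of attempts: let the caller handle it
                delay = base_delay * 2 ** attempt
                sleep(delay + random.uniform(0, 0.1 * delay))
    return wrapped

# Usage with a stand-in that fails twice before succeeding
calls = {"n": 0}

def flaky_generate(text):
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("simulated transient error")
    return f"rewritten: {text}"

delays = []  # record requested delays instead of actually sleeping
safe_generate = with_retries(flaky_generate, sleep=delays.append)
result = safe_generate("example")
print(result, calls["n"], len(delays))
```

In the notebook this could wrap the real generator, e.g. `with_retries(generate_non_microaggressive_equivalent)`, so that only persistent failures reach the per-split error handling.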

## 3. Generate Counterfactual Samples

For each split (train/validation/test):
- Treat every text with original label 1 or 2 as microaggressive and keep it, relabeled as 1
- Generate a counterfactual for each of these texts, labeled 0
- Combine and shuffle within the split
- Save to CSV

In [None]:
def process_split(split_name, dataset_split):
    """
    Process a single split: generate counterfactuals, combine with originals, shuffle.

    Original labels 1 and 2 are both treated as microaggressive and relabeled as 1;
    each generated counterfactual is labeled 0. If generation fails for a sample,
    the whole pair is skipped so no error message enters the data and the split
    stays balanced.

    Returns a balanced DataFrame with 'text' and 'label' columns.
    """
    texts = dataset_split['text']
    labels = dataset_split['label']

    original_samples = []        # Label 1: microaggressive
    counterfactual_samples = []  # Label 0: non-microaggressive
    failed = []                  # Indices whose generation failed

    print(f"Processing {split_name} split ({len(texts)} samples)...")

    for i, (text, label) in enumerate(zip(texts, labels)):
        if label in [1, 2]:  # Microaggressive text
            # Generate the counterfactual first so a failure skips the whole pair
            try:
                counterfactual = generate_non_microaggressive_equivalent(text)
            except Exception as e:
                print(f"  Error at sample {i}: {e}")
                failed.append(i)
            else:
                original_samples.append({'text': text, 'label': 1})
                counterfactual_samples.append({'text': counterfactual, 'label': 0})

        # Report progress over all samples, not only microaggressive ones
        if (i + 1) % 50 == 0:
            print(f"  Processed {i + 1}/{len(texts)}...")

    # Combine and shuffle
    df = pd.concat([
        pd.DataFrame(original_samples),
        pd.DataFrame(counterfactual_samples)
    ], ignore_index=True)

    df = df.sample(frac=1, random_state=RANDOM_SEED).reset_index(drop=True)

    print(f"  Completed: {len(df)} samples (balanced: {len(original_samples)} + {len(counterfactual_samples)})")
    if failed:
        print(f"  Skipped {len(failed)} samples after generation errors: {failed}")
    return df
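
Because `process_split` calls the API, one way to check its pairing and balancing behavior is to run the same kind of loop against a stub generator. The sketch below is a hypothetical pure-function version, `build_balanced_split`, that takes the generator as a parameter; the stub and the toy rows are illustrative. In this sketch a pair is dropped entirely when generation fails, which keeps the split 1:1 and keeps error messages out of the data.

```python
import pandas as pd

def build_balanced_split(texts, labels, generate, seed=42):
    """Pair each microaggressive text (label 1 or 2) with a generated counterfactual.

    Hypothetical pure-function version of the per-split loop: originals are
    relabeled 1, counterfactuals are labeled 0, and a pair is dropped entirely
    if generate() raises, so the two classes always have equal counts.
    """
    rows = []
    for text, label in zip(texts, labels):
        if label not in (1, 2):
            continue  # only microaggressive texts get counterfactuals
        try:
            counterfactual = generate(text)
        except Exception:
            continue  # skip the whole pair to preserve balance
        rows.append({"text": text, "label": 1})
        rows.append({"text": counterfactual, "label": 0})
    df = pd.DataFrame(rows, columns=["text", "label"])
    # Shuffle so paired rows are not adjacent
    return df.sample(frac=1, random_state=seed).reset_index(drop=True)

def stub_generate(text):
    """Stand-in for the API call: fails on one input, rewrites the others."""
    if "fail" in text:
        raise RuntimeError("simulated API error")
    return text.upper()

toy_df = build_balanced_split(
    texts=["a", "b", "please fail", "c"],
    labels=[1, 2, 1, 0],
    generate=stub_generate,
)
print(toy_df["label"].value_counts().to_dict())
```

Passing the generator in as an argument is the design choice that makes the loop testable offline; the real run would pass `generate_non_microaggressive_equivalent`.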

In [None]:
# Process all splits and save
results = {}

for split_name in ['train', 'validation', 'test']:
    df = process_split(split_name, selfMA[split_name])
    results[split_name] = df

    # Save to CSV
    output_path = f"{BASE_PATH}balanced_{split_name}.csv"
    df.to_csv(output_path, index=False)
    print(f"  Saved to {output_path}\n")

Processing train split (1040 samples)...
  Processed 50/1040...
  Processed 100/1040...
  Processed 150/1040...
  Processed 200/1040...
  Processed 250/1040...
  Processed 300/1040...
  Processed 350/1040...
  Processed 400/1040...
  Processed 450/1040...
  Processed 500/1040...
  Processed 550/1040...
  Processed 600/1040...
  Processed 650/1040...
  Processed 700/1040...
  Processed 750/1040...
  Processed 800/1040...
  Processed 850/1040...
  Processed 900/1040...
  Processed 950/1040...
  Processed 1000/1040...
  Completed: 2080 samples (balanced: 1040 + 1040)
  Saved to /content/drive/MyDrive/266_project/balanced_train.csv

Processing validation split (130 samples)...
  Processed 50/130...
  Processed 100/130...
  Completed: 260 samples (balanced: 130 + 130)
  Saved to /content/drive/MyDrive/266_project/balanced_validation.csv

Processing test split (130 samples)...
  Processed 50/130...
  Processed 100/130...
  Completed: 260 samples (balanced: 130 + 130)
  Saved to /content/driv
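
Each counterfactual is generated inside the same split as its source text, so source/counterfactual pairs never straddle splits. Short, generic counterfactuals (for example, "You are really smart.") can still be generated identically in different splits, though. A quick check, sketched below under the assumption that each split is a pandas DataFrame with a `text` column (as in `results`), is to list texts shared between splits; the helper `find_cross_split_overlap` and the toy splits are hypothetical.

```python
from itertools import combinations

import pandas as pd

def find_cross_split_overlap(splits):
    """Return {(split_a, split_b): set_of_shared_texts} for every pair of splits.

    Hypothetical helper: compares exact text strings after stripping
    surrounding whitespace; pairs with no overlap are omitted.
    """
    text_sets = {name: set(df["text"].str.strip()) for name, df in splits.items()}
    overlap = {}
    for a, b in combinations(text_sets, 2):
        shared = text_sets[a] & text_sets[b]
        if shared:
            overlap[(a, b)] = shared
    return overlap

# Toy splits standing in for `results`
toy_splits = {
    "train": pd.DataFrame({"text": ["You are really smart.", "Where did you grow up?"], "label": [0, 0]}),
    "validation": pd.DataFrame({"text": ["Are you the new hire?"], "label": [0]}),
    "test": pd.DataFrame({"text": ["You are really smart. "], "label": [0]}),
}
print(find_cross_split_overlap(toy_splits))
```

On the real data this would be `find_cross_split_overlap(results)`; whether shared generic sentences should be removed is a judgment call for the evaluation setup.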

## 4. Summary

Summarize the generated dataset. Each split pairs the original microaggressive texts with their
generated counterfactuals to form a balanced binary classification set:
- **Label 0**: Non-microaggressive (generated counterfactuals)
- **Label 1**: Microaggressive (original texts)
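
The summary cell reports counts from the in-memory `results`. As an additional check, the saved CSV files can be reloaded and inspected for class balance, empty or duplicated texts, and rows beginning with `Error:` (a placeholder format a failed API call could leave behind). The helper `check_split_csv` is hypothetical, and the sketch writes a toy CSV to a temporary directory so it runs anywhere.

```python
import os
import tempfile

import pandas as pd

def check_split_csv(path):
    """Reload a saved split and report basic quality checks.

    Hypothetical helper: returns class counts, whether labels 0 and 1 have
    equal counts, and the number of empty, duplicated, and "Error:" texts.
    """
    df = pd.read_csv(path, keep_default_na=False)  # keep "" as a string, not NaN
    texts = df["text"].astype(str).str.strip()
    counts = df["label"].value_counts()
    return {
        "label_counts": counts.sort_index().to_dict(),
        "balanced": bool(counts.get(0, 0) == counts.get(1, 0)),
        "empty": int((texts == "").sum()),
        "duplicates": int(texts.duplicated().sum()),
        "error_rows": int(texts.str.startswith("Error:").sum()),
    }

# Toy CSV standing in for a saved split such as balanced_validation.csv
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_split.csv")
    pd.DataFrame({
        "text": ["Are you the diversity hire?", "Are you the new hire?",
                 "Where are you really from?", "Error: timeout"],
        "label": [1, 0, 1, 0],
    }).to_csv(path, index=False)
    report = check_split_csv(path)

print(report)
```

In the notebook, this could be run as `check_split_csv(f"{BASE_PATH}balanced_{split_name}.csv")` for each split.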

In [None]:
print("=" * 60)
print("DATASET GENERATION COMPLETE")
print("=" * 60)

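# Verify the 1:1 balance reported below instead of assuming it
for split_name, df in results.items():
    counts = df['label'].value_counts()
    assert counts.get(0, 0) == counts.get(1, 0), f"{split_name} split is unbalanced: {counts.to_dict()}"

total_samples = sum(len(df) for df in results.values())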

print(f"""
Results:
--------
Train:      {len(results['train']):,} samples ({len(results['train'])//2} per class)
Validation: {len(results['validation']):,} samples ({len(results['validation'])//2} per class)
Test:       {len(results['test']):,} samples ({len(results['test'])//2} per class)

Total:      {total_samples:,} samples

Class Balance: 1:1 (perfectly balanced in each split)

Output Files:
- {BASE_PATH}balanced_train.csv
- {BASE_PATH}balanced_validation.csv
- {BASE_PATH}balanced_test.csv
""")

# Show sample from each split
for split_name, df in results.items():
    print(f"\n{split_name.upper()} sample:")
    print(df.head(4).to_string(index=False))

DATASET GENERATION COMPLETE

Results:
--------
Train:      2,080 samples (1040 per class)
Validation: 260 samples (130 per class)
Test:       260 samples (130 per class)

Total:      2,600 samples

Class Balance: 1:1 (perfectly balanced in each split)

Output Files:
- /content/drive/MyDrive/266_project/balanced_train.csv
- /content/drive/MyDrive/266_project/balanced_validation.csv
- /content/drive/MyDrive/266_project/balanced_test.csv


TRAIN sample:
                                                                                         text  label
                                                         I'm glad to have you as my roommate.      0
                                                                    THERE'S NO BITCH IN TEAM.      1
                                                                              Are you normal?      1
Hi Miss [woman with a PhD], do you know where Dr. [male colleague without a PhD]'s office is?      1

VALIDATION sample:
                    