In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; the dataset is read from and written
# back to paths under this mount point in later cells.
drive.mount('/content/drive')

In [None]:
!pip install datasets -q

In [None]:
from datasets import load_dataset, Dataset
import json

In [None]:
# Read the flat-format preference pairs from Drive; the JSON loader exposes
# them under the 'train' split.
source_file = "/content/drive/MyDrive/Colab Notebooks/woke-odds/dpo_final_dataset.jsonl"
dataset = load_dataset("json", data_files=source_file)

raw_train = dataset['train']
print(f"Original dataset size: {len(raw_train)}")
print("\nFirst example:")
print(raw_train[0])

In [None]:
def convert_to_conversational_explicit_prompt(example):
    """
    Reshape one flat preference record into the conversational layout with an
    explicit prompt.

    Input record (flat):
    {
        "prompt": "[LA|LEX] What is...",
        "chosen": "Are you referring to...",
        "rejected": "What do you mean..."
    }

    Output record (conversational):
    {
        "prompt": [
            {"role": "system", "content": "..."},
            {"role": "user", "content": "..."}
        ],
        "chosen": [
            {"role": "assistant", "content": "..."}
        ],
        "rejected": [
            {"role": "assistant", "content": "..."}
        ]
    }

    The system turn always carries the fixed instruction text defined below;
    the user turn carries the record's original "prompt" value unchanged.
    """
    # Fixed instruction text for the system turn (per the original author's
    # note, this mirrors the system prompt of the SFT dataset).
    instructions = """You are an AI that generates a single, concise clarifying question when a user's query is ambiguous.

Task:
Generate exactly one clarifying question based on the ambiguity type.
If the query is clear and needs no clarification, output: <NO_CLARIFYING_QUESTION>

Output format: One clarifying question (or <NO_CLARIFYING_QUESTION> if not needed)

Categories:
- EM (Epistemic Misalignment): Questions with unfamiliar entities or self-contradictions
- LA (Linguistic Ambiguity): Questions with lexical or semantic ambiguity
- AO (Aleatoric Output): Questions with missing contextual information causing confusion
- NONE: Clear questions that don't require clarification

Subclasses:
For EM:
- UNF (UNFAMILIAR): Query contains unfamiliar entities or facts
- CONT (CONTRADICTION): Query contains self-contradictions

For LA:
- LEX (LEXICAL): Query contains terms with multiple meanings
- SEM (SEMANTIC): Query lacks context leading to multiple interpretations

For AO:
- WHOM: Query output contains confusion due to missing personal elements
- WHEN: Query output contains confusion due to missing temporal elements
- WHERE: Query output contains confusion due to missing spatial elements
- WHAT: Query output contains confusion due to missing task-specific elements

For Clear Questions:
- NONE: Use when require_clarification=0, output <NO_CLARIFYING_QUESTION>"""

    def as_assistant_turn(text):
        # A completion is a one-message conversation spoken by the assistant.
        return [{"role": "assistant", "content": text}]

    return {
        # Prompt: system instructions followed by the user's query.
        "prompt": [
            {"role": "system", "content": instructions},
            {"role": "user", "content": example['prompt']},
        ],
        # Preferred response.
        "chosen": as_assistant_turn(example['chosen']),
        # Dispreferred response.
        "rejected": as_assistant_turn(example['rejected']),
    }

In [None]:
# Apply the converter to every row. The original string columns are dropped
# so that only the values returned by the converter remain in the result.
train_split = dataset['train']
converted_dataset = train_split.map(
    convert_to_conversational_explicit_prompt,
    remove_columns=train_split.column_names,
)

print(f"Converted dataset size: {len(converted_dataset)}")


In [None]:
# Pretty-print the first few converted records for a visual sanity check.
banner = "=" * 80
ordinals = ("First", "Second", "Third")

# Cap the preview at the dataset size so a dataset with fewer than three rows
# is previewed in full instead of raising IndexError.
preview_count = min(len(ordinals), len(converted_dataset))

for position in range(preview_count):
    # Every banner except the very first is preceded by a blank line.
    print(banner if position == 0 else "\n" + banner)
    print(f"CONVERTED FORMAT - {ordinals[position]} Example:")
    print(banner)
    print(json.dumps(converted_dataset[position], indent=2, ensure_ascii=False))


In [None]:
def verify_structure(dataset, num_samples=5):
    """Verify that the leading samples have the conversational DPO structure.

    Args:
        dataset: Sized, integer-indexable collection of samples. Each sample
            must support ``in`` and key lookup for 'prompt', 'chosen' and
            'rejected'.
        num_samples: Upper bound on how many samples, starting at index 0,
            are inspected. Fewer are inspected if the dataset is shorter.

    Returns:
        True once every inspected sample has passed.

    Raises:
        AssertionError: On the first sample that violates the structure; the
            message names the sample index and the failed expectation.
    """
    # Never inspect -- or claim to have inspected -- more samples than exist.
    num_checked = max(0, min(num_samples, len(dataset)))
    print(f"\nVerifying structure of {num_checked} samples...")

    for i in range(num_checked):
        sample = dataset[i]

        # All three top-level fields must be present before looking inside.
        for key in ('prompt', 'chosen', 'rejected'):
            assert key in sample, f"Sample {i}: Missing '{key}' key"

        # The prompt is a non-empty list of messages.
        assert isinstance(sample['prompt'], list), f"Sample {i}: 'prompt' should be a list"
        assert len(sample['prompt']) >= 1, f"Sample {i}: 'prompt' should have at least 1 message"

        # Both completions are non-empty message lists opened by the assistant.
        for key in ('chosen', 'rejected'):
            messages = sample[key]
            assert isinstance(messages, list), f"Sample {i}: '{key}' should be a list"
            assert len(messages) >= 1, f"Sample {i}: '{key}' should have at least 1 message"
            assert messages[0]['role'] == 'assistant', f"Sample {i}: '{key}' should start with assistant"

        print(f"✓ Sample {i}: Valid structure")

    print(f"\n✅ All {num_checked} samples verified successfully!")
    return True

# Spot-check up to the first 10 converted samples; an AssertionError here
# means a sample does not have the expected prompt/chosen/rejected layout.
verify_structure(converted_dataset, num_samples=10)

In [None]:
output_path = "/content/drive/MyDrive/Colab Notebooks/woke-odds/dpo_final_dataset_modified.jsonl"

# Fields that hold message lists, in the order they are written per record.
message_fields = ("prompt", "chosen", "rejected")

# Write one JSON object per line. Each message dict is rebuilt by hand so the
# serialized key order is always "role" first, then "content".
with open(output_path, 'w', encoding='utf-8') as jsonl_file:
    for record in converted_dataset:
        role_first_record = {
            field: [
                {"role": turn["role"], "content": turn["content"]}
                for turn in record[field]
            ]
            for field in message_fields
        }
        serialized = json.dumps(role_first_record, ensure_ascii=False)
        jsonl_file.write(serialized + '\n')

print(f"\n✅ Converted dataset saved to: {output_path}")
print(f"Total samples: {len(converted_dataset)}")


In [None]:
# Reload the file that was just written to confirm it parses as JSONL.
verification_dataset = load_dataset("json", data_files=output_path)
reloaded_count = len(verification_dataset['train'])

# A successful parse alone does not show the write was complete: the number
# of rows read back must equal the number of rows that were converted.
assert reloaded_count == len(converted_dataset), (
    f"Reloaded {reloaded_count} samples but expected {len(converted_dataset)}"
)
print(f"\n✅ Verification: Successfully loaded {reloaded_count} samples from saved file")

print("\nFirst sample from saved file:")
print(json.dumps(verification_dataset['train'][0], indent=2, ensure_ascii=False))

In [None]:
banner = "=" * 80
print("\n" + banner)
print("DATASET STATISTICS")
print(banner)

# A single pass gathers both the per-prompt message counts and the number of
# prompts that contain at least one system message.
prompt_lengths = []
system_count = 0
for record in converted_dataset:
    turns = record['prompt']
    prompt_lengths.append(len(turns))
    if any(turn['role'] == 'system' for turn in turns):
        system_count += 1

# Prompt length is measured in messages, not characters or tokens.
mean_length = sum(prompt_lengths) / len(prompt_lengths)
print(f"Average prompt length (messages): {mean_length:.2f}")
print(f"Min prompt length: {min(prompt_lengths)}")
print(f"Max prompt length: {max(prompt_lengths)}")

print(f"\nSamples with system message: {system_count}/{len(converted_dataset)}")

print("\n✅ Conversion completed successfully!")
