In [None]:
# Cell 1: Parameter Setup
#
# Every knob of the experiment lives here; Cell 2 only consumes these names.

# Sampling parameters
num_train_per_prompt = 9     # Cell 2 reads images 0.jpg .. (N-1).jpg per prompt
dialects = ["ine", "bre", "aae", "sge", "che"]
method = "sft"               # "dpo" or "sft"
prompt_stype = "concise"     # "concise" or "detailed"
generative_model = "stable-diffusion1.5"

# Experiment naming ("4-1-1" matches the train_val_test split folder below)
num_processes = 4
exp_name = "-".join(
    ["all_dialects", method, prompt_stype, "4-1-1", str(num_train_per_prompt)]
)

# Path templates: the literal "{dialect}" placeholder is left in place on
# purpose and filled per dialect with .format(dialect=...) in Cell 2.
csv_template = "/".join([
    "/local1/bryanzhou008/Dialect",
    "multimodal-dialectal-bias/data/text/train_val_test/4-1-1",
    prompt_stype,
    "{dialect}",
    "train.csv",
])
images_template = "/".join([
    "/local1/bryanzhou008/Dialect",
    "multimodal-dialectal-bias/data/image",
    prompt_stype,
    "{dialect}",
    generative_model,
])

# Output directory for the assembled dataset and the generated launcher
output_folder = "/".join([
    "/local1/bryanzhou008/Dialect",
    "multimodal-dialectal-bias/mitigation/baselines",
    "diffusion_dpo/a_data",
    exp_name,
])

# Training script template
#
# Text of the bash launcher that Cell 2 writes to <output_folder>/run_training.sh.
# - {output_folder}, {num_processes} and {exp_name} are substituted by Python
#   when this cell runs; $MODEL_NAME and $DATA_DIR are left for bash to expand.
# - Each "\\" is an escaped backslash: the generated file contains a single
#   "\" at the end of the line, i.e. a shell line continuation.
# NOTE(review): `method` does not alter any flag below (it only enters through
# exp_name / output_folder), so --beta_dpo is passed for "sft" runs as well --
# confirm train.py needs no explicit mode switch.
script_command = f"""#!/bin/bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export DATA_DIR="{output_folder}"

accelerate launch --num_processes {num_processes} \\
  /local1/bryanzhou008/Dialect/multimodal-dialectal-bias/mitigation/baselines/diffusion_dpo/a_src/train/train.py \\
    --pretrained_model_name_or_path=$MODEL_NAME \\
    --train_data_dir=$DATA_DIR \\
    --train_batch_size=1 \\
    --dataloader_num_workers=16 \\
    --gradient_accumulation_steps=128 \\
    --max_train_steps=2000 \\
    --lr_scheduler="constant_with_warmup" --lr_warmup_steps=500 \\
    --learning_rate=1e-8 --scale_lr \\
    --cache_dir="/local1/bryanzhou008/Dialect/multimodal-dialectal-bias/mitigation/baselines/diffusion_dpo/temp_cache/" \\
    --checkpointing_steps=200 \\
    --beta_dpo=5000 \\
    --output_dir="/local1/bryanzhou008/Dialect/multimodal-dialectal-bias/mitigation/baselines/diffusion_dpo/a_checkpoints/{exp_name}" \\
    --report_to "wandb" 
"""


In [None]:
# Cell 2: Data Preparation and Script Generation

import os
import csv
import json
import shutil

# Build the unified training set.
#
# For every dialect, read its train.csv and, for each row, take the first
# `num_train_per_prompt` images (0.jpg, 1.jpg, ...):
#   * "win"  image: generated from the SAE prompt     (sae_imgs/<SAE_Prompt>/)
#   * "lose" image: generated from the dialect prompt (dialect_imgs/<Dialect_Prompt>/)
#     -- only used when method == "dpo".
# Images are copied into one flat directory under a global running index and
# one metadata record is collected per copied file; the caption is always the
# dialect prompt.

# Fail fast on a mistyped method: without this check any value other than
# "dpo" would silently be treated as "sft".
if method not in ("dpo", "sft"):
    raise ValueError(f"Unknown method {method!r}; expected 'dpo' or 'sft'")
use_pairs = method == "dpo"

# Create unified train directory
train_dir = os.path.join(output_folder, "train")
os.makedirs(train_dir, exist_ok=True)

metadata_entries = []
global_counter = 0  # index shared by all dialects so file names never collide

# Loop over each dialect
for dialect in dialects:
    # Resolve paths for this dialect
    csv_file = csv_template.format(dialect=dialect)
    input_images_folder = images_template.format(dialect=dialect)

    # Read CSV
    with open(csv_file, newline='', encoding='utf-8') as f:
        for row in csv.DictReader(f):
            dialect_prompt = row["Dialect_Prompt"].strip()
            sae_prompt = row["SAE_Prompt"].strip()
            prompt_text = dialect_prompt  # caption

            # Source subfolders are named after the (stripped) prompt text
            lose_dir = os.path.join(input_images_folder, "dialect_imgs", dialect_prompt)
            win_dir = os.path.join(input_images_folder, "sae_imgs", sae_prompt)

            # Sample first N images
            for i in range(num_train_per_prompt):
                win_src = os.path.join(win_dir, f"{i}.jpg")
                lose_src = os.path.join(lose_dir, f"{i}.jpg") if use_pairs else None

                # Skip the sample before copying anything if a source is
                # missing, so no half-copied pair ends up in train_dir.
                if not os.path.exists(win_src):
                    print(f"Warning: missing win {win_src}")
                    continue
                if use_pairs and not os.path.exists(lose_src):
                    print(f"Warning: missing lose {lose_src}")
                    continue

                # Copy under destination filenames keyed by the global index
                win_name = f"win_{global_counter}.jpg"
                shutil.copy(win_src, os.path.join(train_dir, win_name))

                if use_pairs:
                    lose_name = f"lose_{global_counter}.jpg"
                    shutil.copy(lose_src, os.path.join(train_dir, lose_name))
                else:
                    # SFT: both image slots point at the winning image
                    lose_name = win_name

                # Build metadata entries: one record per copied file, each
                # carrying the same pair description.
                pair_fields = {
                    "jpg_0": win_name,
                    "jpg_1": lose_name,
                    "label_0": 1,
                    "caption": prompt_text,
                }
                metadata_entries.append({"file_name": win_name, **pair_fields})
                if use_pairs:
                    metadata_entries.append({"file_name": lose_name, **pair_fields})

                global_counter += 1

# Write metadata.jsonl: one JSON object per line, in the order collected
metadata_path = os.path.join(train_dir, "metadata.jsonl")
with open(metadata_path, "w", encoding='utf-8') as metadata_file:
    metadata_file.writelines(
        json.dumps(record) + "\n" for record in metadata_entries
    )

# Short summary for the notebook output
entry_count = len(metadata_entries)
print(f"Dataset created at {output_folder}")
print(f"Total metadata entries: {entry_count}")

# Write training script next to the dataset and mark it executable
script_path = os.path.join(output_folder, "run_training.sh")
with open(script_path, "w", encoding='utf-8') as script_file:
    # end="" keeps the file content exactly equal to script_command
    print(script_command, end="", file=script_file)
os.chmod(script_path, 0o755)  # rwxr-xr-x

print(f"Training script created at {script_path}")
