# Prep DPO and SFT data

In [6]:
# Cell 1: Parameter Setup
#
# Defines the experiment parameters, derives every input/output path from them,
# and renders the shell script text (`script_command`) that launches training.
# Nothing is written to disk in this cell.

# Number of win-lose pairs (for dpo) or number of win examples (for sft) to use per prompt.
num_train_per_prompt = 9
dialect = "ine"
method = "dpo"   # sft or dpo
prompt_stype = "basic" # basic or complex


# Experiment identifier; it names both the prepared dataset folder and the
# checkpoint folder referenced by the training script.
exp_name = f"{dialect}-{method}-{prompt_stype}-2-2-2-{num_train_per_prompt}"
# Forwarded to `accelerate launch --num_processes`.
num_processes = 7

# Shared path prefixes, kept in one place so the paths below cannot drift apart.
_repo_root = "/local1/bryanzhou008/Dialect/multimodal-dialectal-bias"
_dpo_root = f"{_repo_root}/mitigation/baselines/diffusion_dpo"

# Path to the input CSV file (with header: Dialect_Word, SAE_Word, Dialect_Prompt, SAE_Prompt)
csv_file = f"{_repo_root}/data/text/train_val_test/2-2-2/{prompt_stype}/{dialect}/train.csv"

# Path to the input images folder; inside this folder, there will be subfolders named after each prompt.
input_images_folder = f"{_repo_root}/data/image/{prompt_stype}/{dialect}/flux.1-dev"

# Output directory for the prepared dataset.
output_folder = f"{_dpo_root}/a_data/{exp_name}"

# The shell script command template for training (you may adjust this command as needed).
# Shell variable expansions are double-quoted so a path containing spaces or
# glob characters is passed to train.py as a single argument.
# NOTE(review): the same flags (including --beta_dpo) are emitted for both
# methods; if train.py needs a different switch for "sft" runs, it is not
# added here - confirm against train.py.
script_command = """#!/bin/bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export DATA_DIR="{output_dir}"

accelerate launch --num_processes {num_processes} {dpo_root}/a_src/train/train.py \\
  --pretrained_model_name_or_path="$MODEL_NAME" \\
  --train_data_dir="$DATA_DIR" \\
  --train_batch_size=1 \\
  --dataloader_num_workers=16 \\
  --gradient_accumulation_steps=128 \\
  --max_train_steps=2000 \\
  --lr_scheduler="constant_with_warmup" --lr_warmup_steps=500 \\
  --learning_rate=1e-8 --scale_lr \\
  --cache_dir="{dpo_root}/temp_cache/" \\
  --checkpointing_steps 200 \\
  --beta_dpo 5000 \\
  --output_dir="{dpo_root}/a_checkpoints/{exp_name}"
""".format(
    output_dir=output_folder,
    exp_name=exp_name,
    num_processes=num_processes,
    dpo_root=_dpo_root,
)




In [7]:
# Cell 2: Data Preparation and Script Generation

import os
import csv
import json
import shutil
from pathlib import Path

# Build the training dataset and the launch script.
#
# Reads the prompt CSV, copies the first `num_train_per_prompt` images of each
# prompt into <output_folder>/train under collision-free names, writes one
# metadata.jsonl describing them, and finally writes an executable
# run_training.sh containing `script_command`.
#
# Images generated from SAE_Prompt are the "win" side and images generated from
# Dialect_Prompt are the "lose" side; the caption is always the Dialect_Prompt.
#
# Raises:
#     ValueError: if `method` is neither "sft" nor "dpo" (case-insensitive).
#     KeyError: if the CSV lacks a Dialect_Prompt or SAE_Prompt column.

# Fail fast on an unknown method instead of silently preparing an SFT dataset.
method_key = method.lower()
if method_key not in ("dpo", "sft"):
    raise ValueError(f"Unsupported method {method!r}; expected 'sft' or 'dpo'.")
is_dpo = method_key == "dpo"

# Create the output training directory (all processed images and metadata go in a "train" subfolder)
train_dir = os.path.join(output_folder, "train")
os.makedirs(train_dir, exist_ok=True)

metadata = []

# csv.DictReader consumes the header line, so row_index 0 is the first data row.
with open(csv_file, newline='', encoding='utf-8') as f:
    reader = csv.DictReader(f)
    for row_index, row in enumerate(reader):
        dialect_prompt = row["Dialect_Prompt"].strip()
        sae_prompt = row["SAE_Prompt"].strip()

        # Source images live in one subfolder per prompt text.
        lose_dir = os.path.join(input_images_folder, "dialect_imgs", dialect_prompt)
        win_dir = os.path.join(input_images_folder, "sae_imgs", sae_prompt)

        # Use the Dialect_Prompt as caption.
        prompt_text = dialect_prompt

        # Images are assumed to be named "0.jpg", "1.jpg", etc.
        for i in range(num_train_per_prompt):
            win_src = os.path.join(win_dir, f"{i}.jpg")
            # The lose image is only needed for dpo.
            lose_src = os.path.join(lose_dir, f"{i}.jpg") if is_dpo else None

            # Verify every required source before copying anything, so a
            # skipped pair never leaves a half-copied example behind.
            if not os.path.exists(win_src):
                print(f"Warning: Win image missing at {win_src}; skipping row {row_index}, image {i}.")
                continue
            if is_dpo and not os.path.exists(lose_src):
                print(f"Warning: Lose image missing at {lose_src}; skipping row {row_index}, image {i}.")
                continue

            # Destination names embed the row and image index to avoid collisions across rows.
            win_dest_name = f"win_{row_index}_{i}.jpg"
            shutil.copy(win_src, os.path.join(train_dir, win_dest_name))

            if not is_dpo:
                # sft: use only win images.
                metadata.append({
                    "file_name": win_dest_name,
                    "jpg_0": win_dest_name,
                    "caption": prompt_text,
                })
                continue

            lose_dest_name = f"lose_{row_index}_{i}.jpg"
            shutil.copy(lose_src, os.path.join(train_dir, lose_dest_name))

            # dpo: one metadata entry per copied file. Both entries describe the
            # same pair (jpg_0 = win, jpg_1 = lose, label_0 = 1) and differ only
            # in file_name.
            # NOTE(review): this lists every pair twice in metadata.jsonl -
            # presumably what train.py's loader expects; confirm.
            pair_fields = {
                "jpg_0": win_dest_name,
                "jpg_1": lose_dest_name,
                "label_0": 1,
                "caption": prompt_text,
            }
            metadata.append({"file_name": win_dest_name, **pair_fields})
            metadata.append({"file_name": lose_dest_name, **pair_fields})

# Write the metadata.jsonl file in the train directory (one JSON object per line).
metadata_path = os.path.join(train_dir, "metadata.jsonl")
with open(metadata_path, "w", encoding="utf-8") as f:
    f.writelines(json.dumps(entry) + "\n" for entry in metadata)

if not metadata:
    # An empty dataset almost always means a wrong CSV or image folder path.
    print("Warning: no training examples were collected; check csv_file and input_images_folder.")

print(f"Dataset created at {output_folder}")
print(f"Total metadata entries: {len(metadata)}")

# Create the training shell script using the script_command from Cell 1.
script_path = os.path.join(output_folder, "run_training.sh")
with open(script_path, "w", encoding="utf-8") as f:
    f.write(script_command)

# Make the script executable.
os.chmod(script_path, 0o755)
print(f"Training script created at {script_path}")


Dataset created at /local1/bryanzhou008/Dialect/multimodal-dialectal-bias/mitigation/baselines/diffusion_dpo/a_data/ine-dpo-basic-2-2-2-9
Total metadata entries: 1188
Training script created at /local1/bryanzhou008/Dialect/multimodal-dialectal-bias/mitigation/baselines/diffusion_dpo/a_data/ine-dpo-basic-2-2-2-9/run_training.sh
