jingtaozhan/PromptReformulate

PromptReformulate

This is the official code for "Capability-aware Prompt Reformulation Learning for Text-to-Image Generation".

Capability-aware Prompt Reformulation (CAPR) integrates user capability into the reformulation process through two key components: the Conditional Reformulation Model (CRM) and Configurable Capability Features (CCF). The CRM reformulates a prompt according to a specified user capability, which is represented by the CCF. Because the CCF values can be set freely, they also act as a control for tuning and guiding the CRM's behavior. This lets CAPR learn diverse reformulation strategies from users of different capability levels and, at inference time, simulate how a high-capability user would reformulate a prompt.
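Conceptually, the CRM is a language model (the training command fine-tunes TinyLlama) whose input carries both the user prompt and the capability features, so changing the features changes the reformulation. The sketch below shows one possible way to serialize that conditioning as a text prefix. The four feature names follow the clip, aesthetic, overall and phrase_cnt flags in the commands below; the CapabilityFeatures class, the build_crm_input function and the prefix template are hypothetical and are not the repository's actual input format.

```python
from dataclasses import dataclass


# Hypothetical container for the four capability features exposed by the
# command-line flags (clip, aesthetic, overall, phrase_cnt).
@dataclass(frozen=True)
class CapabilityFeatures:
    clip: float
    aesthetic: float
    overall: float
    phrase_cnt: int


def build_crm_input(user_prompt: str, ccf: CapabilityFeatures) -> str:
    """Prefix the prompt with the features so a causal LM can condition on them.

    The template is an illustrative assumption, not the repository's format.
    """
    prefix = (
        f"clip={ccf.clip:g} aesthetic={ccf.aesthetic:g} "
        f"overall={ccf.overall:g} phrases={ccf.phrase_cnt}"
    )
    return f"[{prefix}] {user_prompt.strip()} =>"


# Same prompt, two capability settings: only the conditioning prefix differs.
low = CapabilityFeatures(clip=2, aesthetic=2, overall=2, phrase_cnt=1)
high = CapabilityFeatures(clip=9, aesthetic=9, overall=9, phrase_cnt=8)
print(build_crm_input("a cat on a sofa", low))
print(build_crm_input("a cat on a sofa", high))
```

Putting the features in the input, rather than in separate model heads, is what makes them "configurable": at inference time one can simply request high-capability values.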

Requirements

This repo is developed with PyTorch, clip, and skopt (scikit-optimize). These packages must be installed manually because each requires platform-specific configuration.
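Because these packages are installed by hand, a quick import check can catch a missing one before a long run starts. This is a convenience sketch, not part of the repository; it assumes the import names torch, clip (OpenAI's CLIP package) and skopt (scikit-optimize).

```python
import importlib.util

# Assumed import names of the manually installed dependencies.
REQUIRED = {"torch": "PyTorch", "clip": "CLIP", "skopt": "scikit-optimize"}


def missing_dependencies(modules=REQUIRED):
    """Return display names of modules that cannot be found on sys.path."""
    return [name for mod, name in modules.items()
            if importlib.util.find_spec(mod) is None]


if __name__ == "__main__":
    missing = missing_dependencies()
    print("missing:", ", ".join(missing) if missing else "none")
```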

Training

torchrun --nproc_per_node=4 \
    train.py \
    --task_mode single \
    --train_data_path ${train_data_path} \
    --valid_data_path ${valid_data_path} \
    --model_name_or_path "TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T" \
    --bf16 True \
    --per_device_train_batch_size 10 \
    --per_device_eval_batch_size 32 \
    --gradient_accumulation_steps 1 \
    --dataloader_num_workers 4 \
    --output_dir $output_dir \
    --num_train_epochs 2 \
    --save_strategy "epoch" \
    --evaluation_strategy "epoch" \
    --save_total_limit 20 \
    --learning_rate $lr \
    --optim adamw_hf \
    --seed 2023 \
    --adam_beta1 0.9 \
    --adam_beta2 0.95 \
    --adam_epsilon 1e-5 \
    --weight_decay 0.1 \
    --warmup_ratio 0.02 \
    --lr_scheduler_type "cosine" \
    --logging_steps 10 \
    --tf32 True \
    --gradient_checkpointing True \
    --remove_unused_columns False
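With the settings above, each optimizer step sees 4 processes × 10 examples × 1 accumulation step = 40 examples. The sketch below turns that into approximate step counts for a hypothetical dataset size and reproduces the linear-warmup, cosine-decay multiplier that lr_scheduler_type "cosine" selects in Hugging Face Transformers, assuming train.py builds on the standard Trainer as its flags suggest. The multiplier scales the peak $lr; the step counts ignore per-rank sampler padding, so treat them as estimates.

```python
import math

# Values copied from the training command above.
NPROC, PER_DEVICE_BS, GRAD_ACCUM, EPOCHS, WARMUP_RATIO = 4, 10, 1, 2, 0.02


def schedule_summary(num_examples: int) -> dict:
    """Approximate optimizer-step counts; ignores sampler padding across ranks."""
    effective_bs = NPROC * PER_DEVICE_BS * GRAD_ACCUM
    steps_per_epoch = math.ceil(num_examples / effective_bs)
    total_steps = steps_per_epoch * EPOCHS
    warmup_steps = math.ceil(total_steps * WARMUP_RATIO)
    return {"effective_batch_size": effective_bs,
            "total_steps": total_steps,
            "warmup_steps": warmup_steps}


def lr_multiplier(step: int, warmup_steps: int, total_steps: int) -> float:
    """Linear warmup from 0 to 1, then half-cosine decay from 1 to 0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


info = schedule_summary(100_000)  # hypothetical dataset size
print(info)
print(lr_multiplier(info["warmup_steps"], info["warmup_steps"], info["total_steps"]))  # -> 1.0
```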

Tuning CCF

torchrun --nproc_per_node=1 \
    bayes_tune.py \
    --rewrite_input_data_path ${rewrite_input_data_path} \
    --score_input_data_path ${score_input_data_path} \
    --rewrite_model_name_or_path $rewrite_model_name_or_path \
    --bf16 True \
    --no_repeat_ngram_size 3 \
    --generate_model_name_or_path "CompVis/stable-diffusion-v1-4" \
    --generate_seeds 0 1 \
    --generate_inference_steps 20 \
    --imagereward 1 \
    --log_output_dir $log_output_dir \
    --n_calls 50 \
    --min_clip 0 \
    --max_clip 9 \
    --min_aesthetic 0 \
    --max_aesthetic 9 \
    --min_overall 9 \
    --max_overall 9 \
    --min_phrase_cnt 1 \
    --max_phrase_cnt 10 \
    --dataloader_num_workers 0 \
    --output_dir $output_dir \
    --seed 2022 \
    --per_device_eval_batch_size 8 \
    --remove_unused_columns False \
    --overwrite_output_dir True
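The tuning command searches the CCF space within the min/max bounds above for 50 evaluations (n_calls). Since min_overall and max_overall are both 9, the overall feature is effectively pinned rather than searched. The sketch below mirrors that loop with plain random search so it runs without extra dependencies; it is a stand-in for, not a copy of, the Bayesian optimization that bayes_tune.py presumably performs with skopt. Treating phrase_cnt as an integer is an assumption, and toy_objective is a hypothetical placeholder for the real objective, which presumably scores generated images (the command sets imagereward 1).

```python
import random

# Search bounds copied from the bayes_tune.py flags; overall is pinned at 9.
BOUNDS = {"clip": (0.0, 9.0), "aesthetic": (0.0, 9.0), "phrase_cnt": (1, 10)}
FIXED = {"overall": 9.0}


def sample_ccf(rng: random.Random) -> dict:
    """Draw one candidate; phrase_cnt is treated as an integer (an assumption)."""
    ccf = {"clip": rng.uniform(*BOUNDS["clip"]),
           "aesthetic": rng.uniform(*BOUNDS["aesthetic"]),
           "phrase_cnt": rng.randint(*BOUNDS["phrase_cnt"])}
    ccf.update(FIXED)
    return ccf


def tune(objective, n_calls: int = 50, seed: int = 2022) -> tuple:
    """Random-search stand-in: evaluate n_calls candidates, keep the best score."""
    rng = random.Random(seed)
    best_ccf, best_score = None, float("-inf")
    for _ in range(n_calls):
        ccf = sample_ccf(rng)
        score = objective(ccf)  # real tuning presumably scores generated images
        if score > best_score:
            best_ccf, best_score = ccf, score
    return best_ccf, best_score


def toy_objective(ccf: dict) -> float:
    """Hypothetical smooth score peaking at clip=6, aesthetic=7, phrase_cnt=4."""
    return (-(ccf["clip"] - 6) ** 2 - (ccf["aesthetic"] - 7) ** 2
            - abs(ccf["phrase_cnt"] - 4))


best, score = tune(toy_objective)
print(best, round(score, 3))
```

A Bayesian optimizer replaces the uniform sampling with a surrogate model that proposes promising points, which matters when each call means generating and scoring images. Note that skopt's gp_minimize minimizes its objective, so a score to be maximized would be passed in negated.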

Reformulation

torchrun --nproc_per_node=1 \
    evaluate.py \
    --predict_with_generate True \
    --model_name_or_path $model_name_or_path \
    --data_path $data_path \
    --delta_clip $delta_clip \
    --delta_aesthetic $delta_aesthetic \
    --delta_overall $delta_overall \
    --delta_phrase_cnt $delta_phrase_cnt \
    --output_name "prompt.json" \
    --dataloader_num_workers 1 \
    --output_dir $output_dir \
    --seed 2022 \
    --per_device_eval_batch_size 8 \
    --logging_steps 20 \
    --remove_unused_columns False
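When running reformulation for several checkpoints or CCF settings, it can help to build the command programmatically. The sketch below assembles a subset of the flags above, assuming the values selected during CCF tuning are what the delta_* flags expect. The tuned values, the paths and the reformulate_cmd helper are hypothetical, and the remaining flags (seed, batch size, and so on) can be appended the same way.

```python
import shlex

# Hypothetical tuned CCF values; in practice they would be read from the
# tuning logs in $log_output_dir (their format is not shown in this README).
tuned = {"clip": 6.2, "aesthetic": 7.1, "overall": 9, "phrase_cnt": 4}


def reformulate_cmd(model_path: str, data_path: str, output_dir: str, ccf: dict) -> list:
    """Build an argv list for evaluate.py with the CCF values filled in."""
    cmd = ["torchrun", "--nproc_per_node=1", "evaluate.py",
           "--predict_with_generate", "True",
           "--model_name_or_path", model_path,
           "--data_path", data_path,
           "--output_name", "prompt.json",
           "--output_dir", output_dir]
    for name in ("clip", "aesthetic", "overall", "phrase_cnt"):
        cmd += [f"--delta_{name}", str(ccf[name])]
    return cmd


cmd = reformulate_cmd("ckpt/crm", "data/test.json", "out/reform", tuned)
print(shlex.join(cmd))  # or run it with subprocess.run(cmd, check=True)
```

Keeping the command as a list avoids shell quoting issues and, unlike a multi-line shell string, cannot end with a stray line-continuation backslash.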

Evaluation

torchrun --nproc_per_node=1 \
    generate_image.py \
    --model-path "CompVis/stable-diffusion-v1-4" \
    --input-file $reformulate_prompt_path \
    --output-root ${output_root}/images \
    --num-inference-steps 50 \
    --num-gpus 1

torchrun --nproc_per_node=1 \
    evaluate_image.py \
    --models "hpsv2" "ImageReward" \
    --prompt_path $user_input_prompt_path \
    --image_root ${output_root}/images \
    --output_dir ${output_root}/scores
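The scoring command rates the generated images with HPSv2 and ImageReward while passing the original user prompts ($user_input_prompt_path) rather than the reformulated ones, so the scores appear to measure how well the images match what the user originally asked for. A natural follow-up is comparing two scored runs, for example images generated from the raw user prompts versus images from reformulated prompts. The sketch below does this under an assumed record format of (prompt_id, model, score) tuples; per_prompt_means and compare are hypothetical helpers, and the real files under ${output_root}/scores may be laid out differently.

```python
from collections import defaultdict
from statistics import mean


def per_prompt_means(records):
    """Average repeated images (e.g. several seeds) of one prompt, per scoring model."""
    buckets = defaultdict(list)
    for prompt_id, model, score in records:
        buckets[(model, prompt_id)].append(score)
    return {key: mean(vals) for key, vals in buckets.items()}


def compare(baseline, candidate, model):
    """Mean score of both runs on shared prompts and the candidate's win rate."""
    base, cand = per_prompt_means(baseline), per_prompt_means(candidate)
    shared = sorted(pid for (m, pid) in cand if m == model and (m, pid) in base)
    if not shared:
        raise ValueError(f"no prompts scored by {model!r} in both runs")
    wins = sum(cand[(model, p)] > base[(model, p)] for p in shared)
    return {"baseline": mean(base[(model, p)] for p in shared),
            "candidate": mean(cand[(model, p)] for p in shared),
            "win_rate": wins / len(shared)}


# Hypothetical scores: (prompt_id, scoring model, score)
raw = [(0, "ImageReward", 0.1), (0, "ImageReward", 0.3), (1, "ImageReward", 0.5)]
reformulated = [(0, "ImageReward", 0.6), (1, "ImageReward", 0.4)]
print(compare(raw, reformulated, "ImageReward"))
```

Averaging per prompt first keeps prompts with more images (more seeds) from dominating the overall mean, and the win rate complements the mean by showing how consistently one run beats the other.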
