
chaofengc/TexForce


Enhancing Diffusion Models with Text-Encoder Reinforcement Learning

Official PyTorch code for the paper "Enhancing Diffusion Models with Text-Encoder Reinforcement Learning".


Requirements & Installation

  • Clone the repo and install the required packages:

```bash
# clone this repository
git clone https://github.com/chaofengc/TexForce.git
cd TexForce

# create a new anaconda environment
conda create -n texforce python=3.8
conda activate texforce

# install python dependencies
pip3 install -r requirements.txt
```

Results on SDXL-Turbo

We also applied our method to the recent SDXL-Turbo model (stabilityai/sdxl-turbo). The model is trained with ImageReward feedback through direct back-propagation to save training time. Test it with the following code:

```bash
# Note: SDXL-Turbo requires the latest diffusers, installed from source:
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```
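```python
from diffusers import AutoPipelineForText2Image
import torch

# load SDXL-Turbo in half precision and move it to the GPU
pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
pipe = pipe.to("cuda")
# load the TexForce LoRA weights from the Hugging Face Hub
pipe.load_lora_weights('chaofengc/sdxl-turbo_texforce')

# single-step sampling without classifier-free guidance
pt = ['a photo of a cat.']
img = pipe(prompt=pt, num_inference_steps=1, guidance_scale=0.0).images[0]
```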
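The call above sets num_inference_steps=1 and guidance_scale=0.0. In diffusers' Stable Diffusion pipelines, classifier-free guidance is applied only when guidance_scale is greater than 1; otherwise the text-conditioned noise prediction is used directly. The pure-Python sketch below mirrors that rule on toy vectors; guided_noise is a hypothetical helper for illustration, not a diffusers function.

```python
# Hypothetical helper illustrating classifier-free guidance (CFG):
#   eps = eps_uncond + w * (eps_cond - eps_uncond)
# Assumption (mirroring the diffusers convention): guidance is only applied
# when w > 1; otherwise the conditional prediction is returned unchanged.
from typing import List


def guided_noise(eps_cond: List[float], eps_uncond: List[float], w: float) -> List[float]:
    if w <= 1.0:
        # Guidance disabled (e.g. guidance_scale=0.0 for turbo models):
        # only the text-conditioned prediction is used.
        return list(eps_cond)
    # Push the prediction away from the unconditional branch by factor w.
    return [u + w * (c - u) for c, u in zip(eps_cond, eps_uncond)]


cond = [0.5, -1.0]    # toy text-conditioned noise prediction
uncond = [0.1, 0.0]   # toy unconditional noise prediction

print(guided_noise(cond, uncond, 0.0))  # [0.5, -1.0]
print(guided_noise(cond, uncond, 7.5))
```

Disabling guidance also means the unconditional branch never has to be evaluated, so each step needs only one conditional noise prediction.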

Here are some example results:

(Figure: sdxl-turbo vs. sdxl-turbo + TexForce for the prompts "A photo of a cat.", "An astronaut riding a horse.", and "water bottle.")

Results on SD-Turbo

We applied our method to the recent SD-Turbo model (stabilityai/sd-turbo). The model is trained with Q-Instruct feedback through direct back-propagation to save training time. Test it with the following code:

```python
# Note: SD-Turbo requires diffusers>=0.24.0, which provides the AutoPipelineForText2Image class

from diffusers import AutoPipelineForText2Image
from peft import PeftModel
import torch

pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sd-turbo", torch_dtype=torch.float16, variant="fp16")
pipe = pipe.to("cuda")
# inject the TexForce LoRA layers into the text encoder (modified in place)
PeftModel.from_pretrained(pipe.text_encoder, 'chaofengc/sd-turbo_texforce')

pt = ['a photo of a cat.']
img = pipe(prompt=pt, num_inference_steps=1, guidance_scale=0.0).images[0]
```
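The diffusers>=0.24.0 requirement noted in the SD-Turbo example can be checked before running it. The sketch below uses only the Python standard library: it reads the installed version with importlib.metadata and compares its leading numeric components against 0.24.0. The helper names are hypothetical, and pre-release suffixes are deliberately ignored (a dev build of 0.24.0 counts as 0.24.0); for full PEP 440 semantics, packaging.version.Version is the more robust choice.

```python
# Minimal sketch: check that an installed package meets a minimum version.
# Assumption: comparing the leading numeric "X.Y.Z" components is enough here;
# pre-release/dev suffixes such as "0.25.0.dev0" are ignored.
import re
from importlib import metadata


def version_tuple(version: str) -> tuple:
    """Return the leading numeric components, e.g. '0.25.0.dev0' -> (0, 25, 0)."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at the first non-numeric component (e.g. 'dev0')
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(installed: str, required: str) -> bool:
    """True if `installed` is at least `required` (numeric components only)."""
    return version_tuple(installed) >= version_tuple(required)


if __name__ == "__main__":
    try:
        installed = metadata.version("diffusers")
    except metadata.PackageNotFoundError:
        print("diffusers is not installed")
    else:
        ok = meets_minimum(installed, "0.24.0")
        print(f"diffusers {installed}: {'OK' if ok else 'too old, need >= 0.24.0'}")
```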

Here are some example results:

(Figure: sd-turbo vs. sd-turbo + TexForce for the prompts "A photo of a cat.", "A photo of a dog.", and "A photo of a boy, colorful.")
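Both turbo models above are trained by back-propagating reward feedback directly (ImageReward for SDXL-Turbo, Q-Instruct for SD-Turbo). The toy sketch below illustrates only that general idea, not the authors' training code. A single scalar theta stands in for the trainable LoRA weights, generate() for the frozen one-step generator, and reward() for the differentiable reward model. All three are hypothetical stand-ins, and the gradient is derived by hand with the chain rule rather than by autograd.

```python
# Toy illustration of direct reward back-propagation (hypothetical stand-ins):
#   theta      -> trainable parameter (stands in for LoRA weights)
#   generate() -> frozen, differentiable one-step "generator"
#   reward()   -> differentiable "reward model" preferring outputs near TARGET
TARGET = 2.0


def generate(theta: float) -> float:
    """Frozen generator: maps the trainable parameter to an 'image' value."""
    return 3.0 * theta + 1.0


def reward(image: float) -> float:
    """Higher is better; maximised when image == TARGET."""
    return -(image - TARGET) ** 2


def reward_grad(theta: float) -> float:
    """d reward / d theta via the chain rule: dR/d(image) * d(image)/d(theta)."""
    image = generate(theta)
    d_reward_d_image = -2.0 * (image - TARGET)
    d_image_d_theta = 3.0  # derivative of generate() w.r.t. theta
    return d_reward_d_image * d_image_d_theta


def train(theta: float = 0.0, lr: float = 0.02, steps: int = 200) -> float:
    """Gradient ascent on the reward; the generator itself stays frozen."""
    for _ in range(steps):
        theta += lr * reward_grad(theta)
    return theta


theta = train()
print(theta, generate(theta), reward(generate(theta)))
```

Here theta converges toward 1/3, where the generator's output equals TARGET and the reward reaches its maximum of 0.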

Quick Test

You can load the pretrained LoRA weights with the following code to improve the performance of the original Stable Diffusion model:

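```python
from diffusers import StableDiffusionPipeline
from diffusers import DDIMScheduler
from peft import PeftModel
import torch

def load_model_weights(pipe, weight_path, model_type):
    """Load LoRA weights into the text encoder ('text+lora') or the UNet ('unet+lora')."""
    if model_type == 'text+lora':
        # inject the LoRA layers into the text encoder (modified in place)
        PeftModel.from_pretrained(pipe.text_encoder, weight_path)
    elif model_type == 'unet+lora':
        # load the LoRA attention processors into the UNet
        pipe.unet.load_attn_procs(weight_path)
    else:
        raise ValueError(f'Unknown model_type: {model_type}')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# use half precision only on the GPU
dtype = torch.float16 if device.type == 'cuda' else torch.float32

model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)
pipe = pipe.to(device)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

# ReFL LoRA weights go into the UNet, TexForce LoRA weights into the text encoder
load_model_weights(pipe, './lora_weights/sd14_refl/', 'unet+lora')
load_model_weights(pipe, './lora_weights/sd14_texforce/', 'text+lora')

prompt = ['a painting of a dog.']
img = pipe(prompt).images[0]
```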
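The Quick Test code loads two separate LoRA adapters: ReFL weights into the UNet and TexForce weights into the text encoder. As a rough sketch of what a LoRA adapter does to one linear layer, the pure-Python example below applies the standard low-rank update W' = W + (alpha / r) * B A to a toy matrix. The shapes, values, and alpha/r scaling are illustrative assumptions, not the configuration of either released checkpoint. Because the two adapters update weights in different modules, they can be loaded together, as in the ReFL+TexForce comparison.

```python
# Minimal sketch of a LoRA update on a single linear layer (toy numbers):
#   W' = W + (alpha / r) * (B @ A), with A: r x d_in and B: d_out x r.
# Assumption: values and rank are illustrative, not from any released checkpoint.
from typing import List

Matrix = List[List[float]]


def matmul(x: Matrix, y: Matrix) -> Matrix:
    """Plain matrix product of x (n x k) and y (k x m)."""
    return [[sum(x[i][k] * y[k][j] for k in range(len(y))) for j in range(len(y[0]))]
            for i in range(len(x))]


def lora_merge(w: Matrix, a: Matrix, b: Matrix, alpha: float) -> Matrix:
    """Return the base weight plus the scaled low-rank update."""
    r = len(a)              # rank = number of rows in A
    delta = matmul(b, a)    # d_out x d_in low-rank update
    scale = alpha / r
    return [[w[i][j] + scale * delta[i][j] for j in range(len(w[0]))]
            for i in range(len(w))]


W = [[1.0, 0.0, 0.0],
     [0.0, 1.0, 0.0]]       # frozen base weight, d_out=2, d_in=3
A = [[1.0, 2.0, 0.0]]       # rank r=1, shape 1 x 3
B = [[0.5], [-1.0]]         # shape 2 x 1

print(lora_merge(W, A, B, alpha=1.0))
# [[1.5, 1.0, 0.0], [-1.0, -1.0, 0.0]]
```

Only A and B are trained (here 3 + 2 = 5 numbers instead of the 6 in W); the saving grows quickly for the large matrices in a real text encoder or UNet.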

Here are some example results:

(Figure: results of SDv1.4, ReFL, TexForce, and ReFL+TexForce for the prompts "astronaut drifting afloat in space, in the darkness away from anyone else, alone, black background dotted with stars, realistic", "portrait of a cute cyberpunk cat, realistic, professional", and "a coffee mug made of cardboard")

Citation

If you find this code useful for your research, please cite our paper:

```bibtex
@article{chen2023texforce,
  title={Enhancing Diffusion Models with Text-Encoder Reinforcement Learning},
  author={Chaofeng Chen and Annan Wang and Haoning Wu and Liang Liao and Wenxiu Sun and Qiong Yan and Weisi Lin},
  year={2023},
  eprint={2311.15657},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```

License

This work is licensed under the NTU S-Lab License 1.0 and a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
