CSGO: Content-Style Composition in Text-to-Image Generation 🔥

instantX-research/CSGO

Peng Xing¹,²* · Haofan Wang¹* · Yanpeng Sun² · Qixun Wang¹ · Xu Bai¹,³ · Hao Ai¹,⁴ · Renyuan Huang¹,⁵ · Zechao Li²✉

¹InstantX Team · ²Nanjing University of Science and Technology · ³Xiaohongshu · ⁴Beihang University · ⁵Peking University

*equal contributions, ✉corresponding authors

Updates 🔥

  • 2024/09/04: 🔥 We released the Gradio code. You can simply configure it and use it directly.
  • 2024/09/03: 🔥 We released the online demo on Hugging Face.
  • 2024/09/03: 🔥 We released the pre-trained weights.
  • 2024/09/03: 🔥 We released the initial version of the inference code.
  • 2024/08/30: 🔥 We released the technical report on arXiv.
  • 2024/07/15: 🔥 We released the homepage.

Plan 💪

  • technical report
  • inference code
  • pre-trained weight [4_16]
  • pre-trained weight [4_32]
  • online demo
  • pre-trained weight_v2 [4_32]
  • IMAGStyle dataset
  • training code
  • more pre-trained weight

Introduction 📖

This repo, named CSGO, contains the official PyTorch implementation of our paper CSGO: Content-Style Composition in Text-to-Image Generation. We are actively updating and improving this repository. If you find any bugs or have suggestions, you are welcome to raise issues or submit pull requests (PRs) 💖.

Pipeline 💻

Capabilities 🚅

🔥 CSGO supports image-driven style transfer, text-driven stylized synthesis, and text editing-driven stylized synthesis.

🔥 For more results, visit our homepage 🔥

Getting Started 🏁

1. Clone the code and prepare the environment

git clone https://github.com/instantX-research/CSGO
cd CSGO

# create env using conda
conda create -n CSGO python=3.9
conda activate CSGO

# install dependencies with pip
# for Linux and Windows users
pip install -r requirements.txt

2. Download pretrained weights

We currently release two model weights.

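| Model            | Content tokens | Style tokens | Other                                 |
|------------------|----------------|--------------|---------------------------------------|
| csgo.bin         | 4              | 16           | -                                     |
| csgo_4_32.bin    | 4              | 32           | DeepSpeed ZeRO-2                      |
| csgo_4_32_v2.bin | 4              | 32           | DeepSpeed ZeRO-2 + more (coming soon) |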
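
The token counts listed for each checkpoint presumably need to match the num_content_tokens and num_style_tokens arguments passed to the CSGO constructor in the inference script. The following is a minimal sketch of one way to keep the two in sync. CHECKPOINT_TOKENS and token_kwargs are hypothetical names introduced here for illustration and are not part of the repository.

```python
# Token counts per checkpoint file, copied from the weights table in this README.
# CHECKPOINT_TOKENS and token_kwargs are illustrative names, not repository API.
CHECKPOINT_TOKENS = {
    "csgo.bin": (4, 16),
    "csgo_4_32.bin": (4, 32),
    "csgo_4_32_v2.bin": (4, 32),  # as listed in the table
}


def token_kwargs(ckpt_path: str) -> dict:
    """Return the token-count keyword arguments for a checkpoint file."""
    # Keep only the file name, accepting both "/" and "\" separators.
    name = ckpt_path.replace("\\", "/").rsplit("/", 1)[-1]
    if name not in CHECKPOINT_TOKENS:
        raise KeyError(f"unknown CSGO checkpoint: {name}")
    content, style = CHECKPOINT_TOKENS[name]
    return {"num_content_tokens": content, "num_style_tokens": style}
```

The returned dictionary can be unpacked into the constructor call, for example `CSGO(pipe, image_encoder_path, csgo_ckpt, device, **token_kwargs(csgo_ckpt), ...)`.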

The easiest way to download the pretrained weights is from HuggingFace:

# first, ensure git-lfs is installed, see: https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage
git lfs install
# clone the weights repository
git clone https://huggingface.co/InstantX/CSGO

Our method is fully compatible with the SDXL base model, VAE, ControlNet, and image encoder. Please download them and place them in the ./base_models folder.

Tip: If you want to load the ControlNet directly through the ControlNet pipeline, as CSGO does, rename the downloaded weight file as follows:

git clone https://huggingface.co/TTPlanet/TTPLanet_SDXL_Controlnet_Tile_Realistic
mv TTPLanet_SDXL_Controlnet_Tile_Realistic/TTPLANET_Controlnet_Tile_realistic_v2_fp16.safetensors TTPLanet_SDXL_Controlnet_Tile_Realistic/diffusion_pytorch_model.safetensors
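
Before running inference, it can help to confirm that ./base_models has the layout the inference script expects. The sketch below checks only for the paths that appear in that script, plus the renamed ControlNet weight file from the tip above. The EXPECTED list and the missing_base_models helper are hypothetical names, not part of the repository.

```python
from pathlib import Path

# Paths relative to ./base_models, taken from the inference script in this README.
# EXPECTED and missing_base_models are illustrative names, not repository API.
EXPECTED = [
    "stable-diffusion-xl-base-1.0",
    "IP-Adapter/sdxl_models/image_encoder",
    "sdxl-vae-fp16-fix",
    "TTPLanet_SDXL_Controlnet_Tile_Realistic/diffusion_pytorch_model.safetensors",
]


def missing_base_models(root="./base_models"):
    """Return the expected entries that do not exist under root."""
    base = Path(root)
    return [rel for rel in EXPECTED if not (base / rel).exists()]
```

Calling `missing_base_models()` from the repository root returns an empty list once everything is in place; any returned entry still needs to be downloaded or renamed.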

3. Inference 🚀

import torch
from PIL import Image
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    StableDiffusionXLControlNetPipeline,
)
from ip_adapter import CSGO
from ip_adapter.utils import BLOCKS, controlnet_BLOCKS


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

base_model_path = "./base_models/stable-diffusion-xl-base-1.0"
image_encoder_path = "./base_models/IP-Adapter/sdxl_models/image_encoder"
# csgo_4_32.bin has 4 content / 32 style tokens, matching the CSGO(...) arguments below
csgo_ckpt = "./CSGO/csgo_4_32.bin"
pretrained_vae_name_or_path = "./base_models/sdxl-vae-fp16-fix"
controlnet_path = "./base_models/TTPLanet_SDXL_Controlnet_Tile_Realistic"
weight_dtype = torch.float16


vae = AutoencoderKL.from_pretrained(pretrained_vae_name_or_path, torch_dtype=weight_dtype)
controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=weight_dtype, use_safetensors=True)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    base_model_path,
    controlnet=controlnet,
    torch_dtype=weight_dtype,
    add_watermarker=False,
    vae=vae,
)
pipe.enable_vae_tiling()


target_content_blocks = BLOCKS['content']
target_style_blocks = BLOCKS['style']
controlnet_target_content_blocks = controlnet_BLOCKS['content']
controlnet_target_style_blocks = controlnet_BLOCKS['style']

csgo = CSGO(
    pipe,
    image_encoder_path,
    csgo_ckpt,
    device,
    num_content_tokens=4,
    num_style_tokens=32,
    target_content_blocks=target_content_blocks,
    target_style_blocks=target_style_blocks,
    controlnet=False,
    controlnet_adapter=True,
    controlnet_target_content_blocks=controlnet_target_content_blocks,
    controlnet_target_style_blocks=controlnet_target_style_blocks,
    content_model_resampler=True,
    style_model_resampler=True,
    load_controlnet=False,
)

style_name = 'img_0.png'
content_name = 'img_0.png'
# pil_style_image is passed as a PIL image, like the content image, not as a path string
style_image = Image.open('../assets/{}'.format(style_name)).convert('RGB')
content_image = Image.open('../assets/{}'.format(content_name)).convert('RGB')

caption = 'a small house with a sheep statue on top of it'

num_sample = 4

#image-driven style transfer
images = csgo.generate(pil_content_image= content_image, pil_style_image=style_image,
                           prompt=caption,
                           negative_prompt= "text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
                           content_scale=1.0,
                           style_scale=1.0,
                           guidance_scale=10,
                           num_images_per_prompt=num_sample,
                           num_samples=1,
                           num_inference_steps=50,
                           seed=42,
                           image=content_image.convert('RGB'),
                           controlnet_conditioning_scale=0.6,
                          )

#text-driven stylized synthesis
caption='a cat'
images = csgo.generate(pil_content_image= content_image, pil_style_image=style_image,
                           prompt=caption,
                           negative_prompt= "text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
                           content_scale=1.0,
                           style_scale=1.0,
                           guidance_scale=10,
                           num_images_per_prompt=num_sample,
                           num_samples=1,
                           num_inference_steps=50,
                           seed=42,
                           image=content_image.convert('RGB'),
                           controlnet_conditioning_scale=0.01,
                          )

#text editing-driven stylized synthesis
caption='a small house'
images = csgo.generate(pil_content_image= content_image, pil_style_image=style_image,
                           prompt=caption,
                           negative_prompt= "text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
                           content_scale=1.0,
                           style_scale=1.0,
                           guidance_scale=10,
                           num_images_per_prompt=num_sample,
                           num_samples=1,
                           num_inference_steps=50,
                           seed=42,
                           image=content_image.convert('RGB'),
                           controlnet_conditioning_scale=0.4,
                          )
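
The three calls above differ only in the prompt and in controlnet_conditioning_scale (0.6, 0.01, and 0.4). The sketch below collects the shared settings in one place and selects the scale by mode. COMMON_KWARGS, CONTROLNET_SCALE, generate_kwargs, and the mode names are hypothetical, introduced here for illustration; only the values come from the calls above.

```python
# Settings shared by the three csgo.generate calls in this README.
# All names defined here are illustrative, not repository API.
COMMON_KWARGS = {
    "negative_prompt": (
        "text, watermark, lowres, low quality, worst quality, deformed, "
        "glitch, low contrast, noisy, saturation, blurry"
    ),
    "content_scale": 1.0,
    "style_scale": 1.0,
    "guidance_scale": 10,
    "num_samples": 1,
    "num_inference_steps": 50,
}

# controlnet_conditioning_scale used for each capability in the README calls.
CONTROLNET_SCALE = {
    "image_style_transfer": 0.6,
    "text_driven_synthesis": 0.01,
    "text_editing_synthesis": 0.4,
}


def generate_kwargs(mode, prompt, num_images=4, seed=42):
    """Build keyword arguments for csgo.generate for one of the three modes."""
    if mode not in CONTROLNET_SCALE:
        raise ValueError(f"unknown mode: {mode!r}")
    kwargs = dict(COMMON_KWARGS)  # copy so the shared defaults stay untouched
    kwargs.update(
        prompt=prompt,
        num_images_per_prompt=num_images,
        seed=seed,
        controlnet_conditioning_scale=CONTROLNET_SCALE[mode],
    )
    return kwargs
```

A call would then look like `csgo.generate(pil_content_image=content_image, pil_style_image=style_image, image=content_image, **generate_kwargs("text_driven_synthesis", "a cat"))`.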

4. Gradio interface ⚙️

We also provide a Gradio interface for a better experience. Launch it with:

# For Linux, Windows, and macOS users
python gradio/app.py

If you lack the resources to run it locally, try the online demo instead.

Demos

🔥 For more results, visit our homepage 🔥

Content-Style Composition

Cycle Translation

Text-Driven Style Synthesis

Text Editing-Driven Style Synthesis

Acknowledgements

This project is developed by the InstantX Team and Xiaohongshu; all rights reserved. Sincere thanks to Xiaohongshu for providing the computing resources.

Citation 💖

If you find CSGO useful for your research, please 🌟 this repo and cite our work using the following BibTeX:

@article{xing2024csgo,
       title={CSGO: Content-Style Composition in Text-to-Image Generation}, 
       author={Peng Xing and Haofan Wang and Yanpeng Sun and Qixun Wang and Xu Bai and Hao Ai and Renyuan Huang and Zechao Li},
       year={2024},
       journal = {arXiv 2408.16766},
}