# examples
The images under the ./images folder were generated by our fine-tuned model, which uses SDXL as the base model.


# 0. setup GPU server
We recommend a machine with at least 16 GB of GPU memory and 64 GB of system (CPU) RAM.

# 1. Install requirements
pip install torch torchvision transformers diffusers safetensors pillow requests clip-interrogator
pip install accelerate peft deepspeed xformers

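# 2. Prepare annotation data
Caption the images with CLIP Interrogator (the default) or BLIP; one .txt caption per image is written to the `images.captions` folder.
annotate_image('images')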

# 3. Prepare metadata.jsonl
Generate a metadata.jsonl file containing image paths, captions, and style information:
generate_metadata('images', 'images.captions', "Azuki, Azuki Style, Azuki art, azuki")

# 4. Train your LoRA model with the following bash command
#!/bin/bash
accelerate launch \
  --mixed_precision="fp16" \
  --num_processes=1 \
  --num_machines=1 \
  --machine_rank=0 \
  --dynamo_backend="no" \
  train_text_to_image_lora_sdxl.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
  --train_data_dir="./images" \
  --resolution=768 \
  --output_dir="./images" \
  --train_batch_size=2 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-4 \
  --num_train_epochs=10 \
  --rank=8 \
  --train_text_encoder \
  --mixed_precision="fp16" \
  --enable_xformers_memory_efficient_attention \
  --validation_epochs=5 \
  --checkpointing_steps=100

# 5. Generate Azuki-style NFT avatars
prompt = "sunshine boy, cigarette in mouth, sword in hand, Azuki style, high quality, vivid colors, clean background, single color background" # You can change this to any prompt you like

generate_images(
    prompt=prompt,
    output_dir="output",
    num_images=1,
    width=720,
    height=720,
    pretrained_model_path="stabilityai/stable-diffusion-xl-base-1.0", 
    lora_path="images/pytorch_lora_weights.safetensors",
    lora_scale=0.7,
    seed=42 
)

In [None]:
import os
import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
from clip_interrogator import Config, Interrogator
import sys

def _load_captioner(use_blip, device):
    """Load the selected captioning model and return an ``image_path -> caption`` callable.

    Args:
        use_blip: True for BLIP (``Salesforce/blip-image-captioning-large``),
            False for CLIP Interrogator.
        device: Torch device the model is moved to / configured with.
    """
    if use_blip:
        processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
        model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-large"
        ).to(device)

        def caption_with_blip(image_path):
            # Context manager closes the underlying file handle once converted.
            with Image.open(image_path) as img:
                image = img.convert("RGB")
            inputs = processor(image, return_tensors="pt").to(device)
            caption_ids = model.generate(**inputs)
            return processor.batch_decode(caption_ids, skip_special_tokens=True)[0]

        return caption_with_blip

    interrogator = Interrogator(Config(device=device))

    def caption_with_clip(image_path):
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        return interrogator.interrogate(image)

    return caption_with_clip


def annotate_image(train_data_dir, use_blip=False, caption_dir=None, device="cuda"):
    """Write one ``.txt`` caption file for every image directly inside ``train_data_dir``.

    Captions come from CLIP Interrogator (default) or BLIP. For an image
    ``<stem>.<ext>`` the caption is saved as ``<caption_dir>/<stem>.txt``.
    The captioning model is only loaded when at least one image is found.

    Args:
        train_data_dir: Directory containing the training images (not searched
            recursively).
        use_blip: If True, caption with BLIP; otherwise use CLIP Interrogator.
        caption_dir: Output directory for the caption files. Defaults to
            ``<train_data_dir>.captions``, a sibling of the image folder
            (``images`` -> ``images.captions``).
        device: Torch device the captioning model runs on.

    Returns:
        The caption directory path, suitable as ``captions_path`` for
        ``generate_metadata``.

    Raises:
        FileNotFoundError: If ``train_data_dir`` does not exist.
    """
    if caption_dir is None:
        # normpath drops a trailing separator so "images/" maps to
        # "images.captions" instead of a hidden "images/.captions" folder.
        caption_dir = f"{os.path.normpath(train_data_dir)}.captions"
    os.makedirs(caption_dir, exist_ok=True)

    # Same extension set as generate_metadata, matched case-insensitively so
    # files such as "photo.PNG" are captioned too; sorted for a stable order.
    image_extensions = (".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif")
    image_names = sorted(
        name for name in os.listdir(train_data_dir)
        if name.lower().endswith(image_extensions)
        and os.path.isfile(os.path.join(train_data_dir, name))
    )
    if not image_names:
        print(f"Warning: No image files found in '{train_data_dir}'")
        return caption_dir

    generate_caption = _load_captioner(use_blip, device)

    for filename in image_names:
        image_path = os.path.join(train_data_dir, filename)
        caption = generate_caption(image_path).strip()

        # Caption file shares the image's stem so generate_metadata can pair them.
        txt_path = os.path.join(caption_dir, f"{os.path.splitext(filename)[0]}.txt")
        with open(txt_path, "w", encoding="utf-8") as f:
            f.write(caption)

        print(f"✅ Generated Caption: {filename} -> {caption}")

    return caption_dir

In [None]:

import os
import sys
import json
import glob

def _list_image_files(image_dir):
    """Return the sorted paths of image files directly inside ``image_dir``.

    Extensions are compared case-insensitively (``.PNG``, ``.Jpg`` ...), and each
    file is listed once regardless of filesystem case sensitivity. Dot-files
    (e.g. macOS ``._name.png`` resource forks) are skipped, as glob did.
    """
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp'}
    return sorted(
        os.path.join(image_dir, name)
        for name in os.listdir(image_dir)
        if not name.startswith('.')
        and os.path.splitext(name)[1].lower() in image_extensions
        and os.path.isfile(os.path.join(image_dir, name))
    )


def _load_captions(captions_path, image_files):
    """Return a ``{image file name: caption}`` mapping read from ``captions_path``.

    ``captions_path`` may be:
      * a ``.json``/``.jsonl`` file with one ``{"file": ..., "caption": ...}``
        object per line (blank or malformed lines are skipped);
      * any other file with one ``filename|caption`` pair per line;
      * a directory of ``<image stem>.txt`` files (as written by ``annotate_image``).
    File keys are reduced to base names so an entry like ``images/a.png`` still
    matches ``a.png``. A nonexistent path yields an empty mapping and a warning.
    """
    captions = {}
    if os.path.isfile(captions_path):
        try:
            with open(captions_path, 'r', encoding='utf-8') as f:
                if captions_path.endswith(('.json', '.jsonl')):
                    for line in f:
                        line = line.strip()
                        if not line:
                            continue
                        try:
                            data = json.loads(line)
                        except json.JSONDecodeError:
                            continue
                        # A line can be valid JSON without being an object (e.g. a list).
                        if isinstance(data, dict) and 'file' in data and 'caption' in data:
                            captions[os.path.basename(str(data['file']))] = data['caption']
                else:
                    for line in f:
                        name, sep, caption = line.strip().partition('|')
                        if sep:
                            # Tolerate spaces around the separator: "a.png | a cat".
                            captions[os.path.basename(name.strip())] = caption.strip()
        except (OSError, UnicodeDecodeError) as e:
            print(f"Error reading caption file: {e}")
    elif os.path.isdir(captions_path):
        for image_file in image_files:
            file_name = os.path.basename(image_file)
            caption_file = os.path.join(captions_path, f"{os.path.splitext(file_name)[0]}.txt")
            if not os.path.exists(caption_file):
                continue
            try:
                with open(caption_file, 'r', encoding='utf-8') as f:
                    captions[file_name] = f.read().strip()
            except (OSError, UnicodeDecodeError) as e:
                print(f"Error reading {caption_file}: {e}")
    else:
        print(f"Warning: Captions path '{captions_path}' does not exist; using file names as captions")
    return captions


def generate_metadata(image_dir, captions_path, style_desc):
    """Write ``metadata.jsonl`` pairing every image in ``image_dir`` with a caption.

    The file is written to the *parent* directory of ``image_dir``. Each line is
    ``{"file": <image path relative to that parent>, "caption": <caption>}``.
    Images without a non-empty caption fall back to their file name, and
    ``style_desc`` (when truthy) is appended as ``"<caption>, <style_desc>"``.

    NOTE(review): Hugging Face ``imagefolder`` datasets look for ``metadata.jsonl``
    inside the data directory with a ``file_name`` column; confirm the training
    script actually consumes this file/layout.

    Args:
        image_dir: Directory containing the images (top level only).
        captions_path: Caption source. It may be a ``.json``/``.jsonl`` file, a
            ``filename|caption`` text file, or a directory of ``<stem>.txt``
            files such as the one produced by ``annotate_image``.
        style_desc: Style text appended to every caption; pass an empty string
            or None to leave captions unchanged.

    Returns:
        None. Problems are reported with ``print`` rather than raised. These
        include a missing image directory, no images, or unreadable captions.
    """
    if not os.path.isdir(image_dir):
        print(f"Error: Image directory '{image_dir}' does not exist")
        return

    image_files = _list_image_files(image_dir)
    if not image_files:
        print(f"Warning: No image files found in '{image_dir}'")
        return

    captions = _load_captions(captions_path, image_files)

    parent_dir = os.path.dirname(os.path.abspath(image_dir))
    output_file = os.path.join(parent_dir, "metadata.jsonl")

    with open(output_file, 'w', encoding='utf-8') as f:
        for image_file in image_files:  # already sorted
            file_name = os.path.basename(image_file)
            # `or` also replaces an empty caption, avoiding a leading ", style".
            caption = captions.get(file_name) or file_name
            full_caption = f"{caption}, {style_desc}" if style_desc else caption
            metadata = {
                "file": os.path.relpath(image_file, parent_dir),
                "caption": full_caption,
            }
            f.write(json.dumps(metadata, ensure_ascii=False) + '\n')

    print(f"Successfully generated metadata.jsonl file with {len(image_files)} image entries")
    print(f"Saved to: {output_file}")


In [None]:
# Caption every image in ./images; captions are written to ./images.captions.
annotate_image('images')

In [None]:
 generate_metadata('images', 'images', "Azuki, Azuki Styple, Azuki art, azuki")

In [None]:

import os
import torch
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler, UNet2DConditionModel, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer
from safetensors.torch import load_file

def _resolve_seeds(seed, num_images):
    """Return one generator seed per image: seed, seed+1, ... or distinct random seeds if seed is None."""
    if seed is not None:
        return [seed + i for i in range(num_images)]
    import random
    # Distinct values so images (and their file names) never collide.
    return random.sample(range(2**32), max(num_images, 0))


def _render_images(pipe, prompt, seeds, width, height, num_inference_steps, output_dir, prefix):
    """Render one image per seed with ``pipe`` and save it as ``<prefix>_seed_<seed>.png``."""
    for current_seed in seeds:
        generator = torch.Generator(device="cuda").manual_seed(current_seed)
        image = pipe(
            prompt=prompt,
            width=width,
            height=height,
            num_inference_steps=num_inference_steps,
            generator=generator
        ).images[0]

        image_path = os.path.join(output_dir, f"{prefix}_seed_{current_seed}.png")
        image.save(image_path)
        print(f"Saved image to: {image_path}")


def _load_lora(pipe, lora_path, lora_scale):
    """Load and fuse the LoRA weights at ``lora_path`` into ``pipe``; return True on success."""
    print(f"Loading LoRA weights: {lora_path}")
    if not os.path.exists(lora_path):
        print(f"Error: LoRA file {lora_path} does not exist!")
        return False
    try:
        pipe.load_lora_weights(lora_path)
        pipe.fuse_lora(lora_scale=lora_scale)
    except Exception as e:  # best effort: report and fall back to the base model
        print(f"LoRA loading error: {e}")
        print("LoRA weights are not compatible with SDXL model architecture. This usually happens when trying to use LoRA trained for SD1.5 with SDXL model.")
        print("Continuing with base model generation...")
        return False
    return True


def generate_images(prompt, output_dir="output", num_images=1, width=720, height=720,
                    pretrained_model_path="stabilityai/stable-diffusion-xl-base-1.0",
                    lora_path="images/pytorch_lora_weights.safetensors",
                    lora_scale=0.7, seed=None, num_inference_steps=30):
    """Generate images with the base SDXL model, then with the LoRA fused in.

    The same list of seeds is used for both passes, so ``no_lora_seed_<s>.png``
    and ``with_lora_seed_<s>.png`` form directly comparable pairs.

    Args:
        prompt (str): The prompt for image generation.
        output_dir (str): Output directory (created if missing).
        num_images (int): Number of images to generate for each type.
        width (int): Image width.
        height (int): Image height.
        pretrained_model_path (str): Path or hub id of the SDXL base model.
        lora_path (str): Path to the LoRA weights file.
        lora_scale (float): LoRA intensity passed to ``fuse_lora``.
        seed (int): Seed of the first image (image ``i`` uses ``seed + i``);
            None picks a distinct random seed per image.
        num_inference_steps (int): Denoising steps per image.

    Notes:
        Requires a CUDA device. LoRA loading failures are reported and the
        LoRA pass is skipped rather than raising.
    """
    os.makedirs(output_dir, exist_ok=True)

    if seed is not None:
        torch.manual_seed(seed)
    # Resolved once so both passes reuse the same seeds (previously seed=None
    # reused torch.initial_seed() for every image, producing identical images).
    seeds = _resolve_seeds(seed, num_images)

    print("Loading base SDXL model...")
    pipe = StableDiffusionXLPipeline.from_pretrained(
        pretrained_model_path,
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True
    )

    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.to("cuda")
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except (ImportError, ValueError, RuntimeError) as e:
        # xformers is optional; fall back to the pipeline's default attention.
        print(f"xformers attention unavailable ({e}); using default attention.")

    print(f"Generating {num_images} images without LoRA...")
    _render_images(pipe, prompt, seeds, width, height, num_inference_steps, output_dir, "no_lora")

    if _load_lora(pipe, lora_path, lora_scale):
        print(f"Generating {num_images} images with LoRA...")
        _render_images(pipe, prompt, seeds, width, height, num_inference_steps, output_dir, "with_lora")
        pipe.unfuse_lora()
    else:
        print("Skipping LoRA image generation due to failed LoRA loading.")

    print("Image generation completed!")

In [None]:
# The prompt can be changed to anything you like.
prompt = (
    "sunshine boy, cigarette in mouth, sword in hand, Azuki style, "
    "high quality, vivid colors, clean background, single color background"
)

# Render one 720x720 image with and without the trained LoRA, using seed 42.
generate_images(
    prompt=prompt,
    pretrained_model_path="stabilityai/stable-diffusion-xl-base-1.0",
    lora_path="images/pytorch_lora_weights.safetensors",
    lora_scale=0.7,
    width=720,
    height=720,
    num_images=1,
    seed=42,
    output_dir="output",
)