In [1]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from diffusers import StableDiffusionPipeline, DDPMScheduler
from diffusers.optimization import get_scheduler
from torchvision import transforms
from tqdm import tqdm

from peft import LoraConfig, get_peft_model


# ============================================
# 1. Config
# ============================================
model_id = "sd-legacy/stable-diffusion-v1-5"
dataset_dir = "./vsr_sd"        # dataset directory; must contain images/ and captions.txt
lora_rank = 4
train_steps = 500
learning_rate = 1e-4
batch_size = 1
resolution = 512
output_dir = "./lora_output"

# Pick the best available backend: Apple MPS first, then CUDA, then plain CPU.
# Falling back to "cuda" unconditionally would crash on machines without a GPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)


# ============================================
# 2. Load model (Stable Diffusion v1.5)
# ============================================
# Half precision is only requested on CUDA; MPS and CPU load in float32.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
pipe.to(device)

# Freeze base model: only the LoRA adapters added afterwards should be trainable.
pipe.unet.requires_grad_(False)
pipe.vae.requires_grad_(False)
pipe.text_encoder.requires_grad_(False)

# ============================================
# 3. Apply LoRA to the UNet attention projections
# ============================================
# NOTE(review): the target names below are not specific to cross-attention;
# presumably they match the q/k/v/out projections of every attention block in
# the UNet (self- and cross-attention alike) - confirm against the model.
lora_config = LoraConfig(
    bias="none",
    lora_dropout=0.0,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    lora_alpha=2 * lora_rank,  # scaling numerator is twice the rank
    r=lora_rank,
)

# Wrap the UNet with the adapters and report how many parameters are trainable.
pipe.unet = get_peft_model(pipe.unet, lora_config)
pipe.unet.print_trainable_parameters()


# ============================================
# 4. Dataset
# ============================================
class CustomDataset(Dataset):
    """Image/caption pairs read from ``<root>/images`` and ``<root>/captions.txt``.

    Every non-blank line of ``captions.txt`` must have the form
    ``<filename><TAB><caption>``. Only the first TAB separates the two fields,
    so a caption may itself contain TAB characters. Blank lines are skipped.

    Args:
        root: Dataset directory containing ``images/`` and ``captions.txt``.

    Raises:
        ValueError: If a non-blank caption line contains no TAB separator.
    """

    def __init__(self, root):
        self.image_dir = os.path.join(root, "images")
        self.caption_file = os.path.join(root, "captions.txt")

        # Explicit encoding so non-ASCII captions do not depend on the OS locale.
        with open(self.caption_file, "r", encoding="utf-8") as f:
            self.data = self._parse_caption_lines(f.read().splitlines())

        # Resize to a square of side `resolution` (module-level config value),
        # convert to a tensor, then normalize with mean 0.5 and std 0.5.
        self.preprocess = transforms.Compose([
            transforms.Resize((resolution, resolution)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])

    @staticmethod
    def _parse_caption_lines(lines):
        """Turn raw caption-file lines into a list of ``(filename, caption)`` tuples."""
        pairs = []
        for line_number, line in enumerate(lines, start=1):
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            filename, separator, caption = line.partition("\t")
            if not separator:
                raise ValueError(
                    f"captions file line {line_number}: expected '<filename>\\t<caption>', got {line!r}"
                )
            pairs.append((filename, caption))
        return pairs

    def __len__(self):
        """Number of (image, caption) pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image_tensor, caption)`` for the pair at position ``idx``."""
        filename, caption = self.data[idx]
        image = Image.open(os.path.join(self.image_dir, filename)).convert("RGB")
        image = self.preprocess(image)
        return image, caption


dataset = CustomDataset(dataset_dir)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


# ============================================
# 5. Optimizer + Scheduler
# ============================================
# Hand the optimizer only the parameters that require gradients, rather than
# every UNet weight including the frozen base model.
trainable_params = [p for p in pipe.unet.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=learning_rate)

# Training-time noise scheduler built from the pipeline scheduler's config.
noise_scheduler = DDPMScheduler.from_config(pipe.scheduler.config)

# Cosine learning-rate schedule over `train_steps` optimizer steps, no warmup.
lr_scheduler = get_scheduler(
    name="cosine",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=train_steps,
)




  from .autonotebook import tqdm as notebook_tqdm


Using device: mps


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]`torch_dtype` is deprecated! Use `dtype` instead!
Loading pipeline components...: 100%|██████████| 7/7 [00:00<00:00, 22.49it/s]


trainable params: 797,184 || all params: 860,318,148 || trainable%: 0.0927


In [11]:
# ============================================
# 6. Training Loop
# ============================================
pipe.text_encoder.requires_grad_(False)
text_encoder = pipe.text_encoder
tokenizer = pipe.tokenizer

if len(dataloader) == 0:
    # Without this guard the step-driven loop below would never terminate.
    raise ValueError("The dataloader is empty; nothing to train on.")

pipe.unet.train()

global_step = 0

# Keep cycling through the data until exactly `train_steps` optimizer steps
# have run, so small datasets are not silently under-trained.
while global_step < train_steps:
    for batch in dataloader:
        if global_step >= train_steps:
            break

        images, captions = batch
        # Match the VAE's dtype so a half-precision pipeline accepts the input.
        images = images.to(device, dtype=pipe.vae.dtype)

        # Encode text
        inputs = tokenizer(
            list(captions),
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = inputs.input_ids.to(device)

        # Text encoder and VAE are frozen, so no autograd graph is needed here.
        with torch.no_grad():
            encoder_hidden_states = text_encoder(input_ids)[0]
            latents = pipe.vae.encode(images).latent_dist.sample()
            latents = latents * pipe.vae.config.scaling_factor

        # 1. Sample one timestep per item actually present in this batch
        #    (the last batch of an epoch may be smaller than `batch_size`).
        timesteps = torch.randint(
            0,
            noise_scheduler.config.num_train_timesteps,
            (latents.shape[0],),
            device=device,
        ).long()

        # 2. Sample noise
        noise = torch.randn_like(latents)

        # 3. Add noise according to timestep
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Predict noise
        noise_pred = pipe.unet(noisy_latents, timesteps, encoder_hidden_states).sample

        # Compute the loss in float32 regardless of the model's dtype.
        loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        global_step += 1

        if global_step % 50 == 0:
            print(f"Step {global_step} / {train_steps}, Loss = {loss.item():.4f}")

# Leave the UNet in inference mode for any generation that follows.
pipe.unet.eval()


# ============================================
# 7. Save LoRA weights
# ============================================
os.makedirs(output_dir, exist_ok=True)
pipe.unet.save_pretrained(output_dir)

print("Training finished! LoRA saved to:", output_dir)

  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


Step 50 / 500, Loss = 0.5066
Step 100 / 500, Loss = 0.0054
Step 150 / 500, Loss = 0.0121
Step 200 / 500, Loss = 0.0305
Step 250 / 500, Loss = 0.5797
Step 300 / 500, Loss = 0.0741
Step 350 / 500, Loss = 0.3555
Step 400 / 500, Loss = 0.1010
Step 450 / 500, Loss = 0.0277
Step 500 / 500, Loss = 0.0053
Training finished! LoRA saved to: ./lora_output


In [4]:
from peft import PeftModel, PeftConfig

# Directory that holds the saved LoRA adapter.
lora_weights_dir = "lora_output"

if isinstance(pipe.unet, PeftModel):
    # The UNet is already PEFT-wrapped in this kernel (e.g. by get_peft_model).
    # Reload the saved weights into the existing adapter instead of wrapping a
    # PeftModel inside another PeftModel. "default" is the adapter name used
    # when get_peft_model is called without an explicit name.
    pipe.unet.load_adapter(lora_weights_dir, adapter_name="default")
else:
    # Plain UNet (fresh pipeline): wrap it and load the adapter in one step.
    pipe.unet = PeftModel.from_pretrained(pipe.unet, lora_weights_dir)



In [7]:
# Prompts probing spatial relations between two objects.
spatial = [
    "a chicken on the left of a car",
    "a person on the left of a cow",
    "a horse on the right of a man",
    "a man on side of a cat",
    "a chicken near a book",
    "a bicycle on the right of a girl",
    "a dog next to a phone",
    "a sheep next to a bicycle",
    "a pig on the bottom of a candle",
    "a butterfly on the left of a phone"
]

# Use a dedicated name for the image folder so the `output_dir` config value
# (where the LoRA weights are saved) is not overwritten by this cell.
samples_dir = "spatail_lora_2"
sample_seed = 0  # base seed; prompt i uses sample_seed + i so re-runs are reproducible

os.makedirs(samples_dir, exist_ok=True)
for prompt_index, prompt in enumerate(spatial):
    # A CPU generator gives the same starting noise regardless of the device.
    generator = torch.Generator(device="cpu").manual_seed(sample_seed + prompt_index)
    image = pipe(prompt, generator=generator).images[0]
    image.save(os.path.join(samples_dir, f"{prompt}.png"))

100%|██████████| 50/50 [01:06<00:00,  1.33s/it]
100%|██████████| 50/50 [01:08<00:00,  1.36s/it]
100%|██████████| 50/50 [01:28<00:00,  1.77s/it]
100%|██████████| 50/50 [02:29<00:00,  3.00s/it]
100%|██████████| 50/50 [02:43<00:00,  3.28s/it]
100%|██████████| 50/50 [02:53<00:00,  3.48s/it]
100%|██████████| 50/50 [02:45<00:00,  3.32s/it]
100%|██████████| 50/50 [02:54<00:00,  3.49s/it]
100%|██████████| 50/50 [02:56<00:00,  3.53s/it]
100%|██████████| 50/50 [02:50<00:00,  3.41s/it]
