In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from diffusers import StableDiffusionPipeline, DDPMScheduler
from diffusers.optimization import get_scheduler
from torchvision import transforms
from tqdm import tqdm

from peft import LoraConfig, get_peft_model


# ============================================
# 1. Config
# ============================================
# Hyperparameters and paths a reader may want to tune.
model_id = "sd-legacy/stable-diffusion-v1-5"
dataset_dir = "./vsr_sd"        # your data directory; must contain images/ and captions.txt
lora_rank = 4
train_steps = 500
learning_rate = 1e-4
batch_size = 1
resolution = 512
output_dir = "./lora_output_conv_all"

# Pick the first accelerator that is actually available and fall back to the
# CPU, instead of assuming CUDA exists whenever MPS does not.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)


# ============================================
# 2. Load model (Stable Diffusion v1.5)
# ============================================
# Load the full Stable Diffusion pipeline and move it to the selected device.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    # Half precision only when running on CUDA; every other backend
    # (MPS, CPU, ...) gets float32 rather than silently inheriting float16.
    # NOTE(review): on CUDA the UNet that is fine-tuned below is therefore
    # loaded in float16 - confirm that is intended for training.
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
pipe.to(device)

# Freeze base model
pipe.unet.requires_grad_(False)

# ============================================
# 3. Apply LoRA to the UNet modules named conv1 / conv2
#    (this is what target_modules selects; the attention
#    projections are not targeted by this config)
# ============================================
lora_config = LoraConfig(
    bias="none",
    lora_dropout=0.0,
    target_modules=["conv1", "conv2"],
    lora_alpha=2 * lora_rank,   # alpha is kept at twice the rank
    r=lora_rank,
)

# Replace the pipeline's UNet with its PEFT-wrapped version and report how
# many parameters are trainable after wrapping.
pipe.unet = get_peft_model(pipe.unet, lora_config)
pipe.unet.print_trainable_parameters()


# ============================================
# 4. Dataset
# ============================================
class CustomDataset(Dataset):
    """Image/caption pairs read from ``<root>/images`` and ``<root>/captions.txt``.

    Every non-blank line of ``captions.txt`` must have the form
    ``<filename><TAB><caption>``. Only the first tab separates the two fields,
    so a caption may itself contain tabs. Blank lines are ignored.

    Indexing returns ``(image, caption)`` where ``image`` is the RGB image
    resized to ``resolution x resolution`` (module-level setting), converted
    to a tensor and normalized with mean 0.5 / std 0.5, and ``caption`` is the
    raw caption string.

    Raises:
        ValueError: if a non-blank caption line contains no tab separator.
    """

    def __init__(self, root):
        self.image_dir = os.path.join(root, "images")
        self.caption_file = os.path.join(root, "captions.txt")

        # Explicit encoding so captions decode identically on every platform.
        with open(self.caption_file, "r", encoding="utf-8") as f:
            self.data = self._parse_captions(f.read().splitlines())

        self.preprocess = transforms.Compose([
            transforms.Resize((resolution, resolution)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])

    @staticmethod
    def _parse_captions(lines):
        """Convert caption-file lines into a list of ``(filename, caption)`` tuples."""
        pairs = []
        for lineno, line in enumerate(lines, start=1):
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            filename, sep, caption = line.partition("\t")
            if not sep:
                raise ValueError(
                    f"captions line {lineno} is not '<filename>\\t<caption>': {line!r}"
                )
            pairs.append((filename, caption))
        return pairs

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        filename, caption = self.data[idx]
        # Context manager closes the underlying file handle once the pixels
        # have been copied into the converted RGB image.
        with Image.open(os.path.join(self.image_dir, filename)) as img:
            image = img.convert("RGB")
        image = self.preprocess(image)
        return image, caption


# Build the dataset from dataset_dir and wrap it in a loader that reshuffles
# the samples on every pass.
dataset = CustomDataset(root=dataset_dir)
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=batch_size)


# ============================================
# 5. Optimizer + Scheduler
# ============================================
# Hand the optimizer only the parameters that require gradients, rather than
# every UNet parameter (the base weights were frozen earlier). If nothing is
# trainable, Adam raises immediately instead of training a no-op.
trainable_params = [p for p in pipe.unet.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=learning_rate)

# Training-time noise scheduler built from the pipeline scheduler's config.
noise_scheduler = DDPMScheduler.from_config(pipe.scheduler.config)

# Cosine learning-rate decay over train_steps, with no warmup.
lr_scheduler = get_scheduler(
    name="cosine",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=train_steps,
)




Using device: mps


Loading pipeline components...: 100%|██████████| 7/7 [00:00<00:00, 32.04it/s]


trainable params: 1,958,400 || all params: 861,479,364 || trainable%: 0.2273


In [2]:
# Print the module tree of pipe.unet (the PEFT-wrapped UNet) for inspection.
print(pipe.unet)

PeftModel(
  (base_model): LoraModel(
    (model): UNet2DConditionModel(
      (conv_in): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (time_proj): Timesteps()
      (time_embedding): TimestepEmbedding(
        (linear_1): Linear(in_features=320, out_features=1280, bias=True)
        (act): SiLU()
        (linear_2): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (down_blocks): ModuleList(
        (0): CrossAttnDownBlock2D(
          (attentions): ModuleList(
            (0-1): 2 x Transformer2DModel(
              (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
              (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
              (transformer_blocks): ModuleList(
                (0): BasicTransformerBlock(
                  (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
                  (attn1): Attention(
                    (to_q): Linear(in_features=320, out_features=320, bias=False)
         

In [None]:
import json

# Store the LoRA config and the training hyperparameters in output_dir so a
# run can be identified from its output directory alone.
os.makedirs(output_dir, exist_ok=True)
lora_config.save_pretrained(output_dir)

train_config = {
    # Taken from the config variable rather than a literal so the record
    # cannot drift from the rank actually used.
    "lora_rank": lora_rank,
    "train_steps": train_steps,
    "learning_rate": learning_rate,
    "batch_size": batch_size,
    "resolution": resolution
}

# os.path.join keeps the path valid when output_dir is absolute or already
# starts with "./".
with open(os.path.join(output_dir, "train_config.json"), "w") as f:
    json.dump(train_config, f, indent=2)

In [None]:
import wandb

# Start a Weights & Biases run and record the hyperparameters of this
# training session. Values are read from the config variables so the logged
# run cannot disagree with what is actually trained.
wandb.init(
    project="stable-diffusion-training",   # change this
    name="sd-lora-unet-bs1-lr1e4-conv-all",           # optional run name
    config={
        "train_steps": train_steps,
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "resolution": resolution,
        "lora_rank": lora_rank,
        "optimizer": optimizer.__class__.__name__,
        "scheduler": lr_scheduler.__class__.__name__,
        "model": model_id,
        "train_unet": True,
        "train_text_encoder": False,
    }
)


# ============================================
# 6. Training Loop
# ============================================
# Train the UNet on a noise-prediction MSE objective. The text encoder is
# frozen, and both it and the VAE are only run without gradient tracking.
pipe.text_encoder.requires_grad_(False)
text_encoder = pipe.text_encoder
tokenizer = pipe.tokenizer

pipe.unet.train()

# Fail fast: with no batches the loop below could never reach train_steps.
if len(dataloader) == 0:
    raise ValueError("dataloader is empty; nothing to train on")

# Read through .config: direct attribute access on the scheduler is
# deprecated and emits a warning.
num_train_timesteps = noise_scheduler.config.num_train_timesteps

global_step = 0
epoch = 0

# Keep cycling through the dataset until exactly train_steps optimizer
# updates have been made, however small the dataset is.
while global_step < train_steps:
    for images, captions in dataloader:
        if global_step >= train_steps:
            break

        # Match the VAE's dtype so a half-precision pipeline does not hit a
        # float32/float16 mismatch when encoding.
        images = images.to(device=device, dtype=pipe.vae.dtype)

        # Encode text
        inputs = tokenizer(
            list(captions),
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = inputs.input_ids.to(device)

        # Neither the text encoder nor the VAE is trained, so no autograd
        # graph is needed for them.
        with torch.no_grad():
            encoder_hidden_states = text_encoder(input_ids)[0]
            latents = pipe.vae.encode(images).latent_dist.sample()
            latents = latents * pipe.vae.config.scaling_factor

        # One timestep per sample in *this* batch; the last batch can be
        # smaller than batch_size.
        timesteps = torch.randint(
            0, num_train_timesteps, (latents.shape[0],), device=device
        ).long()

        # Sample noise and add it to the latents according to the timestep
        noise = torch.randn_like(latents)
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Predict noise
        noise_pred = pipe.unet(noisy_latents, timesteps, encoder_hidden_states).sample

        # Compute the loss in float32 even when the model runs in half precision.
        loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        global_step += 1

        wandb.log(
            {
                "loss": loss.item(),
                "lr": lr_scheduler.get_last_lr()[0],
                "step": global_step,
            },
            step=global_step,
        )

        if global_step % 50 == 0:
            print(f"Step {global_step} / {train_steps}, Loss = {loss.item():.4f}")

    epoch += 1


# ============================================
# 7. Save LoRA weights
# ============================================
# Make sure the target directory exists, then write the PEFT-wrapped UNet
# there via its save_pretrained method.
os.makedirs(name=output_dir, exist_ok=True)
pipe.unet.save_pretrained(save_directory=output_dir)

print("Training finished! LoRA saved to:", output_dir)

  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


Step 50 / 2000, Loss = 0.0838
Step 100 / 2000, Loss = 0.0276
Step 150 / 2000, Loss = 0.7123
Step 200 / 2000, Loss = 0.0116
Step 250 / 2000, Loss = 0.1197
Step 300 / 2000, Loss = 0.0057
Step 350 / 2000, Loss = 0.1925
Step 400 / 2000, Loss = 0.0505
Step 450 / 2000, Loss = 0.0771
Step 500 / 2000, Loss = 0.1492
Step 550 / 2000, Loss = 0.4443
Step 600 / 2000, Loss = 0.0078
Step 650 / 2000, Loss = 0.0941
Step 700 / 2000, Loss = 0.0290
Step 750 / 2000, Loss = 0.2600
Step 800 / 2000, Loss = 0.2163
Step 850 / 2000, Loss = 0.1387
Step 900 / 2000, Loss = 0.0219
Step 950 / 2000, Loss = 0.1483
Step 1000 / 2000, Loss = 0.0756
Step 1050 / 2000, Loss = 0.3233
Step 1100 / 2000, Loss = 0.2288
Step 1150 / 2000, Loss = 0.0265


KeyboardInterrupt: 

In [6]:
# Save the current state of the PEFT-wrapped UNet to output_dir.
# NOTE(review): presumably re-run by hand because the training cell above was
# interrupted before reaching its own save step - confirm.
pipe.unet.save_pretrained(save_directory=output_dir)

print("Training finished! LoRA saved to:", output_dir)

Training finished! LoRA saved to: ./lora_output_conv1


In [None]:
# from peft import PeftModel, PeftConfig

# pipe.unet = PeftModel.from_pretrained(pipe.unet, "lora_output")



In [8]:
# Prompts describing a spatial relation between two objects; one image is
# generated and saved per prompt.
spatial = [
    "a chicken on the left of a car",
    "a person on the left of a cow",
    "a horse on the right of a man",
    "a man on side of a cat",
    "a chicken near a book",
    "a bicycle on the right of a girl",
    "a dog next to a phone",
    "a sheep next to a bicycle",
    "a pig on the bottom of a candle",
    "a butterfly on the left of a phone"
]

# Use a separate variable for the sample directory instead of overwriting
# output_dir: re-running this cell no longer stacks another prefix onto the
# path, and the LoRA save location stays intact. Taking the basename drops a
# leading "./" that would otherwise create a "spatail_lora_." parent folder.
# ("spatail" (sic) is the prefix this cell has always used.)
samples_dir = f"spatail_lora_{os.path.basename(os.path.normpath(output_dir))}"
os.makedirs(samples_dir, exist_ok=True)

# Leave training mode before sampling.
pipe.unet.eval()
for prompt in spatial:
    image = pipe(prompt).images[0]
    image.save(os.path.join(samples_dir, f"{prompt}.png"))

100%|██████████| 100/100 [02:27<00:00,  1.48s/it]
100%|██████████| 100/100 [02:31<00:00,  1.52s/it]
100%|██████████| 100/100 [02:37<00:00,  1.57s/it]
100%|██████████| 100/100 [02:43<00:00,  1.64s/it]
100%|██████████| 100/100 [02:49<00:00,  1.70s/it]
100%|██████████| 100/100 [02:58<00:00,  1.79s/it]
100%|██████████| 100/100 [03:04<00:00,  1.85s/it]
100%|██████████| 100/100 [03:10<00:00,  1.91s/it]
100%|██████████| 100/100 [03:03<00:00,  1.84s/it]
100%|██████████| 100/100 [03:01<00:00,  1.81s/it]
