In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from diffusers import StableDiffusionPipeline, DDPMScheduler
from diffusers.optimization import get_scheduler
from torchvision import transforms
from tqdm import tqdm

from peft import LoraConfig, get_peft_model


# ============================================
# 1. Config
# ============================================
model_id = "sd-legacy/stable-diffusion-v1-5"
dataset_dir = "./vsr_sd"        # 你的数据目录，里面有 images/ 和 captions.txt
lora_rank = 4
train_steps = 500
learning_rate = 1e-4
batch_size = 1
resolution = 512
experiment_name = "sd_lora_bs1_lr1e_4_conv_up"
output_dir = f"./{experiment_name}" #MODIFY

device = "mps" if torch.backends.mps.is_available() else "cuda"
print("Using device:", device)


# ============================================
# 2. Load model (Stable Diffusion v1.5)
# ============================================
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float32 if device == "mps" else torch.float16,
)
pipe.to(device)



  from .autonotebook import tqdm as notebook_tqdm


Using device: mps


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]`torch_dtype` is deprecated! Use `dtype` instead!
Loading pipeline components...: 100%|██████████| 7/7 [00:00<00:00, 27.95it/s]


StableDiffusionPipeline {
  "_class_name": "StableDiffusionPipeline",
  "_diffusers_version": "0.36.0",
  "_name_or_path": "sd-legacy/stable-diffusion-v1-5",
  "feature_extractor": [
    "transformers",
    "CLIPImageProcessor"
  ],
  "image_encoder": [
    null,
    null
  ],
  "requires_safety_checker": true,
  "safety_checker": [
    "stable_diffusion",
    "StableDiffusionSafetyChecker"
  ],
  "scheduler": [
    "diffusers",
    "PNDMScheduler"
  ],
  "text_encoder": [
    "transformers",
    "CLIPTextModel"
  ],
  "tokenizer": [
    "transformers",
    "CLIPTokenizer"
  ],
  "unet": [
    "diffusers",
    "UNet2DConditionModel"
  ],
  "vae": [
    "diffusers",
    "AutoencoderKL"
  ]
}

In [None]:
lora_module_list = []

# MODIFY HERE
for name, module in pipe.unet.named_modules():
    if 'resnet' in name and 'up_blocks' in name and 'conv' in name and not 'conv_shortcut' in name:
        print(name)
        lora_module_list.append(name)

print(lora_module_list)

# Freeze base model
pipe.unet.requires_grad_(False)

# ============================================
# 3. Apply LoRA to UNet cross-attention layers
# ============================================
lora_config = LoraConfig(
    r=lora_rank,
    lora_alpha=lora_rank * 2,
    target_modules=lora_module_list, #MODIFY
    lora_dropout=0.0,
    bias="none",
)

pipe.unet = get_peft_model(pipe.unet, lora_config)
pipe.unet.print_trainable_parameters()

print(pipe.unet)

down_blocks.0.resnets.0.conv1
down_blocks.0.resnets.0.conv2
down_blocks.0.resnets.1.conv1
down_blocks.0.resnets.1.conv2
down_blocks.1.resnets.0.conv1
down_blocks.1.resnets.0.conv2
down_blocks.1.resnets.1.conv1
down_blocks.1.resnets.1.conv2
down_blocks.2.resnets.0.conv1
down_blocks.2.resnets.0.conv2
down_blocks.2.resnets.1.conv1
down_blocks.2.resnets.1.conv2
down_blocks.3.resnets.0.conv1
down_blocks.3.resnets.0.conv2
down_blocks.3.resnets.1.conv1
down_blocks.3.resnets.1.conv2
['down_blocks.0.resnets.0.conv1', 'down_blocks.0.resnets.0.conv2', 'down_blocks.0.resnets.1.conv1', 'down_blocks.0.resnets.1.conv2', 'down_blocks.1.resnets.0.conv1', 'down_blocks.1.resnets.0.conv2', 'down_blocks.1.resnets.1.conv1', 'down_blocks.1.resnets.1.conv2', 'down_blocks.2.resnets.0.conv1', 'down_blocks.2.resnets.0.conv2', 'down_blocks.2.resnets.1.conv1', 'down_blocks.2.resnets.1.conv2', 'down_blocks.3.resnets.0.conv1', 'down_blocks.3.resnets.0.conv2', 'down_blocks.3.resnets.1.conv1', 'down_blocks.3.resnets.1

In [3]:
# ============================================
# 4. Dataset
# ============================================
class CustomDataset(Dataset):
    def __init__(self, root):
        self.image_dir = os.path.join(root, "images")
        self.caption_file = os.path.join(root, "captions.txt")

        with open(self.caption_file, "r") as f:
            lines = f.read().splitlines()

        self.data = []
        for line in lines:
            filename, caption = line.split("\t")
            self.data.append((filename, caption))

        self.preprocess = transforms.Compose([
            transforms.Resize((resolution, resolution)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        filename, caption = self.data[idx]
        image = Image.open(os.path.join(self.image_dir, filename)).convert("RGB")
        image = self.preprocess(image)
        return image, caption


dataset = CustomDataset(dataset_dir)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


# ============================================
# 5. Optimizer + Scheduler
# ============================================
optimizer = torch.optim.Adam(pipe.unet.parameters(), lr=learning_rate)

noise_scheduler = DDPMScheduler.from_config(pipe.scheduler.config)

lr_scheduler = get_scheduler(
    name="cosine",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=train_steps,
)


In [4]:
import json
os.makedirs(output_dir, exist_ok=True)
lora_config.save_pretrained(output_dir)

train_config = {
    "lora_rank": 4,
    "train_steps": train_steps,
    "learning_rate": learning_rate,
    "batch_size": batch_size,
    "resolution": resolution
}

with open(f"./{output_dir}/train_config.json", "w") as f:
    json.dump(train_config, f, indent=2)

In [5]:
import wandb
wandb.init(
    project="stable-diffusion-training",   # change this
    name=f"{experiment_name}_2",           # optional run name
    config={
        "train_steps": train_steps,
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "resolution": resolution,
        "optimizer": optimizer.__class__.__name__,
        "scheduler": lr_scheduler.__class__.__name__,
        "model": "sd-legacy/stable-diffusion-v1-5",
        "train_unet": True,
        "train_text_encoder": False,
    },
    reinit=True
)

[34m[1mwandb[0m: Currently logged in as: [33mykmao1515[0m ([33mkaimao-columbia-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:


VAL_PROMPTS = [
    "a sheep next to a bicycle",
    "a pig on the bottom of a candle",
]

EVAL_SEED = 42
EVAL_STEPS = 30
EVAL_GUIDANCE = 7.5

def save_samples(pipe, step):
    out_dir = f"{output_dir}/train_samples"
    os.makedirs(out_dir, exist_ok=True)

    generator = torch.Generator(device=pipe.device).manual_seed(EVAL_SEED)

    images = pipe(
        EVAL_PROMPTS,
        num_inference_steps=EVAL_STEPS,
        guidance_scale=EVAL_GUIDANCE,
        generator=generator,
    ).images

    for i, img in enumerate(images):
        img.save(f"{out_dir}/step_{step:06d}_{EVAL_PROMPTS[i]}.png")


# ============================================
# 6. Training Loop
# ============================================
pipe.text_encoder.requires_grad_(False)
text_encoder = pipe.text_encoder
tokenizer = pipe.tokenizer

pipe.unet.train()

global_step = 0

for epoch in range(100):  # loop until reaching steps
    for batch in dataloader:
        if global_step >= train_steps:
            break

        images, captions = batch
        images = images.to(device)

        # Encode text
        inputs = tokenizer(
            list(captions),
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = inputs.input_ids.to(device)
        encoder_hidden_states = text_encoder(input_ids)[0]

        # Add noise
        with torch.no_grad():
            latents = pipe.vae.encode(images).latent_dist.sample()
            latents = latents * pipe.vae.config.scaling_factor

        timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (batch_size,), device=device).long()

        # 2. Sample noise
        noise = torch.randn_like(latents)

        # 3. Add noise according to timestep
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Predict noise
        noise_pred = pipe.unet(noisy_latents, timesteps, encoder_hidden_states).sample

        loss = torch.nn.functional.mse_loss(noise_pred, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        global_step += 1

        wandb.log(
            {
                "loss": loss.item(),
                "lr": lr_scheduler.get_last_lr()[0],
                "step": global_step,
            },
            step=global_step,
        )

        if global_step % 50 == 0:
            print(f"Step {global_step} / {train_steps}, Loss = {loss.item():.4f}")
        if global_step % 100 == 0:
            pipe.unet.eval()
            with torch.no_grad():
                save_samples(pipe, global_step)
            pipe.unet.train()

    if global_step >= train_steps:
        break


# ============================================
# 7. Save LoRA weights
# ============================================
os.makedirs(output_dir, exist_ok=True)

pipe.unet.save_pretrained(output_dir)

print("Training finished! LoRA saved to:", output_dir)

  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


Step 50 / 500, Loss = 0.1404
Step 100 / 500, Loss = 0.0137


100%|██████████| 30/30 [01:24<00:00,  2.82s/it]


Step 150 / 500, Loss = 0.6212
Step 200 / 500, Loss = 0.0066


100%|██████████| 30/30 [01:33<00:00,  3.12s/it]


Step 250 / 500, Loss = 0.1959
Step 300 / 500, Loss = 0.2727


100%|██████████| 30/30 [01:48<00:00,  3.63s/it]


Step 350 / 500, Loss = 0.1892
Step 400 / 500, Loss = 0.1832


100%|██████████| 30/30 [02:00<00:00,  4.02s/it]


Step 450 / 500, Loss = 0.1596
Step 500 / 500, Loss = 0.1775


100%|██████████| 30/30 [02:01<00:00,  4.06s/it]


Training finished! LoRA saved to: ./sd_lora_bs1_lr1e_4_conv_down


In [None]:
wandb.finish()

In [None]:
# from peft import PeftModel, PeftConfig

# pipe.unet = PeftModel.from_pretrained(pipe.unet, "lora_output")



In [None]:
spatial = [
    "a chicken on the left of a car",
    "a person on the left of a cow",
    "a horse on the right of a man",
    "a man on side of a cat",
    "a chicken near a book",
    "a bicycle on the right of a girl",
    "a dog next to a phone",
    "a sheep next to a bicycle",
    "a pig on the bottom of a candle",
    "a butterfly on the left of a phone"
]

inference_dir = f"{output_dir}/inference"
os.makedirs(inference_dir , exist_ok=True)
for prompt in spatial:
    image = pipe(prompt).images[0]  
    image.save(f"./{inference_dir}/{prompt}.png")

100%|██████████| 100/100 [02:27<00:00,  1.48s/it]
100%|██████████| 100/100 [02:31<00:00,  1.52s/it]
100%|██████████| 100/100 [02:37<00:00,  1.57s/it]
100%|██████████| 100/100 [02:43<00:00,  1.64s/it]
100%|██████████| 100/100 [02:49<00:00,  1.70s/it]
100%|██████████| 100/100 [02:58<00:00,  1.79s/it]
100%|██████████| 100/100 [03:04<00:00,  1.85s/it]
100%|██████████| 100/100 [03:10<00:00,  1.91s/it]
100%|██████████| 100/100 [03:03<00:00,  1.84s/it]
100%|██████████| 100/100 [03:01<00:00,  1.81s/it]
