In [3]:
import model_loader
import pipeline
from PIL import Image
from pathlib import Path
from transformers import CLIPTokenizer
import torch

DEVICE = "cpu"

# Flip these to False to force a fallback even when the accelerator exists.
ALLOW_CUDA = True
ALLOW_MPS = True

# Prefer CUDA, then Apple-silicon MPS, otherwise stay on CPU.
# `torch.has_mps` is deprecated and only says the build was compiled with MPS
# support, not that an MPS device can actually be used, so availability is
# checked through `torch.backends.mps` alone. `getattr` keeps this working on
# torch builds that predate the MPS backend.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available() and ALLOW_CUDA:
    DEVICE = "cuda"
elif ALLOW_MPS and _mps_backend is not None and _mps_backend.is_available():
    DEVICE = "mps"
print(f"Using device: {DEVICE}")

# Tokenizer vocabulary and checkpoint are read relative to the notebook's
# working directory.
tokenizer = CLIPTokenizer("../data/vocab.json", merges_file="../data/merges.txt")
model_file = "../data/v1-5-pruned-emaonly.ckpt"
# Dict of sub-models; later cells index it with "clip", "encoder" and "diffusion".
models = model_loader.preload_models_from_standard_weights(model_file, DEVICE)


Using device: mps


/Users/kaimao/Desktop/DIFFUSION/pytorch-stable-diffusion/venv/lib/python3.10/site-packages/lightning_fabric/__init__.py:41: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.


In [4]:
from peft import LoraConfig, get_peft_model

In [5]:
print(models)

# Per-model parameter census: count each sub-model once and reuse the figure
# for both the per-model line and the running total.
total_params = 0
for model_name, model in models.items():
    param_count = sum(weight.numel() for weight in model.parameters())
    print(f"model_name: {model_name} params: {param_count}")
    total_params = total_params + param_count
print("Total parameters:", total_params)

{'clip': CLIP(
  (embedding): CLIPEmbedding(
    (token_embedding): Embedding(49408, 768)
  )
  (layers): ModuleList(
    (0-11): 12 x CLIPLayer(
      (layernorm_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attention): SelfAttention(
        (in_proj): Linear(in_features=768, out_features=2304, bias=True)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
      )
      (layernorm_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (linear_1): Linear(in_features=768, out_features=3072, bias=True)
      (linear_2): Linear(in_features=3072, out_features=768, bias=True)
    )
  )
  (layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
), 'encoder': VAE_Encoder(
  (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): VAE_ResidualBlock(
    (groupnorm_1): GroupNorm(32, 128, eps=1e-05, affine=True)
    (conv_1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (groupnorm_2): Grou

In [6]:
diffusion_model = models['diffusion']

print(diffusion_model)
print(f"params: {sum(weight.numel() for weight in diffusion_model.parameters())}")

# Layers that receive LoRA adapters, grouped by kind. PEFT matches these
# against the tail of each sub-module's dotted name.
attention_1_targets = ["attention_1.in_proj", "attention_1.out_proj"]
attention_2_targets = [
    "attention_2.q_proj",
    "attention_2.k_proj",
    "attention_2.v_proj",
    "attention_2.out_proj",
]
# The two 3x3 convolutions inside each UNET_ResidualBlock (see printed model).
residual_conv_targets = ["conv_feature", "conv_merged"]

config = LoraConfig(
    r=4,
    lora_alpha=16,
    target_modules=[
        *attention_1_targets,
        *attention_2_targets,
        *residual_conv_targets,
    ],
    lora_dropout=0.0,
)

# Wrap the diffusion model; the report below shows how many weights remain
# trainable after wrapping.
diffusion_model_lora = get_peft_model(diffusion_model, config)
wrapped_weights = list(diffusion_model_lora.parameters())
all_weight_count = sum(weight.numel() for weight in wrapped_weights)
trainable_weight_count = sum(
    weight.numel() for weight in wrapped_weights if weight.requires_grad
)
print(f"params num {all_weight_count}")
print(f"diffusion model trainable num {trainable_weight_count}")

Diffusion(
  (time_embedding): TimeEmbedding(
    (linear_1): Linear(in_features=320, out_features=1280, bias=True)
    (linear_2): Linear(in_features=1280, out_features=1280, bias=True)
  )
  (unet): UNET(
    (encoders): ModuleList(
      (0): SwitchSequential(
        (0): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (1-2): 2 x SwitchSequential(
        (0): UNET_ResidualBlock(
          (groupnorm_feature): GroupNorm(32, 320, eps=1e-05, affine=True)
          (conv_feature): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (linear_time): Linear(in_features=1280, out_features=320, bias=True)
          (groupnorm_merged): GroupNorm(32, 320, eps=1e-05, affine=True)
          (conv_merged): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (residual_layer): Identity()
        )
        (1): UNET_AttentionBlock(
          (groupnorm): GroupNorm(32, 320, eps=1e-06, affine=True)
          (conv_inp

In [None]:
"""
LoRA Training Script for Stable Diffusion

- Loads images and captions
- Prepares data for training
- Sets up model with LoRA
- Trains only LoRA parameters
- Saves LoRA weights compatible with pipeline.generate
"""

import os
import torch
import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import CLIPTokenizer
from torchvision import transforms
from PIL import Image
from pipeline import get_time_embedding

# --- 1. Dataset ---
class StyleDataset(Dataset):
    """Image/caption pairs for style fine-tuning.

    The captions file holds one ``<image file name>: <caption>`` entry per
    line. Only the first colon separates the two fields, so captions may
    themselves contain colons.

    Args:
        image_dir: Directory that the image file names are resolved against.
        captions_file: Path to the text file described above (read as UTF-8).
        image_size: Side length used for both the resize and the centre crop.
    """

    def __init__(self, image_dir, captions_file, image_size=512):
        self.image_dir = image_dir
        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            # (x - 0.5) / 0.5 per channel.
            transforms.Normalize([0.5], [0.5]),
        ])
        # Explicit encoding so captions decode the same way on every platform.
        with open(captions_file, 'r', encoding='utf-8') as f:
            self.samples = self._parse_captions(f)

    @staticmethod
    def _parse_captions(lines):
        """Turn caption lines into a list of ``(image_name, caption)`` tuples.

        Lines without a colon are ignored. Lines whose image name is empty are
        ignored as well, because joining an empty name onto ``image_dir``
        would point at the directory itself rather than at an image.
        """
        samples = []
        for line in lines:
            img, sep, cap = line.strip().partition(':')
            img = img.strip()
            if not sep or not img:
                continue
            samples.append((img, cap.strip()))
        return samples

    def __len__(self):
        """Number of parsed (image, caption) pairs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``{'image': transformed tensor, 'caption': str}`` for ``idx``."""
        img_name, caption = self.samples[idx]
        img_path = os.path.join(self.image_dir, img_name)
        # Close the underlying file as soon as the pixels have been converted;
        # convert() returns an independent image.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')
        image = self.transform(image)

        return {'image':image, 'caption':caption}

# --- 2. Text Encoder Helper (CLIP) ---
def encode_text(clip_model, tokenizer, captions, device, max_length=77):
    """Tokenize ``captions`` and run them through ``clip_model`` without gradients.

    Captions are padded and truncated to a fixed ``max_length`` so every batch
    has the same sequence length, matching the tokenization used by the
    training loop in this notebook (``padding="max_length"``,
    ``max_length=77``, ``truncation=True``).

    Args:
        clip_model: Callable invoked with the token-id tensor.
        tokenizer: Callable tokenizer; its result must expose ``.to(device)``
            and ``.input_ids``.
        captions: Batch of caption strings.
        device: Device the token ids are moved to before encoding.
        max_length: Fixed token length to pad/truncate to.

    Returns:
        Whatever ``clip_model`` returns for the token ids.
    """
    tokens = tokenizer(
        captions,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    ).to(device)
    # The text encoder is only used as a frozen feature extractor here.
    with torch.no_grad():
        text_embeds = clip_model(tokens.input_ids)
    return text_embeds

# --- 3. Training setup ---
# `models_with_lora_diffusion` is the same dict object as `models`, so this
# assignment also swaps the LoRA-wrapped model into `models["diffusion"]`,
# which is what the training loop indexes.
models_with_lora_diffusion = models
models_with_lora_diffusion["diffusion"] = diffusion_model_lora
batch_size=4
epochs=10
lr=1e-4

# Train on the device the checkpoint was loaded onto in the first cell.
# Re-detecting the device here ignored ALLOW_CUDA / ALLOW_MPS and could pick a
# device that differs from the one the model weights live on.
device = DEVICE
print(f"Using device: {device}")

# Tokenizer (use local vocab/merges files) - same paths as the first cell,
# which resolved successfully from the notebook's working directory.
tokenizer = CLIPTokenizer("../data/vocab.json", merges_file="../data/merges.txt")

# Prepare dataset and dataloader
# NOTE(review): `image_dir` and `captions_file` are not assigned in any cell
# visible here; define them (e.g. in a configuration cell) before running.
dataset = StyleDataset(image_dir, captions_file)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


# Optimise only the parameters that still require gradients after LoRA wrapping.
lora_params = [p for p in diffusion_model_lora.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(lora_params, lr=lr)

# --- 4. Noise schedule and forward diffusion ---
NUM_TRAIN_TIMESTEPS = 1000
# Scaled-linear beta schedule (beta_start=0.00085, beta_end=0.0120).
# NOTE(review): these are the values commonly used for Stable Diffusion v1;
# confirm they match the sampler that pipeline.generate uses at inference.
_betas = torch.linspace(0.00085 ** 0.5, 0.0120 ** 0.5, NUM_TRAIN_TIMESTEPS, dtype=torch.float32) ** 2
alphas_cumprod = torch.cumprod(1.0 - _betas, dim=0)


def add_noise(latents, noise, timesteps):
    """Apply the closed-form forward diffusion step to a batch of latents.

    Computes ``sqrt(a_t) * latents + sqrt(1 - a_t) * noise`` where ``a_t`` is
    the cumulative alpha product at each sample's timestep.

    Args:
        latents: Clean latents, batch on the first axis.
        noise: Noise tensor with the same shape as ``latents``.
        timesteps: Integer tensor of shape ``(batch,)`` with values in
            ``[0, NUM_TRAIN_TIMESTEPS)``.

    Returns:
        Noised latents with the same shape as ``latents``.
    """
    schedule = alphas_cumprod.to(device=latents.device, dtype=latents.dtype)
    alpha_t = schedule[timesteps.to(latents.device)]
    # Reshape to (batch, 1, 1, ...) so the per-sample scalars broadcast.
    broadcast_shape = (-1,) + (1,) * (latents.dim() - 1)
    signal_scale = alpha_t.sqrt().view(broadcast_shape)
    noise_scale = (1.0 - alpha_t).sqrt().view(broadcast_shape)
    return signal_scale * latents + noise_scale * noise


# --- 5. Training loop ---
for epoch in range(epochs):
    models["diffusion"].train()
    epoch_loss = 0.0
    samples_seen = 0
    for batch in dataloader:
        images, captions = batch['image'], batch['caption']

        # Move images to device
        images = images.to(device)
        # Kept under its own name so the `batch_size` config value is not
        # overwritten by the size of the last (possibly smaller) batch.
        n_images = images.shape[0]

        # 1. Tokenize captions properly (batch of strings)
        tokens = tokenizer(
            list(captions),
            padding="max_length",
            max_length=77,
            truncation=True,
            return_tensors="pt"
        ).input_ids.to(device)

        with torch.no_grad():
            # 2. Get text embeddings from CLIP (context for cross-attention)
            context = models["clip"](tokens)  # (batch, 77, 768)

            # 3. Encode images to latent space via VAE.
            # The UNet's first conv takes 4 input channels (see printed model),
            # so the sampling noise is drawn with 4 channels.
            # NOTE(review): assumes the VAE downsamples each spatial side by 8
            # and that the encoder, which receives the noise as an argument,
            # performs the mean/log-variance sampling and latent scaling
            # itself. Confirm against the encoder's forward(); if it returns
            # raw moments, the sampling must be re-added here.
            encoder_noise = torch.randn(
                n_images, 4, images.shape[-2] // 8, images.shape[-1] // 8,
                device=device,
            )
            latents = models["encoder"](images, encoder_noise)

        # 4. Sample random timestep for each image in batch
        timesteps = torch.randint(0, NUM_TRAIN_TIMESTEPS, (n_images,), device=device)
        # NOTE(review): embeds one scalar timestep at a time and stacks the
        # rows, on the assumption that get_time_embedding takes a single
        # timestep and returns a (1, dim) row - confirm in pipeline.py.
        time_embedding = torch.cat(
            [get_time_embedding(int(t)) for t in timesteps.tolist()], dim=0
        ).to(device)

        # 5. Add noise to latents (forward diffusion)
        noise = torch.randn_like(latents)
        noisy_latents = add_noise(latents, noise, timesteps)

        # 6. Forward pass - predict noise.
        # Call the whole diffusion module rather than `.unet`: the printed
        # architecture shows `time_embedding` mapping 320 -> 1280 while the
        # residual blocks' `linear_time` expects 1280 features, so skipping
        # `time_embedding` would feed the UNet the wrong width.
        optimizer.zero_grad()
        predicted_noise = models["diffusion"](noisy_latents, context, time_embedding)

        # 7. Compute loss (MSE between predicted and actual noise)
        loss = torch.nn.functional.mse_loss(predicted_noise, noise)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item() * n_images
        samples_seen += n_images

    # Report the sample-weighted mean loss for the epoch; the guard keeps an
    # empty dataloader from raising here.
    print(f"Epoch {epoch+1}/{epochs} completed. Loss: {epoch_loss / max(samples_seen, 1):.4f}")


# Explicitly move all models to CUDA before saving LoRA weights
# NOTE(review): this branch only runs when `device` is already "cuda", so it
# looks like a no-op if the models were loaded onto that device - confirm
# whether it is still needed.
if device == "cuda":
    for model in models.values():
        model.to("cuda")
# NOTE(review): `save_models_lora` and `output_path` are not defined or
# imported in any cell visible here; they must exist before this line runs.
save_models_lora(models, output_path)
print(f"LoRA weights saved to {output_path} (saved from {device})")
print(f"Using device: {device}")

# NOTE(review): everything from here to the end of the cell rebuilds the
# dataset, models and optimizer and runs a second training loop after the
# weights were already saved above. It reads like an earlier draft pasted
# below the newer loop - confirm whether it should be removed.
# Names used below that are not defined or imported in any visible cell:
# preload_models_with_lora, ckpt_path, lora_rank, lora_alpha, lora_dropout,
# get_lora_parameters, save_models_lora, output_path, image_dir, captions_file.

# Dataset and DataLoader
dataset = StyleDataset(image_dir, captions_file)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

# Load models with LoRA
# This rebinds `models`, replacing the dict that held the LoRA-wrapped
# diffusion model trained above.
models = preload_models_with_lora(
    ckpt_path=ckpt_path,
    device=device,
    lora_rank=lora_rank,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    apply_to_diffusion=True,
    apply_to_clip=True
)
diffusion = models['diffusion']
clip = models['clip']

# The CLIPTokenizer created earlier in this cell is reused; it is not
# imported or initialised again here.

# Only train LoRA parameters
# Freeze everything first, then re-enable gradients for the parameters that
# get_lora_parameters returns.
for param in diffusion.parameters():
    param.requires_grad = False
lora_params = get_lora_parameters(diffusion)
for param in lora_params:
    param.requires_grad = True
optimizer = torch.optim.AdamW(lora_params, lr=lr)

# --- Training Loop ---
for epoch in range(epochs):
    diffusion.train()
    total_loss = 0
    # NOTE(review): StyleDataset.__getitem__ returns a dict with 'image' and
    # 'caption' keys, so each batch is presumably a dict; unpacking it into
    # two names would bind the key strings rather than the tensors - verify.
    for images, captions in dataloader:
        images = images.to(device)
        # Encode text
        # NOTE(review): text_embeds is computed but not used in the loss below.
        text_embeds = encode_text(clip, tokenizer, list(captions), device)

        # NOTE(review): the Diffusion module printed earlier shows
        # `time_embedding` and `unet` children (output truncated); confirm the
        # object returned by preload_models_with_lora really exposes
        # `.encoder` / `.decoder`.
        latents = diffusion.encoder(images)
        noise = torch.randn_like(latents)
        # Unscaled additive noise: no timestep or noise schedule is involved.
        noisy_latents = latents + noise
        pred = diffusion.decoder(noisy_latents)
        # Loss compares the decoder output with the input images, i.e. a
        # reconstruction objective rather than noise prediction.
        loss = torch.nn.functional.mse_loss(pred, images)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch average is per sample.
        total_loss += loss.item() * images.size(0)
    avg_loss = total_loss / len(dataset)
    print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}")
    # Optionally save intermediate LoRA weights
    if (epoch+1) % 5 == 0:
        save_models_lora(models, f"{output_path}_epoch{epoch+1}.pt")
# Save final LoRA weights
save_models_lora(models, output_path)
print(f"Training complete. LoRA weights saved to {output_path}")



params num 862176708
trainable num 2655744


