In [1]:
# -*- coding: utf-8 -*-
import os

# Environment variables.
# Some libraries read these when they are first imported, so they must be set
# BEFORE the imports below. Assigning them after the imports has no effect on
# libraries that were already imported. The `check_for_updates()` warning seen
# in this cell's output is the symptom of setting them too late.
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"  # skip albumentations' update check (presumably imported via utils.dataset — confirm)
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"  # avoid TensorFlow oneDNN-related warnings

import json
import copy
import argparse
import traceback

import torch
import wandb
from tqdm import tqdm
import matplotlib.pyplot as plt
import torchvision.transforms.functional as TF

# Diffusers and related libraries
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.optimization import get_scheduler
from accelerate import Accelerator, DistributedDataParallelKwargs  # DDP kwargs, for find_unused_parameters
from torch.utils.data import DataLoader
from peft import LoraConfig, PeftModel, get_peft_model

# Custom project modules (these files must be importable from the notebook's directory)
from config import Config  # assumes the config file is named config.py
from utils.dataset import ControlNetDataset  # assumes the dataset class lives in utils/dataset.py



  from .autonotebook import tqdm as notebook_tqdm
  check_for_updates()


In [5]:
def _str_to_bool(value):
    """Strictly parse a boolean command-line value.

    Accepts (case-insensitive, surrounding whitespace ignored):
    true/t/yes/y/1 -> True and false/f/no/n/0 -> False.

    Raises:
        argparse.ArgumentTypeError: for any other value. argparse then reports a
            usage error instead of silently treating a typo as False.
    """
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean (true/false), got {value!r}")


def parse_args(argv=None):
    """Parse command-line arguments for ControlNet (LoRA) training.

    Every override defaults to None, meaning "keep the value from Config".

    Args:
        argv: Optional sequence of argument strings. When None (the default),
            argparse reads sys.argv[1:]. Inside a Jupyter kernel, sys.argv
            typically holds the kernel launcher's own arguments
            (e.g. `-f <connection file>`). Pass `[]` or an explicit list there.

    Returns:
        argparse.Namespace with attributes: config, condition_type,
        use_text_condition, use_lora, lora_rank, lora_alpha.
    """
    parser = argparse.ArgumentParser(description="Train a ControlNet model using LoRA")
    parser.add_argument("--config", type=str, default="config.py", help="Path to config file")
    # --- Optional command-line overrides of Config settings ---
    parser.add_argument("--condition_type", type=str, choices=["canny", "depth", "pose", "seg"], default=None,
                        help="Override config: Type of condition to use (aligns with config)")
    parser.add_argument("--use_text_condition", default=None, type=_str_to_bool,
                        help="Override config: Use text prompts (True/False)")
    # --- LoRA parameter overrides ---
    parser.add_argument("--use_lora", default=None, type=_str_to_bool,
                        help="Override config: Use LoRA training (True/False)")
    parser.add_argument("--lora_rank", type=int, default=None, help="Override config: LoRA rank")
    parser.add_argument("--lora_alpha", type=int, default=None, help="Override config: LoRA alpha")

    return parser.parse_args(argv)



config = Config()
# --- Weight dtype ---
# Hard-coded to full precision; this is NOT derived from an Accelerator's
# mixed-precision setting.
weight_dtype = torch.float32
# --- Load the base ControlNet model ---
try:
    controlnet = ControlNetModel.from_pretrained(
        config.controlnet_model,
        torch_dtype=weight_dtype,
    )
except Exception as e:
    print(f"Error loading ControlNet model: {e}")
    # Re-raise so the notebook stops here with the real cause. Otherwise later
    # cells fail with a misleading NameError on `controlnet`.
    raise
print(f"Loaded ControlNet base model from {config.controlnet_model}")
# List every submodule's qualified name and type. This stays outside the try
# so that its errors are not reported as load failures.
for name, module in controlnet.named_modules():
    print(name, "->", type(module))


Loaded ControlNet base model from lllyasviel/sd-controlnet-seg
 -> <class 'diffusers.models.controlnet.ControlNetModel'>
conv_in -> <class 'torch.nn.modules.conv.Conv2d'>
time_proj -> <class 'diffusers.models.embeddings.Timesteps'>
time_embedding -> <class 'diffusers.models.embeddings.TimestepEmbedding'>
time_embedding.linear_1 -> <class 'torch.nn.modules.linear.Linear'>
time_embedding.act -> <class 'torch.nn.modules.activation.SiLU'>
time_embedding.linear_2 -> <class 'torch.nn.modules.linear.Linear'>
controlnet_cond_embedding -> <class 'diffusers.models.controlnet.ControlNetConditioningEmbedding'>
controlnet_cond_embedding.conv_in -> <class 'torch.nn.modules.conv.Conv2d'>
controlnet_cond_embedding.blocks -> <class 'torch.nn.modules.container.ModuleList'>
controlnet_cond_embedding.blocks.0 -> <class 'torch.nn.modules.conv.Conv2d'>
controlnet_cond_embedding.blocks.1 -> <class 'torch.nn.modules.conv.Conv2d'>
controlnet_cond_embedding.blocks.2 -> <class 'torch.nn.modules.conv.Conv2d'>
con