<a href="https://colab.research.google.com/github/hshen13/workable-colab/blob/main/controlnettest.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the Python packages used by the cells below (versions are not pinned).
!pip install torch torchvision diffusers transformers accelerate opencv-python numpy matplotlib tqdm


Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [None]:
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)


In [None]:
import os
from torch.utils.data import Dataset, DataLoader

# Root folder of the training data; change this to wherever your data lives.
data_dir = "/content/data"
images_dir = os.path.join(data_dir, "images")        # original images
skeletons_dir = os.path.join(data_dir, "skeletons")  # skeleton images

# Collect the .png/.jpg file names of each folder in sorted order.
# Images and skeletons are paired later by their position in these lists,
# so matching files are assumed to sort identically (e.g. share a name prefix).
image_files = sorted(name for name in os.listdir(images_dir) if name.endswith((".png", ".jpg")))
skeleton_files = sorted(name for name in os.listdir(skeletons_dir) if name.endswith((".png", ".jpg")))

print(f"找到 {len(image_files)} 张原图, {len(skeleton_files)} 张骨骼图.")

# Side length (in pixels) every sample is resized to; can be raised, e.g. to 256.
img_size = 128

class PoseDataset(Dataset):
    """Paired (skeleton, image) samples read from ``skeletons_dir`` / ``images_dir``.

    The class relies on the module-level ``image_files``, ``skeleton_files``,
    ``images_dir``, ``skeletons_dir`` and ``img_size``. Pairs are formed by
    position in the two sorted file lists.
    """

    def __len__(self):
        # Only indices present in both lists form a complete pair.
        return min(len(image_files), len(skeleton_files))

    def __getitem__(self, idx):
        """Load, resize and normalise the pair at position ``idx``.

        Returns:
            tuple: ``(skel_tensor, img_tensor)`` where ``skel_tensor`` is a
            float32 tensor of shape [1, H, W] and ``img_tensor`` is a float32
            RGB tensor of shape [3, H, W]; both are scaled to [-1, 1].

        Raises:
            FileNotFoundError: if OpenCV cannot read either file.
        """
        img_path = os.path.join(images_dir, image_files[idx])
        skel_path = os.path.join(skeletons_dir, skeleton_files[idx])

        # cv2.imread signals failure by returning None instead of raising, so
        # check explicitly to get a clear error rather than an opaque OpenCV one.
        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError(f"Could not read image file: {img_path}")
        # OpenCV loads colour images as BGR; convert to RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        skel = cv2.imread(skel_path, cv2.IMREAD_GRAYSCALE)
        if skel is None:
            raise FileNotFoundError(f"Could not read skeleton file: {skel_path}")

        # Bring both images to the common square resolution.
        if img_size is not None:
            img = cv2.resize(img, (img_size, img_size))
            skel = cv2.resize(skel, (img_size, img_size))

        # uint8 [0, 255] -> float32 [-1, 1]
        img = img.astype(np.float32) / 127.5 - 1.0    # 3 channels
        skel = skel.astype(np.float32) / 127.5 - 1.0  # 1 channel

        # Convert to tensors in channel-first layout.
        img_tensor = torch.from_numpy(img).permute(2, 0, 1)  # [3, H, W]
        skel_tensor = torch.from_numpy(skel).unsqueeze(0)    # [1, H, W]

        return skel_tensor, img_tensor

# Build the dataset and a shuffling loader that drops the last incomplete batch.
dataset = PoseDataset()
batch_size = 8  # tune to the available GPU memory
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=True)

# Sanity check: look at the shapes and the value range of the first sample.
skel_sample, img_sample = dataset[0]
print("骨骼张量形状:", skel_sample.shape, "原图张量形状:", img_sample.shape)
print("骨骼张量值范围:", skel_sample.min().item(), "~", skel_sample.max().item())


In [None]:
from diffusers import UNet2DModel

# Noise-prediction UNet: the input is the noisy RGB image concatenated with the
# 1-channel skeleton map (4 channels in total); the output has 3 channels.
model = UNet2DModel(
    sample_size=img_size,                    # input resolution (H = W = img_size)
    in_channels=4,                           # 3 (image) + 1 (skeleton)
    out_channels=3,                          # 3 (image)
    layers_per_block=2,                      # layers inside each down/up block
    block_out_channels=(64, 128, 128, 256),  # channel width per block
    down_block_types=("DownBlock2D",) * 4,   # attention-free down-sampling blocks
    up_block_types=("UpBlock2D",) * 4,       # attention-free up-sampling blocks
)
model.to(device)
print(f"Model parameters: {sum(map(torch.numel, model.parameters()))}")


In [None]:
import torch.nn.functional as F

# Diffusion hyper-parameters (linear beta schedule).
num_timesteps = 1000  # number of diffusion steps T
beta_start, beta_end = 1e-4, 0.02
betas = torch.linspace(beta_start, beta_end, num_timesteps).to(device)
alphas = 1.0 - betas
alpha_prod = torch.cumprod(alphas, dim=0)  # cumulative product \bar{alpha}_t
# \bar{alpha}_{t-1}, with 1.0 standing in for the step before t = 0.
alpha_prod_prev = torch.cat([torch.tensor([1.0]).to(device), alpha_prod[:-1]], dim=0)

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Training settings
num_epochs = 5        # passes over the dataset
log_interval = 100    # print the loss every this many steps
sample_interval = 1   # generate a preview image every this many epochs

# Training loop: standard DDPM objective (predict the noise that was added).
model.train()
for epoch in range(1, num_epochs+1):
    epoch_loss = 0.0
    for step, (skel_batch, img_batch) in enumerate(dataloader):
        skel_batch = skel_batch.to(device)   # [B, 1, H, W]
        img_batch = img_batch.to(device)     # [B, 3, H, W]

        # One uniformly random timestep per sample. The count is taken from the
        # batch itself (not the global batch_size) so a smaller final batch
        # cannot cause a shape mismatch.
        t = torch.randint(0, num_timesteps, (img_batch.shape[0],), device=device)

        # Standard normal noise with the same shape as the images.
        noise = torch.randn_like(img_batch)
        # Per-sample \bar{alpha}_t, reshaped to broadcast over [B, 3, H, W].
        alpha_prod_t = alpha_prod[t].reshape(-1, 1, 1, 1)
        sqrt_alpha_prod_t = torch.sqrt(alpha_prod_t)
        sqrt_one_minus_alpha_prod_t = torch.sqrt(1 - alpha_prod_t)

        # Forward diffusion: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps
        x_t = sqrt_alpha_prod_t * img_batch + sqrt_one_minus_alpha_prod_t * noise

        # Condition on the skeleton by channel-wise concatenation.
        model_in = torch.cat([x_t, skel_batch], dim=1)  # [B, 4, H, W]

        # The diffusers model returns its prediction in the `.sample` attribute.
        pred_noise = model(model_in, timesteps=t).sample  # [B, 3, H, W]

        # MSE between predicted and true noise.
        loss = F.mse_loss(pred_noise, noise)
        epoch_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (step + 1) % log_interval == 0:
            print(f"Epoch {epoch} Step {step+1}/{len(dataloader)}, Loss: {loss.item():.4f}")

    avg_loss = epoch_loss / len(dataloader)
    print(f"Epoch {epoch} completed, average loss: {avg_loss:.4f}")

    # Every `sample_interval` epochs, run the full reverse process for a preview.
    if epoch % sample_interval == 0:
        model.eval()
        with torch.no_grad():
            # Use the first training sample as the conditioning example.
            test_skel, test_img = dataset[0]
            test_skel = test_skel.unsqueeze(0).to(device)  # [1, 1, H, W]
            true_img = test_img.numpy().transpose(1,2,0)   # (H, W, C) for display
            true_img = ((true_img + 1) * 127.5).astype(np.uint8)  # back to [0, 255]

            # Ancestral sampling, starting from pure Gaussian noise (B = 1).
            x = torch.randn((1, 3, img_size, img_size), device=device)
            for ti in range(num_timesteps-1, -1, -1):
                alpha_t = alphas[ti]
                alpha_prod_t = alpha_prod[ti]
                alpha_prod_prev_t = alpha_prod_prev[ti]
                beta_t = betas[ti]

                model_in = torch.cat([x, test_skel], dim=1)
                t_batch = torch.tensor([ti], device=device)
                pred_noise = model(model_in, timesteps=t_batch).sample  # [1, 3, H, W]

                # Posterior mean:
                # mu = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - abar_t) * eps_theta)
                sqrt_recip_alpha_t = torch.sqrt(1.0/alpha_t)
                sqrt_one_minus_alpha_prod_t = torch.sqrt(1 - alpha_prod_t)
                mu = sqrt_recip_alpha_t * (x - (beta_t / sqrt_one_minus_alpha_prod_t) * pred_noise)
                if ti > 0:
                    # Add noise with the posterior standard deviation.
                    sigma_t = torch.sqrt((1 - alpha_prod_prev_t) / (1 - alpha_prod_t) * beta_t)
                    noise_t = torch.randn_like(x)
                    x = mu + sigma_t * noise_t
                else:
                    x = mu
                # Intermediate x_t is deliberately NOT clamped: at large t it is
                # close to unit-variance noise, and clipping it to [-1, 1] would
                # feed the model inputs unlike those seen during training.

            # Only the final image is restricted to the valid pixel range.
            x = x.clamp(-1, 1)

            gen_img = x.cpu().numpy()[0].transpose(1, 2, 0)  # (H, W, C)
            gen_img = ((gen_img + 1) * 127.5).astype(np.uint8)  # back to [0, 255]

            # Show the ground-truth image next to the generated one.
            fig, axes = plt.subplots(1, 2, figsize=(6,3))
            axes[0].imshow(true_img)
            axes[0].set_title("真实图像")
            axes[0].axis('off')
            axes[1].imshow(gen_img)
            axes[1].set_title("生成图像")
            axes[1].axis('off')
            plt.show()
        model.train()


In [None]:
# Generation helper
def generate_image_from_skeleton(skeleton_img):
    """Generate an RGB image conditioned on one skeleton map (DDPM ancestral sampling).

    Uses the module-level ``model``, ``device``, ``num_timesteps``, ``alphas``,
    ``betas``, ``alpha_prod`` and ``alpha_prod_prev``. The model is switched to
    eval mode and left there.

    Args:
        skeleton_img: numpy array or torch tensor with a single skeleton map,
            shaped [H, W], [1, H, W], [H, W, 1] or [1, 1, H, W]. The values are
            passed to the model unchanged, so they should already use the
            normalisation the model was trained with.

    Returns:
        numpy.ndarray: uint8 image of shape [H, W, 3].
    """
    model.eval()
    with torch.no_grad():
        # Work on a numpy array regardless of the input type.
        if isinstance(skeleton_img, torch.Tensor):
            skel = skeleton_img.detach().cpu().numpy()
        else:
            skel = np.asarray(skeleton_img)

        # Normalise the layout to [1, 1, H, W].
        if skel.ndim == 2:
            skel = skel[None, None, ...]
        elif skel.ndim == 3:
            if skel.shape[-1] == 1 and skel.shape[0] != 1:
                # Channel-last [H, W, 1]: move the channel axis to the front first.
                skel = skel.transpose(2, 0, 1)
            skel = skel[None, ...]
        skel = torch.from_numpy(np.ascontiguousarray(skel, dtype=np.float32)).to(device)

        # Start from pure Gaussian noise x_T.
        x = torch.randn((1, 3, skel.shape[-2], skel.shape[-1]), device=device)
        # Reverse diffusion, t = T-1 ... 0.
        for ti in range(num_timesteps-1, -1, -1):
            alpha_t = alphas[ti]
            alpha_prod_t = alpha_prod[ti]
            alpha_prod_prev_t = alpha_prod_prev[ti]
            beta_t = betas[ti]

            model_in = torch.cat([x, skel], dim=1)
            t_batch = torch.tensor([ti], device=device)
            pred_noise = model(model_in, timesteps=t_batch).sample  # [1, 3, H, W]

            # Posterior mean:
            # mu = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - abar_t) * eps_theta)
            sqrt_recip_alpha_t = (1.0/alpha_t).sqrt()
            sqrt_one_minus_alpha_prod_t = (1 - alpha_prod_t).sqrt()
            mu = sqrt_recip_alpha_t * (x - (beta_t / sqrt_one_minus_alpha_prod_t) * pred_noise)
            if ti > 0:
                sigma_t = ((1 - alpha_prod_prev_t)/(1 - alpha_prod_t) * beta_t).sqrt()
                x = mu + sigma_t * torch.randn_like(x)
            else:
                x = mu
            # Intermediate x_t is intentionally not clamped: at large t it is
            # close to unit-variance noise and clipping would distort it.

        # Restrict only the final image to the valid range, then convert.
        x = x.clamp(-1, 1)
        result = x.cpu().numpy()[0].transpose(1, 2, 0)  # [H, W, C]
        result = ((result + 1) * 127.5).astype(np.uint8)  # [-1, 1] -> [0, 255]
        return result

# Smoke-test the generator with the skeleton of the first training sample.
test_skel_img = np.squeeze(dataset[0][0].numpy())  # skeleton as a numpy array [H, W]
gen_result = generate_image_from_skeleton(skeleton_img=test_skel_img)
print("生成图像尺寸:", gen_result.shape)

# Display the generated image without axes.
plt.imshow(gen_result)
plt.axis('off')
plt.title("生成结果示例")
plt.show()


In [None]:
# Interactive Hugging Face login: prompts for an access token
# (created at https://huggingface.co/settings/tokens) and stores it on this machine.
!huggingface-cli login



    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    A token is already saved on your machine. Run `huggingface-cli whoami` to get more information or `huggingface-cli logout` if you want to log out.
    Setting a new token will erase the existing one.
    To log in, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): 
Add token as git credential? (Y/n) Y
Token is valid (permission: fineG

In [None]:
# 安装最新版本依赖
!pip install --upgrade diffusers transformers accelerate opencv-python

import torch
from PIL import Image
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel

# 设置设备（建议使用 GPU）
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 加载 SDXL 兼容的 ControlNet 模型（请确认仓库名称是否正确）
controlnet = ControlNetModel.from_pretrained("thibaud/controlnet-openpose-sdxl-1.0", torch_dtype=torch.float16)
controlnet.to(device)

# 加载 SDXL + ControlNet 管道
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # SDXL 模型
    controlnet=controlnet,
    torch_dtype=torch.float16
)
pipe.to(device)

# 加载原图和目标骨架图（确保图片尺寸与模型要求一致，必要时进行预处理）
source_image = Image.open("/content/ManofSteel.png").convert("RGB")
target_skeleton = Image.open("/content/139.png").convert("RGB")

prompt = "superman"
strength = 0.9  # 控制骨架条件对生成结果的影响力度

# 使用 image-to-image 模式生成新图
result = pipe(prompt=prompt, image=source_image, control_image=target_skeleton, strength=strength).images[0]
result.show()
result.save("generated_result_sdxl.png")





An error occurred while trying to fetch thibaud/controlnet-openpose-sdxl-1.0: thibaud/controlnet-openpose-sdxl-1.0 does not appear to have a file named diffusion_pytorch_model.safetensors.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

In [None]:
#@title 1. 安装依赖 & 克隆仓库
# 安装最新依赖，确保使用安全加载格式
!pip install --upgrade diffusers transformers accelerate opencv-python safetensors

# 克隆 ControlNetPlus 仓库
!git clone https://github.com/xinsir6/ControlNetPlus.git

# 如果仓库内有 requirements.txt，也可以安装其依赖：
!pip install -r ControlNetPlus/requirements.txt


Cloning into 'ControlNetPlus'...
remote: Enumerating objects: 339, done.[K
remote: Counting objects: 100% (51/51), done.[K
remote: Compressing objects: 100% (33/33), done.[K
remote: Total 339 (delta 30), reused 28 (delta 17), pack-reused 288 (from 1)[K
Receiving objects: 100% (339/339), 46.87 MiB | 67.03 MiB/s, done.
Resolving deltas: 100% (123/123), done.
Collecting absl-py==2.1.0 (from -r ControlNetPlus/requirements.txt (line 1))
  Downloading absl_py-2.1.0-py3-none-any.whl.metadata (2.3 kB)
Collecting accelerate==0.28.0 (from -r ControlNetPlus/requirements.txt (line 2))
  Downloading accelerate-0.28.0-py3-none-any.whl.metadata (18 kB)
Collecting aiofiles==23.2.1 (from -r ControlNetPlus/requirements.txt (line 3))
  Downloading aiofiles-23.2.1-py3-none-any.whl.metadata (9.7 kB)
Collecting aiohttp==3.9.3 (from -r ControlNetPlus/requirements.txt (line 4))
  Downloading aiohttp-3.9.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.4 kB)
Collecting aiosignal==1.

In [None]:
# Clone the xinsir/controlnet-union-sdxl-1.0 model repository from Hugging Face
# into ./controlnet-union-sdxl-1.0 (the checkout pulls several GiB of content).
!git clone https://huggingface.co/xinsir/controlnet-union-sdxl-1.0

Cloning into 'controlnet-union-sdxl-1.0'...
remote: Enumerating objects: 177, done.[K
remote: Total 177 (delta 0), reused 0 (delta 0), pack-reused 177 (from 1)[K
Receiving objects: 100% (177/177), 37.19 MiB | 43.87 MiB/s, done.
Resolving deltas: 100% (29/29), done.
Filtering content: 100% (4/4), 4.68 GiB | 63.96 MiB/s, done.


In [None]:
# Enter the cloned ControlNetPlus checkout (this changes the working directory
# for all later cells) and run its OpenPose demo script.
# NOTE: the script imports `controlnet_aux`; it must be installed beforehand,
# otherwise the run stops with ModuleNotFoundError.
%cd ControlNetPlus
!python controlnet_union_test_openpose.py


/content/ControlNetPlus
2025-02-23 02:09:31.504225: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1740276571.526254   18127 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1740276571.533073   18127 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Traceback (most recent call last):
  File "/content/ControlNetPlus/controlnet_union_test_openpose.py", line 8, in <module>
    from controlnet_aux import OpenposeDetector
ModuleNotFoundError: No module named 'controlnet_aux'


ControlNet 官方示例

In [None]:
# Clone the ControlNet source repository from GitHub into ./ControlNet.
!git clone https://github.com/lllyasviel/ControlNet.git

Cloning into 'ControlNet'...
remote: Enumerating objects: 1356, done.[K
remote: Total 1356 (delta 0), reused 0 (delta 0), pack-reused 1356 (from 1)[K
Receiving objects: 100% (1356/1356), 122.40 MiB | 42.41 MiB/s, done.
Resolving deltas: 100% (596/596), done.


In [None]:
# Clone the lllyasviel/ControlNet repository from Hugging Face into ./ControlNet1
# (a separate folder, so it does not collide with the GitHub checkout named ControlNet).
!git clone https://huggingface.co/lllyasviel/ControlNet ControlNet1/


Cloning into 'ControlNet1'...
remote: Enumerating objects: 52, done.[K
remote: Counting objects:   1% (1/52)[Kremote: Counting objects:   3% (2/52)[Kremote: Counting objects:   5% (3/52)[Kremote: Counting objects:   7% (4/52)[Kremote: Counting objects:   9% (5/52)[Kremote: Counting objects:  11% (6/52)[Kremote: Counting objects:  13% (7/52)[Kremote: Counting objects:  15% (8/52)[Kremote: Counting objects:  17% (9/52)[Kremote: Counting objects:  19% (10/52)[Kremote: Counting objects:  21% (11/52)[Kremote: Counting objects:  23% (12/52)[Kremote: Counting objects:  25% (13/52)[Kremote: Counting objects:  26% (14/52)[Kremote: Counting objects:  28% (15/52)[Kremote: Counting objects:  30% (16/52)[Kremote: Counting objects:  32% (17/52)[Kremote: Counting objects:  34% (18/52)[Kremote: Counting objects:  36% (19/52)[Kremote: Counting objects:  38% (20/52)[Kremote: Counting objects:  40% (21/52)[Kremote: Counting objects:  42% (22/52)[Kremote: Count

In [None]:
!unzip /content/ControlNet1/training/fill50k.zip

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: fill50k/target/5499.png  
  inflating: fill50k/target/55.png   
  inflating: fill50k/target/550.png  
  inflating: fill50k/target/5500.png  
  inflating: fill50k/target/5501.png  
  inflating: fill50k/target/5502.png  
  inflating: fill50k/target/5503.png  
  inflating: fill50k/target/5504.png  
  inflating: fill50k/target/5505.png  
  inflating: fill50k/target/5506.png  
  inflating: fill50k/target/5507.png  
  inflating: fill50k/target/5508.png  
  inflating: fill50k/target/5509.png  
  inflating: fill50k/target/551.png  
  inflating: fill50k/target/5510.png  
  inflating: fill50k/target/5511.png  
  inflating: fill50k/target/5512.png  
  inflating: fill50k/target/5513.png  
  inflating: fill50k/target/5514.png  
  inflating: fill50k/target/5515.png  
  inflating: fill50k/target/5516.png  
  inflating: fill50k/target/5517.png  
  inflating: fill50k/target/5518.png  
  inflating: fill50k/target/5519.png  
  

In [None]:
!git clone https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main

Cloning into 'main'...
remote: Entry not found
fatal: repository 'https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/' not found


In [None]:
import json
import cv2
import numpy as np

from torch.utils.data import Dataset


class MyDataset(Dataset):
    """Prompt/source/target triples described by a JSON-lines ``prompt.json``.

    Each non-empty line of ``<root>/prompt.json`` is a JSON object with the keys
    ``source``, ``target`` and ``prompt``; the two file names are resolved
    relative to ``root``.

    Args:
        root: directory that contains ``prompt.json`` and the referenced image
            files. Defaults to the original hard-coded ``./training/fill50k/``.
    """

    def __init__(self, root='./training/fill50k/'):
        import os  # local import so the class does not depend on other cells

        self.root = root
        self.data = []
        with open(os.path.join(root, 'prompt.json'), 'rt') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                self.data.append(json.loads(line))

    def __len__(self):
        return len(self.data)

    def _read_rgb(self, filename):
        """Read ``filename`` (relative to ``self.root``) as an RGB uint8 array."""
        import os

        path = os.path.join(self.root, filename)
        image = cv2.imread(path)
        if image is None:
            # cv2.imread returns None instead of raising when it cannot read a file.
            raise FileNotFoundError(f"Could not read image file: {path}")
        # OpenCV reads images in BGR order; convert to RGB.
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        """Return ``dict(jpg=target, txt=prompt, hint=source)`` for entry ``idx``.

        ``hint`` is the source image scaled to [0, 1]; ``jpg`` is the target
        image scaled to [-1, 1]; both are float32 arrays. ``txt`` is the prompt.

        Raises:
            FileNotFoundError: if either image file cannot be read.
        """
        item = self.data[idx]
        prompt = item['prompt']

        source = self._read_rgb(item['source'])
        target = self._read_rgb(item['target'])

        # Normalize source images to [0, 1].
        source = source.astype(np.float32) / 255.0

        # Normalize target images to [-1, 1].
        target = (target.astype(np.float32) / 127.5) - 1.0

        return dict(jpg=target, txt=prompt, hint=source)


In [None]:
from tutorial_dataset import MyDataset

# Instantiate the dataset and report how many samples it holds.
dataset = MyDataset()
print(len(dataset))

# Inspect one sample: the prompt text, then the shapes of the 'jpg' and 'hint' arrays.
item = dataset[1234]
jpg, txt, hint = (item[key] for key in ('jpg', 'txt', 'hint'))
print(txt, jpg.shape, hint.shape, sep="\n")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DownUpBlock(nn.Module):
    """One down-sample / up-sample stage with a skip connection.

    The input is convolved to ``mid_channels``, max-pooled by 2, up-sampled again
    with a transposed convolution, concatenated with the pre-pooling features and
    convolved to ``out_channels``. The output keeps the input's spatial size.

    Args:
        in_channels: channels of the input feature map.
        mid_channels: channels used inside the block.
        out_channels: channels of the returned feature map.
    """

    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()
        # Down path: convolution + 2x max pooling.
        self.conv_down = nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Up path: 2x transposed convolution + convolution over the concatenation.
        self.upconv = nn.ConvTranspose2d(mid_channels, mid_channels, kernel_size=2, stride=2)
        self.conv_up = nn.Conv2d(mid_channels * 2, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Map ``x`` of shape [B, in_channels, H, W] to [B, out_channels, H, W]."""
        # Down-sample.
        x1 = F.relu(self.conv_down(x))
        x_down = self.pool(x1)
        # Up-sample.
        x_up = F.relu(self.upconv(x_down))

        # For odd H or W the pooling floors the size, so the up-sampled map is one
        # pixel short of x1. Zero-pad on the bottom/right so the skip connection
        # can be concatenated; even sizes are left untouched.
        pad_h = x1.shape[-2] - x_up.shape[-2]
        pad_w = x1.shape[-1] - x_up.shape[-1]
        if pad_h or pad_w:
            x_up = F.pad(x_up, (0, pad_w, 0, pad_h))

        # Skip connection: concatenate with the pre-pooling features.
        x_cat = torch.cat([x_up, x1], dim=1)
        x_out = F.relu(self.conv_up(x_cat))
        return x_out

# Quick shape check when the cell is run as the main module.
if __name__ == '__main__':
    x = torch.randn((1, 3, 64, 64))  # one 3-channel 64x64 input
    block = DownUpBlock(3, 16, 8)    # in_channels, mid_channels, out_channels
    y = block(x)
    print("DownUpBlock输出尺寸:", y.shape)

DownUpBlock输出尺寸: torch.Size([1, 8, 64, 64])
