In [None]:
# 导入必要的库
import os
import torch
import numpy as np
from PIL import Image
from transformers import T5EncoderModel, T5Tokenizer
from diffusers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler, AutoencoderKLCogVideoX
from diffusers.utils import load_video, export_to_video

# 导入自定义模块 - 确保这些模块能够被正确导入
import sys
sys.path.append('..')  # 如果需要导入上级目录的模块
from controlnet_pipeline import ControlnetCogVideoXPipeline
from cogvideo_transformer import CustomCogVideoXTransformer3DModel
from cogvideo_controlnet import CogVideoXControlnet

# Inference inputs/outputs. Relative paths are resolved against the notebook's
# current working directory. Forward slashes also work on Windows, so these are
# plain strings; a path written with backslashes would need a raw string r"...".
depth_video_path = "data/depth/depth (1).mp4"  # depth-map video used as the ControlNet condition
checkpoint_path = "cogvideox-depth-controlnet/checkpoint-1000.pt"  # trained ControlNet checkpoint
output_path = "generated_depth_video.mp4"  # where the generated video is written
base_model_path = "THUDM/CogVideoX-2b"  # base model (Hugging Face Hub id or local directory)

# Text prompt. Built from adjacent string literals so it is one line with single
# spaces; a triple-quoted string would embed hard newlines and trailing spaces
# into the text that is passed to the tokenizer.
prompt = (
    "The video shows a street with cars parked on the side. "
    "The camera pans to the right, revealing more of the street and the surrounding area. "
    "The scene is overcast and foggy, creating a somewhat gloomy atmosphere. "
    "The camera movement is smooth and steady, allowing for a clear view of the surroundings."
)

# Load the pretrained CogVideoX base components (tokenizer, T5 text encoder,
# denoising transformer, VAE, scheduler).
# The weights are loaded directly in fp16 -- the dtype the pipeline is cast to
# later anyway -- so no intermediate fp32 copy has to be held in host RAM. This
# roughly halves peak CPU memory and load time; the final fp16 weights are the same.
print("正在加载基础模型...")
load_dtype = torch.float16
tokenizer = T5Tokenizer.from_pretrained(base_model_path, subfolder="tokenizer")
text_encoder = T5EncoderModel.from_pretrained(
    base_model_path, subfolder="text_encoder", torch_dtype=load_dtype
)
# NOTE(review): assumes CustomCogVideoXTransformer3DModel inherits diffusers'
# ModelMixin.from_pretrained, which accepts torch_dtype -- confirm.
transformer = CustomCogVideoXTransformer3DModel.from_pretrained(
    base_model_path, subfolder="transformer", torch_dtype=load_dtype
)
vae = AutoencoderKLCogVideoX.from_pretrained(
    base_model_path, subfolder="vae", torch_dtype=load_dtype
)
scheduler = CogVideoXDDIMScheduler.from_pretrained(base_model_path, subfolder="scheduler")

# Build the trained ControlNet. These hyper-parameters must match the ones used
# at training time; otherwise load_state_dict() below (strict by default) raises
# on missing/unexpected keys or shape mismatches.
print("正在加载ControlNet模型...")
controlnet = CogVideoXControlnet(
    num_layers=8,
    downscale_coef=8,
    in_channels=3,
    num_attention_heads=30,  # CogVideoX-2b uses 30 attention heads (x 64 dims)
    attention_head_dim=64,
    vae_channels=16,
)

# Load the fine-tuned ControlNet weights onto the CPU; the pipeline moves and
# casts them later.
# SECURITY NOTE: on PyTorch versions where weights_only defaults to False,
# torch.load unpickles arbitrary objects -- only load checkpoints you trust.
ckpt = torch.load(checkpoint_path, map_location='cpu')
controlnet.load_state_dict(ckpt['state_dict'])
print(f"ControlNet检查点已加载: {checkpoint_path}")

# Assemble the ControlNet-conditioned CogVideoX pipeline from the parts loaded above.
print("正在创建推理管道...")
pipe = ControlnetCogVideoXPipeline(
    scheduler=scheduler,
    controlnet=controlnet,
    vae=vae,
    transformer=transformer,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
)

# Cast the pipeline's components to half precision and place them on the GPU
# to reduce GPU memory use.
pipe = pipe.to(device="cuda", dtype=torch.float16)

# Optional, for GPUs that cannot hold the whole model: model CPU offload.
# NOTE(review): diffusers expects enable_model_cpu_offload() to be used INSTEAD
# of moving the pipeline to CUDA -- if you enable it, drop device="cuda" from the
# .to(...) call above and keep only the dtype cast.
# pipe.enable_model_cpu_offload()

# Let the VAE decode in slices and spatial tiles to lower peak memory.
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()

# Load the depth-map video that conditions the ControlNet, truncated to num_frames.
print(f"正在加载深度图视频: {depth_video_path}")
num_frames = 49  # CogVideoX generates at most 49 frames per clip
depth_frames = load_video(depth_video_path)[:num_frames]
print(f"已加载 {len(depth_frames)} 帧")

# Fail clearly on an empty/unreadable video instead of a ZeroDivisionError in
# the padding step below.
if not depth_frames:
    raise ValueError(f"No frames were loaded from depth video: {depth_video_path}")

# Pad a short clip up to num_frames by looping it from the start
# (0, 1, ..., k-1, 0, 1, ...). The index must be taken modulo the ORIGINAL
# clip length; taking it modulo the growing list's length always yields 0.
if len(depth_frames) < num_frames:
    print(f"警告：帧数不足 {num_frames}，将进行重复填充")
    n_loaded = len(depth_frames)
    depth_frames = [depth_frames[i % n_loaded] for i in range(num_frames)]

# Run the conditioned denoising loop. Autograd is switched off because no
# gradients are needed at inference time, which saves memory.
print("开始推理过程...")
with torch.no_grad():
    output = pipe(
        prompt=prompt,
        controlnet_frames=depth_frames,
        num_frames=num_frames,
        height=480,
        width=720,
        num_inference_steps=50,
        guidance_scale=6.0,
        generator=torch.Generator(device="cuda").manual_seed(42),  # fixed seed for reproducibility
        # NOTE(review): presumably the ControlNet strength and the fraction of the
        # denoising schedule (start..end) during which it is applied -- confirm in
        # ControlnetCogVideoXPipeline.
        controlnet_weights=1.0,
        controlnet_guidance_start=0.0,
        controlnet_guidance_end=0.8,
    )

# Write the first (only) generated video to disk at 8 fps.
print(f"正在保存输出视频至: {output_path}")
export_to_video(output.frames[0], output_path, fps=8)
print("推理完成！")

# Show the first generated frame inline in the notebook as a quick preview.
from IPython.display import Image as IPImage
from IPython.display import display
import tempfile

# The preview JPEG is only needed to build the inline image, so write it into a
# throw-away temporary directory instead of leaving it in the working directory.
first_frame = output.frames[0][0]
with tempfile.TemporaryDirectory() as preview_dir:
    temp_image_path = os.path.join(preview_dir, "first_frame_preview.jpg")
    first_frame.save(temp_image_path)

    # IPython's Image reads (embeds) the file bytes when constructed, so the
    # temp file may be deleted once this `with` block exits.
    print("生成视频的第一帧预览:")
    display(IPImage(temp_image_path))
print(f"完整视频已保存至: {output_path}")

  from .autonotebook import tqdm as notebook_tqdm


正在加载基础模型...


Downloading shards: 100%|██████████| 2/2 [00:00<?, ?it/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:11<00:00,  5.61s/it]


正在加载ControlNet模型...
ControlNet检查点已加载: cogvideox-depth-controlnet/checkpoint-1000.pt
正在创建推理管道...
正在加载深度图视频: D:\Programs\cogvideox-more-controlnet\data\depth\depth (1).mp4
已加载 49 帧
开始推理过程...


  0%|          | 0/50 [00:00<?, ?it/s]