In [16]:
import torch
import numpy as np
import cv2
import os
from typing import List
from wan.modules.vae import WanVAE

def tensors_to_mp4_rgb(tensor_list: List[torch.Tensor], output_path: str, fps: int = 30):
    """Convert a list of PyTorch tensors to an MP4 video file with H.264 codec."""
    if not tensor_list:
        raise ValueError("tensor_list is empty")
    
    height, width = 480, 832
    
    # First, write to a temporary file with a codec that OpenCV supports
    temp_path = output_path.replace('.mp4', '_temp.mp4')
    
    # Try different codecs in order of preference
    codecs_to_try = [
        cv2.VideoWriter_fourcc(*'avc1'),  # H.264 variant
        cv2.VideoWriter_fourcc(*'H264'),  # H.264
        cv2.VideoWriter_fourcc(*'mp4v'),  # MPEG-4
        cv2.VideoWriter_fourcc(*'XVID'),  # XVID
    ]
    
    out = None
    for fourcc in codecs_to_try:
        out = cv2.VideoWriter(temp_path, fourcc, fps, (width, height), isColor=True)
        if out.isOpened():
            print(f"Using codec: {fourcc}")
            break
    
    if not out or not out.isOpened():
        raise ValueError("Failed to open video writer with any codec")
    
    for tensor in tensor_list:
        # Handle both [3, 832, 480] and [3, 1, 832, 480] shapes
        if tensor.dim() == 4:
            frame = tensor.squeeze(1)
        else:
            frame = tensor
            
        # Convert from CHW to HWC format
        frame = frame.permute(1, 2, 0)
        frame_np = frame.cpu().numpy()
        
        # Normalize to 0-255 range
        if frame_np.max() <= 1.0:
            frame_np = (frame_np * 255).astype(np.uint8)
        else:
            frame_np = frame_np.astype(np.uint8)
        
        # Convert RGB to BGR for OpenCV
        frame_bgr = cv2.cvtColor(frame_np, cv2.COLOR_RGB2BGR)
        out.write(frame_bgr)
    
    out.release()
    cv2.destroyAllWindows()
    
    # Convert using ffmpeg if available (best compatibility)
    try:
        import subprocess
        # Use ffmpeg to convert to a universally compatible H.264 MP4
        cmd = [
            'ffmpeg', '-y', '-i', temp_path,
            '-c:v', 'libx264',  # H.264 codec
            '-pix_fmt', 'yuv420p',  # Pixel format for compatibility
            '-movflags', '+faststart',  # For web playback
            output_path
        ]
        subprocess.run(cmd, check=True, capture_output=True)
        os.remove(temp_path)
        print(f"Video saved with H.264 codec to: {output_path}")
    except:
        # If ffmpeg is not available, just rename the temp file
        if os.path.exists(temp_path):
            os.rename(temp_path, output_path)
        print(f"Video saved to: {output_path} (without ffmpeg conversion)")

# def tensors_to_mp4_rgb(tensor_list: List[torch.Tensor], output_path: str, fps: int = 16):
#     """Convert a list of PyTorch tensors to an MP4 video file (RGB)."""
#     if not tensor_list:
#         raise ValueError("tensor_list is empty")
    
#     height, width = 832, 480
#     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
#     out = cv2.VideoWriter(output_path, fourcc, fps, (width, height), isColor=True)
    
#     for tensor in tensor_list:
#         # Handle both [3, 832, 480] and [3, 1, 832, 480] shapes
#         if tensor.dim() == 4:
#             frame = tensor.squeeze(1)
#         else:
#             frame = tensor
            
#         # Convert from CHW to HWC format
#         frame = frame.permute(1, 2, 0)
#         frame_np = frame.cpu().numpy()
        
#         # Normalize to 0-255 range
#         if frame_np.max() <= 1.0:
#             frame_np = (frame_np * 255).astype(np.uint8)
#         else:
#             frame_np = frame_np.astype(np.uint8)
        
#         # Convert RGB to BGR for OpenCV
#         frame_bgr = cv2.cvtColor(frame_np, cv2.COLOR_RGB2BGR)
#         out.write(frame_bgr)
    
#     out.release()
#     cv2.destroyAllWindows()
#     print(f"Video saved to: {output_path}")

In [11]:
latent_orig = torch.load('latent_pre_patch.pt')
latent_orig.shape

torch.Size([20, 21, 60, 104])

In [12]:
# Setup device and VAE
device = torch.device("cuda:0")
vae = WanVAE(
    vae_pth=os.path.join('/workspace/', 'wan_2.1_vae.safetensors'),
    device=device
)

# Load tensors
t_orig = torch.load('vae_latent.pt')
t_comfy = torch.load('/ComfyUI_fork/res.pt')

if t_orig.dim() == 5:
    t_orig = t_orig.squeeze(0)
if t_comfy.dim() == 5:
    t_comfy = t_comfy.squeeze(0)

# Compare tensors
print(f"t_orig shape: {t_orig.shape}, t_comfy shape: {t_comfy.shape}")

t_orig shape: torch.Size([20, 21, 60, 104]), t_comfy shape: torch.Size([20, 21, 60, 104])


In [5]:
print(f"Difference - min: {(t_orig - t_comfy).min():.4f}, max: {(t_orig - t_comfy).max():.4f}")
# Decode latents (skip first 4 channels if needed)
# Note: Based on your history, you're using indices [4:] to skip some channels
d_orig = vae.decode([t_orig[4:]])[0]
d_comfy = vae.decode([t_comfy[4:]])[0]

Difference - min: -0.0406, max: 0.0397


  with amp.autocast(dtype=self.dtype):


In [None]:
t_orig - t_comfy

In [7]:
diff = d_orig - d_comfy
print(diff.sum())

tensor(135106.1562, device='cuda:0')


In [17]:
print("d_orig:", d_orig.shape, "d_comfy:", d_comfy[0].shape)

# Convert decoded tensors to video
# Split along dimension 1 to get individual frames
frames_orig = torch.split(d_orig, 1, dim=1)
frames_comfy = torch.split(d_comfy, 1, dim=1)
frames_diff = torch.split((d_comfy - d_orig), 1, dim=1)

# Save videos
tensors_to_mp4_rgb(frames_orig, "orig.mp4", fps=16)
tensors_to_mp4_rgb(frames_comfy, "comfy.mp4", fps=16)
tensors_to_mp4_rgb(frames_diff, "diff.mp4", fps=16)

d_orig: torch.Size([3, 81, 480, 832]) d_comfy: torch.Size([81, 480, 832])
Using codec: 1983148141


[ERROR:0@146.256] global cap_ffmpeg_impl.hpp:3207 open Could not find encoder for codec_id=27, error: Encoder not found
[ERROR:0@146.256] global cap_ffmpeg_impl.hpp:3285 open VIDEOIO/FFMPEG: Failed to initialize VideoWriter
OpenCV: FFMPEG: tag 0x34363248/'H264' is not supported with codec id 27 and format 'mp4 / MP4 (MPEG-4 Part 14)'
OpenCV: FFMPEG: fallback to use tag 0x31637661/'avc1'
[ERROR:0@146.256] global cap_ffmpeg_impl.hpp:3207 open Could not find encoder for codec_id=27, error: Encoder not found
[ERROR:0@146.256] global cap_ffmpeg_impl.hpp:3285 open VIDEOIO/FFMPEG: Failed to initialize VideoWriter


Video saved with H.264 codec to: orig.mp4
Using codec: 1983148141


[ERROR:0@146.794] global cap_ffmpeg_impl.hpp:3207 open Could not find encoder for codec_id=27, error: Encoder not found
[ERROR:0@146.794] global cap_ffmpeg_impl.hpp:3285 open VIDEOIO/FFMPEG: Failed to initialize VideoWriter
OpenCV: FFMPEG: tag 0x34363248/'H264' is not supported with codec id 27 and format 'mp4 / MP4 (MPEG-4 Part 14)'
OpenCV: FFMPEG: fallback to use tag 0x31637661/'avc1'
[ERROR:0@146.794] global cap_ffmpeg_impl.hpp:3207 open Could not find encoder for codec_id=27, error: Encoder not found
[ERROR:0@146.794] global cap_ffmpeg_impl.hpp:3285 open VIDEOIO/FFMPEG: Failed to initialize VideoWriter


Video saved with H.264 codec to: comfy.mp4
Using codec: 1983148141


[ERROR:0@147.379] global cap_ffmpeg_impl.hpp:3207 open Could not find encoder for codec_id=27, error: Encoder not found
[ERROR:0@147.379] global cap_ffmpeg_impl.hpp:3285 open VIDEOIO/FFMPEG: Failed to initialize VideoWriter
OpenCV: FFMPEG: tag 0x34363248/'H264' is not supported with codec id 27 and format 'mp4 / MP4 (MPEG-4 Part 14)'
OpenCV: FFMPEG: fallback to use tag 0x31637661/'avc1'
[ERROR:0@147.379] global cap_ffmpeg_impl.hpp:3207 open Could not find encoder for codec_id=27, error: Encoder not found
[ERROR:0@147.379] global cap_ffmpeg_impl.hpp:3285 open VIDEOIO/FFMPEG: Failed to initialize VideoWriter


Video saved with H.264 codec to: diff.mp4


<video controls src="/ATI/wan_ati_standalone/diff.mp4">animation</video>


In [18]:
from ipywidgets import Output, GridspecLayout
from IPython import display

filepaths = ['orig.mp4', 'comfy.mp4', 'diff.mp4']
grid = GridspecLayout(1, len(filepaths))

for i, filepath in enumerate(filepaths):
    out = Output()
    with out:
        display.display(display.Video(filepath, width=832//2, height=480//2, embed=True))
    grid[0, i] = out

grid

GridspecLayout(children=(Output(layout=Layout(grid_area='widget001')), Output(layout=Layout(grid_area='widget0…

In [110]:
import torch
tracks_orig = torch.load('tracks_pre_patch.pt')
tracks_orig2 = torch.load('tracks_pre_process_out.pt')
tracks_comfy = torch.load('/ComfyUI_fork/tracks.pt')

In [114]:
tracks_diff = tracks_orig - tracks_comfy.to(tracks_orig.device)

In [115]:
tracks_diff

tensor([[[[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         ...,

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]]]], device='cuda:0')

In [96]:
tracks_orig2.shape

torch.Size([81, 4, 4])

In [68]:
tracks_p_orig = torch.load('tracks_pre_process.pt')
tracks_p_comfy = torch.load('/ComfyUI_fork/tracks_pre_process.pt')

In [69]:
diff = tracks_p_orig-tracks_p_comfy
diff.max()

tensor(0.)

In [70]:
(tracks_p_orig.floor()-tracks_p_comfy).max()

tensor(0.)

In [60]:
%ls

43_cond.pt                   [0m[01;34mexamples[0m/
43_cond_p5.pt                fix_cuda.py
46_cond.pt                   orig.mp4
README.md                    requirements.txt
Untitled.ipynb               [01;32mrun_wan_ati.py[0m*
check_arch_support.py        run_wan_ati_refactored.py
check_checkpoint.py          test_chunked_load.py
check_compute_capability.py  test_cuda_kernel.py
check_gpu.sh                 test_fp8_minimal.py
check_pytorch.py             test_fp8_ops.py
comfy.mp4                    test_memory.py
d1.mp4                       test_model_load.py
d1_fixed.mp4                 test_motion_preprocessing.py
d2.mp4                       tracks_pre_patch.pt
debug_clip_checkpoint.py     tracks_pre_process.pt
debug_cuda.py                vae_latent.pt
diff.mp4                     validate_inputs.py
diff_fixed.mp4               [01;34mwan[0m/
