diff --git a/visual-ai/Wan2.2/README.md b/visual-ai/Wan2.2/README.md
new file mode 100644
index 0000000..f78bf51
--- /dev/null
+++ b/visual-ai/Wan2.2/README.md
@@ -0,0 +1,37 @@
+# Docker setup
+
+Build the Docker image:
+
+```bash
+bash build.sh
+```
+
+Run the Docker container:
+
+```bash
+export DOCKER_IMAGE=llm-scaler-visualai:latest-wan2.2
+export CONTAINER_NAME=wan-2.2
+export MODEL_DIR=
+sudo docker run -itd \
+  --privileged \
+  --net=host \
+  --device=/dev/dri \
+  -e no_proxy=localhost,127.0.0.1 \
+  --name=$CONTAINER_NAME \
+  -v $MODEL_DIR:/llm/models/ \
+  --shm-size="16g" \
+  --entrypoint=/bin/bash \
+  $DOCKER_IMAGE
+
+docker exec -it $CONTAINER_NAME bash
+```
+
+Run the Wan 2.2 demo on a single B60 GPU:
+```bash
+python3 generate.py --task ti2v-5B --size 1280*704 --ckpt_dir /llm/models/Wan2.2-TI2V-5B/ --offload_model True --t5_cpu --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." --convert_model_dtype --frame_num 101 --sample_steps 50
+```
+
+Run the Wan 2.2 demo on two B60 GPUs:
+```bash
+torchrun --nproc_per_node=2 generate.py --task ti2v-5B --size 1280*704 --ckpt_dir /llm/models/Wan2.2-TI2V-5B/ --ulysses_size 2 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." --offload_model True --t5_cpu --convert_model_dtype --frame_num 101 --sample_steps 50
+```
diff --git a/visual-ai/Wan2.2/build.sh b/visual-ai/Wan2.2/build.sh
new file mode 100644
index 0000000..f24fcba
--- /dev/null
+++ b/visual-ai/Wan2.2/build.sh
@@ -0,0 +1,6 @@
+set -x
+
+export HTTP_PROXY=
+export HTTPS_PROXY=
+
+docker build -f ./docker/Dockerfile . -t llm-scaler-visualai:latest-wan2.2 --build-arg https_proxy=$HTTPS_PROXY --build-arg http_proxy=$HTTP_PROXY
\ No newline at end of file
diff --git a/visual-ai/Wan2.2/docker/Dockerfile b/visual-ai/Wan2.2/docker/Dockerfile new file mode 100644 index 0000000..2cd5a39 --- /dev/null +++ b/visual-ai/Wan2.2/docker/Dockerfile @@ -0,0 +1,109 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# ======== Base Stage ======== +FROM intel/deep-learning-essentials:2025.0.2-0-devel-ubuntu24.04 AS vllm-base + +ARG https_proxy +ARG http_proxy + +# Add Intel oneAPI repo and PPA for GPU support +RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && \ + echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list && \ + add-apt-repository -y ppa:kobuk-team/intel-graphics-testing + +# Install dependencies and Python 3.10 +RUN apt-get update -y && \ + apt-get install -y software-properties-common && \ + add-apt-repository ppa:deadsnakes/ppa && \ + apt-get update -y && \ + apt-get install -y python3.10 python3.10-distutils python3.10-dev && \ + curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10 && \ + apt-get install -y --no-install-recommends --fix-missing \ + curl \ + ffmpeg \ + git \ + libsndfile1 \ + libsm6 \ + libxext6 \ + libgl1 \ + lsb-release \ + numactl \ + wget \ + vim \ + linux-libc-dev && \ + update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1 && \ + # Install Intel GPU runtime packages + apt-get update -y && \ + apt-get install -y libze1 libze-dev libze-intel-gpu1 intel-opencl-icd libze-intel-gpu-raytracing && \ + apt-get install -y intel-oneapi-dpcpp-ct=2025.0.1-17 && \ + apt-get clean && rm -rf
/var/lib/apt/lists/* + +# pin compute runtime version +RUN mkdir /tmp/neo && \ + cd /tmp/neo && \ + wget https://github.com/intel/intel-graphics-compiler/releases/download/v2.12.5/intel-igc-core-2_2.12.5+19302_amd64.deb && \ + wget https://github.com/intel/intel-graphics-compiler/releases/download/v2.12.5/intel-igc-opencl-2_2.12.5+19302_amd64.deb && \ + wget https://github.com/intel/compute-runtime/releases/download/25.22.33944.8/intel-ocloc-dbgsym_25.22.33944.8-0_amd64.ddeb && \ + wget https://github.com/intel/compute-runtime/releases/download/25.22.33944.8/intel-ocloc_25.22.33944.8-0_amd64.deb && \ + wget https://github.com/intel/compute-runtime/releases/download/25.22.33944.8/intel-opencl-icd-dbgsym_25.22.33944.8-0_amd64.ddeb && \ + wget https://github.com/intel/compute-runtime/releases/download/25.22.33944.8/intel-opencl-icd_25.22.33944.8-0_amd64.deb && \ + wget https://github.com/intel/compute-runtime/releases/download/25.22.33944.8/libigdgmm12_22.7.0_amd64.deb && \ + wget https://github.com/intel/compute-runtime/releases/download/25.22.33944.8/libze-intel-gpu1-dbgsym_25.22.33944.8-0_amd64.ddeb && \ + wget https://github.com/intel/compute-runtime/releases/download/25.22.33944.8/libze-intel-gpu1_25.22.33944.8-0_amd64.deb && \ + dpkg -i *.deb + +WORKDIR /llm +COPY ./patches/wan22_for_multi_arc.patch /tmp/ +COPY ./patches/0001-oneccl-align-global-V0.1.1.patch /tmp/ + +# Set environment variables early +ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/" + +# ======= Add oneCCL build ======= +RUN apt-get update && apt-get install -y \ + cmake \ + g++ \ + && rm -rf /var/lib/apt/lists/* + +# Build 1ccl +RUN git clone https://github.com/oneapi-src/oneCCL.git && \ + cd oneCCL && \ + git checkout def870543749186b6f38cdc865b44d52174c7492 && \ + git apply /tmp/0001-oneccl-align-global-V0.1.1.patch && \ + mkdir build && cd build && \ + export IGC_VISAOptions=-activeThreadsOnlyBarrier && \ + /usr/bin/cmake .. 
\ + -DCMAKE_INSTALL_PREFIX=_install \ + -DCMAKE_C_COMPILER=icx \ + -DCMAKE_CXX_COMPILER=icpx \ + -DCOMPUTE_BACKEND=dpcpp \ + -DCCL_ENABLE_ARCB=1 && \ + make -j && make install && \ + mv _install /opt/intel/oneapi/ccl/2021.15.3 && \ + cd /opt/intel/oneapi/ccl/ && \ + ln -snf 2021.15.3 latest + +# Configure environment to source oneAPI +RUN echo "source /opt/intel/oneapi/setvars.sh --force" >> /root/.bashrc + +SHELL ["bash", "-c"] +CMD ["bash", "-c", "source /root/.bashrc && exec bash"] + +ENV LD_LIBRARY_PATH="/usr/local/lib:/usr/local/lib/python3.10/dist-packages/torch/lib:$LD_LIBRARY_PATH" + +RUN pip install torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/xpu && \ + pip install intel-extension-for-pytorch==2.7.10+xpu oneccl_bind_pt==2.7.0+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/ && \ + pip install bigdl-core-xe-all==2.6.0 --extra-index-url https://download.pytorch.org/whl/xpu && \ + apt remove python3-blinker -y + +RUN cd /llm && \ + git clone https://github.com/Wan-Video/Wan2.2.git && \ + cd ./Wan2.2 && \ + git checkout 031a9be56cec91e86d140d3d3a74280fb05a9b1c && \ + git apply /tmp/wan22_for_multi_arc.patch && \ + pip install -r requirements.txt && \ + pip install einops && \ + pip install cffi + +WORKDIR /llm/Wan2.2 diff --git a/visual-ai/Wan2.2/patches/0001-oneccl-align-global-V0.1.1.patch b/visual-ai/Wan2.2/patches/0001-oneccl-align-global-V0.1.1.patch new file mode 100644 index 0000000..8f8a987 --- /dev/null +++ b/visual-ai/Wan2.2/patches/0001-oneccl-align-global-V0.1.1.patch @@ -0,0 +1,125 @@ +From 7f7a3d65541828d9889bfdec799bc23339e8e520 Mon Sep 17 00:00:00 2001 +From: YongZhuIntel +Date: Wed, 21 May 2025 09:37:06 +0800 +Subject: [PATCH] oneccl align global V0.1.1 + +base on public branch release/ccl_2021.15.3-arc(def870543749186b6f38cdc865b44d52174c7492) + +Build: + 1. mkdir build; cd build + 2. source /opt/intel/oneapi/setvars.sh + 3. export IGC_VISAOptions=-activeThreadsOnlyBarrier + 4. cmake .. 
-DCMAKE_INSTALL_PREFIX=_install -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DCOMPUTE_BACKEND=dpcpp -DCCL_ENABLE_ARCB=1 && make -j && make install + +print bandwidth in benchmark +--- + examples/benchmark/include/benchmark.hpp | 40 +++++++++++++++++++++--- + examples/benchmark/src/benchmark.cpp | 7 +++-- + 2 files changed, 41 insertions(+), 6 deletions(-) + +diff --git a/examples/benchmark/include/benchmark.hpp b/examples/benchmark/include/benchmark.hpp +index 08a3625..bff6275 100644 +--- a/examples/benchmark/include/benchmark.hpp ++++ b/examples/benchmark/include/benchmark.hpp +@@ -377,7 +377,9 @@ void store_to_csv(const user_options_t& options, + double max_time, + double avg_time, + double stddev, +- double wait_avg_time) { ++ double wait_avg_time, ++ double algbw, ++ double busbw) { + std::ofstream csvf; + csvf.open(options.csv_filepath, std::ofstream::out | std::ofstream::app); + +@@ -396,7 +398,7 @@ void store_to_csv(const user_options_t& options, + << "," << ccl::get_datatype_size(dtype) << "," << elem_count << "," + << ccl::get_datatype_size(dtype) * elem_count << "," << buf_count << "," + << iter_count << "," << min_time << "," << max_time << "," << avg_time << "," +- << stddev << "," << wait_avg_time << std::endl; ++ << stddev << "," << wait_avg_time << "," << algbw << "," << busbw << std::endl; + } + csvf.close(); + } +@@ -472,13 +474,41 @@ void print_timings(const ccl::communicator& comm, + max_time /= iter_count; + + size_t bytes = elem_count * ccl::get_datatype_size(dtype) * buf_count; ++ ++ double algbw = bytes*1000/total_avg_time/1024/1024; ++ ++ if (ncolls == 1) { ++ if (options.coll_names.front() == "allgather" || ++ options.coll_names.front() == "allgatherv" || ++ options.coll_names.front() == "reducescatter" || ++ options.coll_names.front() == "alltoall" || ++ options.coll_names.front() == "alltoallv") { ++ algbw = algbw * nranks; ++ } ++ } ++ ++ double busbw = algbw; ++ if (ncolls == 1) { ++ if (options.coll_names.front() == "allreduce") { ++ busbw = algbw * 2 * (nranks -1) / nranks; ++ } else if (options.coll_names.front() == "allgather" || ++ options.coll_names.front() == "allgatherv" || ++ options.coll_names.front() == "reducescatter" || ++ options.coll_names.front() == "alltoall" || ++ options.coll_names.front() == "alltoallv") { ++ busbw = algbw * (nranks -1) / nranks; ++ } ++ } ++ + std::stringstream ss; + ss << std::right << std::fixed << std::setw(COL_WIDTH) << bytes << std::setw(COL_WIDTH) + << elem_count * buf_count << std::setw(COL_WIDTH) << iter_count << std::setw(COL_WIDTH) + << std::setprecision(COL_PRECISION) << min_time << std::setw(COL_WIDTH) + << std::setprecision(COL_PRECISION) << max_time << std::setw(COL_WIDTH) + << std::setprecision(COL_PRECISION) << total_avg_time << std::setw(COL_WIDTH - 3) +- << std::setprecision(COL_PRECISION) << stddev << std::setw(COL_WIDTH + 3); ++ << std::setprecision(COL_PRECISION) << stddev << std::setw(COL_WIDTH) ++ << std::setprecision(COL_PRECISION) << algbw << std::setw(COL_WIDTH) ++ << std::setprecision(COL_PRECISION) << busbw << std::setw(COL_WIDTH + 3); + + if (show_extened_info(options.show_additional_info)) { + ss << std::right << std::fixed << std::setprecision(COL_PRECISION) << wait_avg_time; +@@ -497,7 +527,9 @@ void print_timings(const ccl::communicator& comm, + max_time, + total_avg_time, + stddev, +- wait_avg_time); ++ wait_avg_time, ++ algbw, ++ busbw); + } + } + +diff --git a/examples/benchmark/src/benchmark.cpp b/examples/benchmark/src/benchmark.cpp +index d90fb9b..78957f2 100644 +--- 
a/examples/benchmark/src/benchmark.cpp ++++ b/examples/benchmark/src/benchmark.cpp +@@ -105,7 +105,8 @@ void run(ccl::communicator& service_comm, + << "#elem_count" << std::setw(COL_WIDTH) << "#repetitions" + << std::setw(COL_WIDTH) << "t_min[usec]" << std::setw(COL_WIDTH) << "t_max[usec]" + << std::setw(COL_WIDTH) << "t_avg[usec]" << std::setw(COL_WIDTH - 3) +- << "stddev[%]"; ++ << "stddev[%]" << std::setw(COL_WIDTH) << "algbw[GB/s]" << std::setw(COL_WIDTH) ++ << "busbw[GB/s]"; + + if (show_extened_info(options.show_additional_info)) { + ss << std::right << std::setw(COL_WIDTH + 3) << "wait_t_avg[usec]"; +@@ -435,7 +436,9 @@ int main(int argc, char* argv[]) { + << "t_max[usec]," + << "t_avg[usec]," + << "stddev[%]," +- << "wait_t_avg[usec]" << std::endl; ++ << "wait_t_avg[usec]," ++ << "algbw[GB/s]," ++ << "busbw[GB/s]" << std::endl; + csvf.close(); + } + +-- +2.25.1 + diff --git a/visual-ai/Wan2.2/patches/wan22_for_multi_arc.patch b/visual-ai/Wan2.2/patches/wan22_for_multi_arc.patch new file mode 100644 index 0000000..0be176c --- /dev/null +++ b/visual-ai/Wan2.2/patches/wan22_for_multi_arc.patch @@ -0,0 +1,717 @@ +diff --git a/generate.py b/generate.py +index c3e5816..efdbaea 100644 +--- a/generate.py ++++ b/generate.py +@@ -7,7 +7,8 @@ import warnings + from datetime import datetime + + warnings.filterwarnings('ignore') +- ++import intel_extension_for_pytorch ++import oneccl_bindings_for_pytorch + import random + + import torch +@@ -225,9 +226,9 @@ def generate(args): + logging.info( + f"offload_model is not specified, set to {args.offload_model}.") + if world_size > 1: +- torch.cuda.set_device(local_rank) ++ torch.xpu.set_device(local_rank) + dist.init_process_group( +- backend="nccl", ++ backend="ccl", + init_method="env://", + rank=rank, + world_size=world_size) +@@ -398,7 +399,7 @@ def generate(args): + value_range=(-1, 1)) + del video + +- torch.cuda.synchronize() ++ torch.xpu.synchronize() + if dist.is_initialized(): + dist.barrier() + dist.destroy_process_group() +diff --git a/pyproject.toml b/pyproject.toml +index 337240a..0773d24 100644 +--- a/pyproject.toml ++++ b/pyproject.toml +@@ -26,7 +26,6 @@ dependencies = [ + "ftfy", + "dashscope", + "imageio-ffmpeg", +- "flash_attn", + "numpy>=1.23.5,<2" + ] + +diff --git a/requirements.txt b/requirements.txt +index 77c1e6d..cb26887 100644 +--- a/requirements.txt ++++ b/requirements.txt +@@ -11,5 +11,4 @@ easydict + ftfy + dashscope + imageio-ffmpeg +-flash_attn + numpy>=1.23.5,<2 +diff --git a/wan/distributed/fsdp.py b/wan/distributed/fsdp.py +index 6bb496d..9dc907a 100644 +--- a/wan/distributed/fsdp.py ++++ b/wan/distributed/fsdp.py +@@ -40,4 +40,4 @@ def free_model(model): + _free_storage(m._handle.flat_param.data) + del model + gc.collect() +- torch.cuda.empty_cache() ++ torch.xpu.empty_cache() +diff --git a/wan/distributed/sequence_parallel.py b/wan/distributed/sequence_parallel.py +index 9c1ad78..9ea6ee6 100644 +--- a/wan/distributed/sequence_parallel.py ++++ b/wan/distributed/sequence_parallel.py +@@ -1,6 +1,6 @@ + # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. + import torch +-import torch.cuda.amp as amp ++import torch.xpu.amp as amp + + from ..modules.model import sinusoidal_embedding_1d + from .ulysses import distributed_attention +@@ -20,7 +20,7 @@ def pad_freqs(original_tensor, target_len): + return padded_tensor + + +-@torch.amp.autocast('cuda', enabled=False) ++@torch.amp.autocast('xpu', enabled=False) + def rope_apply(x, grid_sizes, freqs): + """ + x: [B, L, N, C]. 
+@@ -99,7 +99,7 @@ def sp_dit_forward( + # time embeddings + if t.dim() == 1: + t = t.expand(t.size(0), seq_len) +- with torch.amp.autocast('cuda', dtype=torch.float32): ++ with torch.amp.autocast('xpu', dtype=torch.float32): + bt = t.size(0) + t = t.flatten() + e = self.time_embedding( +diff --git a/wan/distributed/ulysses.py b/wan/distributed/ulysses.py +index 12d7d30..1b8a9f0 100644 +--- a/wan/distributed/ulysses.py ++++ b/wan/distributed/ulysses.py +@@ -2,7 +2,7 @@ + import torch + import torch.distributed as dist + +-from ..modules.attention import flash_attention ++from ..modules.attention import flash_attention, attention + from .util import all_to_all + + +@@ -34,7 +34,7 @@ def distributed_attention( + v = all_to_all(v, scatter_dim=2, gather_dim=1) + + # apply attention +- x = flash_attention( ++ x = attention( + q, + k, + v, +diff --git a/wan/distributed/util.py b/wan/distributed/util.py +index 241efa1..9efcf88 100644 +--- a/wan/distributed/util.py ++++ b/wan/distributed/util.py +@@ -7,7 +7,7 @@ def init_distributed_group(): + """r initialize sequence parallel group. + """ + if not dist.is_initialized(): +- dist.init_process_group(backend='nccl') ++ dist.init_process_group(backend='ccl') + + + def get_rank(): +diff --git a/wan/image2video.py b/wan/image2video.py +index 659564c..efc1913 100644 +--- a/wan/image2video.py ++++ b/wan/image2video.py +@@ -11,7 +11,7 @@ from functools import partial + + import numpy as np + import torch +-import torch.cuda.amp as amp ++import torch.xpu.amp as amp + import torch.distributed as dist + import torchvision.transforms.functional as TF + from tqdm import tqdm +@@ -71,7 +71,7 @@ class WanI2V: + Convert DiT model parameters dtype to 'config.param_dtype'. + Only works without FSDP. + """ +- self.device = torch.device(f"cuda:{device_id}") ++ self.device = torch.device(f"xpu:{device_id}") + self.config = config + self.rank = rank + self.t5_cpu = t5_cpu +@@ -195,7 +195,7 @@ class WanI2V: + if offload_model or self.init_on_cpu: + if next(getattr( + self, +- offload_model_name).parameters()).device.type == 'cuda': ++ offload_model_name).parameters()).device.type == 'xpu': + getattr(self, offload_model_name).to('cpu') + if next(getattr( + self, +@@ -333,7 +333,7 @@ class WanI2V: + + # evaluation mode + with ( +- torch.amp.autocast('cuda', dtype=self.param_dtype), ++ torch.amp.autocast('xpu', dtype=self.param_dtype), + torch.no_grad(), + no_sync_low_noise(), + no_sync_high_noise(), +@@ -377,7 +377,7 @@ class WanI2V: + } + + if offload_model: +- torch.cuda.empty_cache() ++ torch.xpu.empty_cache() + + for _, t in enumerate(tqdm(timesteps)): + latent_model_input = [latent.to(self.device)] +@@ -393,11 +393,11 @@ class WanI2V: + noise_pred_cond = model( + latent_model_input, t=timestep, **arg_c)[0] + if offload_model: +- torch.cuda.empty_cache() ++ torch.xpu.empty_cache() + noise_pred_uncond = model( + latent_model_input, t=timestep, **arg_null)[0] + if offload_model: +- torch.cuda.empty_cache() ++ torch.xpu.empty_cache() + noise_pred = noise_pred_uncond + sample_guide_scale * ( + noise_pred_cond - noise_pred_uncond) + +@@ -415,7 +415,7 @@ class WanI2V: + if offload_model: + self.low_noise_model.cpu() + self.high_noise_model.cpu() +- torch.cuda.empty_cache() ++ torch.xpu.empty_cache() + + if self.rank == 0: + videos = self.vae.decode(x0) +@@ -424,7 +424,7 @@ class WanI2V: + del sample_scheduler + if offload_model: + gc.collect() +- torch.cuda.synchronize() ++ torch.xpu.synchronize() + if dist.is_initialized(): + dist.barrier() + +diff --git 
a/wan/modules/attention.py b/wan/modules/attention.py +index 4dbbe03..379bcdc 100644 +--- a/wan/modules/attention.py ++++ b/wan/modules/attention.py +@@ -13,6 +13,7 @@ try: + except ModuleNotFoundError: + FLASH_ATTN_2_AVAILABLE = False + ++FLASH_ATTN_2_AVAILABLE = False + import warnings + + __all__ = [ +@@ -51,7 +52,7 @@ def flash_attention( + """ + half_dtypes = (torch.float16, torch.bfloat16) + assert dtype in half_dtypes +- assert q.device.type == 'cuda' and q.size(-1) <= 256 ++ assert q.device.type == 'xpu' and q.size(-1) <= 256 + + # params + b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype +@@ -168,12 +169,27 @@ def attention( + ) + attn_mask = None + ++ # q = q.transpose(1, 2).to(dtype) ++ # k = k.transpose(1, 2).to(dtype) ++ # v = v.transpose(1, 2).to(dtype) ++ ++ # out = torch.nn.functional.scaled_dot_product_attention( ++ # q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p) ++ ++ # out = out.transpose(1, 2).contiguous() ++ # return out ++ ++ dtype1 = dtype ++ dtype = torch.float16 + q = q.transpose(1, 2).to(dtype) + k = k.transpose(1, 2).to(dtype) + v = v.transpose(1, 2).to(dtype) + +- out = torch.nn.functional.scaled_dot_product_attention( +- q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p) ++ import xe_addons ++ head_size = q.shape[-1] ++ import math ++ scale = 1 / math.sqrt(head_size) ++ out = xe_addons.sdp_non_causal(q.contiguous(), k.contiguous(), v.contiguous(), attn_mask, scale) + + out = out.transpose(1, 2).contiguous() +- return out ++ return out.to(dtype1) +diff --git a/wan/modules/model.py b/wan/modules/model.py +index 96e5cd4..63e6c86 100644 +--- a/wan/modules/model.py ++++ b/wan/modules/model.py +@@ -6,7 +6,7 @@ import torch.nn as nn + from diffusers.configuration_utils import ConfigMixin, register_to_config + from diffusers.models.modeling_utils import ModelMixin + +-from .attention import flash_attention ++from .attention import flash_attention, attention + + __all__ = ['WanModel'] + +@@ -24,7 +24,7 @@ def sinusoidal_embedding_1d(dim, position): + return x + + +-@torch.amp.autocast('cuda', enabled=False) ++@torch.amp.autocast('xpu', enabled=False) + def rope_params(max_seq_len, dim, theta=10000): + assert dim % 2 == 0 + freqs = torch.outer( +@@ -35,7 +35,7 @@ def rope_params(max_seq_len, dim, theta=10000): + return freqs + + +-@torch.amp.autocast('cuda', enabled=False) ++@torch.amp.autocast('xpu', enabled=False) + def rope_apply(x, grid_sizes, freqs): + n, c = x.size(2), x.size(3) // 2 + +@@ -142,7 +142,7 @@ class WanSelfAttention(nn.Module): + + q, k, v = qkv_fn(x) + +- x = flash_attention( ++ x = attention( + q=rope_apply(q, grid_sizes, freqs), + k=rope_apply(k, grid_sizes, freqs), + v=v, +@@ -172,7 +172,7 @@ class WanCrossAttention(WanSelfAttention): + v = self.v(context).view(b, -1, n, d) + + # compute attention +- x = flash_attention(q, k, v, k_lens=context_lens) ++ x = attention(q, k, v, k_lens=context_lens) + + # output + x = x.flatten(2) +@@ -235,7 +235,7 @@ class WanAttentionBlock(nn.Module): + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + assert e.dtype == torch.float32 +- with torch.amp.autocast('cuda', dtype=torch.float32): ++ with torch.amp.autocast('xpu', dtype=torch.float32): + e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2) + assert e[0].dtype == torch.float32 + +@@ -243,7 +243,7 @@ class WanAttentionBlock(nn.Module): + y = self.self_attn( + self.norm1(x).float() * (1 + e[1].squeeze(2)) + e[0].squeeze(2), + seq_lens, grid_sizes, freqs) +- with 
torch.amp.autocast('cuda', dtype=torch.float32): ++ with torch.amp.autocast('xpu', dtype=torch.float32): + x = x + y * e[2].squeeze(2) + + # cross-attention & ffn function +@@ -251,7 +251,7 @@ class WanAttentionBlock(nn.Module): + x = x + self.cross_attn(self.norm3(x), context, context_lens) + y = self.ffn( + self.norm2(x).float() * (1 + e[4].squeeze(2)) + e[3].squeeze(2)) +- with torch.amp.autocast('cuda', dtype=torch.float32): ++ with torch.amp.autocast('xpu', dtype=torch.float32): + x = x + y * e[5].squeeze(2) + return x + +@@ -283,7 +283,7 @@ class Head(nn.Module): + e(Tensor): Shape [B, L1, C] + """ + assert e.dtype == torch.float32 +- with torch.amp.autocast('cuda', dtype=torch.float32): ++ with torch.amp.autocast('xpu', dtype=torch.float32): + e = (self.modulation.unsqueeze(0) + e.unsqueeze(2)).chunk(2, dim=2) + x = ( + self.head( +@@ -459,7 +459,8 @@ class WanModel(ModelMixin, ConfigMixin): + # time embeddings + if t.dim() == 1: + t = t.expand(t.size(0), seq_len) +- with torch.amp.autocast('cuda', dtype=torch.float32): ++ ++ with torch.xpu.amp.autocast(dtype=torch.float32): + bt = t.size(0) + t = t.flatten() + e = self.time_embedding( +diff --git a/wan/modules/t5.py b/wan/modules/t5.py +index c841b04..bdba262 100644 +--- a/wan/modules/t5.py ++++ b/wan/modules/t5.py +@@ -475,7 +475,7 @@ class T5EncoderModel: + self, + text_len, + dtype=torch.bfloat16, +- device=torch.cuda.current_device(), ++ device=torch.xpu.current_device(), + checkpoint_path=None, + tokenizer_path=None, + shard_fn=None, +diff --git a/wan/modules/vae2_1.py b/wan/modules/vae2_1.py +index 98c2590..8bdf51d 100644 +--- a/wan/modules/vae2_1.py ++++ b/wan/modules/vae2_1.py +@@ -2,7 +2,7 @@ + import logging + + import torch +-import torch.cuda.amp as amp ++import torch.xpu.amp as amp + import torch.nn as nn + import torch.nn.functional as F + from einops import rearrange +@@ -622,7 +622,7 @@ class Wan2_1_VAE: + z_dim=16, + vae_pth='cache/vae_step_411000.pth', + dtype=torch.float, +- device="cuda"): ++ device="xpu"): + self.dtype = dtype + self.device = device + +diff --git a/wan/modules/vae2_2.py b/wan/modules/vae2_2.py +index c0b3f29..015ac3e 100644 +--- a/wan/modules/vae2_2.py ++++ b/wan/modules/vae2_2.py +@@ -2,7 +2,7 @@ + import logging + + import torch +-import torch.cuda.amp as amp ++import torch.xpu.amp as amp + import torch.nn as nn + import torch.nn.functional as F + from einops import rearrange +@@ -31,12 +31,20 @@ class CausalConv3d(nn.Conv3d): + ) + self.padding = (0, 0, 0) + +- def forward(self, x, cache_x=None): ++ def forward(self, x, cache_x=None, cache_list=None, cache_idx=None): ++ if cache_list is not None: ++ cache_x = cache_list[cache_idx] ++ cache_list[cache_idx] = None ++ + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: ++ # torch.xpu.empty_cache() + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] ++ del cache_x ++ torch.xpu.synchronize() ++ torch.xpu.empty_cache() + x = F.pad(x, padding) + + return super().forward(x) +@@ -212,7 +220,8 @@ class ResidualBlock(nn.Module): + if in_dim != out_dim else nn.Identity()) + + def forward(self, x, feat_cache=None, feat_idx=[0]): +- h = self.shortcut(x) ++ # h = self.shortcut(x) ++ old_x = x + for layer in self.residual: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] +@@ -227,12 +236,12 @@ class ResidualBlock(nn.Module): + ], + dim=2, + ) +- x = layer(x, feat_cache[idx]) ++ x = layer(x, cache_list=feat_cache, cache_idx=idx) + 
feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) +- return x + h ++ return x + self.shortcut(old_x) + + + class AttentionBlock(nn.Module): +@@ -487,14 +496,24 @@ class Up_ResidualBlock(nn.Module): + self.upsamples = nn.Sequential(*upsamples) + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): +- x_main = x.clone() +- for module in self.upsamples: +- x_main = module(x_main, feat_cache, feat_idx) ++ # x_main = x.clone() ++ # for module in self.upsamples: ++ # x_main = module(x_main, feat_cache, feat_idx) ++ # if self.avg_shortcut is not None: ++ # x_shortcut = self.avg_shortcut(x, first_chunk) ++ # return x_main + x_shortcut ++ # else: ++ # return x_main + if self.avg_shortcut is not None: ++ x_main = x.clone() ++ for module in self.upsamples: ++ x_main = module(x_main, feat_cache, feat_idx) + x_shortcut = self.avg_shortcut(x, first_chunk) + return x_main + x_shortcut + else: +- return x_main ++ for module in self.upsamples: ++ x = module(x, feat_cache, feat_idx) ++ return x + + + class Encoder3d(nn.Module): +@@ -604,7 +623,8 @@ class Encoder3d(nn.Module): + ], + dim=2, + ) +- x = layer(x, feat_cache[idx]) ++ # x = layer(x, feat_cache[idx]) ++ x = layer(x, cache_list=feat_cache, cache_idx=idx) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: +@@ -715,7 +735,8 @@ class Decoder3d(nn.Module): + ], + dim=2, + ) +- x = layer(x, feat_cache[idx]) ++ # x = layer(x, feat_cache[idx]) ++ x = layer(x, cache_list=feat_cache, cache_idx=idx) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: +@@ -894,8 +915,8 @@ class Wan2_2_VAE: + vae_pth=None, + dim_mult=[1, 2, 4, 4], + temperal_downsample=[False, True, True], +- dtype=torch.float, +- device="cuda", ++ dtype=torch.float16, ++ device="xpu", + ): + + self.dtype = dtype +diff --git a/wan/text2video.py b/wan/text2video.py +index 7c79c66..cb8bf56 100644 +--- a/wan/text2video.py ++++ b/wan/text2video.py +@@ -10,7 +10,7 @@ from contextlib import contextmanager + from functools import partial + + import torch +-import torch.cuda.amp as amp ++import torch.xpu.amp as amp + import torch.distributed as dist + from tqdm import tqdm + +@@ -69,7 +69,7 @@ class WanT2V: + Convert DiT model parameters dtype to 'config.param_dtype'. + Only works without FSDP. 
+ """ +- self.device = torch.device(f"cuda:{device_id}") ++ self.device = torch.device(f"xpu:{device_id}") + self.config = config + self.rank = rank + self.t5_cpu = t5_cpu +@@ -192,7 +192,7 @@ class WanT2V: + if offload_model or self.init_on_cpu: + if next(getattr( + self, +- offload_model_name).parameters()).device.type == 'cuda': ++ offload_model_name).parameters()).device.type == 'xpu': + getattr(self, offload_model_name).to('cpu') + if next(getattr( + self, +@@ -298,7 +298,7 @@ class WanT2V: + + # evaluation mode + with ( +- torch.amp.autocast('cuda', dtype=self.param_dtype), ++ torch.amp.autocast('xpu', dtype=self.param_dtype), + torch.no_grad(), + no_sync_low_noise(), + no_sync_high_noise(), +@@ -363,7 +363,7 @@ class WanT2V: + if offload_model: + self.low_noise_model.cpu() + self.high_noise_model.cpu() +- torch.cuda.empty_cache() ++ torch.xpu.empty_cache() + if self.rank == 0: + videos = self.vae.decode(x0) + +@@ -371,7 +371,7 @@ class WanT2V: + del sample_scheduler + if offload_model: + gc.collect() +- torch.cuda.synchronize() ++ torch.xpu.synchronize() + if dist.is_initialized(): + dist.barrier() + +diff --git a/wan/textimage2video.py b/wan/textimage2video.py +index 67e9fd2..e786bc5 100644 +--- a/wan/textimage2video.py ++++ b/wan/textimage2video.py +@@ -10,7 +10,7 @@ from contextlib import contextmanager + from functools import partial + + import torch +-import torch.cuda.amp as amp ++import torch.xpu.amp as amp + import torch.distributed as dist + import torchvision.transforms.functional as TF + from PIL import Image +@@ -72,7 +72,7 @@ class WanTI2V: + Convert DiT model parameters dtype to 'config.param_dtype'. + Only works without FSDP. + """ +- self.device = torch.device(f"cuda:{device_id}") ++ self.device = torch.device(f"xpu:{device_id}") + self.config = config + self.rank = rank + self.t5_cpu = t5_cpu +@@ -154,6 +154,10 @@ class WanTI2V: + else: + if convert_model_dtype: + model.to(self.param_dtype) ++ # TODO: check autocast on XPU ++ model.time_embedding.to(torch.float) ++ model.time_projection.to(torch.float) ++ model.head.to(torch.float) + if not self.init_on_cpu: + model.to(self.device) + +@@ -327,7 +331,7 @@ class WanTI2V: + + # evaluation mode + with ( +- torch.amp.autocast('cuda', dtype=self.param_dtype), ++ torch.xpu.amp.autocast(dtype=self.param_dtype), + torch.no_grad(), + no_sync(), + ): +@@ -362,7 +366,7 @@ class WanTI2V: + + if offload_model or self.init_on_cpu: + self.model.to(self.device) +- torch.cuda.empty_cache() ++ torch.xpu.empty_cache() + + for _, t in enumerate(tqdm(timesteps)): + latent_model_input = latents +@@ -395,16 +399,22 @@ class WanTI2V: + x0 = latents + if offload_model: + self.model.cpu() +- torch.cuda.synchronize() +- torch.cuda.empty_cache() ++ torch.xpu.synchronize() ++ torch.xpu.empty_cache() ++ del noise, latents, noise_pred ++ del sample_scheduler ++ del self.model ++ torch.xpu.empty_cache() ++ torch.xpu.synchronize() ++ + if self.rank == 0: + videos = self.vae.decode(x0) + +- del noise, latents +- del sample_scheduler ++ # del noise, latents ++ # del sample_scheduler + if offload_model: + gc.collect() +- torch.cuda.synchronize() ++ torch.xpu.synchronize() + if dist.is_initialized(): + dist.barrier() + +@@ -519,7 +529,7 @@ class WanTI2V: + + # evaluation mode + with ( +- torch.amp.autocast('cuda', dtype=self.param_dtype), ++ torch.amp.autocast('xpu', dtype=self.param_dtype), + torch.no_grad(), + no_sync(), + ): +@@ -562,7 +572,7 @@ class WanTI2V: + + if offload_model or self.init_on_cpu: + self.model.to(self.device) +- 
torch.cuda.empty_cache() ++ torch.xpu.empty_cache() + + for _, t in enumerate(tqdm(timesteps)): + latent_model_input = [latent.to(self.device)] +@@ -580,11 +590,11 @@ class WanTI2V: + noise_pred_cond = self.model( + latent_model_input, t=timestep, **arg_c)[0] + if offload_model: +- torch.cuda.empty_cache() ++ torch.xpu.empty_cache() + noise_pred_uncond = self.model( + latent_model_input, t=timestep, **arg_null)[0] + if offload_model: +- torch.cuda.empty_cache() ++ torch.xpu.empty_cache() + noise_pred = noise_pred_uncond + guide_scale * ( + noise_pred_cond - noise_pred_uncond) + +@@ -601,18 +611,25 @@ class WanTI2V: + del latent_model_input, timestep + + if offload_model: +- self.model.cpu() +- torch.cuda.synchronize() +- torch.cuda.empty_cache() ++ # self.model.cpu() ++ del self.model ++ torch.xpu.synchronize() ++ torch.xpu.empty_cache() + +- if self.rank == 0: +- videos = self.vae.decode(x0) ++ del context, context_null ++ del sample_scheduler ++ del self.vae.encoder ++ torch.xpu.synchronize() ++ torch.xpu.empty_cache() ++ ++ if self.rank == 0: ++ videos = self.vae.decode(x0) + + del noise, latent, x0 +- del sample_scheduler ++ + if offload_model: + gc.collect() +- torch.cuda.synchronize() ++ torch.xpu.synchronize() + if dist.is_initialized(): + dist.barrier() +
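Before launching the demo commands from the README above, it can help to confirm that the container actually sees the B60 GPUs and that the patched XPU/oneCCL stack imports cleanly. The snippet below is a minimal sketch, not a file shipped in this change; it assumes only the packages installed by the Dockerfile (PyTorch 2.7 XPU, intel_extension_for_pytorch, oneccl_bindings_for_pytorch) and mirrors the imports that wan22_for_multi_arc.patch adds to generate.py.

```python
# check_xpu.py -- hypothetical helper, not part of this patch set.
# Run inside the container: python3 check_xpu.py
import torch

# Same imports the patched generate.py performs; importing
# oneccl_bindings_for_pytorch is what registers the "ccl" backend
# used by the two-GPU torchrun command.
import intel_extension_for_pytorch  # noqa: F401
import oneccl_bindings_for_pytorch  # noqa: F401

print("XPU available:", torch.xpu.is_available())
print("XPU device count:", torch.xpu.device_count())
for i in range(torch.xpu.device_count()):
    print(f"  xpu:{i} ->", torch.xpu.get_device_name(i))
```

A device count of 2 is what `torchrun --nproc_per_node=2 ... --ulysses_size 2` expects; if only one device shows up, use the single-GPU command instead.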