# Wan-ATI: High-Quality Video Generation from a Single Image

This notebook is a converted version of the original inference script for the Wan-ATI model. It generates a video from a text prompt and a single starting image, guided by an optional stored point-trajectory (`track`) file.

### References:
- **Original Script:** Provided by the user.
- **Project Page / GitHub (Hypothetical):** [Link to Project]

### How to Use:
1.  **Setup Environment:** Make sure you have all the required libraries installed (`wan`, `torch`, `Pillow`, etc.) and have downloaded the model checkpoints.
2.  **Configure Parameters:** In the "1. Configuration" section below, you **must** set `ckpt_dir` to the path where you saved the checkpoints. You also need to provide your own `prompt` and `image`, and optionally a `track` file.
3.  **Run All Cells:** Click `Kernel -> Restart & Run All` to execute the notebook and generate your video.

## 0. Imports

In [4]:
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
# Copyright (c) 2024-2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: Apache-2.0

from wan.utils.motion import get_tracks_inference
from wan.utils.utils import cache_video, cache_image
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
import wan
from PIL import Image
import torch.distributed as dist
import torch
import random
import logging
import os
import sys
import warnings
import yaml
from datetime import datetime
from types import SimpleNamespace

# Install a blanket 'ignore' filter: every Python warning raised after this
# point is suppressed for the rest of the session.
warnings.filterwarnings('ignore')

## 1. Configuration

All parameters that were previously command-line arguments are defined here as regular variables. **Modify the values in the next cell** to control the generation process.

In [5]:
# Every former command-line flag is an attribute of this single namespace.
# Edit the values below to control the generation run.
args = SimpleNamespace(
    # ----------------------------------------------------------------------
    # REQUIRED USER INPUTS
    # ----------------------------------------------------------------------
    # Checkpoint directory. YOU MUST CHANGE THIS to where the models were
    # downloaded; validation fails if the path does not exist.
    ckpt_dir="./Wan2.1-ATI-14B-480P",
    # Text prompt, or a path ending in '.yaml' to run a batch of inputs.
    prompt="A tranquil koi pond edged by mossy stone, with lily pads drifting on the surface and several orange\u2011and\u2011white koi fish gliding beneath.",
    # [image to video] Path of the starting image.
    image="./examples/images/fish.jpg",
    # Path of the stored point trajectory. Validation only checks that the
    # file exists when a value is set.
    track="./examples/tracks/fish.pth",

    # ----------------------------------------------------------------------
    # Core generation parameters
    # ----------------------------------------------------------------------
    # Task name; must be a key of WAN_CONFIGS.
    task="ati-14B",
    # Area (width*height) of the generated video; must be listed in
    # SUPPORTED_SIZES for the chosen task.
    size="832*480",
    # Number of frames to sample (should be 4n+1). None -> 81, or 1 for
    # "t2i" tasks.
    frame_num=None,
    # Output file, e.g. "my_video.mp4". None -> a name is generated
    # automatically under "output/".
    save_file=None,
    # Random seed. None or a negative value -> a random seed is drawn.
    base_seed=42,

    # ----------------------------------------------------------------------
    # Advanced sampling parameters
    # ----------------------------------------------------------------------
    # Sampling solver; choices: 'unipc', 'dpm++'.
    sample_solver='unipc',
    # Number of sampling steps. None -> 10.
    sample_steps=10,
    # Shift factor for flow matching schedulers. None -> 5.0.
    sample_shift=None,
    # Classifier free guidance scale.
    sample_guide_scale=5.0,

    # ----------------------------------------------------------------------
    # System & performance parameters (advanced)
    # ----------------------------------------------------------------------
    # Offload the model to CPU after each forward pass to save GPU memory.
    # None -> True for a single process, False for multi-GPU.
    offload_model=None,
    # Ulysses parallelism size in DiT (multi-GPU only; must stay 1 here).
    ulysses_size=1,
    # Ring attention parallelism size in DiT (multi-GPU only; must stay 1 here).
    ring_size=1,
    # Use FSDP for T5 (multi-GPU only; must stay False here).
    t5_fsdp=False,
    # Place the T5 model on CPU.
    t5_cpu=False,
    # Use FSDP for DiT (multi-GPU only; must stay False here).
    dit_fsdp=False,

    # ----------------------------------------------------------------------
    # Prompt extension parameters (optional)
    # ----------------------------------------------------------------------
    use_prompt_extend=False,
    # choices: "dashscope", "local_qwen"
    prompt_extend_method="local_qwen",
    prompt_extend_model=None,
    # choices: "zh", "en"
    prompt_extend_target_lang="zh",

    # ----------------------------------------------------------------------
    # Other task-specific parameters (leave as None if not used)
    # ----------------------------------------------------------------------
    src_video=None,
    src_mask=None,
    src_ref_images=None,
    first_frame=None,
    last_frame=None,
)

## 2. Helper Functions

In [6]:
def _validate_args(args):
    """Check the notebook configuration and fill in unset defaults.

    ``args`` is modified in place: ``sample_steps``, ``sample_shift`` and
    ``frame_num`` receive defaults when they are ``None``, and ``base_seed``
    is replaced by a random value when it is ``None`` or negative.

    Raises:
        AssertionError: if the checkpoint directory, task, frame count, size
            or an input file fails its check.
        ValueError: if ``args.prompt`` is empty or ``None``.
    """
    # Checkpoint directory and task come first; everything else depends on them.
    assert args.ckpt_dir is not None, "Please specify the checkpoint directory in the cell above."
    assert os.path.exists(args.ckpt_dir), f"Checkpoint directory not found at: {args.ckpt_dir}"
    assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}"

    # Sampler defaults used by this notebook: 10 steps and a shift of 5.0.
    if args.sample_steps is None:
        args.sample_steps = 10
    if args.sample_shift is None:
        args.sample_shift = 5.0

    # "t2i" tasks produce exactly one frame; every other task defaults to 81.
    is_t2i = "t2i" in args.task
    if args.frame_num is None:
        args.frame_num = 1 if is_t2i else 81
    if is_t2i:
        assert args.frame_num == 1, f"Unsupport frame_num {args.frame_num} for task {args.task}"

    # A missing or negative seed means "pick one at random".
    seed = args.base_seed
    if seed is None or seed < 0:
        args.base_seed = random.randint(0, sys.maxsize)

    # The requested size must be one the chosen task supports.
    supported = SUPPORTED_SIZES[args.task]
    assert args.size in supported, f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(supported)}"

    # Input checks: a prompt is mandatory, either as text or as a YAML path.
    prompt = args.prompt
    if not prompt:
        raise ValueError("Please provide a text prompt or a YAML file for batch processing.")

    if prompt.endswith('.yaml'):
        # Batch mode: only the YAML file itself is checked here, not its entries.
        assert os.path.exists(args.prompt), f"YAML prompt file not found at: {args.prompt}"
        return

    # Single-prompt mode: the image is required, the track file is checked
    # only when one is configured.
    assert args.image is not None, "Please provide an input image via 'args.image'"
    assert os.path.exists(args.image), f"Input image not found at: {args.image}"
    if args.track:
        assert os.path.exists(args.track), f"Track file not found at: {args.track}"

def _init_logging(rank):
    """Reset the root logger and configure it according to ``rank``.

    Rank 0 logs at INFO level to stdout using a timestamped format; every
    other rank is configured at ERROR level with ``basicConfig`` defaults.
    """
    root_logger = logging.getLogger()

    # Drop whatever handlers are already installed; basicConfig() is a no-op
    # while the root logger still has handlers. Iterate over a copy because
    # removeHandler() mutates the list.
    for stale_handler in list(root_logger.handlers):
        root_logger.removeHandler(stale_handler)

    if rank != 0:
        logging.basicConfig(level=logging.ERROR)
        return

    stdout_handler = logging.StreamHandler(stream=sys.stdout)
    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s] %(levelname)s: %(message)s",
        handlers=[stdout_handler])

## 3. Execution

The following cell contains the main logic to set up the environment, load the model, and run the generation.

In [None]:
# --- 1. Validate Arguments & Environment Setup ---
# Fail fast on a bad configuration and fill in defaults for unset values.
_validate_args(args)

# For a standard notebook environment, we simulate a single-process run.
rank = 0
world_size = 1
local_rank = 0
device = 0
# Export the values as environment variables as well.
# NOTE(review): presumably for library code that reads the usual
# distributed-launch variables; nothing in this notebook reads them back.
os.environ['RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['LOCAL_RANK'] = str(local_rank)

_init_logging(rank)

if args.offload_model is None:
    # Offload by default only when running as a single process.
    args.offload_model = world_size <= 1
    logging.info(f"offload_model is not specified, set to {args.offload_model}.")

# This notebook assumes a single-GPU/CPU setup. Distributed features are disabled.
assert not (args.t5_fsdp or args.dit_fsdp), "FSDP is not supported in this notebook setup."
assert not (args.ulysses_size > 1 or args.ring_size > 1), "Context parallelism is not supported in this notebook setup."

# --- 2. Setup Prompt Expander (if used) ---
# `prompt_expander` is only defined when prompt extension is enabled.
if args.use_prompt_extend:
    if args.prompt_extend_method == "dashscope":
        prompt_expander = DashScopePromptExpander(model_name=args.prompt_extend_model, is_vl=True)
    elif args.prompt_extend_method == "local_qwen":
        prompt_expander = QwenPromptExpander(model_name=args.prompt_extend_model, is_vl=True, device=device)
    else:
        raise NotImplementedError(f"Unsupport prompt_extend_method: {args.prompt_extend_method}")

# --- 3. Load Model and Prepare Inputs ---
cfg = WAN_CONFIGS[args.task]
logging.info(f"Generation job args: {vars(args)}")
logging.info(f"Generation model config: {cfg}")

# Seed both torch and the stdlib RNG with the same base seed.
torch.manual_seed(args.base_seed)
random.seed(args.base_seed)

# Build the work list as (image, prompt, track) tuples.
if args.prompt.endswith('.yaml'):
    # Batch mode: every YAML entry must provide 'image', 'text' and 'track'.
    # The context manager closes the file handle once the YAML is parsed.
    with open(args.prompt) as prompt_file:
        fl_list = yaml.safe_load(prompt_file)
    inputs_ = [
        (line['image'], line['text'].strip(), line['track'])
        for line in fl_list
    ]
else:
    inputs_ = [(args.image, args.prompt, args.track)]

logging.info("Creating WanATI pipeline...")
wan_ati = wan.WanATI(
    config=cfg,
    checkpoint_dir=args.ckpt_dir,
    device_id=device,
    rank=rank,
    t5_fsdp=args.t5_fsdp,
    dit_fsdp=args.dit_fsdp,
    # Always False here: the asserts above reject any parallel size > 1.
    use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
    t5_cpu=args.t5_cpu,
)

# --- 4. Run Generation Loop ---
# One video per (image, prompt, track) entry. An entry whose output file
# already exists is skipped.
for ii, input_ in enumerate(inputs_):
    # Determine save file name
    current_save_file = args.save_file
    if current_save_file is None:
        formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
        if args.prompt.endswith(".yaml"):
            # Batch mode: identify each output by its index in the YAML list.
            formatted_prompt = f"{ii:02d}"
        else:
            # Sanitize prompt for use in filename: keep only letters, digits
            # and spaces (so no path separator can survive), then turn spaces
            # into underscores and cap the length at 50 characters.
            sanitized_prompt = "".join(
                c for c in args.prompt
                if c.isalpha() or c.isdigit() or c == ' ').rstrip()
            formatted_prompt = sanitized_prompt.replace(" ", "_")[:50]
        suffix = '.mp4'
        current_save_file = f"output/{args.task}_{args.size.replace('*','x')}_{formatted_prompt}_{formatted_time}{suffix}"
    elif '%' in current_save_file:
        # A user-supplied template such as "out_%02d.mp4" gets the input index.
        current_save_file = current_save_file % ii

    # Create output directory if it doesn't exist. A bare file name such as
    # "my_video.mp4" has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    output_dir = os.path.dirname(current_save_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # NOTE: with several inputs and a fixed save_file containing no '%', every
    # input maps to the same path, so later inputs are skipped here once that
    # file exists.
    if os.path.exists(current_save_file):
        logging.info(f"File {current_save_file} already exists, skipping.")
        continue

    image, prompt, tracks = input_
    logging.info(f"Input prompt: {prompt}")
    logging.info(f"Input image: {image}")

    img = Image.open(image).convert("RGB")

    width, height = img.size
    # NOTE(review): this runs even when no track is configured (tracks is
    # None); confirm get_tracks_inference accepts that.
    tracks = get_tracks_inference(tracks, height, width)

    if args.use_prompt_extend:
        logging.info("Extending prompt ...")
        prompt_output = prompt_expander(
            prompt,
            tar_lang=args.prompt_extend_target_lang,
            image=img,
            seed=args.base_seed)
        if prompt_output.status == False:
            logging.info(f"Extending prompt failed: {prompt_output.message}")
            logging.info("Falling back to original prompt.")
        else:
            prompt = prompt_output.prompt
        # Logged in both cases; after a failure this is the original prompt.
        logging.info(f"Extended prompt: {prompt}")

    logging.info("Generating video ...")
    video = wan_ati.generate(
        prompt,
        img,
        tracks,
        max_area=MAX_AREA_CONFIGS[args.size],
        frame_num=args.frame_num,
        shift=args.sample_shift,
        sample_solver=args.sample_solver,
        sampling_steps=args.sample_steps,
        guide_scale=args.sample_guide_scale,
        seed=args.base_seed,
        offload_model=args.offload_model)

    logging.info(f"Saving generated video to {current_save_file}")
    cache_video(
        # presumably a tensor; [None] prepends a batch axis -- TODO confirm
        tensor=video[None],
        save_file=current_save_file,
        fps=cfg.sample_fps,
        nrow=1,
        normalize=True,
        value_range=(-1, 1))

    # Display video in notebook (optional)
    try:
        from IPython.display import Video, display
        display(Video(current_save_file, embed=True, width=400))
    except ImportError:
        print("Install IPython to display the video directly in the notebook.")

logging.info("Finished.")

[2025-07-02 20:00:16,725] INFO: Generation job args: {'ckpt_dir': './Wan2.1-ATI-14B-480P', 'prompt': 'A tranquil koi pond edged by mossy stone, with lily pads drifting on the surface and several orange‑and‑white koi fish gliding beneath.', 'image': './examples/images/fish.jpg', 'track': './examples/tracks/fish.pth', 'task': 'ati-14B', 'size': '832*480', 'frame_num': 81, 'save_file': None, 'base_seed': 42, 'sample_solver': 'unipc', 'sample_steps': 10, 'sample_shift': 5.0, 'sample_guide_scale': 5.0, 'offload_model': True, 'ulysses_size': 1, 'ring_size': 1, 't5_fsdp': False, 't5_cpu': False, 'dit_fsdp': False, 'use_prompt_extend': False, 'prompt_extend_method': 'local_qwen', 'prompt_extend_model': None, 'prompt_extend_target_lang': 'zh', 'src_video': None, 'src_mask': None, 'src_ref_images': None, 'first_frame': None, 'last_frame': None}
[2025-07-02 20:00:16,726] INFO: Generation model config: {'__name__': 'Config: Wan I2V 14B', 't5_model': 'umt5_xxl', 't5_dtype': torch.bfloat16, 'text_le