# Quantize Wan2.2-Animate-14B DiT to NF4 (4-bit)

**Run this once on a Kaggle GPU (T4) session.**

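This notebook:
1. Downloads the Wan2.2 code repo and the full-precision DiT weights from HuggingFace
2. Quantizes the DiT's Linear layers to NF4 using bitsandbytes
3. Saves the quantized model (~9 GB)
4. Reloads the saved checkpoint as a sanity check
5. Pushes it to a private HuggingFace repo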

After running, create a Kaggle dataset from the output for use in inference notebooks.
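Step 2 is the core of the notebook. The pure-Python sketch below gives a rough mental model of NF4. It splits a flat list of weights into blocks of 64 and scales each block by its absolute maximum. It then snaps every scaled value to the nearest of the 16 NF4 levels published in the QLoRA paper. This is an illustration only, not how bitsandbytes implements it. The real kernels run on the GPU and pack two 4-bit codes per byte. With `bnb_4bit_use_double_quant=True`, they also quantize the per-block scales. `quantize_nf4`, `dequantize_nf4`, and `nearest_level` are hypothetical helper names.

```python
# The 16 NF4 levels from the QLoRA paper, normalized to [-1, 1].
NF4_CODE = [
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224,
    0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0,
]


def nearest_level(x):
    """Index (0..15) of the NF4 level closest to x."""
    return min(range(len(NF4_CODE)), key=lambda i: abs(NF4_CODE[i] - x))


def quantize_nf4(values, blocksize=64):
    """Blockwise absmax scaling, then nearest-level lookup. Returns (codes, absmaxes)."""
    codes, absmaxes = [], []
    for start in range(0, len(values), blocksize):
        block = values[start:start + blocksize]
        absmax = max(abs(v) for v in block) or 1.0  # avoid dividing by zero for all-zero blocks
        absmaxes.append(absmax)
        for v in block:
            codes.append(nearest_level(v / absmax))  # v / absmax lies in [-1, 1]
    return codes, absmaxes


def dequantize_nf4(codes, absmaxes, blocksize=64):
    """Map each 4-bit code back to its level and undo the per-block scaling."""
    return [NF4_CODE[c] * absmaxes[i // blocksize] for i, c in enumerate(codes)]


weights = [0.5, -1.0, 0.0, 0.25] * 20  # 80 weights -> blocks of 64 and 16
codes, scales = quantize_nf4(weights)
restored = dequantize_nf4(codes, scales)
print("max abs error:", max(abs(w - r) for w, r in zip(weights, restored)))
```

Values that land exactly on a level after scaling, like `-1.0` or `0.0`, round-trip exactly. Everything else picks up a small rounding error. That error is the accuracy cost NF4 pays for a footprint roughly 4x smaller than fp16.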

## 1. Install dependencies & clone repo
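Kaggle images typically come with `transformers` preinstalled. After the install cell below runs, it is worth confirming which version the kernel actually sees. The sketch below does this with the stdlib only. `check_pin`, `in_range`, and `parse_release` are hypothetical helpers. The parser compares release numbers only and drops pre-/post-release tags, so prefer `packaging.version` if it is available.

```python
from importlib.metadata import PackageNotFoundError, version


def parse_release(v):
    """Turn '4.51.3' into (4, 51, 3); stops at the first non-numeric part."""
    parts = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
        if digits != piece:  # e.g. '0rc1': keep 0, drop the suffix
            break
    return tuple(parts)


def in_range(v, lower, upper):
    """Inclusive check lower <= v <= upper, on release numbers only."""
    return parse_release(lower) <= parse_release(v) <= parse_release(upper)


def check_pin(package, lower, upper):
    """Print whether the installed version of `package` is inside the pinned window."""
    try:
        installed = version(package)
    except PackageNotFoundError:
        print(f"{package}: not installed")
        return
    status = "OK" if in_range(installed, lower, upper) else "OUT OF RANGE"
    print(f"{package} {installed}: {status} (wanted {lower} to {upper})")


check_pin("transformers", "4.49.0", "4.51.3")
```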

In [None]:
# Kaggle already has: torch, torchvision, numpy, opencv, tqdm, PIL
# Install/upgrade what's needed with pinned versions from Wan2.2.
# Quote every spec containing ">" or "<"; unquoted, the shell treats it as a redirect.
!pip install -q \
    "bitsandbytes>=0.43.0" \
    huggingface_hub \
    "transformers>=4.49.0,<=4.51.3" \
    "diffusers>=0.31.0" \
    "accelerate>=1.1.1" \
    easydict ftfy regex decord peft einops safetensors sentencepiece

# Clone the Wan2.2 repo (needed for the WanAnimateModel class definition); skip if already cloned
![ -d /kaggle/working/Wan2.2 ] || git clone https://github.com/Wan-Video/Wan2.2.git /kaggle/working/Wan2.2

## 2. Download DiT model from HuggingFace

We only download the DiT files (config.json + safetensors shards + index).
Skip T5, CLIP, VAE — they're not needed for quantization.
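The shard filenames in `dit_files` below are hardcoded. You could derive them instead. First download only `diffusion_pytorch_model.safetensors.index.json`. Then read its `weight_map` (tensor name to shard file) and take the unique values. The sketch below assumes the standard sharded-safetensors index layout (`weight_map` plus `metadata.total_size`). `load_index`, `shards_from_index`, and `expected_tensor_bytes` are hypothetical helpers.

```python
import json


def load_index(index_path):
    """Parse a sharded-safetensors index file into a dict."""
    with open(index_path) as f:
        return json.load(f)


def shards_from_index(index):
    """Sorted, de-duplicated shard filenames referenced by the index's weight_map."""
    # weight_map maps every tensor name to the shard file that stores it
    return sorted(set(index["weight_map"].values()))


def expected_tensor_bytes(index):
    """Total tensor bytes recorded in the index metadata (excludes file headers), or None."""
    return index.get("metadata", {}).get("total_size")


# Usage, assuming the index was fetched first with hf_hub_download(..., local_dir=TMP_MODEL_DIR):
# index = load_index(f"{TMP_MODEL_DIR}/diffusion_pytorch_model.safetensors.index.json")
# dit_files = ["config.json", "diffusion_pytorch_model.safetensors.index.json"] + shards_from_index(index)
```

`total_size` counts tensor bytes only, so it comes out slightly smaller than the sum of the shard file sizes on disk.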

In [None]:
import os
import shutil
from huggingface_hub import hf_hub_download

REPO_ID = "Wan-AI/Wan2.2-Animate-14B"
TMP_MODEL_DIR = "/tmp/Wan2.2-Animate-14B"       # full-precision shards go here (scratch space, not saved with /kaggle/working output)
OUTPUT_DIR = "/kaggle/working/Wan2.2-Animate-14B-NF4"  # quantized model goes here (persistent)

os.makedirs(TMP_MODEL_DIR, exist_ok=True)

# Only the DiT files — skip T5 (~11.4 GB), CLIP (~4.8 GB), VAE (~0.5 GB)
dit_files = [
    "config.json",
    "diffusion_pytorch_model.safetensors.index.json",
    "diffusion_pytorch_model-00001-of-00004.safetensors",
    "diffusion_pytorch_model-00002-of-00004.safetensors",
    "diffusion_pytorch_model-00003-of-00004.safetensors",
    "diffusion_pytorch_model-00004-of-00004.safetensors",
]

for filename in dit_files:
    dest = os.path.join(TMP_MODEL_DIR, filename)
    if os.path.exists(dest):
        print(f"Already exists: {filename}")
        continue
    print(f"Downloading: {filename}...")
    hf_hub_download(
        repo_id=REPO_ID,
        filename=filename,
        local_dir=TMP_MODEL_DIR,
    )
    print(f"  Done: {os.path.getsize(dest) / 1e9:.2f} GB")

print("\nModel directory contents:")
for f in sorted(os.listdir(TMP_MODEL_DIR)):
    size = os.path.getsize(os.path.join(TMP_MODEL_DIR, f))
    print(f"  {f:60s} {size / 1e9:.2f} GB")

## 3. Setup imports & check GPUs

In [None]:
import torch
import sys
from diffusers import BitsAndBytesConfig

# Add the cloned Wan2.2 repo to path
WAN_REPO_PATH = "/kaggle/working/Wan2.2"
if WAN_REPO_PATH not in sys.path:
    sys.path.insert(0, WAN_REPO_PATH)

from wan.modules.animate.model_animate import WanAnimateModel

print(f"torch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        print(f"GPU {i}: {props.name} ({props.total_memory / 1e9:.1f} GB)")

## 4. Load and quantize the DiT

`from_pretrained` with `device_map="auto"` loads safetensors shards one at a time,
quantizes each parameter on the fly, and places it on GPU(s).

- **Peak RAM:** ~12 GB (one shard at a time)
- **Peak VRAM:** ~9 GB (accumulated quantized parameters)
- **Time:** ~5-10 minutes on T4
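The peak-VRAM figure follows from simple arithmetic. The sketch below uses the storage accounting from the QLoRA paper. Each weight costs 4 bits, plus one fp32 absmax per block of 64 weights (0.5 bits/weight). Double quantization stores those absmax values in 8 bits, with one fp32 constant per 256 blocks, which works out to about 0.127 bits/weight. `nf4_bytes` is a hypothetical helper. It ignores layers left in fp16 and small per-tensor overheads, so treat its output as an estimate for the quantized Linear weights only.

```python
def nf4_bytes(n_params, blocksize=64, double_quant=False, dq_blocksize=256):
    """Approximate storage in bytes for n_params weights in blockwise NF4."""
    bits = 4.0  # one 4-bit code per weight
    if double_quant:
        # absmax constants stored in 8 bits, plus one fp32 constant per group of dq_blocksize blocks
        bits += 8 / blocksize + 32 / (blocksize * dq_blocksize)
    else:
        # one fp32 absmax per block of `blocksize` weights
        bits += 32 / blocksize
    return n_params * bits / 8


n = 10**9
print(f"fp16:           {n * 2 / 1e9:.3f} GB")
print(f"NF4:            {nf4_bytes(n) / 1e9:.3f} GB")
print(f"NF4 + double-q: {nf4_bytes(n, double_quant=True) / 1e9:.3f} GB")
```

To estimate this model, pass in the quantized-parameter count printed by the inspection cell, e.g. `nf4_bytes(quantized_params, double_quant=True) / 1e9`.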

In [None]:
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,  # also quantize the per-block scale factors (~0.37 bits/param saved)
)

print("Loading and quantizing DiT...")
model = WanAnimateModel.from_pretrained(
    TMP_MODEL_DIR,
    quantization_config=quantization_config,
    device_map="auto",
    torch_dtype=torch.float16,
)
print("Done!")

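# Clean up full-precision shards from /tmp; the quantized weights are already on the GPU(s)
shard_bytes = sum(
    os.path.getsize(os.path.join(root, f))
    for root, _, files in os.walk(TMP_MODEL_DIR)
    for f in files
)
shutil.rmtree(TMP_MODEL_DIR)
print(f"Freed {shard_bytes / 1e9:.1f} GB from /tmp")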

## 4b. Inspect the quantized model

Verify that Linear layers were replaced with Linear4bit.

In [None]:
import bitsandbytes as bnb

total_params = 0
quantized_params = 0
full_precision_params = 0

for name, module in model.named_modules():
    if isinstance(module, bnb.nn.Linear4bit):
        n = module.in_features * module.out_features
        quantized_params += n
        total_params += n
    elif isinstance(module, torch.nn.Linear):
        n = module.in_features * module.out_features
        full_precision_params += n
        total_params += n

print(f"Total Linear params:          {total_params / 1e9:.2f} B")
print(f"Quantized (NF4):              {quantized_params / 1e9:.2f} B ({100 * quantized_params / total_params:.1f}%)")
print(f"Full precision (not touched): {full_precision_params / 1e9:.2f} B ({100 * full_precision_params / total_params:.1f}%)")
print()

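# Show VRAM usage summed over all GPUs (device_map="auto" may spread the model across several)
if torch.cuda.is_available():
    n_gpus = torch.cuda.device_count()
    allocated = sum(torch.cuda.memory_allocated(i) for i in range(n_gpus)) / 1e9
    reserved = sum(torch.cuda.memory_reserved(i) for i in range(n_gpus)) / 1e9
    print(f"VRAM allocated: {allocated:.1f} GB")
    print(f"VRAM reserved:  {reserved:.1f} GB")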

## 5. Save quantized model

In [None]:
print(f"Saving quantized model to {OUTPUT_DIR}...")
model.save_pretrained(OUTPUT_DIR)
print("Saved!")

# Report output size
total_bytes = sum(
    os.path.getsize(os.path.join(root, f))
    for root, _, files in os.walk(OUTPUT_DIR)
    for f in files
)
print(f"\nQuantized model size: {total_bytes / 1e9:.1f} GB")
print("\nOutput files:")
for f in sorted(os.listdir(OUTPUT_DIR)):
    size = os.path.getsize(os.path.join(OUTPUT_DIR, f))
    print(f"  {f:60s} {size / 1e9:.2f} GB")

## 6. Verify: reload the quantized model

Quick sanity check — load from the saved directory to confirm it works.
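Before paying for a full reload, you can cheaply confirm that the quantization settings were written to disk. The sketch below makes two assumptions. First, that diffusers stores the settings under a top-level `quantization_config` key in `config.json`. Second, that it uses the same field names as `BitsAndBytesConfig` (`load_in_4bit`, `bnb_4bit_quant_type`). `read_quant_config` and `looks_like_nf4` are hypothetical helpers.

```python
import json
import os


def read_quant_config(model_dir):
    """Return the quantization_config dict stored in <model_dir>/config.json, or None."""
    with open(os.path.join(model_dir, "config.json")) as f:
        config = json.load(f)
    return config.get("quantization_config")


def looks_like_nf4(qc):
    """True if the stored config describes 4-bit NF4 quantization."""
    return bool(qc) and qc.get("load_in_4bit") is True and qc.get("bnb_4bit_quant_type") == "nf4"


# Usage in this notebook:
# qc = read_quant_config(OUTPUT_DIR)
# print(qc)
# assert looks_like_nf4(qc), "quantization settings missing from config.json"
```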

In [None]:
# Free the current model
del model
torch.cuda.empty_cache()

# Reload from saved quantized checkpoint
print("Reloading quantized model from disk...")
model_reloaded = WanAnimateModel.from_pretrained(
    OUTPUT_DIR,
    device_map="auto",
    torch_dtype=torch.float16,
)
print("Reload successful!")

if torch.cuda.is_available():
    allocated = sum(torch.cuda.memory_allocated(i) for i in range(torch.cuda.device_count())) / 1e9
    print(f"VRAM after reload (all GPUs): {allocated:.1f} GB")

del model_reloaded
torch.cuda.empty_cache()
print("\nDone. Create a Kaggle dataset from the output directory for inference.")

## 7. Push to HuggingFace private repo

Upload the quantized model so you can pull it in any future session.

In [None]:
from huggingface_hub import HfApi, login

# Login — on Kaggle, add your HF token as a Kaggle secret named "HF_TOKEN"
# Or paste it directly: login(token="hf_...")
from kaggle_secrets import UserSecretsClient
secrets = UserSecretsClient()
login(token=secrets.get_secret("HF_TOKEN"))

# Change this to your HF username
HF_USERNAME = "YOUR_USERNAME"
HF_REPO = f"{HF_USERNAME}/Wan2.2-Animate-14B-NF4"

api = HfApi()
api.create_repo(HF_REPO, private=True, exist_ok=True)

print(f"Uploading quantized model to {HF_REPO} (private)...")
api.upload_folder(
    folder_path=OUTPUT_DIR,
    repo_id=HF_REPO,
    repo_type="model",
)
print(f"Done! Model available at: https://huggingface.co/{HF_REPO}")

## Next steps

In your inference notebook, pull the quantized model:
```python
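import sys
import torch
from huggingface_hub import snapshot_download
sys.path.insert(0, "/kaggle/working/Wan2.2")  # cloned Wan2.2 repo, for the model class
from wan.modules.animate.model_animate import WanAnimateModel

snapshot_download("YOUR_USERNAME/Wan2.2-Animate-14B-NF4", local_dir="/tmp/dit-nf4")
model = WanAnimateModel.from_pretrained("/tmp/dit-nf4", device_map="auto", torch_dtype=torch.float16)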
```
The saved `config.json` contains the quantization config, so there is no need to re-specify `BitsAndBytesConfig` (bitsandbytes must still be installed in the inference environment).

LoRA works on the quantized layers as usual (the QLoRA pattern): the NF4 base weights stay frozen and only the low-rank adapter weights are trained.
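The sketch below shows that pattern in plain Python, assuming the standard LoRA formulation (Hu et al.). The output is `W x + (alpha / r) * B (A x)`. `W` is the frozen base weight; in QLoRA, this is the NF4 weight dequantized on the fly for the matmul. `A` is `r x in_features`, `B` is `out_features x r`, and only `A` and `B` are trained. `lora_forward` and `matvec` are hypothetical helpers that operate on nested lists; this is not the peft API.

```python
def matvec(M, x):
    """Dense matrix-vector product on nested lists."""
    return [sum(m * v for m, v in zip(row, x)) for row in M]


def lora_forward(W_frozen, A, B, x, alpha, r):
    """Frozen base projection plus the scaled low-rank update B @ (A @ x)."""
    base = matvec(W_frozen, x)        # frozen (dequantized) base weight, never updated
    update = matvec(B, matvec(A, x))  # trainable low-rank path
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]


identity = [[1.0, 0.0], [0.0, 1.0]]
# B starts at zeros, so a freshly attached adapter leaves the output unchanged
print(lora_forward(identity, A=[[1.0, 1.0]], B=[[0.0], [0.0]], x=[1.0, 2.0], alpha=2, r=1))  # -> [1.0, 2.0]
```

Initializing `B` to zeros is the usual LoRA convention. Training therefore starts exactly from the quantized base model's behavior.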