# Ditto Training Pipeline (Global Style)

This notebook sets up the environment on Google Colab Pro and runs the Ditto training pipeline on the global-style editing subset of Ditto-1M (`videos/global_style1` and `videos/global_style2`, together with the matching `videos/source` clips).

## 1. Mount Google Drive
Mount your Google Drive to save checkpoints and access the dataset.

In [None]:
# Attach Google Drive at /content/drive. Every persistent path used later in
# this notebook (dataset, Hugging Face caches, training outputs) lives under
# /content/drive/MyDrive, so this cell must run first.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

## 2. Setup Environment
Clone the repository and install the necessary dependencies.

In [None]:
# Clone the repository to a fixed location and enter it. Anchoring to /content
# makes the cell safe to re-run: an existing clone is reused instead of failing,
# and a relative `%cd` can no longer descend into (and clone) a nested copy.
%cd /content
!test -d ditto-vace-fork || git clone https://github.com/dhruv0000/ditto-vace-fork.git
%cd /content/ditto-vace-fork
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install -r requirements.txt
%pip install -e .
%pip install accelerate

## 3. Data Setup
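Download the global-style metadata and the selected video archives from Hugging Face into Google Drive and extract them. Then build a subsampled ("mini") metadata CSV that lists only videos present on disk.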

In [None]:
#@title Data Download/Setup (Global Style)
# Downloads the global-style metadata and the selected split video archives of
# Ditto-1M into Google Drive, then optionally extracts the archives in place.
import glob
import os
import subprocess

# Store dataset and caches on Google Drive so they persist across sessions.
# huggingface_hub reads HF_HOME once, when it is first imported, so these must
# be set *before* the import below. (If huggingface_hub was already imported in
# this kernel, restart the runtime for a changed value to take effect.)
os.environ["HF_HOME"] = "/content/drive/MyDrive/hf_home"
os.environ["HF_DATASETS_CACHE"] = "/content/drive/MyDrive/hf_home"

from huggingface_hub import snapshot_download

HF_REPO_ID = "QingyanBai/Ditto-1M"
dataset_root = "/content/drive/MyDrive/Ditto-1M"

# If True, files that already exist in Drive are not downloaded again.
skip_download_if_exists = True  #@param {type:"boolean"}

# Comma-separated list of tar indices to download for each subset, e.g. "01,02"
tar_indices = "01"  #@param {type:"string"}

# Whether to download archives for each subset
download_source_archives = True  #@param {type:"boolean"}
download_global_style1_archives = True  #@param {type:"boolean"}
download_global_style2_archives = True  #@param {type:"boolean"}


def _files_to_fetch(repo_paths):
    """Return the repo-relative file paths that still need downloading.

    With ``skip_download_if_exists`` enabled, paths already present under
    ``dataset_root`` are dropped. Otherwise every path is returned and
    ``snapshot_download`` decides what to fetch.
    """
    if not skip_download_if_exists:
        return list(repo_paths)
    return [p for p in repo_paths if not os.path.exists(os.path.join(dataset_root, p))]


def _fetch(repo_paths):
    """Download exactly ``repo_paths`` (literal file names) from the dataset repo into Drive."""
    snapshot_download(
        repo_id=HF_REPO_ID,
        repo_type="dataset",
        local_dir=dataset_root,
        allow_patterns=repo_paths,
    )


# --- Metadata (JSON + DiffSynth CSV) for the global style task ---------------
metadata_files = [
    "training_metadata/global_style.json",
    "csvs_for_DiffSynth/global_style.csv",
]
missing_metadata = _files_to_fetch(metadata_files)
if missing_metadata:
    _fetch(missing_metadata)
    print("Downloaded metadata:", ", ".join(missing_metadata))
else:
    print(f"Found existing Ditto-1M metadata in Drive at {dataset_root}. Skipping download.")

# --- Split video archives ------------------------------------------------------
# Checked independently of the metadata, so adding indices to `tar_indices` (or
# enabling another subset) later still fetches the new parts.
indices = [s.strip() for s in tar_indices.split(",") if s.strip()]
subset_enabled = {
    "source": download_source_archives,
    "global_style1": download_global_style1_archives,
    "global_style2": download_global_style2_archives,
}
tar_allow_patterns = [
    f"videos/{subset}/{subset}.tar.gz.{idx}"
    for subset, enabled in subset_enabled.items()
    if enabled
    for idx in indices
]

if not tar_allow_patterns:
    print("No tar archives selected for download (tar_indices or subset flags empty).")
else:
    missing_archives = _files_to_fetch(tar_allow_patterns)
    if missing_archives:
        print("Downloading tar archives:")
        for p in missing_archives:
            print("  ", p)
        _fetch(missing_archives)
        print("Selected video archives downloaded to Google Drive.")
    else:
        print("All selected tar archives already exist in Drive. Skipping download.")

# Optionally extract the split archives from Drive
extract_archives = True  #@param {type:"boolean"}


def _extract_split_archive(subset_dir, part_files):
    """Concatenate ``part_files`` in order and pipe them into ``tar -zx`` inside ``subset_dir``.

    Runs without a shell and without changing the notebook's working directory.
    ``-v`` is intentionally omitted: listing every extracted video floods the
    cell output. Returns True only if both ``cat`` and ``tar`` exit with status 0.
    """
    with subprocess.Popen(["cat", *part_files], cwd=subset_dir, stdout=subprocess.PIPE) as cat_proc:
        tar_result = subprocess.run(["tar", "-zx"], cwd=subset_dir, stdin=cat_proc.stdout)
    # Leaving the `with` block closes the pipe and waits, so returncode is set.
    return cat_proc.returncode == 0 and tar_result.returncode == 0


if extract_archives:
    for subset in ["source", "global_style1", "global_style2"]:
        subset_dir = os.path.join(dataset_root, "videos", subset)
        if not os.path.isdir(subset_dir):
            print(f"Directory {subset_dir} does not exist, skipping.")
            continue

        # Parts are named <subset>.tar.gz.<NN>; zero-padded indices make the
        # lexicographic sort match the part order.
        part_files = sorted(
            os.path.basename(p)
            for p in glob.glob(os.path.join(glob.escape(subset_dir), f"{subset}.tar.gz.*"))
        )
        if not part_files:
            print(f"No split archives found for {subset}, skipping extraction.")
            continue

        print(f"Extracting {subset} from {len(part_files)} tar archives...")
        if _extract_split_archive(subset_dir, part_files):
            print(f"  {subset}: extraction finished.")
        else:
            print(
                f"  WARNING: extracting {subset} exited with an error. This is likely "
                "when only some parts of the split archive were downloaded (tar stops "
                "with an unexpected EOF after extracting what it could); otherwise "
                "check the messages above."
            )

print("Data setup complete.")

In [None]:
#@title Build mini CSV from extracted videos
# Keeps only the rows of the full global_style CSV whose `video` and
# `vace_video` files were actually extracted, then subsamples them.
import os
import pandas as pd

dataset_root = "/content/drive/MyDrive/Ditto-1M"
csv_path = os.path.join(dataset_root, "csvs_for_DiffSynth", "global_style.csv")

print("Loading full global_style CSV from:", csv_path)
df = pd.read_csv(csv_path)

# Validate the schema before the (slow) walk over Drive.
required_cols = ["video", "vace_video"]
for col in required_cols:
    if col not in df.columns:
        raise ValueError(f"Required column '{col}' not found in CSV. Available columns: {df.columns.tolist()}")

videos_root = os.path.join(dataset_root, "videos")

# Index every extracted file relative to both candidate base directories:
# dataset_root gives "videos/<subset>/...", videos_root gives "<subset>/...".
# Whichever form the CSV uses determines the directory that must be passed as
# `dataset_base_path` in the Training Configuration cell.
files_by_base = {dataset_root: set(), videos_root: set()}
for root, _, files in os.walk(videos_root):
    for fname in files:
        full_path = os.path.join(root, fname)
        for base, rel_paths in files_by_base.items():
            rel_paths.add(os.path.relpath(full_path, base))

print("Total existing video files under videos/:", len(files_by_base[videos_root]))

# Pick the base under which the most CSV rows resolve to existing files.
masks = {
    base: df["video"].isin(rel_paths) & df["vace_video"].isin(rel_paths)
    for base, rel_paths in files_by_base.items()
}
csv_base_path = max(masks, key=lambda base: int(masks[base].sum()))
df_existing = df[masks[csv_base_path]].reset_index(drop=True)
print("Rows with both video and vace_video present:", len(df_existing))

if df_existing.empty:
    # Writing an empty CSV would only fail later, inside the training job.
    example_csv_path = df["video"].iloc[0] if len(df) else None
    example_file = next(iter(files_by_base[videos_root]), None)
    raise ValueError(
        "No CSV rows reference extracted files. "
        f"Example CSV path: {example_csv_path!r}; example file under videos/: {example_file!r}. "
        "Check that the archives were downloaded and extracted."
    )
print(f"CSV paths resolve relative to {csv_base_path}; use it as dataset_base_path.")

# n_samples <= 0 keeps every matching row.
n_samples = 20000  #@param {type:"integer"}
if n_samples > len(df_existing):
    n_samples = len(df_existing)
    print(f"Requested n_samples larger than available; using {n_samples} rows.")

if n_samples > 0:
    df_mini = df_existing.sample(n=n_samples, random_state=0)
else:
    df_mini = df_existing

mini_csv_path = os.path.join(dataset_root, "csvs_for_DiffSynth", "global_style_mini.csv")
df_mini.to_csv(mini_csv_path, index=False)
print("Saved mini CSV to:", mini_csv_path)


## 4. Configuration
Configure the training parameters for the global style task. The default model is `Wan-AI/Wan2.1-VACE-1.3B`, and `dataset_metadata_path` defaults to the mini CSV built in the previous section.

In [None]:
#@title Training Configuration
# Values consumed by the training cell below; edit them through the Colab form.

dataset_base_path = "/content/drive/MyDrive/Ditto-1M/videos" #@param {type:"string"}
dataset_metadata_path = "/content/drive/MyDrive/Ditto-1M/csvs_for_DiffSynth/global_style_mini.csv" #@param {type:"string"}
output_path = "/content/drive/MyDrive/exps/ditto" #@param {type:"string"}
model_id = "Wan-AI/Wan2.1-VACE-1.3B" #@param {type:"string"}
num_epochs = 5 #@param {type:"integer"}
learning_rate = "1e-4" #@param {type:"string"}

# Echo the effective settings, one "  <label>: <value>" line each.
print("Configuration:")
print("\n".join(
    f"  {label}: {value}"
    for label, value in (
        ("Dataset Base Path", dataset_base_path),
        ("Metadata Path", dataset_metadata_path),
        ("Output Path", output_path),
        ("Model ID", model_id),
        ("Epochs", num_epochs),
        ("Learning Rate", learning_rate),
    )
))

## 5. Run Training
Execute the training script with the configured parameters. The cell runs from the repository directory entered in step 2.

In [None]:
# Launch training. `train.sh` is resolved relative to the current working
# directory, which the setup cell's `%cd ditto-vace-fork` sets to the repo root.
# IPython replaces each {name} below with the value of the Python variable of
# that name (defined in the Training Configuration cell) before the shell runs.
# NOTE(review): the flags train.sh accepts are defined in the repository, not
# in this notebook — confirm they still match if the script changes.
!chmod +x train.sh
!./train.sh \
  --dataset_base_path "{dataset_base_path}" \
  --dataset_metadata_path "{dataset_metadata_path}" \
  --output_path "{output_path}" \
  --model_id "{model_id}" \
  --num_epochs "{num_epochs}" \
  --learning_rate "{learning_rate}"