Skip to content
95 changes: 95 additions & 0 deletions examples/research_projects/sana/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Training SANA Sprint Diffuser

This README explains how to use the provided bash script commands to download a pre-trained teacher diffuser model and train it on a specific dataset, following the [SANA Sprint methodology](https://arxiv.org/abs/2503.09641).


## Setup

### 1. Define the local paths

Set a variable for your desired output directory. This directory will store the downloaded model and the training checkpoints/results.

```bash
your_local_path='output' # Or any other path you prefer
mkdir -p $your_local_path # Create the directory if it doesn't exist
```

### 2. Download the pre-trained model

Download the SANA Sprint teacher model from Hugging Face Hub. The script uses the 1.6B parameter model.

```bash
huggingface-cli download Efficient-Large-Model/SANA_Sprint_1.6B_1024px_teacher_diffusers --local-dir $your_local_path/SANA_Sprint_1.6B_1024px_teacher_diffusers
```

*(Optional: You can also download the 0.6B model by replacing the model name: `Efficient-Large-Model/Sana_Sprint_0.6B_1024px_teacher_diffusers`)*

### 3. Acquire the dataset shards

The training script in this example uses specific `.parquet` shards from a randomly selected `brivangl/midjourney-v6-llava` dataset instead of downloading the entire dataset automatically via `dataset_name`.

The script specifically uses these three files:
* `data/train_000.parquet`
* `data/train_001.parquet`
* `data/train_002.parquet`



You can either:

Let the script download the dataset automatically during first run

Or download it manually

**Note:** The full `brivangl/midjourney-v6-llava` dataset is much larger and contains many more shards. This script example explicitly trains *only* on the three specified shards.

## Usage

Once the model is downloaded, you can run the training script.

```bash

your_local_path='output' # Ensure this variable is set

python train_sana_sprint_diffusers.py \
--pretrained_model_name_or_path=$your_local_path/SANA_Sprint_1.6B_1024px_teacher_diffusers \
--output_dir=$your_local_path \
--mixed_precision=bf16 \
--resolution=1024 \
--learning_rate=1e-6 \
--max_train_steps=30000 \
--dataloader_num_workers=8 \
--dataset_name='brivangl/midjourney-v6-llava' \
--file_path data/train_000.parquet data/train_001.parquet data/train_002.parquet \
--checkpointing_steps=500 --checkpoints_total_limit=10 \
--train_batch_size=1 \
--gradient_accumulation_steps=1 \
--seed=453645634 \
--train_largest_timestep \
--misaligned_pairs_D \
--gradient_checkpointing \
--resume_from_checkpoint="latest" \
```

### Explanation of parameters

* `--pretrained_model_name_or_path`: Path to the downloaded pre-trained model directory.
* `--output_dir`: Directory where training logs, checkpoints, and the final model will be saved.
* `--mixed_precision`: Use BF16 mixed precision for training, which can save memory and speed up training on compatible hardware.
* `--resolution`: The image resolution used for training (1024x1024).
* `--learning_rate`: The learning rate for the optimizer.
* `--max_train_steps`: The total number of training steps to perform.
* `--dataloader_num_workers`: Number of worker processes for loading data. Increase for faster data loading if your CPU and disk can handle it.
* `--dataset_name`: The name of the dataset on Hugging Face Hub (`brivangl/midjourney-v6-llava`).
* `--file_path`: **Specifies the local paths to the dataset shards to be used for training.** In this case, `data/train_000.parquet`, `data/train_001.parquet`, and `data/train_002.parquet`.
* `--checkpointing_steps`: Save a training checkpoint every X steps.
* `--checkpoints_total_limit`: Maximum number of checkpoints to keep. Older checkpoints will be deleted.
* `--train_batch_size`: The batch size per GPU.
* `--gradient_accumulation_steps`: Number of steps to accumulate gradients before performing an optimizer step.
* `--seed`: Random seed for reproducibility.
* `--train_largest_timestep`: A specific training strategy focusing on larger timesteps.
* `--misaligned_pairs_D`: Another specific training strategy to add misaligned image-text pairs as fake data for GAN.
* `--gradient_checkpointing`: Enable gradient checkpointing to save GPU memory.
* `--resume_from_checkpoint`: Allows resuming training from the latest saved checkpoint in the `--output_dir`.


Loading
Loading