# Change directory into the `ai-toolkit` repo clone

In [None]:
# IPython magic: switch the kernel's working directory to the repo clone so
# the relative `config/` paths used in later cells resolve against it.
%cd /root/ai-toolkit

# Load and set configuration

In [None]:
# Shell escape: list the contents of the config directory.
!ls config/

In [None]:
import yaml

# Base training config; the cells below modify the parsed copy in memory.
yaml_path = "config/train_cfg.yaml"

# Decode explicitly as UTF-8 instead of relying on the locale's default
# encoding, so the file parses the same way on every machine.
# safe_load builds only plain Python containers and scalars.
with open(yaml_path, encoding="utf-8") as f:
    wan_config = yaml.safe_load(f)

### Most things you'd want to change are in `.config.process.0`

In [None]:
# Shortcut to the first entry of the config's `process` list. This is a
# reference, not a copy, so edits made through it also change `wan_config`.
wan_config_dict = wan_config["config"]["process"][0]

In [None]:
# Display the current process settings for inspection.
wan_config_dict

In [None]:
# wan_config_dict["train"]["steps"] = 4000  # example of overriding the number of training steps

### Select the dataset

In [None]:
# Shell escape: list the contents of the data directory.
!ls /root/ai-toolkit/data

In [None]:
# Fallback dataset location, used whenever no custom path is supplied.
default_dataset_path = "/root/ai-toolkit/data/sample"
my_dataset_path = None # set this to train on your data!

# Any falsy value (None or an empty string) selects the default path.
dataset_path = default_dataset_path if not my_dataset_path else my_dataset_path

In [None]:
# Point the first dataset entry of the config at the chosen dataset folder.
wan_config_dict["datasets"][0]["folder_path"] = dataset_path

### Select the model

In [None]:
# Model size tag; it is interpolated into the model repo id below.
# Start with the smaller one, it trains and runs much faster.
model = "1.3B"
# model = "14B"  # run the larger model for better results

# Build the full repo id first, then store it in the config.
model_repo_id = f"Wan-AI/Wan2.1-T2V-{model}-Diffusers"
wan_config_dict["model"]["name_or_path"] = model_repo_id

### Set the training parameters

In [None]:
import math

# Alias for the training section; writes through it mutate the config in place.
train_cfg = wan_config_dict["train"]

# The 1.3B model uses a batch of 4; any other selection uses 2.
batch_size = 2 if model != "1.3B" else 4
train_cfg["batch_size"] = batch_size
train_cfg["gradient_checkpointing"] = True
train_cfg["lr"] = 1e-4
train_cfg["optimizer_params"]["weight_decay"] = 1e-4  # * math.sqrt(batch_size)

wan_config_dict["datasets"][0]["shuffle_tokens"] = True

# Step budget: 1000 floor-divided by the integer square root of the batch size.
steps = 1000 // math.isqrt(batch_size)
# Save and sample at the same interval: a quarter of the total steps.
check_every = steps // 4

train_cfg["steps"] = steps
wan_config_dict["save"]["save_every"] = check_every
wan_config_dict["sample"]["sample_every"] = check_every

### Set the model save directory if not provided

In [None]:
from hashlib import md5 as hasher

name = None  # override manually if you want

# Derive a name only when none was supplied above. Previously the hash was
# assigned unconditionally, so a manual override was always discarded.
if not name:
    # Fingerprint of the process settings' string form: the same settings
    # give the same name, and changed settings give a different one.
    name = hasher(str(wan_config_dict).encode("utf-8")).hexdigest()

wan_config["config"]["name"] = name
# Echo the chosen name as the cell output.
wan_config["config"]["name"]

### Persist the config to disk

In [None]:
# Serialize the full, modified config to a new file, leaving the original
# config/train_cfg.yaml untouched.
with open("config/final_train_cfg.yaml", "w") as out_file:
    yaml.dump(wan_config, stream=out_file)

# Run fine-tuning

In [None]:
# Shell escape: launch run.py with the final config file written to disk.
!python run.py config/final_train_cfg.yaml