# StableTuner v0.1.1 for Linux

Original version https://github.com/devilismyfriend/StableTuner
<br>
STv1.0 Update (12/16/2022)
<br><br>
Tested on Docker image dnwalkup/cuda:116-cudnn8-devel-u2004:
<br>
* Ubuntu 20.04, NVIDIA cuDNN 8, CUDA 11.6 <br>
* Installed with pip: gdown, wget, ftfy, OmegaConf, tqdm, tensorboard, transformers, triton, pillow, IPython, pycuda, ipywidgets, jupyterlab <br>
* Installed with apt: zip, unzip, rename, python3, python3-apt, python3-distutils, python3-pip <br>

### Set Global Variables and Imports

In [None]:
# ****************************************      DATASET
goog_dataset_url = ''
# Google Drive share URL for your training dataset (e.g. instance images or user images). Use a zip file.
# Downloaded with gdown when google_data_cap is False.

backup_dataset_url = ''
# Direct-download URL (Hugging Face, Dropbox, etc.) for your training dataset, downloaded with wget instead of
# the Google URL when google_data_cap is True (e.g. if Google is capping your bandwidth). Use a zip file.

dataset_repeats_var = 1
# How many times would you like to repeat the dataset? The default of 1 here usually works well.


# ****************************************      PRIOR PRESERVATION LOSS MITIGATION
use_regularization = False
# Do you want to use regularization & prior preservation loss (e.g. class images) in your training? If so, this should be True.
# The regularization images are only downloaded when this is True.

goog_regimg_url = ''
# Google Drive share URL for your regularization images (e.g. class images). Use a zip file.
# Downloaded with gdown when google_data_cap is False.

backup_regimg_url = ''
# Direct-download URL (Hugging Face, Dropbox, etc.) for your regularization images (e.g. class images), downloaded
# with wget instead of the Google URL when google_data_cap is True (e.g. if Google is capping your bandwidth). Use a zip file.

num_class_images_var = 0
# Number of regularization images you would like to generate, or that already exist in your directory.
# Reset to 0 when use_regularization is False.


# ****************************************      TRAINING SETTINGS
# TIP #1: What is an "epoch"?
# One epoch is one run through your dataset, in other words, one epoch is the number of steps it takes
# to "view" each image once. (-morbuto)
#
# TIP #2: For best results, use a low learning rate, high batch size, and a captioned dataset (-devilismyfriend)
#
# TIP #3: Maths
# Number of dataset images / batch size = steps per epoch
# For example, a dataset of 100 images and batch size 4 will have 25 steps per epoch. (-morbuto)

batch_size_var = 16
# Play with this number. If you run out of VRAM, lower it.

train_epochs_var = 100
# 50 to 600 epochs will be the right range for most trainings. (-morbuto)

learning_rate_var = 1e-6
# 3e-6 to 5e-7 will be the right range for most trainings.

save_n_epoch_var = 25
# Save every n epochs to help prevent overtraining.


# ****************************************      GENERAL SETTINGS 
sd_base_model = 'SDv15'
# Input the model you'd like to use.
# Options are 'SDv14', 'SDv15', 'SDv20-512', 'SDv20-768', 'SDv21-512', 'SDv21-768'. Default is 'SDv15'.
# Any unrecognized value falls back to 'SDv15'.

seed_var = ''
# Leave empty (or '0') for a random seed between 1 and 999999.

google_data_cap = False
# Is Google capping your bandwidth and not letting you download your files? If so, set this to True
# to download from the backup_* URLs with wget instead of the Google URLs with gdown.

aspect_ratio_bucketing = False
# If you don't want to limit yourself to square 1:1 dataset images, change to True.
# Also makes the samples use the bucket aspect ratios.

auto_balance = False
# Will balance the number of images in each concept dataset to match the minimum number of images in any concept dataset

train_txt_encoder = True
# Whether or not to train the text encoder.

use_vae = True
# True passes stabilityai/sd-vae-ft-mse to the trainer as the pretrained VAE; False passes an empty VAE path.

output_dir_var = '/workspace/output_model'
# Directory for output including diffusers, checkpoints, and inference samples.

mixed_precision_var = 'fp16'
# Whether to use mixed precision. Choose between 'no', 'fp16' and 'bf16'.
# Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU.


# ****************************************      TOKEN SETTINGS
txt_file_caption = False
# If you have text files you'd like to use for captions, change to True.

use_img_captions = False
# If you want to use the names of each image as captions, change to True.

instance_prompt = ''
# Add a unique identifier here if you want a specific token trained to call your person or style.

class_prompt = ''
# Add a unique identifier here if you want a specific token trained for the class.


# ****************************************      SAVE SAMPLE SETTINGS
sample_progress = False
# Do you want to generate inference samples as you train? sample_prompt and num_samples are only used when True.

sample_height_var, sample_width_var = 768, 768
# Height & width for samples (in that order)

sample_at_start = False
# Generate a sample at the start of training (passes --sample_on_training_start) to help track progress. Default is False.

sample_prompt = ''
# The prompt you would like to use for your samples (only used when sample_progress is True).

num_samples = 0
# How many inference samples to generate? (only used when sample_progress is True)


# ****************************************     ADVANCED SETTINGS
prior_preservation_weight = 1.0
# Prior preservation loss weight. Only used if 'use_regularization' is True.

cudnn_benchmark = False
# This enables or disables CUDNN benchmarking. Default is disabled (False).

txt_encoder_training_epoch_var = 999999999999999
# The epoch at which the text encoder is no longer trained. The default is large enough that it effectively never stops.

new_latent_cache = True
# Regenerate latent cache. Necessary if you've made changes to batch size.

save_latents_cache = False
# Passes --save_latents_cache to the trainer.
# NOTE(review): the exact effect is defined in trainer.py, which is not shown here; document once confirmed.

add_reg_img_dataset = False
# Will generate and/or add existing regularization images to the dataset (does not enable prior preservation loss).
# Note: regularization images are only downloaded, and a class data dir only passed, when use_regularization is True.

train_1024 = False
# If you want to train at 1024 resolution, change to True. Overrides the base model's resolution.

use_8bit_adam = True
# Whether or not to use 8-bit Adam from bitsandbytes. Default is True.

use_gradient_checkpointing = True
# Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.

gradient_accumulation_steps_var = 1
# Number of update steps to accumulate before performing a backward/update pass.

scheduler_var = 'constant'
# Scheduler type to use. Choose between 'linear', 'cosine', 'cosine_with_restarts', 'polynomial',
# 'constant', 'constant_with_warmup'

scheduler_warmup_var = 0
# Number of steps for the warmup in the scheduler.

use_concepts = False
# Use multiple concepts via concepts file? Will overwrite parameters like instance_prompt, class_prompt, etc.
# When True, the instance data dir is not passed to the trainer.

concepts_file = '/workspace/stabletune_concept_list.json'
# JSON file containing the concepts list (only used when use_concepts is True).


import os, gc, random
from subprocess import getoutput
from IPython.display import clear_output

clear_output()
print('Imports complete')

### Install Pip Dependencies

In [None]:
# Install the extra Python packages the trainer needs, one pip call per package
# (same order as before) so one failing package does not stop the others.
# NOTE(review): the PyPI package `huggingface` is not `huggingface_hub`; diffusers
# normally pulls in huggingface_hub itself. Confirm `huggingface` is really wanted.
pip_packages = [
    'scikit-image',
    'scipy',
    'numpy',
    'huggingface',
    'albumentations',
    'opencv-python',
    'einops',
    'pytorch_lightning',
    'safetensors',
    'diffusers',
]

# os.system returns non-zero when pip fails. Record those packages so the failure
# is still visible after clear_output() wipes pip's own console output.
failed_pip_installs = []
for package in pip_packages:
    if os.system(f'python3 -m pip install {package}') != 0:
        failed_pip_installs.append(package)

clear_output()
if failed_pip_installs:
    print('Pip install FAILED for: ' + ', '.join(failed_pip_installs))
else:
    print('Pip installs complete')

### Install xFormers and Final Dependencies

In [None]:
# Install a prebuilt xFormers wheel plus the matching PyTorch/torchvision wheels
# (CUDA 11.6, Python 3.8 builds) when the GPU is one we have a wheel for.
# Keys are matched as substrings of the nvidia-smi card name, in dict order
# (3090 first, then A5000, as before).
XFORMERS_WHEELS = {
    '3090': 'https://huggingface.co/dnwalkup/xformers-precompiles/resolve/main/RTX3090-xf14-cu116-py38/xformers-0.0.14.dev0-cp38-cp38-linux_x86_64.whl',
    'A5000': 'https://huggingface.co/dnwalkup/xformers-precompiles/resolve/main/A5000-xf14-cu116-py38/xformers-0.0.14.dev0-cp38-cp38-linux_x86_64.whl',
}

GPU_CardName = getoutput('nvidia-smi --query-gpu=name --format=csv,noheader')

xformers_wheel = next(
    (url for card, url in XFORMERS_WHEELS.items() if card in GPU_CardName),
    None,
)

if xformers_wheel is None:
    print('No prebuilt xformers available for your card. You may struggle with speed and vram.')
else:
    # Same command sequence for every supported card; only the xFormers wheel differs.
    install_cmds = [
        f'python3 -m pip install {xformers_wheel}',
        'python3 -m pip uninstall -y torch',
        'python3 -m pip install https://download.pytorch.org/whl/cu116/torch-1.12.1%2Bcu116-cp38-cp38-linux_x86_64.whl',
        'python3 -m pip install https://download.pytorch.org/whl/cu116/torchvision-0.13.1%2Bcu116-cp38-cp38-linux_x86_64.whl',
        'python3 -m pip install timm',
    ]
    # Remember failing commands so they are still reported after clear_output().
    failed_cmds = [cmd for cmd in install_cmds if os.system(cmd) != 0]
    clear_output()
    if failed_cmds:
        print('These install commands FAILED:')
        for cmd in failed_cmds:
            print('  ' + cmd)
    else:
        print('xFormers and pyTorch installs complete')

del GPU_CardName
gc.collect()

### Download Necessary Files

In [None]:
# Download the StableTuner trainer and converter scripts into /workspace/scripts
# (the launch cell runs /workspace/scripts/trainer.py).
SCRIPTS_DIR = '/workspace/scripts'
SCRIPTS_URL = 'https://github.com/dnwalkup/StableTuner/raw/main/linux/scripts'

os.makedirs(SCRIPTS_DIR, exist_ok=True)
os.chdir(SCRIPTS_DIR)

failed_downloads = []
for script_name in ('trainer.py', 'converters.py'):
    wget_status = os.system(f'wget -O {script_name} {SCRIPTS_URL}/{script_name}')
    # `wget -O` can leave an empty file behind when the download fails, so the
    # file merely existing is not proof of success.
    if (wget_status != 0
            or not os.path.isfile(script_name)
            or os.path.getsize(script_name) == 0):
        failed_downloads.append(script_name)

if failed_downloads:
    print('Download FAILED for: ' + ', '.join(failed_downloads))
else:
    print('Training files downloaded')

### Download Training Dataset

In [None]:
import shlex
import shutil

# Download the training dataset zip, extract it, and normalise the extracted
# folder to /workspace/dataset/user_images (DATASET_DIR, used by later cells).
DATASET_ROOT = '/workspace/dataset'
DATASET_ZIP = os.path.join(DATASET_ROOT, 'dataset.zip')
DATASET_DIR = os.path.join(DATASET_ROOT, 'user_images')

os.makedirs(DATASET_ROOT, exist_ok=True)
os.chdir(DATASET_ROOT)

# Quote the URL: share links often contain `&` or `?`, which the shell would
# otherwise interpret (an unquoted `&` backgrounds the command and cuts the URL).
if not google_data_cap:
    download_status = os.system(f'gdown -O dataset.zip --fuzzy {shlex.quote(goog_dataset_url)}')
else:
    download_status = os.system(f'wget -O dataset.zip {shlex.quote(backup_dataset_url)}')

# A failed `wget -O` can leave an empty file behind, so require a zero exit
# status and a non-empty archive instead of only checking that the file exists.
dataset_extracted = False
if download_status == 0 and os.path.isfile(DATASET_ZIP) and os.path.getsize(DATASET_ZIP) > 0:
    # -o: overwrite without prompting, so re-running the cell cannot stall on
    # unzip's "replace file?" prompt.
    dataset_extracted = os.system('unzip -o dataset.zip') == 0
if os.path.exists(DATASET_ZIP):
    os.remove(DATASET_ZIP)

# Find the folder the zip extracted. Skip macOS zip metadata (__MACOSX), hidden
# folders (e.g. .ipynb_checkpoints), and the target folder left by a previous run.
extracted_dirs = [
    entry.path
    for entry in os.scandir(DATASET_ROOT)
    if entry.is_dir()
    and not entry.name.startswith('.')
    and entry.name not in ('__MACOSX', 'user_images')
]
if len(extracted_dirs) > 1:
    raise RuntimeError(
        f'Expected one top-level folder in the dataset zip, found: {extracted_dirs}'
    )
if extracted_dirs:
    if os.path.isdir(DATASET_DIR):
        # A fresh extraction replaces the images from a previous run.
        shutil.rmtree(DATASET_DIR)
    os.rename(extracted_dirs[0], DATASET_DIR)
elif not os.path.isdir(DATASET_DIR):
    raise RuntimeError(
        'Check your settings, it looks like the dataset is not there '
        f'(no image folder found in {DATASET_ROOT}; the zip should contain one folder of images)'
    )

# Replace spaces in file names with underscores. Only the file name changes
# (never a parent folder), and an existing file is never overwritten.
for folder, _subdirs, filenames in os.walk(DATASET_DIR):
    for filename in filenames:
        if ' ' in filename:
            src = os.path.join(folder, filename)
            dst = os.path.join(folder, filename.replace(' ', '_'))
            if not os.path.exists(dst):
                os.rename(src, dst)

os.chdir(DATASET_DIR)

clear_output()
if dataset_extracted:
    print('Instance images downloaded and unzipped')
else:
    # Printed after clear_output() so the warning is not wiped from the cell output.
    print('Check your settings, it looks like the dataset is not there')
    print(f'Using the images already in {DATASET_DIR}')

### Download Regularization Images
Only used if "use_regularization" is set to True.

In [None]:
import shlex
import shutil

# Download the regularization (class) images zip, extract it, and normalise the
# extracted folder to /workspace/regularization/reg_images (CLASS_DIR, used by
# later cells). Skipped entirely unless use_regularization is True.
if use_regularization:
    REG_ROOT = '/workspace/regularization'
    REG_ZIP = os.path.join(REG_ROOT, 'reg_images.zip')
    CLASS_DIR = os.path.join(REG_ROOT, 'reg_images')

    os.makedirs(REG_ROOT, exist_ok=True)
    os.chdir(REG_ROOT)

    # Quote the URL: share links often contain `&` or `?`, which the shell would
    # otherwise interpret (an unquoted `&` backgrounds the command and cuts the URL).
    if not google_data_cap:
        reg_status = os.system(f'gdown -O reg_images.zip --fuzzy {shlex.quote(goog_regimg_url)}')
    else:
        reg_status = os.system(f'wget -O reg_images.zip {shlex.quote(backup_regimg_url)}')

    # A failed `wget -O` can leave an empty file behind, so require a zero exit
    # status and a non-empty archive instead of only checking that the file exists.
    reg_extracted = False
    if reg_status == 0 and os.path.isfile(REG_ZIP) and os.path.getsize(REG_ZIP) > 0:
        # -o: overwrite without prompting, so re-running the cell cannot stall on
        # unzip's "replace file?" prompt.
        reg_extracted = os.system('unzip -o reg_images.zip') == 0
    if os.path.exists(REG_ZIP):
        os.remove(REG_ZIP)

    # Find the folder the zip extracted. Skip macOS zip metadata (__MACOSX), hidden
    # folders (e.g. .ipynb_checkpoints), and the target folder left by a previous run.
    reg_dirs = [
        entry.path
        for entry in os.scandir(REG_ROOT)
        if entry.is_dir()
        and not entry.name.startswith('.')
        and entry.name not in ('__MACOSX', 'reg_images')
    ]
    if len(reg_dirs) > 1:
        raise RuntimeError(
            f'Expected one top-level folder in the regularization zip, found: {reg_dirs}'
        )
    if reg_dirs:
        if os.path.isdir(CLASS_DIR):
            # A fresh extraction replaces the images from a previous run.
            shutil.rmtree(CLASS_DIR)
        os.rename(reg_dirs[0], CLASS_DIR)
    elif not os.path.isdir(CLASS_DIR):
        raise RuntimeError(
            'Check your settings, it looks like the regularization images are not there '
            f'(no image folder found in {REG_ROOT}; the zip should contain one folder of images)'
        )

    # Replace spaces in file names with underscores. Only the file name changes
    # (never a parent folder), and an existing file is never overwritten.
    for folder, _subdirs, filenames in os.walk(CLASS_DIR):
        for filename in filenames:
            if ' ' in filename:
                src = os.path.join(folder, filename)
                dst = os.path.join(folder, filename.replace(' ', '_'))
                if not os.path.exists(dst):
                    os.rename(src, dst)

    os.chdir(CLASS_DIR)

    clear_output()
    if reg_extracted:
        print('Regularization images downloaded and unzipped')
    else:
        # Printed after clear_output() so the warning is not wiped from the cell output.
        print('Check your settings, it looks like the regularization images are not there')
        print(f'Using the images already in {CLASS_DIR}')
else:
    print('Please move on to the next cell')

### Set Final Training Variables

In [None]:
import shlex

# Translate the settings from the first cell into command-line fragments for
# trainer.py. An empty string means "omit this argument".
# The launch cell runs the command through the shell, so free-text values
# (prompts) and paths are passed through shlex.quote. Otherwise a prompt such as
# "photo of sks person" would be split into several stray arguments.
# shlex.quote leaves values made only of shell-safe characters unchanged.


def _flag(enabled, flag):
    """Return `flag` when `enabled` is truthy, otherwise ''."""
    return flag if enabled else ''


def _quoted_opt(name, value):
    """Return `--name=value` quoted as a single shell word."""
    return shlex.quote(f'--{name}={value}')


settings_warnings = []

cudnn_benchmark_args = '' if cudnn_benchmark else '--disable_cudnn_benchmark'
txt_file_caption_args = _flag(txt_file_caption, '--use_text_files_as_captions')
# Sample aspect ratios follow the bucketing setting.
aspect_ratio_bucketing_args = _flag(aspect_ratio_bucketing, '--use_bucketing')
sample_aspect_ratio_args = _flag(aspect_ratio_bucketing, '--sample_aspect_ratios')
new_latent_cache_args = _flag(new_latent_cache, '--regenerate_latent_cache')
save_latents_cache_args = _flag(save_latents_cache, '--save_latents_cache')
add_reg_img_dataset_args = _flag(add_reg_img_dataset, '--add_class_images_to_dataset')
auto_balance_args = _flag(auto_balance, '--auto_balance_concept_datasets')
train_txt_encoder_args = _flag(train_txt_encoder, '--train_text_encoder')
use_8bit_adam_args = _flag(use_8bit_adam, '--use_8bit_adam')
gradient_checkpointing_args = _flag(use_gradient_checkpointing, '--gradient_checkpointing')
use_img_captions_args = _flag(use_img_captions, '--use_image_names_as_captions')
sample_at_start_args = _flag(sample_at_start, '--sample_on_training_start')

# With the VAE disabled, an empty --pretrained_vae_name_or_path= is still passed (not omitted).
use_vae_args = '--pretrained_vae_name_or_path=' + ('stabilityai/sd-vae-ft-mse' if use_vae else '')

# A concepts file replaces the single instance data dir.
if use_concepts:
    concepts_args = _quoted_opt('concepts_list', concepts_file)
    dataset_dir_args = ''
else:
    concepts_args = ''
    dataset_dir_args = _quoted_opt('instance_data_dir', DATASET_DIR)

instance_prompt_args = _quoted_opt('instance_prompt', instance_prompt) if instance_prompt != '' else ''
class_prompt_args = _quoted_opt('class_prompt', class_prompt) if class_prompt != '' else ''

# Base model name -> (Hugging Face repo, native training resolution).
SD_BASE_MODELS = {
    'SDv14': ('CompVis/stable-diffusion-v1-4', 512),
    'SDv15': ('runwayml/stable-diffusion-v1-5', 512),
    'SDv20-512': ('stabilityai/stable-diffusion-2-base', 512),
    'SDv20-768': ('stabilityai/stable-diffusion-2', 768),
    'SDv21-512': ('stabilityai/stable-diffusion-2-1-base', 512),
    'SDv21-768': ('stabilityai/stable-diffusion-2-1', 768),
}
if sd_base_model not in SD_BASE_MODELS:
    settings_warnings.append(f"Unknown sd_base_model {sd_base_model!r}; falling back to 'SDv15'")
model_repo, model_resolution = SD_BASE_MODELS.get(sd_base_model, SD_BASE_MODELS['SDv15'])
sd_base_model_args = f'--pretrained_model_name_or_path={model_repo}'
# train_1024 overrides the model's native resolution.
resolution_args = f'--resolution={1024 if train_1024 else model_resolution}'

if use_regularization:
    prior_preservation_args = '--with_prior_preservation'
    prior_preservation_weight_args = f'--prior_loss_weight={prior_preservation_weight}'
    class_dir_args = _quoted_opt('class_data_dir', CLASS_DIR)
else:
    prior_preservation_args = ''
    prior_preservation_weight_args = ''
    class_dir_args = ''
    num_class_images_var = 0

if sample_progress:
    num_samples_args = f'--n_save_sample={num_samples}'
    sample_prompt_args = _quoted_opt('save_sample_prompt', sample_prompt)
else:
    num_samples_args = ''
    sample_prompt_args = ''

# '' or '0' (string or int) means "pick a random seed"; anything else must parse as an int.
if str(seed_var).strip() in ('', '0'):
    seed_var = random.randint(1, 999999)
else:
    seed_var = int(seed_var)

clear_output()
# Warnings are printed after clear_output() so they stay visible.
for warning in settings_warnings:
    print(warning)
print('Training ready, please proceed')

### Launch Training

In [None]:
import shlex

# Launch trainer.py through `accelerate launch` (mixed precision is passed to both
# accelerate and trainer.py). Optional fragments built in the previous cell are ''
# when disabled and are dropped. output_dir is shell-quoted because the command
# runs through the shell.
trainer_args = [
    cudnn_benchmark_args,
    txt_file_caption_args,
    use_img_captions_args,
    instance_prompt_args,
    class_prompt_args,
    sd_base_model_args,
    use_vae_args,
    resolution_args,
    aspect_ratio_bucketing_args,
    sample_aspect_ratio_args,
    new_latent_cache_args,
    save_latents_cache_args,
    auto_balance_args,
    add_reg_img_dataset_args,
    train_txt_encoder_args,
    num_samples_args,
    sample_prompt_args,
    sample_at_start_args,
    use_8bit_adam_args,
    gradient_checkpointing_args,
    concepts_args,
    dataset_dir_args,
    prior_preservation_args,
    prior_preservation_weight_args,
    class_dir_args,
    f'--num_class_images={num_class_images_var}',
    shlex.quote(f'--output_dir={output_dir_var}'),
    f'--seed={seed_var}',
    f'--train_batch_size={batch_size_var}',
    f'--num_train_epochs={train_epochs_var}',
    f'--learning_rate={learning_rate_var}',
    f'--mixed_precision={mixed_precision_var}',
    f'--stop_text_encoder_training={txt_encoder_training_epoch_var}',
    f'--gradient_accumulation_steps={gradient_accumulation_steps_var}',
    f'--lr_scheduler={scheduler_var}',
    f'--lr_warmup_steps={scheduler_warmup_var}',
    f'--save_every_n_epoch={save_n_epoch_var}',
    f'--sample_height={sample_height_var}',
    f'--sample_width={sample_width_var}',
    f'--dataset_repeats={dataset_repeats_var}',
]

launch_cmd = ' '.join(
    [f'accelerate launch --mixed_precision={mixed_precision_var} /workspace/scripts/trainer.py']
    + [arg for arg in trainer_args if arg]
)
launch_status = os.system(launch_cmd)
if launch_status != 0:
    # os.system returns the raw wait status from the shell, not the plain exit code.
    print(f'Training command failed (os.system status {launch_status})')